diff options
| author | Adam Izraelevitz | 2019-10-18 19:01:19 -0700 |
|---|---|---|
| committer | GitHub | 2019-10-18 19:01:19 -0700 |
| commit | fd981848c7d2a800a15f9acfbf33b57dd1c6225b (patch) | |
| tree | 3609a301cb0ec867deefea4a0d08425810b00418 /src/main/scala/firrtl/passes | |
| parent | 973ecf516c0ef2b222f2eb68dc8b514767db59af (diff) | |
Upstream intervals (#870)
Major features:
- Added Interval type, as well as PrimOps asInterval, clip, wrap, and sqz.
- Changed PrimOp names: bpset -> setp, bpshl -> incp, bpshr -> decp
- Refactored width/bound inferencer into a separate constraint solver
- Added transforms to infer, trim, and remove interval bounds
- Tests for said features
Plan to be released with 1.3
Diffstat (limited to 'src/main/scala/firrtl/passes')
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckWidths.scala | 63 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 97 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ConvertFixedToSInt.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferBinaryPoints.scala | 101 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferTypes.scala | 14 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 486 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveCHIRRTL.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveIntervals.scala | 173 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/TrimIntervals.scala | 97 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/MemUtils.scala | 2 |
10 files changed, 641 insertions, 400 deletions
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 07784e19..5ae5dad4 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -7,14 +7,21 @@ import firrtl.ir._ import firrtl.PrimOps._ import firrtl.traversals.Foreachers._ import firrtl.Utils._ +import firrtl.constraint.IsKnown import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken} object CheckWidths extends Pass { /** The maximum allowed width for any circuit element */ val MaxWidth = 1000000 - val DshlMaxWidth = ceilLog2(MaxWidth + 1) + val DshlMaxWidth = getUIntWidth(MaxWidth) class UninferredWidth (info: Info, target: String) extends PassException( - s"""|$info : Uninferred width for target below. (Did you forget to assign to it?) + s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?) + |$target""".stripMargin) + class UninferredBound (info: Info, target: String, bound: String) extends PassException( + s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?) + |$target""".stripMargin) + class InvalidRange (info: Info, target: String, i: IntervalType) extends PassException( + s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?) |$target""".stripMargin) class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException( s"$info : [target $mname] Width too small for constant $b.") @@ -32,17 +39,25 @@ object CheckWidths extends Pass { s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.") class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) extends PassException( s"$info: [target $mname] Attach source $source and expression $eName must have identical widths.") + class DisjointSqueeze(info: Info, mname: String, squeeze: DoPrim) + extends PassException({ + val toSqz = squeeze.args.head.serialize + val toSqzTpe = squeeze.args.head.tpe.serialize + val sqzTo = squeeze.args(1).serialize + val sqzToTpe = squeeze.args(1).tpe.serialize + s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe" + }) def run(c: Circuit): Circuit = { val errors = new Errors() - def check_width_w(info: Info, target: Target)(w: Width): Unit = { - w match { - case IntWidth(width) if width >= MaxWidth => + def check_width_w(info: Info, target: Target, t: Type)(w: Width): Unit = { + (w, t) match { + case (IntWidth(width), _) if width >= MaxWidth => errors.append(new WidthTooBig(info, target.serialize, width)) - case w: IntWidth if w.width >= 0 => - case _: IntWidth => + case (w: IntWidth, f: FixedType) if (w.width < 0 && w.width == f.width) => errors append new NegWidthException(info, target.serialize) + case (_: IntWidth, _) => case _ => errors append new UninferredWidth(info, target.prettyPrint(" ")) } @@ -57,9 +72,28 @@ object CheckWidths extends Pass { def check_width_t(info: Info, target: Target)(t: Type): Unit = { t match { case tt: BundleType => tt.fields.foreach(check_width_f(info, target)) + //Supports when l = u (if closed) + case i@IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i + case i:IntervalType if i.range == Some(Nil) => + errors append new InvalidRange(info, target.prettyPrint(" "), i) + i + case i@IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u => + errors append new InvalidRange(info, target.prettyPrint(" "), i) + i + case i@IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i + case i@IntervalType(_: IsKnown, _, _) => + errors append new UninferredBound(info, target.prettyPrint(" "), "upper") + i + case i@IntervalType(_, _: IsKnown, _) => + errors append new UninferredBound(info, target.prettyPrint(" "), "lower") + i + case i@IntervalType(_, _, _) => + errors append new UninferredBound(info, target.prettyPrint(" "), "lower") + errors append new UninferredBound(info, target.prettyPrint(" "), "upper") + i case tt => tt foreach check_width_t(info, target) } - t foreach check_width_w(info, target) + t foreach check_width_w(info, target, t) } def check_width_f(info: Info, target: Target)(f: Field): Unit = @@ -77,6 +111,12 @@ object CheckWidths extends Pass { errors append new WidthTooSmall(info, target.serialize, e.value) case _ => } + case sqz@DoPrim(Squeeze, Seq(a, b), _, IntervalType(Closed(min), Closed(max), _)) => + (a.tpe, b.tpe) match { + case (IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) => + errors append new DisjointSqueeze(info, target.serialize, sqz) + case other => + } case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) => errors append new BitsWidthException(info, target.serialize, hi, bitWidth(a.tpe), e.serialize) case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) => @@ -87,7 +127,6 @@ object CheckWidths extends Pass { errors append new DshlTooBig(info, target.serialize) case _ => } - //e map check_width_t(info, mname) map check_width_e(info, mname) e foreach check_width_e(info, target) } @@ -111,11 +150,15 @@ object CheckWidths extends Pass { case ResetType => case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) } + if(!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) { + val conMsg = sx.copy(info = NoInfo).serialize + errors.append(new CheckTypes.InvalidConnect(info, target.module, conMsg, WRef(sx), sx.init)) + } case _ => } } - def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target)(p.tpe) + def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target.ref(p.name))(p.tpe) def check_width_m(circuit: CircuitTarget)(m: DefModule): Unit = { m foreach check_width_p(m.info, circuit.module(m.name)) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 85ed7de0..4239247c 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -8,6 +8,7 @@ import firrtl.PrimOps._ import firrtl.Utils._ import firrtl.traversals.Foreachers._ import firrtl.WrappedType._ +import firrtl.constraint.{Constraint, IsKnown} trait CheckHighFormLike { type NameSet = collection.mutable.HashSet[String] @@ -84,11 +85,11 @@ trait CheckHighFormLike { e.op match { case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq | - Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw => + Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | Clip | Wrap | Squeeze => correctNum(Option(2), 0) case AsUInt | AsSInt | AsClock | AsAsyncReset | Cvt | Neq | Not => correctNum(Option(1), 0) - case AsFixedPoint | Pad | Head | Tail | BPShl | BPShr | BPSet => + case AsFixedPoint | Pad | Head | Tail | IncP | DecP | SetP => correctNum(Option(1), 1) case Shl | Shr => correctNum(Option(1), 1) @@ -101,6 +102,8 @@ trait CheckHighFormLike { if (lsb > msb) { errors.append(new LsbLargerThanMsbException(info, mname, e.op.toString, lsb, msb)) } + case AsInterval => + correctNum(Option(1), 3) case Andr | Orr | Xorr | Neg => correctNum(None,0) } @@ -137,7 +140,9 @@ trait CheckHighFormLike { def checkHighFormT(info: Info, mname: String)(t: Type): Unit = { t foreach checkHighFormT(info, mname) t match { - case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname)) + case tx: VectorType if tx.size < 0 => + errors.append(new NegVecSizeException(info, mname)) + case i: IntervalType => i case _ => t foreach checkHighFormW(info, mname) } } @@ -146,6 +151,7 @@ trait CheckHighFormLike { e match { case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error + case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error case _ => errors.append(new InvalidAccessException(info, mname)) } } @@ -164,8 +170,8 @@ trait CheckHighFormLike { case ex: WSubAccess => validSubexp(info, mname)(ex.expr) case ex => ex foreach validSubexp(info, mname) } - e foreach checkHighFormW(info, mname) - e foreach checkHighFormT(info, mname) + e foreach checkHighFormW(info, mname + "/" + e.serialize) + e foreach checkHighFormT(info, mname + "/" + e.serialize) e foreach checkHighFormE(info, mname, names) } @@ -215,8 +221,7 @@ trait CheckHighFormLike { if (names(p.name)) errors.append(new NotUniqueException(NoInfo, mname, p.name)) names += p.name - p.tpe foreach checkHighFormT(p.info, mname) - p.tpe foreach checkHighFormW(p.info, mname) + checkHighFormT(p.info, mname)(p.tpe) } // Search for ResetType Ports of direction @@ -339,6 +344,11 @@ object CheckTypes extends Pass { s"$info: [module $mname] Uninferred type: $exp." ) + def fits(bigger: Constraint, smaller: Constraint): Boolean = (bigger, smaller) match { + case (IsKnown(v1), IsKnown(v2)) if v1 < v2 => false + case _ => true + } + def legalResetType(tpe: Type): Boolean = tpe match { case UIntType(IntWidth(w)) if w == 1 => true case AsyncResetType => true @@ -355,6 +365,9 @@ object CheckTypes extends Pass { case (_: UIntType, _: UIntType) => flip1 == flip2 case (_: SIntType, _: SIntType) => flip1 == flip2 case (_: FixedType, _: FixedType) => flip1 == flip2 + case (i1: IntervalType, i2: IntervalType) => + import Implicits.width2constraint + fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point) case (_: AnalogType, _: AnalogType) => true case (AsyncResetType, AsyncResetType) => flip1 == flip2 case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2 @@ -375,7 +388,17 @@ object CheckTypes extends Pass { } } - def validConnect(con: Connect): Boolean = wt(con.loc.tpe).superTypeOf(wt(con.expr.tpe)) + def validConnect(locTpe: Type, expTpe: Type): Boolean = { + val itFits = (locTpe, expTpe) match { + case (i1: IntervalType, i2: IntervalType) => + import Implicits.width2constraint + fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point) + case _ => true + } + wt(locTpe).superTypeOf(wt(expTpe)) && itFits + } + + def validConnect(con: Connect): Boolean = validConnect(con.loc.tpe, con.expr.tpe) def validPartialConnect(con: PartialConnect): Boolean = bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default) @@ -393,44 +416,51 @@ object CheckTypes extends Pass { case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe)) case tx => true } + def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = { - def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean): Unit = { - exprs.foldLeft((false, false, false, false, false)) { - case ((isUInt, isSInt, isClock, isFix, isAsync), expr) => expr.tpe match { - case u: UIntType => (true, isSInt, isClock, isFix, isAsync) - case s: SIntType => (isUInt, true, isClock, isFix, isAsync) - case ClockType => (isUInt, isSInt, true, isFix, isAsync) - case f: FixedType => (isUInt, isSInt, isClock, true, isAsync) - case AsyncResetType => (isUInt, isSInt, isClock, isFix, true) + def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean, okInterval: Boolean): Unit = { + exprs.foldLeft((false, false, false, false, false, false)) { + case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => expr.tpe match { + case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval) + case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval) + case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval) + case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval) + case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval) + case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true) case UnknownType => errors.append(new IllegalUnknownType(info, mname, e.serialize)) - (isUInt, isSInt, isClock, isFix, isAsync) + (isUInt, isSInt, isClock, isFix, isAsync, isInterval) case other => throwInternalError(s"Illegal Type: ${other.serialize}") } } match { - // (UInt, SInt, Clock, Fixed) - case (isAll, false, false, false, false) if isAll == okUInt => - case (false, isAll, false, false, false) if isAll == okSInt => - case (false, false, isAll, false, false) if isAll == okClock => - case (false, false, false, isAll, false) if isAll == okFix => - case (false, false, false, false, isAll) if isAll == okAsync => + // (UInt, SInt, Clock, Fixed, Async, Interval) + case (isAll, false, false, false, false, false) if isAll == okUInt => + case (false, isAll, false, false, false, false) if isAll == okSInt => + case (false, false, isAll, false, false, false) if isAll == okClock => + case (false, false, false, isAll, false, false) if isAll == okFix => + case (false, false, false, false, isAll, false) if isAll == okAsync => + case (false, false, false, false, false, isAll) if isAll == okInterval => case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize))) } } e.op match { - case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset => + case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset | AsInterval => // All types are ok case Dshl | Dshr => - checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false) - checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false) + checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false, okInterval=false) case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false) - case Pad | Shl | Shr | Cat | Bits | Head | Tail => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false) - case BPShl | BPShr | BPSet => - checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false) + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + case Pad | Bits | Head | Tail => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=false) + case Shl | Shr | Cat => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + case IncP | DecP | SetP => + checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false, okInterval=true) + case Wrap | Clip | Squeeze => + checkAllTypes(e.args, okUInt = false, okSInt = false, okClock = false, okFix = false, okAsync=false, okInterval = true) case _ => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false) + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false, okInterval=false) } } @@ -494,6 +524,9 @@ object CheckTypes extends Pass { sx.tpe match { case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname)) + case t if !validConnect(sx.tpe, sx.init.tpe) => + val conMsg = sx.copy(info = NoInfo).serialize + errors.append(new CheckTypes.InvalidConnect(info, mname, conMsg, WRef(sx), sx.init)) case t => } if (!legalResetType(sx.reset.tpe)) { diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 67fdfea0..05a000c5 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -39,9 +39,9 @@ object ConvertFixedToSInt extends Pass { def updateExpType(e:Expression): Expression = e match { case DoPrim(Mul, args, consts, tpe) => e map updateExpType case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) map updateExpType - case DoPrim(BPShl, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType - case DoPrim(BPShr, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType - case DoPrim(BPSet, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType + case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType + case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType + case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType case DoPrim(op, args, consts, tpe) => val point = calcPoint(args) val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType) diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala new file mode 100644 index 00000000..258c9697 --- /dev/null +++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala @@ -0,0 +1,101 @@ +// See LICENSE for license details. + +package firrtl.passes + +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target} +import firrtl.constraint.ConstraintSolver + +class InferBinaryPoints extends Pass { + private val constraintSolver = new ConstraintSolver() + + private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { + case (UIntType(w1), UIntType(w2)) => + case (SIntType(w1), SIntType(w2)) => + case (ClockType, ClockType) => + case (ResetType, _) => + case (_, ResetType) => + case (AsyncResetType, AsyncResetType) => + case (FixedType(w1, p1), FixedType(w2, p2)) => + constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => + constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) + case (AnalogType(w1), AnalogType(w2)) => + case (t1: BundleType, t2: BundleType) => + (t1.fields zip t2.fields) foreach { case (f1, f2) => + (f1.flip, f2.flip) match { + case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) + case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) + case _ => sys.error("Shouldn't be here") + } + } + case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe) + case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other") + } + private def addDecConstraints(t: Type): Type = t map addDecConstraints + private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match { + case c: Connect => + val n = get_size(c.loc.tpe) + val locs = create_exps(c.loc) + val exps = create_exps(c.expr) + (locs zip exps) foreach { case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } + } + c + case pc: PartialConnect => + val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) + val locs = create_exps(pc.loc) + val exps = create_exps(pc.expr) + ls foreach { case (x, y) => + val loc = locs(x) + val exp = exps(y) + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } + } + pc + case r: DefRegister => + addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) + r + case x => x map addStmtConstraints(mt) + } + private def fixWidth(w: Width): Width = constraintSolver.get(w) match { + case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) + case None => w + case _ => sys.error("Shouldn't be here") + } + private def fixType(t: Type): Type = t map fixType map fixWidth match { + case IntervalType(l, u, p) => + val px = constraintSolver.get(p) match { + case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) + case None => p + case _ => sys.error("Shouldn't be here") + } + IntervalType(l, u, px) + case FixedType(w, p) => + val px = constraintSolver.get(p) match { + case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) + case None => p + case _ => sys.error("Shouldn't be here") + } + FixedType(w, px) + case x => x + } + private def fixStmt(s: Statement): Statement = s map fixStmt map fixType + private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe)) + def run (c: Circuit): Circuit = { + val ct = CircuitTarget(c.main) + c.modules foreach (m => m map addStmtConstraints(ct.module(m.name))) + c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)}) + constraintSolver.solve() + InferTypes.run(c.copy(modules = c.modules map (_ + map fixPort + map fixStmt))) + } +} diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 288b62ba..3c5cf7fb 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -14,13 +14,23 @@ object InferTypes extends Pass { val namespace = Namespace() val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap + def remove_unknowns_b(b: Bound): Bound = b match { + case UnknownBound => VarBound(namespace.newName("b")) + case k => k + } + def remove_unknowns_w(w: Width): Width = w match { case UnknownWidth => VarWidth(namespace.newName("w")) case wx => wx } - def remove_unknowns(t: Type): Type = - t map remove_unknowns map remove_unknowns_w + def remove_unknowns(t: Type): Type = { + t map remove_unknowns map remove_unknowns_w match { + case IntervalType(l, u, p) => + IntervalType(remove_unknowns_b(l), remove_unknowns_b(u), p) + case x => x + } + } def infer_types_e(types: TypeMap)(e: Expression): Expression = e map infer_types_e(types) match { diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 9c58da2c..2211d238 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -3,15 +3,20 @@ package firrtl.passes // Datastructures -import scala.collection.mutable.ArrayBuffer -import scala.collection.immutable.ListMap - import firrtl._ import firrtl.annotations.{Annotation, ReferenceTarget} import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -import firrtl.traversals.Foreachers._ +import firrtl.Implicits.width2constraint +import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target} +import firrtl.constraint.{ConstraintSolver, IsMax} + +object InferWidths { + def apply(): InferWidths = new InferWidths() + def run(c: Circuit): Circuit = new InferWidths().run(c) + def execute(state: CircuitState): CircuitState = new InferWidths().execute(state) +} case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarget) extends Annotation { def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = { @@ -33,369 +38,146 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg } } +/** Infers the widths of all signals with unknown widths + * + * Is a global width inference algorithm + * - Instances of the same module with unknown input port widths will be assigned the + * largest width of all assignments to each of its instance ports + * - If you don't want the global inference behavior, then be sure to define all your input widths + * + * Infers the smallest width is larger than all assigned widths to a signal + * - Note that this means that dummy assignments that are overwritten by last-connect-semantics + * can still influence width inference + * - E.g. + * wire x: UInt + * x <= UInt<5>(15) + * x <= UInt<1>(1) + * + * Since width inference occurs before lowering, it infers x's width to be 5 but with an assignment of UInt(1): + * + * wire x: UInt<5> + * x <= UInt<1>(1) + * + * Uses firrtl.constraint package to infer widths + */ class InferWidths extends Transform with ResolvedAnnotationPaths { def inputForm: CircuitForm = UnknownForm def outputForm: CircuitForm = UnknownForm - val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) + private val constraintSolver = new ConstraintSolver() - type ConstraintMap = collection.mutable.LinkedHashMap[String, Width] - - def solve_constraints(l: Seq[WGeq]): ConstraintMap = { - def unique(ls: Seq[Width]) : Seq[Width] = - (ls map (new WrappedWidth(_))).distinct map (_.w) - // Combines constraints on the same VarWidth into the same constraint - def make_unique(ls: Seq[WGeq]): ListMap[String,Width] = { - ls.foldLeft(ListMap.empty[String, Width])((acc, wgeq) => wgeq.loc match { - case VarWidth(name) => acc.get(name) match { - case None => acc + (name -> wgeq.exp) - // Avoid constructing massive MaxWidth chains - case Some(MaxWidth(args)) => acc + (name -> MaxWidth(wgeq.exp +: args)) - case Some(width) => acc + (name -> MaxWidth(Seq(wgeq.exp, width))) - } - case _ => acc - }) - } - def pullMinMax(w: Width): Width = w map pullMinMax match { - case PlusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) - case PlusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) - case MinusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => MinusWidth(m, IntWidth(i)))) - case MinusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => MinusWidth(IntWidth(i), m))) - case PlusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) - case PlusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) - case MinusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => MinusWidth(m, IntWidth(i)))) - case MinusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => MinusWidth(IntWidth(i), m))) - case wx => wx - } - def collectMinMax(w: Width): Width = w map collectMinMax match { - case MinWidth(args) => MinWidth(unique(args.foldLeft(List[Width]()) { - case (res, wxx: MinWidth) => wxx.args ++: res - case (res, wxx) => wxx +: res - })) - case MaxWidth(args) => MaxWidth(unique(args.foldLeft(List[Width]()) { - case (res, wxx: MaxWidth) => wxx.args ++: res - case (res, wxx) => wxx +: res - })) - case wx => wx - } - def mergePlusMinus(w: Width): Width = w map mergePlusMinus match { - case wx: PlusWidth => (wx.arg1, wx.arg2) match { - case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width) - case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) - case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) - case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1) - case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1) - case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(y - x), w1) - case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(y - x), w1) - case _ => wx - } - case wx: MinusWidth => (wx.arg1, wx.arg2) match { - case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) - case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) - case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) - case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) - case _ => wx - } - case wx: ExpWidth => wx.arg1 match { - case w1: IntWidth => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong)) - case _ => wx - } - case wx => wx - } - def removeZeros(w: Width): Width = w map removeZeros match { - case wx: PlusWidth => (wx.arg1, wx.arg2) match { - case (w1, IntWidth(x)) if x == 0 => w1 - case (IntWidth(x), w1) if x == 0 => w1 - case _ => wx - } - case wx: MinusWidth => (wx.arg1, wx.arg2) match { - case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) - case (w1, IntWidth(x)) if x == 0 => w1 - case _ => wx - } - case wx => wx - } - def simplify(w: Width): Width = { - val opts = Seq( - pullMinMax _, - collectMinMax _, - mergePlusMinus _, - removeZeros _ - ) - opts.foldLeft(w) { (width, opt) => opt(width) } - } - - def substitute(h: ConstraintMap)(w: Width): Width = { - //;println-all-debug(["Substituting for [" w "]"]) - val wx = simplify(w) - //;println-all-debug(["After Simplify: [" wx "]"]) - wx map substitute(h) match { - //;("matched println-debugvarwidth!") - case wxx: VarWidth => h get wxx.name match { - case None => wxx - case Some(p) => - //;println-debug("Contained!") - //;println-all-debug(["Width: " wxx]) - //;println-all-debug(["Accessed: " h[name(wxx)]]) - val t = simplify(substitute(h)(p)) - h(wxx.name) = t - t - } - case wxx => wxx - //;println-all-debug(["not varwidth!" w]) - } - } - - def b_sub(h: ConstraintMap)(w: Width): Width = { - w map b_sub(h) match { - case wx: VarWidth => h getOrElse (wx.name, wx) - case wx => wx - } - } - - def remove_cycle(n: String)(w: Width): Width = { - //;println-all-debug(["Removing cycle for " n " inside " w]) - w match { - case wx: MaxWidth => MaxWidth(wx.args filter { - case wxx: VarWidth => !(n equals wxx.name) - case MinusWidth(VarWidth(name), IntWidth(i)) if ((i >= 0) && (n == name)) => false - case _ => true - }) - case wx: MinusWidth => wx.arg1 match { - case v: VarWidth if n == v.name => v - case v => wx - } - case wx => wx - } - //;println-all-debug(["After removing cycle for " n ", returning " wx]) - } + val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) - def hasVarWidth(n: String)(w: Width): Boolean = { - var has = false - def rec(w: Width): Width = { - w match { - case wx: VarWidth if wx.name == n => has = true - case _ => + private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { + case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) + case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) + case (ClockType, ClockType) => + case (FixedType(w1, p1), FixedType(w2, p2)) => + constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) + constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => + constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) + constraintSolver.addLeq(l1, l2, r1.prettyPrint(""), r2.prettyPrint("")) + constraintSolver.addGeq(u1, u2, r1.prettyPrint(""), r2.prettyPrint("")) + case (AnalogType(w1), AnalogType(w2)) => + constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) + constraintSolver.addGeq(w2, w1, r1.prettyPrint(""), r2.prettyPrint("")) + case (t1: BundleType, t2: BundleType) => + (t1.fields zip t2.fields) foreach { case (f1, f2) => + (f1.flip, f2.flip) match { + case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) + case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) + case _ => sys.error("Shouldn't be here") } - w map rec - } - rec(w) - has - } - - //; Forward solve - //; Returns a solved list where each constraint undergoes: - //; 1) Continuous Solving (using triangular solving) - //; 2) Remove Cycles - //; 3) Move to solved if not self-recursive - val u = make_unique(l) - - //println("======== UNIQUE CONSTRAINTS ========") - //for (x <- u) { println(x) } - //println("====================================") - - val f = new ConstraintMap - val o = ArrayBuffer[String]() - for ((n, e) <- u) { - //println("==== SOLUTIONS TABLE ====") - //for (x <- f) println(x) - //println("=========================") - - val e_sub = simplify(substitute(f)(e)) - - //println("Solving " + n + " => " + e) - //println("After Substitute: " + n + " => " + e_sub) - //println("==== SOLUTIONS TABLE (Post Substitute) ====") - //for (x <- f) println(x) - //println("=========================") - - val ex = remove_cycle(n)(e_sub) - - //println("After Remove Cycle: " + n + " => " + ex) - if (!hasVarWidth(n)(ex)) { - //println("Not rec!: " + n + " => " + ex) - //println("Adding [" + n + "=>" + ex + "] to Solutions Table") - f(n) = ex - o += n } - } - - //println("Forward Solved Constraints") - //for (x <- f) println(x) - - //; Backwards Solve - val b = new ConstraintMap - for (i <- (o.size - 1) to 0 by -1) { - val n = o(i) // Should visit `o` backward - /* - println("SOLVE BACK: [" + n + " => " + f(n) + "]") - println("==== SOLUTIONS TABLE ====") - for (x <- b) println(x) - println("=========================") - */ - val ex = simplify(b_sub(b)(f(n))) - /* - println("BACK RETURN: [" + n + " => " + ex + "]") - */ - b(n) = ex - /* - println("==== SOLUTIONS TABLE (Post backsolve) ====") - for (x <- b) println(x) - println("=========================") - */ - } - b - } - - def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match { - case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width)) - case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width)) - case (ClockType, ClockType) => Nil + case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe) case (AsyncResetType, AsyncResetType) => Nil - case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2)) - case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1)) - case (t1: BundleType, t2: BundleType) => - (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) => - res ++ (f1.flip match { - case Default => get_constraints_t(f1.tpe, f2.tpe) - case Flip => get_constraints_t(f2.tpe, f1.tpe) - }) - } - case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe) case (ResetType, _) => Nil case (_, ResetType) => Nil } - def run(c: Circuit, extra: Seq[WGeq]): Circuit = { - val v = ArrayBuffer[WGeq]() ++ extra + private def addExpConstraints(e: Expression): Expression = e map addExpConstraints match { + case m@Mux(p, tVal, fVal, t) => + constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W") + m + case other => other + } - def get_constraints_e(e: Expression): Unit = { - e match { - case (e: Mux) => v ++= Seq( - WGeq(getWidth(e.cond), IntWidth(1)), - WGeq(IntWidth(1), getWidth(e.cond)) - ) - case _ => + private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addExpConstraints match { + case c: Connect => + val n = get_size(c.loc.tpe) + val locs = create_exps(c.loc) + val exps = create_exps(c.expr) + (locs zip exps).foreach { case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } - e.foreach(get_constraints_e) - } - - def get_constraints_declared_type (t: Type): Type = t match { - case FixedType(_, p) => - v += WGeq(p,IntWidth(0)) - t - case _ => t map get_constraints_declared_type - } - - def get_constraints_s(s: Statement): Unit = { - s map get_constraints_declared_type match { - case (s: Connect) => - val locs = create_exps(s.loc) - val exps = create_exps(s.expr) - v ++= locs.zip(exps).flatMap { case (locx, expx) => - to_flip(flow(locx)) match { - case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx)) - case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx)) - } - } - case (s: PartialConnect) => - val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default) - val locs = create_exps(s.loc) - val exps = create_exps(s.expr) - v ++= (ls flatMap {case (x, y) => - val locx = locs(x) - val expx = exps(y) - to_flip(flow(locx)) match { - case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx)) - case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx)) - } - }) - case (s: DefRegister) => - if (s.reset.tpe != AsyncResetType ) { - v ++= ( - get_constraints_t(s.reset.tpe, UIntType(IntWidth(1))) ++ - get_constraints_t(UIntType(IntWidth(1)), s.reset.tpe)) - } - v ++= get_constraints_t(s.tpe, s.init.tpe) - case (s:Conditionally) => v ++= - get_constraints_t(s.pred.tpe, UIntType(IntWidth(1))) ++ - get_constraints_t(UIntType(IntWidth(1)), s.pred.tpe) - case Attach(_, exprs) => - // All widths must be equal - val widths = exprs map (e => getWidth(e.tpe)) - v ++= widths.tail map (WGeq(widths.head, _)) - case _ => + c + case pc: PartialConnect => + val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) + val locs = create_exps(pc.loc) + val exps = create_exps(pc.expr) + ls foreach { case (x, y) => + val loc = locs(x) + val exp = exps(y) + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } - s.foreach(get_constraints_e) - s.foreach(get_constraints_s) - } - - c.modules.foreach(_.foreach(get_constraints_s)) - c.modules.foreach(_.ports.foreach({p => get_constraints_declared_type(p.tpe)})) - - //println("======== ALL CONSTRAINTS ========") - //for(x <- v) println(x) - //println("=================================") - val h = solve_constraints(v) - //println("======== SOLVED CONSTRAINTS ========") - //for(x <- h) println(x) - //println("====================================") - - def evaluate(w: Width): Width = { - def map2(a: Option[BigInt], b: Option[BigInt], f: (BigInt,BigInt) => BigInt): Option[BigInt] = - for (a_num <- a; b_num <- b) yield f(a_num, b_num) - def reduceOptions(l: Seq[Option[BigInt]], f: (BigInt,BigInt) => BigInt): Option[BigInt] = - l.reduce(map2(_, _, f)) + pc + case r: DefRegister => + if (r.reset.tpe != AsyncResetType ) { + addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1))) + } + addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) + r + case a@Attach(_, exprs) => + val widths = exprs map (e => (e, getWidth(e.tpe))) + val maxWidth = IsMax(widths.map(x => width2constraint(x._2))) + widths.foreach { case (e, w) => + constraintSolver.addGeq(w, CalcWidth(maxWidth), Target.asTarget(mt)(e).prettyPrint(""), mt.ref(a.serialize).prettyPrint("")) + } + a + case c: Conditionally => + addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1))) + c map addStmtConstraints(mt) + case x => x map addStmtConstraints(mt) + } + private def fixWidth(w: Width): Width = constraintSolver.get(w) match { + case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) + case None => w + case _ => sys.error("Shouldn't be here") + } + private def fixType(t: Type): Type = t map fixType map fixWidth match { + case IntervalType(l, u, p) => + val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match { + case (Some(x: Bound), Some(y: Bound)) => (x, y) + case (None, None) => (l, u) + case x => sys.error(s"Shouldn't be here: $x") - // This function shouldn't be necessary - // Added as protection in case a constraint accidentally uses MinWidth/MaxWidth - // without any actual Widths. This should be elevated to an earlier error - def forceNonEmpty(in: Seq[Option[BigInt]], default: Option[BigInt]): Seq[Option[BigInt]] = - if (in.isEmpty) Seq(default) - else in - def solve(w: Width): Option[BigInt] = w match { - case wx: VarWidth => - for{ - v <- h.get(wx.name) if !v.isInstanceOf[VarWidth] - result <- solve(v) - } yield result - case wx: MaxWidth => reduceOptions(forceNonEmpty(wx.args.map(solve), Some(BigInt(0))), max) - case wx: MinWidth => reduceOptions(forceNonEmpty(wx.args.map(solve), None), min) - case wx: PlusWidth => map2(solve(wx.arg1), solve(wx.arg2), {_ + _}) - case wx: MinusWidth => map2(solve(wx.arg1), solve(wx.arg2), {_ - _}) - case wx: ExpWidth => map2(Some(BigInt(2)), solve(wx.arg1), pow_minus_one) - case wx: IntWidth => Some(wx.width) - case wx => throwInternalError(s"solve: shouldn't be here - %$wx") } + IntervalType(lx, ux, fixWidth(p)) + case FixedType(w, p) => FixedType(w, fixWidth(p)) + case x => x + } + private def fixStmt(s: Statement): Statement = s map fixStmt map fixType + private def fixPort(p: Port): Port = { + Port(p.info, p.name, p.direction, fixType(p.tpe)) + } - solve(w) match { - case None => w - case Some(s) => IntWidth(s) - } - } - - def reduce_var_widths_w(w: Width): Width = { - //println-all-debug(["REPLACE: " w]) - evaluate(w) - //println-all-debug(["WITH: " wx]) - } - - def reduce_var_widths_t(t: Type): Type = { - t map reduce_var_widths_t map reduce_var_widths_w - } - - def reduce_var_widths_s(s: Statement): Statement = { - s map reduce_var_widths_s map reduce_var_widths_t - } - - def reduce_var_widths_p(p: Port): Port = { - Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe)) - } - - InferTypes.run(c.copy(modules = c.modules map (_ - map reduce_var_widths_p - map reduce_var_widths_s))) + def run (c: Circuit): Circuit = { + val ct = CircuitTarget(c.main) + c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name))) + constraintSolver.solve() + val ret = InferTypes.run(c.copy(modules = c.modules map (_ + map fixPort + map fixStmt))) + constraintSolver.clear() + ret } def execute(state: CircuitState): CircuitState = { @@ -426,7 +208,7 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { } } - val extraConstraints = state.annotations.flatMap { + state.annotations.foreach { case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target => val baseType = typeMap.getOrElse(target.copy(component = Seq.empty), @@ -440,10 +222,12 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { leafType } - get_constraints_t(locType, expType) - case other => Seq.empty + //get_constraints_t(locType, expType) + addTypeConstraints(anno.loc, anno.exp)(locType, expType) + case other => } - state.copy(circuit = run(state.circuit, extraConstraints)) + state.copy(circuit = run(state.circuit)) } + } diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 48b5f041..921ec3c7 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -72,7 +72,7 @@ object RemoveCHIRRTL extends Transform { refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match { case sx: CDefMemory => types(sx.name) = sx.tpe - val taddr = UIntType(IntWidth(1 max ceilLog2(sx.size))) + val taddr = UIntType(IntWidth(1 max getUIntWidth(sx.size - 1))) val tdata = sx.tpe def set_poison(vec: Seq[MPort]) = vec flatMap (r => Seq( IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)), diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala new file mode 100644 index 00000000..73f59b59 --- /dev/null +++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala @@ -0,0 +1,173 @@ +// See LICENSE for license details. + +package firrtl.passes + +import firrtl.PrimOps._ +import firrtl.ir._ +import firrtl._ +import firrtl.Mappers._ +import Implicits.{bigint2WInt} +import firrtl.constraint.IsKnown + +import scala.math.BigDecimal.RoundingMode._ + +class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) + extends PassException({ + val toWrap = wrap.args.head.serialize + val toWrapTpe = wrap.args.head.tpe.serialize + val wrapTo = wrap.args(1).serialize + val wrapToTpe = wrap.args(1).tpe.serialize + s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe" + }) + + +/** Replaces IntervalType with SIntType, three AST walks: + * 1) Align binary points + * - adds shift operators to primop args and connections + * - does not affect declaration- or inferred-types + * 2) Replace Interval [[DefNode]] with [[DefWire]] + [[Connect]] + * - You have to do this to capture the smaller bitwidths of nodes that intervals give you. Otherwise, any future + * InferTypes would reinfer the larger widths on these nodes from SInt width inference rules + * 3) Replace declaration IntervalType's with SIntType's + * - for each declaration: + * a. remove non-zero binary points + * b. remove open bounds + * c. replace with SIntType + * 3) Run InferTypes + */ +class RemoveIntervals extends Pass { + + def run(c: Circuit): Circuit = { + val alignedCircuit = c + val errors = new Errors() + val wiredCircuit = alignedCircuit map makeWireModule + val replacedCircuit = wiredCircuit map replaceModuleInterval(errors) + errors.trigger() + InferTypes.run(replacedCircuit) + } + + /* Replace interval types */ + private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule = + m map replaceStmtInterval(errors, m.name) map replacePortInterval + + private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = { + val info = s match { + case h: HasInfo => h.info + case _ => NoInfo + } + s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname) + + } + + private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match { + case _: WRef | _: WSubIndex | _: WSubField => e + case o => + o map replaceExprInterval(errors, info, mname) match { + case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe) + case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) + case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) + case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) => + // Output interval (pre-calculated) + val clipLo = tpe.minAdjusted.get + val clipHi = tpe.maxAdjusted.get + // Input interval + val (inLow, inHigh) = a1.tpe match { + case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) + case _ => sys.error("Shouldn't be here") + } + val gtOpt = clipHi >= inHigh + val ltOpt = clipLo <= inLow + (gtOpt, ltOpt) match { + // input range within output range -> no optimization + case (true, true) => a1 + case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1) + case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1) + case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1)) + } + + case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) => + // Using (conditional) reassign interval w/o adding mux + val a1tpe = a1.tpe.asInstanceOf[IntervalType] + val a2tpe = a2.tpe.asInstanceOf[IntervalType] + val min2 = a2tpe.min.get * BigDecimal(BigInt(1) << a1tpe.point.asInstanceOf[IntWidth].width.toInt) + val max2 = a2tpe.max.get * BigDecimal(BigInt(1) << a1tpe.point.asInstanceOf[IntWidth].width.toInt) + val w1 = Seq(a1tpe.minAdjusted.get.bitLength, a1tpe.maxAdjusted.get.bitLength).max + 1 + // Conservative + val minOpt2 = min2.setScale(0, FLOOR).toBigInt + val maxOpt2 = max2.setScale(0, CEILING).toBigInt + val w2 = Seq(minOpt2.bitLength, maxOpt2.bitLength).max + 1 + if (w1 < w2) { + a1 + } else { + val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2))) + DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2))) + } + case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match { + // If a2 type is Interval wrap around range. If UInt, wrap around width + case t: IntervalType => + // Need to match binary points before getting *adjusted! + val (wrapLo, wrapHi) = t.copy(point = tpe.point) match { + case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get) + case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType") + } + val (inLo, inHi) = a1.tpe match { + case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) + case _ => sys.error("Shouldn't be here") + } + // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap) + val range = wrapHi - wrapLo + val ltOpt = Add(a1, (range + 1).S) + val gtOpt = Sub(a1, (range + 1).S) + // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it. + // If x < wl + // output: wh - (wl - x) + 1 AKA x + r + 1 + // worst case: wh - (wl - xl) + 1 = wl + // -> xl + wr + 1 = wl + // If x > wh + // output: wl + (x - wh) - 1 AKA x - r - 1 + // worst case: wl + (xh - wh) - 1 = wh + // -> xh - wr - 1 = wh + val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S) + (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match { + case (true, true, _, _) => a1 + case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1) + case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1) + // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work) + case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1)) + case _ => + errors.append(new WrapWithRemainder(info, mname, w)) + default + } + case _ => sys.error("Shouldn't be here") + } + case other => other + } + } + + private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + + private def replaceTypeInterval(t: Type): Type = t match { + case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width) + case i: IntervalType => sys.error(s"Shouldn't be here: $i") + case v => v map replaceTypeInterval + } + + /** Replace Interval Nodes with Interval Wires + * + * You have to do this to capture the smaller bitwidths of nodes that intervals give you. Otherwise, + * any future InferTypes would reinfer the larger widths on these nodes from SInt width inference rules + * @param m module to replace nodes with wire + connection + * @return + */ + private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt + + private def makeWireStmt(s: Statement): Statement = s match { + case DefNode(info, name, value) => value.tpe match { + case IntervalType(l, u, p) => + val newType = IntervalType(l, u, p) + Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, FEMALE), value))) + case other => s + } + case other => other map makeWireStmt + } +} diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala new file mode 100644 index 00000000..dec64ee7 --- /dev/null +++ b/src/main/scala/firrtl/passes/TrimIntervals.scala @@ -0,0 +1,97 @@ +// See LICENSE for license details. + +package firrtl.passes + +import scala.collection.mutable +import firrtl.PrimOps._ +import firrtl.ir._ +import firrtl._ +import firrtl.Mappers._ +import firrtl.Utils.{error, field_type, getUIntWidth, max, module_type, sub_type} +import Implicits.{bigint2WInt, int2WInt} +import firrtl.constraint.{IsFloor, IsKnown, IsMul} + +/** Replaces IntervalType with SIntType, three AST walks: + * 1) Align binary points + * - adds shift operators to primop args and connections + * - does not affect declaration- or inferred-types + * 2) Replace declaration IntervalType's with SIntType's + * - for each declaration: + * a. remove non-zero binary points + * b. remove open bounds + * c. replace with SIntType + * 3) Run InferTypes + */ +class TrimIntervals extends Pass { + def run(c: Circuit): Circuit = { + // Open -> closed + val firstPass = InferTypes.run(c map replaceModuleInterval) + // Align binary points and adjust range accordingly (loss of precision changes range) + firstPass map alignModuleBP + } + + /* Replace interval types */ + private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval + + private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval + + private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + + private def replaceTypeInterval(t: Type): Type = t match { + case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) => + IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p)) + case i: IntervalType => i + case v => v map replaceTypeInterval + } + + /* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */ + private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP + + private def alignStmtBP(s: Statement): Statement = s map alignExpBP match { + case c@Connect(info, loc, expr) => loc.tpe match { + case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr)) + case _ => c + } + case c@PartialConnect(info, loc, expr) => loc.tpe match { + case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr)) + case _ => c + } + case other => other map alignStmtBP + } + + // Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned + // Note - Mul does not need its binary points aligned, because multiplication is cool like that + private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/) + + private def alignExpBP(e: Expression): Expression = e map alignExpBP match { + case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg) + case DoPrim(o, args, consts, t) if opsToFix.contains(o) && + (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size => + val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) + DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t) + case Mux(cond, tval, fval, t: IntervalType) => + val maxBP = Seq(tval, fval).map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) + Mux(cond, fixBP(maxBP)(tval), fixBP(maxBP)(fval), t) + case other => other + } + private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match { + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current => + DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired))) + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current => + val shiftAmt = current - desired + val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt) + val shiftMul = Closed(BigDecimal(1) / shiftGain) + val bpGain = BigDecimal(BigInt(1) << current.toInt) + // BP is inferred at this point + // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt) + val newBPRes = Closed(shiftGain / bpGain) + val bpResInv = Closed(bpGain) + val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes) + val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), newBPRes) + DoPrim(DecP, Seq(e), Seq(current - desired), IntervalType(CalcBound(newL), CalcBound(newU), IntWidth(desired))) + case x => sys.error(s"Shouldn't be here: $x") + } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala index bb441ebb..69c6b284 100644 --- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala @@ -56,7 +56,7 @@ object MemPortUtils { type Modules = collection.mutable.ArrayBuffer[DefModule] def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq( - Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))), + Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))), Field("en", Default, BoolType), Field("clk", Default, ClockType) ) |
