aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes
diff options
context:
space:
mode:
authorAdam Izraelevitz2019-10-18 19:01:19 -0700
committerGitHub2019-10-18 19:01:19 -0700
commitfd981848c7d2a800a15f9acfbf33b57dd1c6225b (patch)
tree3609a301cb0ec867deefea4a0d08425810b00418 /src/main/scala/firrtl/passes
parent973ecf516c0ef2b222f2eb68dc8b514767db59af (diff)
Upstream intervals (#870)
Major features: - Added Interval type, as well as PrimOps asInterval, clip, wrap, and sqz. - Changed PrimOp names: bpset -> setp, bpshl -> incp, bpshr -> decp - Refactored width/bound inferencer into a separate constraint solver - Added transforms to infer, trim, and remove interval bounds - Tests for said features Plan to be released with 1.3
Diffstat (limited to 'src/main/scala/firrtl/passes')
-rw-r--r--src/main/scala/firrtl/passes/CheckWidths.scala63
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala97
-rw-r--r--src/main/scala/firrtl/passes/ConvertFixedToSInt.scala6
-rw-r--r--src/main/scala/firrtl/passes/InferBinaryPoints.scala101
-rw-r--r--src/main/scala/firrtl/passes/InferTypes.scala14
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala486
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala2
-rw-r--r--src/main/scala/firrtl/passes/RemoveIntervals.scala173
-rw-r--r--src/main/scala/firrtl/passes/TrimIntervals.scala97
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemUtils.scala2
10 files changed, 641 insertions, 400 deletions
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala
index 07784e19..5ae5dad4 100644
--- a/src/main/scala/firrtl/passes/CheckWidths.scala
+++ b/src/main/scala/firrtl/passes/CheckWidths.scala
@@ -7,14 +7,21 @@ import firrtl.ir._
import firrtl.PrimOps._
import firrtl.traversals.Foreachers._
import firrtl.Utils._
+import firrtl.constraint.IsKnown
import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken}
object CheckWidths extends Pass {
/** The maximum allowed width for any circuit element */
val MaxWidth = 1000000
- val DshlMaxWidth = ceilLog2(MaxWidth + 1)
+ val DshlMaxWidth = getUIntWidth(MaxWidth)
class UninferredWidth (info: Info, target: String) extends PassException(
- s"""|$info : Uninferred width for target below. (Did you forget to assign to it?)
+ s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?)
+ |$target""".stripMargin)
+ class UninferredBound (info: Info, target: String, bound: String) extends PassException(
+ s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?)
+ |$target""".stripMargin)
+ class InvalidRange (info: Info, target: String, i: IntervalType) extends PassException(
+ s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?)
|$target""".stripMargin)
class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException(
s"$info : [target $mname] Width too small for constant $b.")
@@ -32,17 +39,25 @@ object CheckWidths extends Pass {
s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.")
class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) extends PassException(
s"$info: [target $mname] Attach source $source and expression $eName must have identical widths.")
+ class DisjointSqueeze(info: Info, mname: String, squeeze: DoPrim)
+ extends PassException({
+ val toSqz = squeeze.args.head.serialize
+ val toSqzTpe = squeeze.args.head.tpe.serialize
+ val sqzTo = squeeze.args(1).serialize
+ val sqzToTpe = squeeze.args(1).tpe.serialize
+ s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe"
+ })
def run(c: Circuit): Circuit = {
val errors = new Errors()
- def check_width_w(info: Info, target: Target)(w: Width): Unit = {
- w match {
- case IntWidth(width) if width >= MaxWidth =>
+ def check_width_w(info: Info, target: Target, t: Type)(w: Width): Unit = {
+ (w, t) match {
+ case (IntWidth(width), _) if width >= MaxWidth =>
errors.append(new WidthTooBig(info, target.serialize, width))
- case w: IntWidth if w.width >= 0 =>
- case _: IntWidth =>
+ case (w: IntWidth, f: FixedType) if (w.width < 0 && w.width == f.width) =>
errors append new NegWidthException(info, target.serialize)
+ case (_: IntWidth, _) =>
case _ =>
errors append new UninferredWidth(info, target.prettyPrint(" "))
}
@@ -57,9 +72,28 @@ object CheckWidths extends Pass {
def check_width_t(info: Info, target: Target)(t: Type): Unit = {
t match {
case tt: BundleType => tt.fields.foreach(check_width_f(info, target))
+ //Supports when l = u (if closed)
+ case i@IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i
+ case i:IntervalType if i.range == Some(Nil) =>
+ errors append new InvalidRange(info, target.prettyPrint(" "), i)
+ i
+ case i@IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u =>
+ errors append new InvalidRange(info, target.prettyPrint(" "), i)
+ i
+ case i@IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i
+ case i@IntervalType(_: IsKnown, _, _) =>
+ errors append new UninferredBound(info, target.prettyPrint(" "), "upper")
+ i
+ case i@IntervalType(_, _: IsKnown, _) =>
+ errors append new UninferredBound(info, target.prettyPrint(" "), "lower")
+ i
+ case i@IntervalType(_, _, _) =>
+ errors append new UninferredBound(info, target.prettyPrint(" "), "lower")
+ errors append new UninferredBound(info, target.prettyPrint(" "), "upper")
+ i
case tt => tt foreach check_width_t(info, target)
}
- t foreach check_width_w(info, target)
+ t foreach check_width_w(info, target, t)
}
def check_width_f(info: Info, target: Target)(f: Field): Unit =
@@ -77,6 +111,12 @@ object CheckWidths extends Pass {
errors append new WidthTooSmall(info, target.serialize, e.value)
case _ =>
}
+ case sqz@DoPrim(Squeeze, Seq(a, b), _, IntervalType(Closed(min), Closed(max), _)) =>
+ (a.tpe, b.tpe) match {
+ case (IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) =>
+ errors append new DisjointSqueeze(info, target.serialize, sqz)
+ case other =>
+ }
case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) =>
errors append new BitsWidthException(info, target.serialize, hi, bitWidth(a.tpe), e.serialize)
case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) =>
@@ -87,7 +127,6 @@ object CheckWidths extends Pass {
errors append new DshlTooBig(info, target.serialize)
case _ =>
}
- //e map check_width_t(info, mname) map check_width_e(info, mname)
e foreach check_width_e(info, target)
}
@@ -111,11 +150,15 @@ object CheckWidths extends Pass {
case ResetType =>
case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name))
}
+ if(!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) {
+ val conMsg = sx.copy(info = NoInfo).serialize
+ errors.append(new CheckTypes.InvalidConnect(info, target.module, conMsg, WRef(sx), sx.init))
+ }
case _ =>
}
}
- def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target)(p.tpe)
+ def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target.ref(p.name))(p.tpe)
def check_width_m(circuit: CircuitTarget)(m: DefModule): Unit = {
m foreach check_width_p(m.info, circuit.module(m.name))
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 85ed7de0..4239247c 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -8,6 +8,7 @@ import firrtl.PrimOps._
import firrtl.Utils._
import firrtl.traversals.Foreachers._
import firrtl.WrappedType._
+import firrtl.constraint.{Constraint, IsKnown}
trait CheckHighFormLike {
type NameSet = collection.mutable.HashSet[String]
@@ -84,11 +85,11 @@ trait CheckHighFormLike {
e.op match {
case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq |
- Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw =>
+ Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | Clip | Wrap | Squeeze =>
correctNum(Option(2), 0)
case AsUInt | AsSInt | AsClock | AsAsyncReset | Cvt | Neq | Not =>
correctNum(Option(1), 0)
- case AsFixedPoint | Pad | Head | Tail | BPShl | BPShr | BPSet =>
+ case AsFixedPoint | Pad | Head | Tail | IncP | DecP | SetP =>
correctNum(Option(1), 1)
case Shl | Shr =>
correctNum(Option(1), 1)
@@ -101,6 +102,8 @@ trait CheckHighFormLike {
if (lsb > msb) {
errors.append(new LsbLargerThanMsbException(info, mname, e.op.toString, lsb, msb))
}
+ case AsInterval =>
+ correctNum(Option(1), 3)
case Andr | Orr | Xorr | Neg =>
correctNum(None,0)
}
@@ -137,7 +140,9 @@ trait CheckHighFormLike {
def checkHighFormT(info: Info, mname: String)(t: Type): Unit = {
t foreach checkHighFormT(info, mname)
t match {
- case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname))
+ case tx: VectorType if tx.size < 0 =>
+ errors.append(new NegVecSizeException(info, mname))
+ case i: IntervalType => i
case _ => t foreach checkHighFormW(info, mname)
}
}
@@ -146,6 +151,7 @@ trait CheckHighFormLike {
e match {
case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error
+ case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error
case _ => errors.append(new InvalidAccessException(info, mname))
}
}
@@ -164,8 +170,8 @@ trait CheckHighFormLike {
case ex: WSubAccess => validSubexp(info, mname)(ex.expr)
case ex => ex foreach validSubexp(info, mname)
}
- e foreach checkHighFormW(info, mname)
- e foreach checkHighFormT(info, mname)
+ e foreach checkHighFormW(info, mname + "/" + e.serialize)
+ e foreach checkHighFormT(info, mname + "/" + e.serialize)
e foreach checkHighFormE(info, mname, names)
}
@@ -215,8 +221,7 @@ trait CheckHighFormLike {
if (names(p.name))
errors.append(new NotUniqueException(NoInfo, mname, p.name))
names += p.name
- p.tpe foreach checkHighFormT(p.info, mname)
- p.tpe foreach checkHighFormW(p.info, mname)
+ checkHighFormT(p.info, mname)(p.tpe)
}
// Search for ResetType Ports of direction
@@ -339,6 +344,11 @@ object CheckTypes extends Pass {
s"$info: [module $mname] Uninferred type: $exp."
)
+ def fits(bigger: Constraint, smaller: Constraint): Boolean = (bigger, smaller) match {
+ case (IsKnown(v1), IsKnown(v2)) if v1 < v2 => false
+ case _ => true
+ }
+
def legalResetType(tpe: Type): Boolean = tpe match {
case UIntType(IntWidth(w)) if w == 1 => true
case AsyncResetType => true
@@ -355,6 +365,9 @@ object CheckTypes extends Pass {
case (_: UIntType, _: UIntType) => flip1 == flip2
case (_: SIntType, _: SIntType) => flip1 == flip2
case (_: FixedType, _: FixedType) => flip1 == flip2
+ case (i1: IntervalType, i2: IntervalType) =>
+ import Implicits.width2constraint
+ fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point)
case (_: AnalogType, _: AnalogType) => true
case (AsyncResetType, AsyncResetType) => flip1 == flip2
case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2
@@ -375,7 +388,17 @@ object CheckTypes extends Pass {
}
}
- def validConnect(con: Connect): Boolean = wt(con.loc.tpe).superTypeOf(wt(con.expr.tpe))
+ def validConnect(locTpe: Type, expTpe: Type): Boolean = {
+ val itFits = (locTpe, expTpe) match {
+ case (i1: IntervalType, i2: IntervalType) =>
+ import Implicits.width2constraint
+ fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point)
+ case _ => true
+ }
+ wt(locTpe).superTypeOf(wt(expTpe)) && itFits
+ }
+
+ def validConnect(con: Connect): Boolean = validConnect(con.loc.tpe, con.expr.tpe)
def validPartialConnect(con: PartialConnect): Boolean =
bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default)
@@ -393,44 +416,51 @@ object CheckTypes extends Pass {
case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe))
case tx => true
}
+
def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = {
- def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean): Unit = {
- exprs.foldLeft((false, false, false, false, false)) {
- case ((isUInt, isSInt, isClock, isFix, isAsync), expr) => expr.tpe match {
- case u: UIntType => (true, isSInt, isClock, isFix, isAsync)
- case s: SIntType => (isUInt, true, isClock, isFix, isAsync)
- case ClockType => (isUInt, isSInt, true, isFix, isAsync)
- case f: FixedType => (isUInt, isSInt, isClock, true, isAsync)
- case AsyncResetType => (isUInt, isSInt, isClock, isFix, true)
+ def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean, okInterval: Boolean): Unit = {
+ exprs.foldLeft((false, false, false, false, false, false)) {
+ case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => expr.tpe match {
+ case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval)
+ case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval)
+ case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval)
+ case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval)
+ case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval)
+ case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true)
case UnknownType =>
errors.append(new IllegalUnknownType(info, mname, e.serialize))
- (isUInt, isSInt, isClock, isFix, isAsync)
+ (isUInt, isSInt, isClock, isFix, isAsync, isInterval)
case other => throwInternalError(s"Illegal Type: ${other.serialize}")
}
} match {
- // (UInt, SInt, Clock, Fixed)
- case (isAll, false, false, false, false) if isAll == okUInt =>
- case (false, isAll, false, false, false) if isAll == okSInt =>
- case (false, false, isAll, false, false) if isAll == okClock =>
- case (false, false, false, isAll, false) if isAll == okFix =>
- case (false, false, false, false, isAll) if isAll == okAsync =>
+ // (UInt, SInt, Clock, Fixed, Async, Interval)
+ case (isAll, false, false, false, false, false) if isAll == okUInt =>
+ case (false, isAll, false, false, false, false) if isAll == okSInt =>
+ case (false, false, isAll, false, false, false) if isAll == okClock =>
+ case (false, false, false, isAll, false, false) if isAll == okFix =>
+ case (false, false, false, false, isAll, false) if isAll == okAsync =>
+ case (false, false, false, false, false, isAll) if isAll == okInterval =>
case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize)))
}
}
e.op match {
- case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset =>
+ case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset | AsInterval =>
// All types are ok
case Dshl | Dshr =>
- checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false)
- checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false)
+ checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false, okInterval=false)
case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false)
- case Pad | Shl | Shr | Cat | Bits | Head | Tail =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false)
- case BPShl | BPShr | BPSet =>
- checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false)
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ case Pad | Bits | Head | Tail =>
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=false)
+ case Shl | Shr | Cat =>
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ case IncP | DecP | SetP =>
+ checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ case Wrap | Clip | Squeeze =>
+ checkAllTypes(e.args, okUInt = false, okSInt = false, okClock = false, okFix = false, okAsync=false, okInterval = true)
case _ =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false)
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false, okInterval=false)
}
}
@@ -494,6 +524,9 @@ object CheckTypes extends Pass {
sx.tpe match {
case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname))
+ case t if !validConnect(sx.tpe, sx.init.tpe) =>
+ val conMsg = sx.copy(info = NoInfo).serialize
+ errors.append(new CheckTypes.InvalidConnect(info, mname, conMsg, WRef(sx), sx.init))
case t =>
}
if (!legalResetType(sx.reset.tpe)) {
diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
index 67fdfea0..05a000c5 100644
--- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
+++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
@@ -39,9 +39,9 @@ object ConvertFixedToSInt extends Pass {
def updateExpType(e:Expression): Expression = e match {
case DoPrim(Mul, args, consts, tpe) => e map updateExpType
case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) map updateExpType
- case DoPrim(BPShl, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType
- case DoPrim(BPShr, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType
- case DoPrim(BPSet, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType
+ case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType
+ case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType
+ case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType
case DoPrim(op, args, consts, tpe) =>
val point = calcPoint(args)
val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType)
diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
new file mode 100644
index 00000000..258c9697
--- /dev/null
+++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
@@ -0,0 +1,101 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target}
+import firrtl.constraint.ConstraintSolver
+
+class InferBinaryPoints extends Pass {
+ private val constraintSolver = new ConstraintSolver()
+
+ private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match {
+ case (UIntType(w1), UIntType(w2)) =>
+ case (SIntType(w1), SIntType(w2)) =>
+ case (ClockType, ClockType) =>
+ case (ResetType, _) =>
+ case (_, ResetType) =>
+ case (AsyncResetType, AsyncResetType) =>
+ case (FixedType(w1, p1), FixedType(w2, p2)) =>
+ constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) =>
+ constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (AnalogType(w1), AnalogType(w2)) =>
+ case (t1: BundleType, t2: BundleType) =>
+ (t1.fields zip t2.fields) foreach { case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
+ }
+ }
+ case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
+ case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other")
+ }
+ private def addDecConstraints(t: Type): Type = t map addDecConstraints
+ private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match {
+ case c: Connect =>
+ val n = get_size(c.loc.tpe)
+ val locs = create_exps(c.loc)
+ val exps = create_exps(c.expr)
+ (locs zip exps) foreach { case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
+ }
+ c
+ case pc: PartialConnect =>
+ val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
+ val locs = create_exps(pc.loc)
+ val exps = create_exps(pc.expr)
+ ls foreach { case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
+ }
+ pc
+ case r: DefRegister =>
+ addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
+ r
+ case x => x map addStmtConstraints(mt)
+ }
+ private def fixWidth(w: Width): Width = constraintSolver.get(w) match {
+ case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
+ case None => w
+ case _ => sys.error("Shouldn't be here")
+ }
+ private def fixType(t: Type): Type = t map fixType map fixWidth match {
+ case IntervalType(l, u, p) =>
+ val px = constraintSolver.get(p) match {
+ case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
+ case None => p
+ case _ => sys.error("Shouldn't be here")
+ }
+ IntervalType(l, u, px)
+ case FixedType(w, p) =>
+ val px = constraintSolver.get(p) match {
+ case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
+ case None => p
+ case _ => sys.error("Shouldn't be here")
+ }
+ FixedType(w, px)
+ case x => x
+ }
+ private def fixStmt(s: Statement): Statement = s map fixStmt map fixType
+ private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe))
+ def run (c: Circuit): Circuit = {
+ val ct = CircuitTarget(c.main)
+ c.modules foreach (m => m map addStmtConstraints(ct.module(m.name)))
+ c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)})
+ constraintSolver.solve()
+ InferTypes.run(c.copy(modules = c.modules map (_
+ map fixPort
+ map fixStmt)))
+ }
+}
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala
index 288b62ba..3c5cf7fb 100644
--- a/src/main/scala/firrtl/passes/InferTypes.scala
+++ b/src/main/scala/firrtl/passes/InferTypes.scala
@@ -14,13 +14,23 @@ object InferTypes extends Pass {
val namespace = Namespace()
val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
+ def remove_unknowns_b(b: Bound): Bound = b match {
+ case UnknownBound => VarBound(namespace.newName("b"))
+ case k => k
+ }
+
def remove_unknowns_w(w: Width): Width = w match {
case UnknownWidth => VarWidth(namespace.newName("w"))
case wx => wx
}
- def remove_unknowns(t: Type): Type =
- t map remove_unknowns map remove_unknowns_w
+ def remove_unknowns(t: Type): Type = {
+ t map remove_unknowns map remove_unknowns_w match {
+ case IntervalType(l, u, p) =>
+ IntervalType(remove_unknowns_b(l), remove_unknowns_b(u), p)
+ case x => x
+ }
+ }
def infer_types_e(types: TypeMap)(e: Expression): Expression =
e map infer_types_e(types) match {
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 9c58da2c..2211d238 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -3,15 +3,20 @@
package firrtl.passes
// Datastructures
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.immutable.ListMap
-
import firrtl._
import firrtl.annotations.{Annotation, ReferenceTarget}
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
-import firrtl.traversals.Foreachers._
+import firrtl.Implicits.width2constraint
+import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target}
+import firrtl.constraint.{ConstraintSolver, IsMax}
+
+object InferWidths {
+ def apply(): InferWidths = new InferWidths()
+ def run(c: Circuit): Circuit = new InferWidths().run(c)
+ def execute(state: CircuitState): CircuitState = new InferWidths().execute(state)
+}
case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarget) extends Annotation {
def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = {
@@ -33,369 +38,146 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg
}
}
+/** Infers the widths of all signals with unknown widths
+ *
+ * Is a global width inference algorithm
+ * - Instances of the same module with unknown input port widths will be assigned the
+ * largest width of all assignments to each of its instance ports
+ * - If you don't want the global inference behavior, then be sure to define all your input widths
+ *
+ * Infers the smallest width is larger than all assigned widths to a signal
+ * - Note that this means that dummy assignments that are overwritten by last-connect-semantics
+ * can still influence width inference
+ * - E.g.
+ * wire x: UInt
+ * x <= UInt<5>(15)
+ * x <= UInt<1>(1)
+ *
+ * Since width inference occurs before lowering, it infers x's width to be 5 but with an assignment of UInt(1):
+ *
+ * wire x: UInt<5>
+ * x <= UInt<1>(1)
+ *
+ * Uses firrtl.constraint package to infer widths
+ */
class InferWidths extends Transform with ResolvedAnnotationPaths {
def inputForm: CircuitForm = UnknownForm
def outputForm: CircuitForm = UnknownForm
- val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
+ private val constraintSolver = new ConstraintSolver()
- type ConstraintMap = collection.mutable.LinkedHashMap[String, Width]
-
- def solve_constraints(l: Seq[WGeq]): ConstraintMap = {
- def unique(ls: Seq[Width]) : Seq[Width] =
- (ls map (new WrappedWidth(_))).distinct map (_.w)
- // Combines constraints on the same VarWidth into the same constraint
- def make_unique(ls: Seq[WGeq]): ListMap[String,Width] = {
- ls.foldLeft(ListMap.empty[String, Width])((acc, wgeq) => wgeq.loc match {
- case VarWidth(name) => acc.get(name) match {
- case None => acc + (name -> wgeq.exp)
- // Avoid constructing massive MaxWidth chains
- case Some(MaxWidth(args)) => acc + (name -> MaxWidth(wgeq.exp +: args))
- case Some(width) => acc + (name -> MaxWidth(Seq(wgeq.exp, width)))
- }
- case _ => acc
- })
- }
- def pullMinMax(w: Width): Width = w map pullMinMax match {
- case PlusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i))))
- case PlusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i))))
- case MinusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => MinusWidth(m, IntWidth(i))))
- case MinusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => MinusWidth(IntWidth(i), m)))
- case PlusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i))))
- case PlusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i))))
- case MinusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => MinusWidth(m, IntWidth(i))))
- case MinusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => MinusWidth(IntWidth(i), m)))
- case wx => wx
- }
- def collectMinMax(w: Width): Width = w map collectMinMax match {
- case MinWidth(args) => MinWidth(unique(args.foldLeft(List[Width]()) {
- case (res, wxx: MinWidth) => wxx.args ++: res
- case (res, wxx) => wxx +: res
- }))
- case MaxWidth(args) => MaxWidth(unique(args.foldLeft(List[Width]()) {
- case (res, wxx: MaxWidth) => wxx.args ++: res
- case (res, wxx) => wxx +: res
- }))
- case wx => wx
- }
- def mergePlusMinus(w: Width): Width = w map mergePlusMinus match {
- case wx: PlusWidth => (wx.arg1, wx.arg2) match {
- case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width)
- case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
- case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
- case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1)
- case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1)
- case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(y - x), w1)
- case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(y - x), w1)
- case _ => wx
- }
- case wx: MinusWidth => (wx.arg1, wx.arg2) match {
- case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
- case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
- case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
- case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
- case _ => wx
- }
- case wx: ExpWidth => wx.arg1 match {
- case w1: IntWidth => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong))
- case _ => wx
- }
- case wx => wx
- }
- def removeZeros(w: Width): Width = w map removeZeros match {
- case wx: PlusWidth => (wx.arg1, wx.arg2) match {
- case (w1, IntWidth(x)) if x == 0 => w1
- case (IntWidth(x), w1) if x == 0 => w1
- case _ => wx
- }
- case wx: MinusWidth => (wx.arg1, wx.arg2) match {
- case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
- case (w1, IntWidth(x)) if x == 0 => w1
- case _ => wx
- }
- case wx => wx
- }
- def simplify(w: Width): Width = {
- val opts = Seq(
- pullMinMax _,
- collectMinMax _,
- mergePlusMinus _,
- removeZeros _
- )
- opts.foldLeft(w) { (width, opt) => opt(width) }
- }
-
- def substitute(h: ConstraintMap)(w: Width): Width = {
- //;println-all-debug(["Substituting for [" w "]"])
- val wx = simplify(w)
- //;println-all-debug(["After Simplify: [" wx "]"])
- wx map substitute(h) match {
- //;("matched println-debugvarwidth!")
- case wxx: VarWidth => h get wxx.name match {
- case None => wxx
- case Some(p) =>
- //;println-debug("Contained!")
- //;println-all-debug(["Width: " wxx])
- //;println-all-debug(["Accessed: " h[name(wxx)]])
- val t = simplify(substitute(h)(p))
- h(wxx.name) = t
- t
- }
- case wxx => wxx
- //;println-all-debug(["not varwidth!" w])
- }
- }
-
- def b_sub(h: ConstraintMap)(w: Width): Width = {
- w map b_sub(h) match {
- case wx: VarWidth => h getOrElse (wx.name, wx)
- case wx => wx
- }
- }
-
- def remove_cycle(n: String)(w: Width): Width = {
- //;println-all-debug(["Removing cycle for " n " inside " w])
- w match {
- case wx: MaxWidth => MaxWidth(wx.args filter {
- case wxx: VarWidth => !(n equals wxx.name)
- case MinusWidth(VarWidth(name), IntWidth(i)) if ((i >= 0) && (n == name)) => false
- case _ => true
- })
- case wx: MinusWidth => wx.arg1 match {
- case v: VarWidth if n == v.name => v
- case v => wx
- }
- case wx => wx
- }
- //;println-all-debug(["After removing cycle for " n ", returning " wx])
- }
+ val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
- def hasVarWidth(n: String)(w: Width): Boolean = {
- var has = false
- def rec(w: Width): Width = {
- w match {
- case wx: VarWidth if wx.name == n => has = true
- case _ =>
+ private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match {
+ case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (ClockType, ClockType) =>
+ case (FixedType(w1, p1), FixedType(w2, p2)) =>
+ constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
+ constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) =>
+ constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
+ constraintSolver.addLeq(l1, l2, r1.prettyPrint(""), r2.prettyPrint(""))
+ constraintSolver.addGeq(u1, u2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (AnalogType(w1), AnalogType(w2)) =>
+ constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
+ constraintSolver.addGeq(w2, w1, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (t1: BundleType, t2: BundleType) =>
+ (t1.fields zip t2.fields) foreach { case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
}
- w map rec
- }
- rec(w)
- has
- }
-
- //; Forward solve
- //; Returns a solved list where each constraint undergoes:
- //; 1) Continuous Solving (using triangular solving)
- //; 2) Remove Cycles
- //; 3) Move to solved if not self-recursive
- val u = make_unique(l)
-
- //println("======== UNIQUE CONSTRAINTS ========")
- //for (x <- u) { println(x) }
- //println("====================================")
-
- val f = new ConstraintMap
- val o = ArrayBuffer[String]()
- for ((n, e) <- u) {
- //println("==== SOLUTIONS TABLE ====")
- //for (x <- f) println(x)
- //println("=========================")
-
- val e_sub = simplify(substitute(f)(e))
-
- //println("Solving " + n + " => " + e)
- //println("After Substitute: " + n + " => " + e_sub)
- //println("==== SOLUTIONS TABLE (Post Substitute) ====")
- //for (x <- f) println(x)
- //println("=========================")
-
- val ex = remove_cycle(n)(e_sub)
-
- //println("After Remove Cycle: " + n + " => " + ex)
- if (!hasVarWidth(n)(ex)) {
- //println("Not rec!: " + n + " => " + ex)
- //println("Adding [" + n + "=>" + ex + "] to Solutions Table")
- f(n) = ex
- o += n
}
- }
-
- //println("Forward Solved Constraints")
- //for (x <- f) println(x)
-
- //; Backwards Solve
- val b = new ConstraintMap
- for (i <- (o.size - 1) to 0 by -1) {
- val n = o(i) // Should visit `o` backward
- /*
- println("SOLVE BACK: [" + n + " => " + f(n) + "]")
- println("==== SOLUTIONS TABLE ====")
- for (x <- b) println(x)
- println("=========================")
- */
- val ex = simplify(b_sub(b)(f(n)))
- /*
- println("BACK RETURN: [" + n + " => " + ex + "]")
- */
- b(n) = ex
- /*
- println("==== SOLUTIONS TABLE (Post backsolve) ====")
- for (x <- b) println(x)
- println("=========================")
- */
- }
- b
- }
-
- def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
- case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
- case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
- case (ClockType, ClockType) => Nil
+ case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
case (AsyncResetType, AsyncResetType) => Nil
- case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
- case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1))
- case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
- res ++ (f1.flip match {
- case Default => get_constraints_t(f1.tpe, f2.tpe)
- case Flip => get_constraints_t(f2.tpe, f1.tpe)
- })
- }
- case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
case (ResetType, _) => Nil
case (_, ResetType) => Nil
}
- def run(c: Circuit, extra: Seq[WGeq]): Circuit = {
- val v = ArrayBuffer[WGeq]() ++ extra
+ private def addExpConstraints(e: Expression): Expression = e map addExpConstraints match {
+ case m@Mux(p, tVal, fVal, t) =>
+ constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
+ m
+ case other => other
+ }
- def get_constraints_e(e: Expression): Unit = {
- e match {
- case (e: Mux) => v ++= Seq(
- WGeq(getWidth(e.cond), IntWidth(1)),
- WGeq(IntWidth(1), getWidth(e.cond))
- )
- case _ =>
+ private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addExpConstraints match {
+ case c: Connect =>
+ val n = get_size(c.loc.tpe)
+ val locs = create_exps(c.loc)
+ val exps = create_exps(c.expr)
+ (locs zip exps).foreach { case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
- e.foreach(get_constraints_e)
- }
-
- def get_constraints_declared_type (t: Type): Type = t match {
- case FixedType(_, p) =>
- v += WGeq(p,IntWidth(0))
- t
- case _ => t map get_constraints_declared_type
- }
-
- def get_constraints_s(s: Statement): Unit = {
- s map get_constraints_declared_type match {
- case (s: Connect) =>
- val locs = create_exps(s.loc)
- val exps = create_exps(s.expr)
- v ++= locs.zip(exps).flatMap { case (locx, expx) =>
- to_flip(flow(locx)) match {
- case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx))
- case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx))
- }
- }
- case (s: PartialConnect) =>
- val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default)
- val locs = create_exps(s.loc)
- val exps = create_exps(s.expr)
- v ++= (ls flatMap {case (x, y) =>
- val locx = locs(x)
- val expx = exps(y)
- to_flip(flow(locx)) match {
- case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx))
- case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx))
- }
- })
- case (s: DefRegister) =>
- if (s.reset.tpe != AsyncResetType ) {
- v ++= (
- get_constraints_t(s.reset.tpe, UIntType(IntWidth(1))) ++
- get_constraints_t(UIntType(IntWidth(1)), s.reset.tpe))
- }
- v ++= get_constraints_t(s.tpe, s.init.tpe)
- case (s:Conditionally) => v ++=
- get_constraints_t(s.pred.tpe, UIntType(IntWidth(1))) ++
- get_constraints_t(UIntType(IntWidth(1)), s.pred.tpe)
- case Attach(_, exprs) =>
- // All widths must be equal
- val widths = exprs map (e => getWidth(e.tpe))
- v ++= widths.tail map (WGeq(widths.head, _))
- case _ =>
+ c
+ case pc: PartialConnect =>
+ val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
+ val locs = create_exps(pc.loc)
+ val exps = create_exps(pc.expr)
+ ls foreach { case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
- s.foreach(get_constraints_e)
- s.foreach(get_constraints_s)
- }
-
- c.modules.foreach(_.foreach(get_constraints_s))
- c.modules.foreach(_.ports.foreach({p => get_constraints_declared_type(p.tpe)}))
-
- //println("======== ALL CONSTRAINTS ========")
- //for(x <- v) println(x)
- //println("=================================")
- val h = solve_constraints(v)
- //println("======== SOLVED CONSTRAINTS ========")
- //for(x <- h) println(x)
- //println("====================================")
-
- def evaluate(w: Width): Width = {
- def map2(a: Option[BigInt], b: Option[BigInt], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
- for (a_num <- a; b_num <- b) yield f(a_num, b_num)
- def reduceOptions(l: Seq[Option[BigInt]], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
- l.reduce(map2(_, _, f))
+ pc
+ case r: DefRegister =>
+ if (r.reset.tpe != AsyncResetType ) {
+ addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1)))
+ }
+ addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
+ r
+ case a@Attach(_, exprs) =>
+ val widths = exprs map (e => (e, getWidth(e.tpe)))
+ val maxWidth = IsMax(widths.map(x => width2constraint(x._2)))
+ widths.foreach { case (e, w) =>
+ constraintSolver.addGeq(w, CalcWidth(maxWidth), Target.asTarget(mt)(e).prettyPrint(""), mt.ref(a.serialize).prettyPrint(""))
+ }
+ a
+ case c: Conditionally =>
+ addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1)))
+ c map addStmtConstraints(mt)
+ case x => x map addStmtConstraints(mt)
+ }
+ private def fixWidth(w: Width): Width = constraintSolver.get(w) match {
+ case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
+ case None => w
+ case _ => sys.error("Shouldn't be here")
+ }
+ private def fixType(t: Type): Type = t map fixType map fixWidth match {
+ case IntervalType(l, u, p) =>
+ val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match {
+ case (Some(x: Bound), Some(y: Bound)) => (x, y)
+ case (None, None) => (l, u)
+ case x => sys.error(s"Shouldn't be here: $x")
- // This function shouldn't be necessary
- // Added as protection in case a constraint accidentally uses MinWidth/MaxWidth
- // without any actual Widths. This should be elevated to an earlier error
- def forceNonEmpty(in: Seq[Option[BigInt]], default: Option[BigInt]): Seq[Option[BigInt]] =
- if (in.isEmpty) Seq(default)
- else in
- def solve(w: Width): Option[BigInt] = w match {
- case wx: VarWidth =>
- for{
- v <- h.get(wx.name) if !v.isInstanceOf[VarWidth]
- result <- solve(v)
- } yield result
- case wx: MaxWidth => reduceOptions(forceNonEmpty(wx.args.map(solve), Some(BigInt(0))), max)
- case wx: MinWidth => reduceOptions(forceNonEmpty(wx.args.map(solve), None), min)
- case wx: PlusWidth => map2(solve(wx.arg1), solve(wx.arg2), {_ + _})
- case wx: MinusWidth => map2(solve(wx.arg1), solve(wx.arg2), {_ - _})
- case wx: ExpWidth => map2(Some(BigInt(2)), solve(wx.arg1), pow_minus_one)
- case wx: IntWidth => Some(wx.width)
- case wx => throwInternalError(s"solve: shouldn't be here - %$wx")
}
+ IntervalType(lx, ux, fixWidth(p))
+ case FixedType(w, p) => FixedType(w, fixWidth(p))
+ case x => x
+ }
+ private def fixStmt(s: Statement): Statement = s map fixStmt map fixType
+ private def fixPort(p: Port): Port = {
+ Port(p.info, p.name, p.direction, fixType(p.tpe))
+ }
- solve(w) match {
- case None => w
- case Some(s) => IntWidth(s)
- }
- }
-
- def reduce_var_widths_w(w: Width): Width = {
- //println-all-debug(["REPLACE: " w])
- evaluate(w)
- //println-all-debug(["WITH: " wx])
- }
-
- def reduce_var_widths_t(t: Type): Type = {
- t map reduce_var_widths_t map reduce_var_widths_w
- }
-
- def reduce_var_widths_s(s: Statement): Statement = {
- s map reduce_var_widths_s map reduce_var_widths_t
- }
-
- def reduce_var_widths_p(p: Port): Port = {
- Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe))
- }
-
- InferTypes.run(c.copy(modules = c.modules map (_
- map reduce_var_widths_p
- map reduce_var_widths_s)))
+ def run (c: Circuit): Circuit = {
+ val ct = CircuitTarget(c.main)
+ c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name)))
+ constraintSolver.solve()
+ val ret = InferTypes.run(c.copy(modules = c.modules map (_
+ map fixPort
+ map fixStmt)))
+ constraintSolver.clear()
+ ret
}
def execute(state: CircuitState): CircuitState = {
@@ -426,7 +208,7 @@ class InferWidths extends Transform with ResolvedAnnotationPaths {
}
}
- val extraConstraints = state.annotations.flatMap {
+ state.annotations.foreach {
case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target =>
val baseType = typeMap.getOrElse(target.copy(component = Seq.empty),
@@ -440,10 +222,12 @@ class InferWidths extends Transform with ResolvedAnnotationPaths {
leafType
}
- get_constraints_t(locType, expType)
- case other => Seq.empty
+ //get_constraints_t(locType, expType)
+ addTypeConstraints(anno.loc, anno.exp)(locType, expType)
+ case other =>
}
- state.copy(circuit = run(state.circuit, extraConstraints))
+ state.copy(circuit = run(state.circuit))
}
+
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 48b5f041..921ec3c7 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -72,7 +72,7 @@ object RemoveCHIRRTL extends Transform {
refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match {
case sx: CDefMemory =>
types(sx.name) = sx.tpe
- val taddr = UIntType(IntWidth(1 max ceilLog2(sx.size)))
+ val taddr = UIntType(IntWidth(1 max getUIntWidth(sx.size - 1)))
val tdata = sx.tpe
def set_poison(vec: Seq[MPort]) = vec flatMap (r => Seq(
IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)),
diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala
new file mode 100644
index 00000000..73f59b59
--- /dev/null
+++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala
@@ -0,0 +1,173 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+
+import firrtl.PrimOps._
+import firrtl.ir._
+import firrtl._
+import firrtl.Mappers._
+import Implicits.{bigint2WInt}
+import firrtl.constraint.IsKnown
+
+import scala.math.BigDecimal.RoundingMode._
+
+class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim)
+ extends PassException({
+ val toWrap = wrap.args.head.serialize
+ val toWrapTpe = wrap.args.head.tpe.serialize
+ val wrapTo = wrap.args(1).serialize
+ val wrapToTpe = wrap.args(1).tpe.serialize
+ s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe"
+ })
+
+
+/** Replaces IntervalType with SIntType, three AST walks:
+ * 1) Align binary points
+ * - adds shift operators to primop args and connections
+ * - does not affect declaration- or inferred-types
+ * 2) Replace Interval [[DefNode]] with [[DefWire]] + [[Connect]]
+ * - You have to do this to capture the smaller bitwidths of nodes that intervals give you. Otherwise, any future
+ * InferTypes would reinfer the larger widths on these nodes from SInt width inference rules
+ * 3) Replace declaration IntervalType's with SIntType's
+ * - for each declaration:
+ * a. remove non-zero binary points
+ * b. remove open bounds
+ * c. replace with SIntType
+ * 3) Run InferTypes
+ */
+class RemoveIntervals extends Pass {
+
+ def run(c: Circuit): Circuit = {
+ val alignedCircuit = c
+ val errors = new Errors()
+ val wiredCircuit = alignedCircuit map makeWireModule
+ val replacedCircuit = wiredCircuit map replaceModuleInterval(errors)
+ errors.trigger()
+ InferTypes.run(replacedCircuit)
+ }
+
+ /* Replace interval types */
+ private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule =
+ m map replaceStmtInterval(errors, m.name) map replacePortInterval
+
+ private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = {
+ val info = s match {
+ case h: HasInfo => h.info
+ case _ => NoInfo
+ }
+ s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname)
+
+ }
+
+ private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match {
+ case _: WRef | _: WSubIndex | _: WSubField => e
+ case o =>
+ o map replaceExprInterval(errors, info, mname) match {
+ case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe)
+ case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe)
+ case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe)
+ case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) =>
+ // Output interval (pre-calculated)
+ val clipLo = tpe.minAdjusted.get
+ val clipHi = tpe.maxAdjusted.get
+ // Input interval
+ val (inLow, inHigh) = a1.tpe match {
+ case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
+ case _ => sys.error("Shouldn't be here")
+ }
+ val gtOpt = clipHi >= inHigh
+ val ltOpt = clipLo <= inLow
+ (gtOpt, ltOpt) match {
+ // input range within output range -> no optimization
+ case (true, true) => a1
+ case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1)
+ case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1)
+ case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1))
+ }
+
+ case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) =>
+ // Using (conditional) reassign interval w/o adding mux
+ val a1tpe = a1.tpe.asInstanceOf[IntervalType]
+ val a2tpe = a2.tpe.asInstanceOf[IntervalType]
+ val min2 = a2tpe.min.get * BigDecimal(BigInt(1) << a1tpe.point.asInstanceOf[IntWidth].width.toInt)
+ val max2 = a2tpe.max.get * BigDecimal(BigInt(1) << a1tpe.point.asInstanceOf[IntWidth].width.toInt)
+ val w1 = Seq(a1tpe.minAdjusted.get.bitLength, a1tpe.maxAdjusted.get.bitLength).max + 1
+ // Conservative
+ val minOpt2 = min2.setScale(0, FLOOR).toBigInt
+ val maxOpt2 = max2.setScale(0, CEILING).toBigInt
+ val w2 = Seq(minOpt2.bitLength, maxOpt2.bitLength).max + 1
+ if (w1 < w2) {
+ a1
+ } else {
+ val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2)))
+ DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2)))
+ }
+ case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match {
+ // If a2 type is Interval wrap around range. If UInt, wrap around width
+ case t: IntervalType =>
+ // Need to match binary points before getting *adjusted!
+ val (wrapLo, wrapHi) = t.copy(point = tpe.point) match {
+ case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get)
+ case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType")
+ }
+ val (inLo, inHi) = a1.tpe match {
+ case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
+ case _ => sys.error("Shouldn't be here")
+ }
+ // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap)
+ val range = wrapHi - wrapLo
+ val ltOpt = Add(a1, (range + 1).S)
+ val gtOpt = Sub(a1, (range + 1).S)
+ // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it.
+ // If x < wl
+ // output: wh - (wl - x) + 1 AKA x + r + 1
+ // worst case: wh - (wl - xl) + 1 = wl
+ // -> xl + wr + 1 = wl
+ // If x > wh
+ // output: wl + (x - wh) - 1 AKA x - r - 1
+ // worst case: wl + (xh - wh) - 1 = wh
+ // -> xh - wr - 1 = wh
+ val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S)
+ (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match {
+ case (true, true, _, _) => a1
+ case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1)
+ case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1)
+ // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work)
+ case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1))
+ case _ =>
+ errors.append(new WrapWithRemainder(info, mname, w))
+ default
+ }
+ case _ => sys.error("Shouldn't be here")
+ }
+ case other => other
+ }
+ }
+
+ private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+
+ private def replaceTypeInterval(t: Type): Type = t match {
+ case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width)
+ case i: IntervalType => sys.error(s"Shouldn't be here: $i")
+ case v => v map replaceTypeInterval
+ }
+
+ /** Replace Interval Nodes with Interval Wires
+ *
+ * You have to do this to capture the smaller bitwidths of nodes that intervals give you. Otherwise,
+ * any future InferTypes would reinfer the larger widths on these nodes from SInt width inference rules
+ * @param m module to replace nodes with wire + connection
+ * @return
+ */
+ private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt
+
+ private def makeWireStmt(s: Statement): Statement = s match {
+ case DefNode(info, name, value) => value.tpe match {
+ case IntervalType(l, u, p) =>
+ val newType = IntervalType(l, u, p)
+ Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, FEMALE), value)))
+ case other => s
+ }
+ case other => other map makeWireStmt
+ }
+}
diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala
new file mode 100644
index 00000000..dec64ee7
--- /dev/null
+++ b/src/main/scala/firrtl/passes/TrimIntervals.scala
@@ -0,0 +1,97 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+
+import scala.collection.mutable
+import firrtl.PrimOps._
+import firrtl.ir._
+import firrtl._
+import firrtl.Mappers._
+import firrtl.Utils.{error, field_type, getUIntWidth, max, module_type, sub_type}
+import Implicits.{bigint2WInt, int2WInt}
+import firrtl.constraint.{IsFloor, IsKnown, IsMul}
+
+/** Replaces IntervalType with SIntType, three AST walks:
+ * 1) Align binary points
+ * - adds shift operators to primop args and connections
+ * - does not affect declaration- or inferred-types
+ * 2) Replace declaration IntervalType's with SIntType's
+ * - for each declaration:
+ * a. remove non-zero binary points
+ * b. remove open bounds
+ * c. replace with SIntType
+ * 3) Run InferTypes
+ */
+class TrimIntervals extends Pass {
+ def run(c: Circuit): Circuit = {
+ // Open -> closed
+ val firstPass = InferTypes.run(c map replaceModuleInterval)
+ // Align binary points and adjust range accordingly (loss of precision changes range)
+ firstPass map alignModuleBP
+ }
+
+ /* Replace interval types */
+ private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval
+
+ private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval
+
+ private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+
+ private def replaceTypeInterval(t: Type): Type = t match {
+ case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) =>
+ IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p))
+ case i: IntervalType => i
+ case v => v map replaceTypeInterval
+ }
+
+ /* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */
+ private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP
+
+ private def alignStmtBP(s: Statement): Statement = s map alignExpBP match {
+ case c@Connect(info, loc, expr) => loc.tpe match {
+ case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case c@PartialConnect(info, loc, expr) => loc.tpe match {
+ case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case other => other map alignStmtBP
+ }
+
+ // Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned
+ // Note - Mul does not need its binary points aligned, because multiplication is cool like that
+ private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/)
+
+ private def alignExpBP(e: Expression): Expression = e map alignExpBP match {
+ case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg)
+ case DoPrim(o, args, consts, t) if opsToFix.contains(o) &&
+ (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size =>
+ val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _)
+ DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t)
+ case Mux(cond, tval, fval, t: IntervalType) =>
+ val maxBP = Seq(tval, fval).map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _)
+ Mux(cond, fixBP(maxBP)(tval), fixBP(maxBP)(fval), t)
+ case other => other
+ }
+ private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match {
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current =>
+ DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired)))
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current =>
+ val shiftAmt = current - desired
+ val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt)
+ val shiftMul = Closed(BigDecimal(1) / shiftGain)
+ val bpGain = BigDecimal(BigInt(1) << current.toInt)
+ // BP is inferred at this point
+ // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt)
+ val newBPRes = Closed(shiftGain / bpGain)
+ val bpResInv = Closed(bpGain)
+ val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes)
+ val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), newBPRes)
+ DoPrim(DecP, Seq(e), Seq(current - desired), IntervalType(CalcBound(newL), CalcBound(newU), IntWidth(desired)))
+ case x => sys.error(s"Shouldn't be here: $x")
+ }
+}
+
+// vim: set ts=4 sw=4 et:
diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
index bb441ebb..69c6b284 100644
--- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
@@ -56,7 +56,7 @@ object MemPortUtils {
type Modules = collection.mutable.ArrayBuffer[DefModule]
def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq(
- Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))),
+ Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))),
Field("en", Default, BoolType),
Field("clk", Default, ClockType)
)