aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAdam Izraelevitz2019-10-18 19:01:19 -0700
committerGitHub2019-10-18 19:01:19 -0700
commitfd981848c7d2a800a15f9acfbf33b57dd1c6225b (patch)
tree3609a301cb0ec867deefea4a0d08425810b00418 /src
parent973ecf516c0ef2b222f2eb68dc8b514767db59af (diff)
Upstream intervals (#870)
Major features: - Added Interval type, as well as PrimOps asInterval, clip, wrap, and sqz. - Changed PrimOp names: bpset -> setp, bpshl -> incp, bpshr -> decp - Refactored width/bound inferencer into a separate constraint solver - Added transforms to infer, trim, and remove interval bounds - Tests for said features Plan to be released with 1.3
Diffstat (limited to 'src')
-rw-r--r--src/main/antlr4/FIRRTL.g430
-rw-r--r--src/main/proto/firrtl.proto10
-rw-r--r--src/main/scala/firrtl/Implicits.scala30
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala3
-rw-r--r--src/main/scala/firrtl/Mappers.scala13
-rw-r--r--src/main/scala/firrtl/PrimOps.scala664
-rw-r--r--src/main/scala/firrtl/Utils.scala16
-rw-r--r--src/main/scala/firrtl/Visitor.scala35
-rw-r--r--src/main/scala/firrtl/WIR.scala97
-rw-r--r--src/main/scala/firrtl/annotations/Target.scala19
-rw-r--r--src/main/scala/firrtl/constraint/Constraint.scala22
-rw-r--r--src/main/scala/firrtl/constraint/ConstraintSolver.scala357
-rw-r--r--src/main/scala/firrtl/constraint/Inequality.scala24
-rw-r--r--src/main/scala/firrtl/constraint/IsAdd.scala52
-rw-r--r--src/main/scala/firrtl/constraint/IsFloor.scala32
-rw-r--r--src/main/scala/firrtl/constraint/IsKnown.scala44
-rw-r--r--src/main/scala/firrtl/constraint/IsMax.scala59
-rw-r--r--src/main/scala/firrtl/constraint/IsMin.scala57
-rw-r--r--src/main/scala/firrtl/constraint/IsMul.scala52
-rw-r--r--src/main/scala/firrtl/constraint/IsNeg.scala32
-rw-r--r--src/main/scala/firrtl/constraint/IsPow.scala33
-rw-r--r--src/main/scala/firrtl/constraint/IsVar.scala27
-rw-r--r--src/main/scala/firrtl/ir/IR.scala177
-rw-r--r--src/main/scala/firrtl/passes/CheckWidths.scala63
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala97
-rw-r--r--src/main/scala/firrtl/passes/ConvertFixedToSInt.scala6
-rw-r--r--src/main/scala/firrtl/passes/InferBinaryPoints.scala101
-rw-r--r--src/main/scala/firrtl/passes/InferTypes.scala14
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala486
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala2
-rw-r--r--src/main/scala/firrtl/passes/RemoveIntervals.scala173
-rw-r--r--src/main/scala/firrtl/passes/TrimIntervals.scala97
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemUtils.scala2
-rw-r--r--src/main/scala/firrtl/proto/ToProto.scala10
-rw-r--r--src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala4
-rw-r--r--src/test/scala/firrtlTests/AsyncResetSpec.scala29
-rw-r--r--src/test/scala/firrtlTests/ChirrtlSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/InfoSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala2
-rw-r--r--src/test/scala/firrtlTests/UniquifySpec.scala2
-rw-r--r--src/test/scala/firrtlTests/ZeroWidthTests.scala2
-rw-r--r--src/test/scala/firrtlTests/constraint/InequalitySpec.scala197
-rw-r--r--src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala44
-rw-r--r--src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala13
-rw-r--r--src/test/scala/firrtlTests/interval/IntervalMathSpec.scala183
-rw-r--r--src/test/scala/firrtlTests/interval/IntervalSpec.scala530
-rw-r--r--src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala2
47 files changed, 3229 insertions, 723 deletions
diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4
index be15ab7c..518cb698 100644
--- a/src/main/antlr4/FIRRTL.g4
+++ b/src/main/antlr4/FIRRTL.g4
@@ -50,6 +50,7 @@ type
: 'UInt' ('<' intLit '>')?
| 'SInt' ('<' intLit '>')?
| 'Fixed' ('<' intLit '>')? ('<' '<' intLit '>' '>')?
+ | 'Interval' (lowerBound boundValue boundValue upperBound)? ('.' intLit)?
| 'Clock'
| 'AsyncReset'
| 'Reset'
@@ -187,6 +188,23 @@ intLit
| HexLit
;
+lowerBound
+ : '['
+ | '('
+ ;
+
+upperBound
+ : ']'
+ | ')'
+ ;
+
+boundValue
+ : '?'
+ | DoubleLit
+ | UnsignedInt
+ | SignedInt
+ ;
+
// Keywords that are also legal ids
keywordAsId
: 'circuit'
@@ -253,6 +271,8 @@ primop
| 'asAsyncReset('
| 'asSInt('
| 'asClock('
+ | 'asFixedPoint('
+ | 'asInterval('
| 'shl('
| 'shr('
| 'dshl('
@@ -270,10 +290,12 @@ primop
| 'bits('
| 'head('
| 'tail('
- | 'asFixedPoint('
- | 'bpshl('
- | 'bpshr('
- | 'bpset('
+ | 'incp('
+ | 'decp('
+ | 'setp('
+ | 'wrap('
+ | 'clip('
+ | 'squz('
;
/*------------------------------------------------------------------
diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto
index 0cf14f41..3d2c89f1 100644
--- a/src/main/proto/firrtl.proto
+++ b/src/main/proto/firrtl.proto
@@ -451,10 +451,14 @@ message Firrtl {
OP_AS_FIXED_POINT = 32;
OP_AND_REDUCE = 33;
OP_OR_REDUCE = 34;
- OP_SHIFT_BINARY_POINT_LEFT = 35;
- OP_SHIFT_BINARY_POINT_RIGHT = 36;
- OP_SET_BINARY_POINT = 37;
+ OP_INCREASE_PRECISION = 35;
+ OP_DECREASE_PRECISION = 36;
+ OP_SET_PRECISION = 37;
OP_AS_ASYNC_RESET = 38;
+ OP_WRAP = 39;
+ OP_CLIP = 40;
+ OP_SQUEEZE = 41;
+ OP_AS_INTERVAL = 42;
}
// Required.
diff --git a/src/main/scala/firrtl/Implicits.scala b/src/main/scala/firrtl/Implicits.scala
new file mode 100644
index 00000000..ec1cf3d6
--- /dev/null
+++ b/src/main/scala/firrtl/Implicits.scala
@@ -0,0 +1,30 @@
+// See LICENSE for license details.
+
+package firrtl
+
+import firrtl.ir._
+import Utils.trim
+import firrtl.constraint.Constraint
+
+object Implicits {
+ implicit def int2WInt(i: Int): WrappedInt = WrappedInt(BigInt(i))
+ implicit def bigint2WInt(i: BigInt): WrappedInt = WrappedInt(i)
+ implicit def constraint2bound(c: Constraint): Bound = c match {
+ case x: Bound => x
+ case x => CalcBound(x)
+ }
+ implicit def constraint2width(c: Constraint): Width = c match {
+ case Closed(x) if trim(x).isWhole => IntWidth(x.toBigInt)
+ case x => CalcWidth(x)
+ }
+ implicit def width2constraint(w: Width): Constraint = w match {
+ case CalcWidth(x: Constraint) => x
+ case IntWidth(x) => Closed(BigDecimal(x))
+ case UnknownWidth => UnknownBound
+ case v: Constraint => v
+ }
+}
+case class WrappedInt(value: BigInt) {
+ def U: Expression = UIntLiteral(value, IntWidth(Utils.getUIntWidth(value)))
+ def S: Expression = SIntLiteral(value, IntWidth(Utils.getSIntWidth(value)))
+}
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 05cdbe96..75645319 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -45,6 +45,8 @@ class ResolveAndCheck extends CoreTransform {
passes.InferTypes,
passes.ResolveFlows,
passes.CheckFlows,
+ new passes.InferBinaryPoints(),
+ new passes.TrimIntervals(),
new passes.InferWidths,
passes.CheckWidths,
new firrtl.transforms.InferResets)
@@ -73,6 +75,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform {
passes.ResolveFlows,
new passes.InferWidths,
passes.CheckWidths,
+ new passes.RemoveIntervals(),
passes.ConvertFixedToSInt,
passes.ZeroWidth,
passes.InferTypes)
diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala
index e8283d93..b30b4518 100644
--- a/src/main/scala/firrtl/Mappers.scala
+++ b/src/main/scala/firrtl/Mappers.scala
@@ -6,6 +6,19 @@ import firrtl.ir._
// TODO: Implement remaining mappers and recursive mappers
object Mappers {
+ // ********** Port Mappers **********
+ private trait PortMagnet {
+ def map(p: Port): Port
+ }
+ private object PortMagnet {
+ implicit def forType(f: Type => Type): PortMagnet = new PortMagnet {
+ override def map(port: Port): Port = port mapType f
+ }
+ }
+ implicit class PortMap(val _port: Port) extends AnyVal {
+ def map[T](f: T => T)(implicit magnet: (T => T) => PortMagnet): Port = magnet(f).map(_port)
+ }
+
// ********** Stmt Mappers **********
private trait StmtMagnet {
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala
index af8328da..02404f70 100644
--- a/src/main/scala/firrtl/PrimOps.scala
+++ b/src/main/scala/firrtl/PrimOps.scala
@@ -3,92 +3,487 @@
package firrtl
import firrtl.ir._
-import firrtl.Utils.{min, max, pow_minus_one}
-
import com.typesafe.scalalogging.LazyLogging
+import Implicits.{constraint2bound, constraint2width, width2constraint}
+import firrtl.constraint._
/** Definitions and Utility functions for [[ir.PrimOp]]s */
object PrimOps extends LazyLogging {
+ def t1(e: DoPrim): Type = e.args.head.tpe
+ def t2(e: DoPrim): Type = e.args(1).tpe
+ def t3(e: DoPrim): Type = e.args(2).tpe
+ def w1(e: DoPrim): Width = getWidth(t1(e))
+ def w2(e: DoPrim): Width = getWidth(t2(e))
+ def p1(e: DoPrim): Width = t1(e) match {
+ case FixedType(w, p) => p
+ case IntervalType(min, max, p) => p
+ case _ => sys.error(s"Cannot get binary point from ${t1(e)}")
+ }
+ def p2(e: DoPrim): Width = t2(e) match {
+ case FixedType(w, p) => p
+ case IntervalType(min, max, p) => p
+ case _ => sys.error(s"Cannot get binary point from ${t1(e)}")
+ }
+ def c1(e: DoPrim) = IntWidth(e.consts.head)
+ def c2(e: DoPrim) = IntWidth(e.consts(1))
+ def o1(e: DoPrim) = e.consts(0)
+ def o2(e: DoPrim) = e.consts(1)
+ def o3(e: DoPrim) = e.consts(2)
+
/** Addition */
- case object Add extends PrimOp { override def toString = "add" }
+ case object Add extends PrimOp {
+ override def toString = "add"
+ override def propagateType(e: DoPrim): Type = {
+ (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => UIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1)))
+ case (_: SIntType, _: SIntType) => SIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1)))
+ case (_: FixedType, _: FixedType) => FixedType(IsAdd(IsAdd(IsMax(p1(e), p2(e)), IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))), IntWidth(1)), IsMax(p1(e), p2(e)))
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => IntervalType(IsAdd(l1, l2), IsAdd(u1, u2), IsMax(p1, p2))
+ case _ => UnknownType
+ }
+ }
+ }
+
/** Subtraction */
- case object Sub extends PrimOp { override def toString = "sub" }
+ case object Sub extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => UIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1)))
+ case (_: SIntType, _: SIntType) => SIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1)))
+ case (_: FixedType, _: FixedType) => FixedType(IsAdd(IsAdd(IsMax(p1(e), p2(e)),IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))),IntWidth(1)), IsMax(p1(e), p2(e)))
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => IntervalType(IsAdd(l1, IsNeg(u2)), IsAdd(u1, IsNeg(l2)), IsMax(p1, p2))
+ case _ => UnknownType
+ }
+ override def toString = "sub"
+ }
+
/** Multiplication */
- case object Mul extends PrimOp { override def toString = "mul" }
+ case object Mul extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => UIntType(IsAdd(w1(e), w2(e)))
+ case (_: SIntType, _: SIntType) => SIntType(IsAdd(w1(e), w2(e)))
+ case (_: FixedType, _: FixedType) => FixedType(IsAdd(w1(e), w2(e)), IsAdd(p1(e), p2(e)))
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) =>
+ IntervalType(
+ IsMin(Seq(IsMul(l1, l2), IsMul(l1, u2), IsMul(u1, l2), IsMul(u1, u2))),
+ IsMax(Seq(IsMul(l1, l2), IsMul(l1, u2), IsMul(u1, l2), IsMul(u1, u2))),
+ IsAdd(p1, p2)
+ )
+ case _ => UnknownType
+ }
+ override def toString = "mul" }
+
/** Division */
- case object Div extends PrimOp { override def toString = "div" }
+ case object Div extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => UIntType(w1(e))
+ case (_: SIntType, _: SIntType) => SIntType(IsAdd(w1(e), IntWidth(1)))
+ case _ => UnknownType
+ }
+ override def toString = "div" }
+
/** Remainder */
- case object Rem extends PrimOp { override def toString = "rem" }
+ case object Rem extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => UIntType(MIN(w1(e), w2(e)))
+ case (_: SIntType, _: SIntType) => SIntType(MIN(w1(e), w2(e)))
+ case _ => UnknownType
+ }
+ override def toString = "rem" }
/** Less Than */
- case object Lt extends PrimOp { override def toString = "lt" }
+ case object Lt extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => Utils.BoolType
+ case (_: SIntType, _: SIntType) => Utils.BoolType
+ case (_: FixedType, _: FixedType) => Utils.BoolType
+ case (_: IntervalType, _: IntervalType) => Utils.BoolType
+ case _ => UnknownType
+ }
+ override def toString = "lt" }
/** Less Than Or Equal To */
- case object Leq extends PrimOp { override def toString = "leq" }
+ case object Leq extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => Utils.BoolType
+ case (_: SIntType, _: SIntType) => Utils.BoolType
+ case (_: FixedType, _: FixedType) => Utils.BoolType
+ case (_: IntervalType, _: IntervalType) => Utils.BoolType
+ case _ => UnknownType
+ }
+ override def toString = "leq" }
/** Greater Than */
- case object Gt extends PrimOp { override def toString = "gt" }
+ case object Gt extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => Utils.BoolType
+ case (_: SIntType, _: SIntType) => Utils.BoolType
+ case (_: FixedType, _: FixedType) => Utils.BoolType
+ case (_: IntervalType, _: IntervalType) => Utils.BoolType
+ case _ => UnknownType
+ }
+ override def toString = "gt" }
/** Greater Than Or Equal To */
- case object Geq extends PrimOp { override def toString = "geq" }
+ case object Geq extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => Utils.BoolType
+ case (_: SIntType, _: SIntType) => Utils.BoolType
+ case (_: FixedType, _: FixedType) => Utils.BoolType
+ case (_: IntervalType, _: IntervalType) => Utils.BoolType
+ case _ => UnknownType
+ }
+ override def toString = "geq" }
/** Equal To */
- case object Eq extends PrimOp { override def toString = "eq" }
+ case object Eq extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => Utils.BoolType
+ case (_: SIntType, _: SIntType) => Utils.BoolType
+ case (_: FixedType, _: FixedType) => Utils.BoolType
+ case (_: IntervalType, _: IntervalType) => Utils.BoolType
+ case _ => UnknownType
+ }
+ override def toString = "eq" }
/** Not Equal To */
- case object Neq extends PrimOp { override def toString = "neq" }
+ case object Neq extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType, _: UIntType) => Utils.BoolType
+ case (_: SIntType, _: SIntType) => Utils.BoolType
+ case (_: FixedType, _: FixedType) => Utils.BoolType
+ case (_: IntervalType, _: IntervalType) => Utils.BoolType
+ case _ => UnknownType
+ }
+ override def toString = "neq" }
/** Padding */
- case object Pad extends PrimOp { override def toString = "pad" }
- /** Interpret As UInt */
- case object AsUInt extends PrimOp { override def toString = "asUInt" }
- /** Interpret As SInt */
- case object AsSInt extends PrimOp { override def toString = "asSInt" }
- /** Interpret As Clock */
- case object AsClock extends PrimOp { override def toString = "asClock" }
- /** Interpret As AsyncReset */
- case object AsAsyncReset extends PrimOp { override def toString = "asAsyncReset" }
+ case object Pad extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => UIntType(IsMax(w1(e), c1(e)))
+ case _: SIntType => SIntType(IsMax(w1(e), c1(e)))
+ case _: FixedType => FixedType(IsMax(w1(e), c1(e)), p1(e))
+ case _ => UnknownType
+ }
+ override def toString = "pad" }
/** Static Shift Left */
- case object Shl extends PrimOp { override def toString = "shl" }
+ case object Shl extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => UIntType(IsAdd(w1(e), c1(e)))
+ case _: SIntType => SIntType(IsAdd(w1(e), c1(e)))
+ case _: FixedType => FixedType(IsAdd(w1(e),c1(e)), p1(e))
+ case IntervalType(l, u, p) => IntervalType(IsMul(l, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), IsMul(u, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), p)
+ case _ => UnknownType
+ }
+ override def toString = "shl" }
/** Static Shift Right */
- case object Shr extends PrimOp { override def toString = "shr" }
+ case object Shr extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => UIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)))
+ case _: SIntType => SIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)))
+ case _: FixedType => FixedType(IsMax(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)), p1(e)), p1(e))
+ case IntervalType(l, u, IntWidth(p)) =>
+ val shiftMul = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt))
+ // BP is inferred at this point
+ val bpRes = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << p.toInt))
+ val bpResInv = Closed(BigDecimal(BigInt(1) << p.toInt))
+ val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), bpRes)
+ val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), bpRes)
+ // BP doesn't grow
+ IntervalType(newL, newU, IntWidth(p))
+ case _ => UnknownType
+ }
+ override def toString = "shr"
+ }
/** Dynamic Shift Left */
- case object Dshl extends PrimOp { override def toString = "dshl" }
+ case object Dshl extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => UIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))))
+ case _: SIntType => SIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))))
+ case _: FixedType => FixedType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))), p1(e))
+ case IntervalType(l, u, p) =>
+ val maxShiftAmt = IsAdd(IsPow(w2(e)), Closed(-1))
+ val shiftMul = IsPow(maxShiftAmt)
+ // Magnitude matters! i.e. if l is negative, shifting by the largest amount makes the outcome more negative
+ // whereas if l is positive, shifting by the largest amount makes the outcome more positive (in this case, the lower bound is the previous l)
+ val newL = IsMin(l, IsMul(l, shiftMul))
+ val newU = IsMax(u, IsMul(u, shiftMul))
+ // BP doesn't grow
+ IntervalType(newL, newU, p)
+ case _ => UnknownType
+ }
+ override def toString = "dshl"
+ }
/** Dynamic Shift Right */
- case object Dshr extends PrimOp { override def toString = "dshr" }
+ case object Dshr extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => UIntType(w1(e))
+ case _: SIntType => SIntType(w1(e))
+ case _: FixedType => FixedType(w1(e), p1(e))
+ // Decreasing magnitude -- don't need more bits
+ case IntervalType(l, u, p) => IntervalType(l, u, p)
+ case _ => UnknownType
+ }
+ override def toString = "dshr"
+ }
/** Arithmetic Convert to Signed */
- case object Cvt extends PrimOp { override def toString = "cvt" }
+ case object Cvt extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => SIntType(IsAdd(w1(e), IntWidth(1)))
+ case _: SIntType => SIntType(w1(e))
+ case _ => UnknownType
+ }
+ override def toString = "cvt"
+ }
/** Negate */
- case object Neg extends PrimOp { override def toString = "neg" }
+ case object Neg extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => SIntType(IsAdd(w1(e), IntWidth(1)))
+ case _: SIntType => SIntType(IsAdd(w1(e), IntWidth(1)))
+ case _ => UnknownType
+ }
+ override def toString = "neg"
+ }
/** Bitwise Complement */
- case object Not extends PrimOp { override def toString = "not" }
+ case object Not extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => UIntType(w1(e))
+ case _: SIntType => UIntType(w1(e))
+ case _ => UnknownType
+ }
+ override def toString = "not"
+ }
/** Bitwise And */
- case object And extends PrimOp { override def toString = "and" }
+ case object And extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: SIntType | _: UIntType, _: SIntType | _: UIntType) => UIntType(IsMax(w1(e), w2(e)))
+ case _ => UnknownType
+ }
+ override def toString = "and"
+ }
/** Bitwise Or */
- case object Or extends PrimOp { override def toString = "or" }
+ case object Or extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: SIntType | _: UIntType, _: SIntType | _: UIntType) => UIntType(IsMax(w1(e), w2(e)))
+ case _ => UnknownType
+ }
+ override def toString = "or"
+ }
/** Bitwise Exclusive Or */
- case object Xor extends PrimOp { override def toString = "xor" }
+ case object Xor extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: SIntType | _: UIntType, _: SIntType | _: UIntType) => UIntType(IsMax(w1(e), w2(e)))
+ case _ => UnknownType
+ }
+ override def toString = "xor"
+ }
/** Bitwise And Reduce */
- case object Andr extends PrimOp { override def toString = "andr" }
+ case object Andr extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case (_: UIntType | _: SIntType) => Utils.BoolType
+ case _ => UnknownType
+ }
+ override def toString = "andr"
+ }
/** Bitwise Or Reduce */
- case object Orr extends PrimOp { override def toString = "orr" }
+ case object Orr extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case (_: UIntType | _: SIntType) => Utils.BoolType
+ case _ => UnknownType
+ }
+ override def toString = "orr"
+ }
/** Bitwise Exclusive Or Reduce */
- case object Xorr extends PrimOp { override def toString = "xorr" }
+ case object Xorr extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case (_: UIntType | _: SIntType) => Utils.BoolType
+ case _ => UnknownType
+ }
+ override def toString = "xorr"
+ }
/** Concatenate */
- case object Cat extends PrimOp { override def toString = "cat" }
+ case object Cat extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType, _: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(w1(e), w2(e)))
+ case (t1, t2) => UnknownType
+ }
+ override def toString = "cat"
+ }
/** Bit Extraction */
- case object Bits extends PrimOp { override def toString = "bits" }
+ case object Bits extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(IsAdd(c1(e), IsNeg(c2(e))), IntWidth(1)))
+ case _ => UnknownType
+ }
+ override def toString = "bits"
+ }
/** Head */
- case object Head extends PrimOp { override def toString = "head" }
+ case object Head extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(c1(e))
+ case _ => UnknownType
+ }
+ override def toString = "head"
+ }
/** Tail */
- case object Tail extends PrimOp { override def toString = "tail" }
+ case object Tail extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(w1(e), IsNeg(c1(e))))
+ case _ => UnknownType
+ }
+ override def toString = "tail"
+ }
+ /** Increase Precision **/
+ case object IncP extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: FixedType => FixedType(IsAdd(w1(e),c1(e)), IsAdd(p1(e), c1(e)))
+ // Keeps the same exact value, but adds more precision for the future i.e. aaa.bbb -> aaa.bbb00
+ case IntervalType(l, u, p) => IntervalType(l, u, IsAdd(p, c1(e)))
+ case _ => UnknownType
+ }
+ override def toString = "incp"
+ }
+ /** Decrease Precision **/
+ case object DecP extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: FixedType => FixedType(IsAdd(w1(e),IsNeg(c1(e))), IsAdd(p1(e), IsNeg(c1(e))))
+ case IntervalType(l, u, IntWidth(p)) =>
+ val shiftMul = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt))
+ // BP is inferred at this point
+ // newBPRes is the only difference in calculating bpshr from shr
+ // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt)
+ // without amt, same op as shr
+ val newBPRes = Closed(BigDecimal(BigInt(1) << o1(e).toInt) / BigDecimal(BigInt(1) << p.toInt))
+ val bpResInv = Closed(BigDecimal(BigInt(1) << p.toInt))
+ val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes)
+ val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), newBPRes)
+ // BP doesn't grow
+ IntervalType(newL, newU, IsAdd(IntWidth(p), IsNeg(c1(e))))
+ case _ => UnknownType
+ }
+ override def toString = "decp"
+ }
+ /** Set Precision **/
+ case object SetP extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: FixedType => FixedType(IsAdd(c1(e), IsAdd(w1(e), IsNeg(p1(e)))), c1(e))
+ case IntervalType(l, u, p) =>
+ val newBPResInv = Closed(BigDecimal(BigInt(1) << o1(e).toInt))
+ val newBPRes = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt))
+ val newL = IsMul(IsFloor(IsMul(l, newBPResInv)), newBPRes)
+ val newU = IsMul(IsFloor(IsMul(u, newBPResInv)), newBPRes)
+ IntervalType(newL, newU, c1(e))
+ case _ => UnknownType
+ }
+ override def toString = "setp"
+ }
+ /** Interpret As UInt */
+ case object AsUInt extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => UIntType(w1(e))
+ case _: SIntType => UIntType(w1(e))
+ case _: FixedType => UIntType(w1(e))
+ case ClockType => UIntType(IntWidth(1))
+ case AsyncResetType => UIntType(IntWidth(1))
+ case ResetType => UIntType(IntWidth(1))
+ case AnalogType(w) => UIntType(w1(e))
+ case _: IntervalType => UIntType(w1(e))
+ case _ => UnknownType
+ }
+ override def toString = "asUInt"
+ }
+ /** Interpret As SInt */
+ case object AsSInt extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => SIntType(w1(e))
+ case _: SIntType => SIntType(w1(e))
+ case _: FixedType => SIntType(w1(e))
+ case ClockType => SIntType(IntWidth(1))
+ case AsyncResetType => SIntType(IntWidth(1))
+ case ResetType => SIntType(IntWidth(1))
+ case _: AnalogType => SIntType(w1(e))
+ case _: IntervalType => SIntType(w1(e))
+ case _ => UnknownType
+ }
+ override def toString = "asSInt"
+ }
+ /** Interpret As Clock */
+ case object AsClock extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => ClockType
+ case _: SIntType => ClockType
+ case ClockType => ClockType
+ case AsyncResetType => ClockType
+ case ResetType => ClockType
+ case _: AnalogType => ClockType
+ case _: IntervalType => ClockType
+ case _ => UnknownType
+ }
+ override def toString = "asClock"
+ }
+ /** Interpret As AsyncReset */
+ case object AsAsyncReset extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType | _: IntervalType | _: FixedType => AsyncResetType
+ case _ => UnknownType
+ }
+ override def toString = "asAsyncReset"
+ }
/** Interpret as Fixed Point **/
- case object AsFixedPoint extends PrimOp { override def toString = "asFixedPoint" }
- /** Shift Binary Point Left **/
- case object BPShl extends PrimOp { override def toString = "bpshl" }
- /** Shift Binary Point Right **/
- case object BPShr extends PrimOp { override def toString = "bpshr" }
- /** Set Binary Point **/
- case object BPSet extends PrimOp { override def toString = "bpset" }
+ case object AsFixedPoint extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ case _: UIntType => FixedType(w1(e), c1(e))
+ case _: SIntType => FixedType(w1(e), c1(e))
+ case _: FixedType => FixedType(w1(e), c1(e))
+ case ClockType => FixedType(IntWidth(1), c1(e))
+ case _: AnalogType => FixedType(w1(e), c1(e))
+ case AsyncResetType => FixedType(IntWidth(1), c1(e))
+ case ResetType => FixedType(IntWidth(1), c1(e))
+ case _: IntervalType => FixedType(w1(e), c1(e))
+ case _ => UnknownType
+ }
+ override def toString = "asFixedPoint"
+ }
+ /** Interpret as Interval (closed lower bound, closed upper bound, binary point) **/
+ case object AsInterval extends PrimOp {
+ override def propagateType(e: DoPrim): Type = t1(e) match {
+ // Chisel shifts up and rounds first.
+ case _: UIntType | _: SIntType | _: FixedType | ClockType | AsyncResetType | ResetType | _: AnalogType | _: IntervalType =>
+ IntervalType(Closed(BigDecimal(o1(e))/BigDecimal(BigInt(1) << o3(e).toInt)), Closed(BigDecimal(o2(e))/BigDecimal(BigInt(1) << o3(e).toInt)), IntWidth(o3(e)))
+ case _ => UnknownType
+ }
+ override def toString = "asInterval"
+ }
+ /** Try to fit the first argument into the type of the smaller argument **/
+ case object Squeeze extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) =>
+ val low = IsMax(l1, l2)
+ val high = IsMin(u1, u2)
+ IntervalType(IsMin(low, u2), IsMax(l2, high), p1)
+ case _ => UnknownType
+ }
+ override def toString = "squz"
+ }
+ /** Wrap First Operand Around Range/Width of Second Operand **/
+ case object Wrap extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) => IntervalType(l2, u2, p1)
+ case _ => UnknownType
+ }
+ override def toString = "wrap"
+ }
+ /** Clip First Operand At Range/Width of Second Operand **/
+ case object Clip extends PrimOp {
+ override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) =>
+ val low = IsMax(l1, l2)
+ val high = IsMin(u1, u2)
+ IntervalType(IsMin(low, u2), IsMax(l2, high), p1)
+ case _ => UnknownType
+ }
+ override def toString = "clip"
+ }
private lazy val builtinPrimOps: Seq[PrimOp] =
- Seq(Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Pad, AsUInt, AsSInt, AsClock,
- AsAsyncReset, Shl, Shr, Dshl, Dshr, Neg, Cvt, Not, And, Or, Xor, Andr, Orr, Xorr, Cat, Bits,
- Head, Tail, AsFixedPoint, BPShl, BPShr, BPSet)
- private lazy val strToPrimOp: Map[String, PrimOp] = builtinPrimOps.map { case op : PrimOp=> op.toString -> op }.toMap
+ Seq(Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Pad, AsUInt, AsSInt, AsInterval, AsClock, AsAsyncReset, Shl, Shr,
+ Dshl, Dshr, Neg, Cvt, Not, And, Or, Xor, Andr, Orr, Xorr, Cat, Bits, Head, Tail, AsFixedPoint, IncP, DecP,
+ SetP, Wrap, Clip, Squeeze)
+ private lazy val strToPrimOp: Map[String, PrimOp] = {
+ builtinPrimOps.map { case op : PrimOp=> op.toString -> op }.toMap
+ }
/** Seq of String representations of [[ir.PrimOp]]s */
lazy val listing: Seq[String] = builtinPrimOps map (_.toString)
@@ -96,169 +491,10 @@ object PrimOps extends LazyLogging {
def fromString(op: String): PrimOp = strToPrimOp(op)
// Width Constraint Functions
- def PLUS (w1:Width, w2:Width) : Width = (w1, w2) match {
- case (IntWidth(i), IntWidth(j)) => IntWidth(i + j)
- case _ => PlusWidth(w1, w2)
- }
- def MAX (w1:Width, w2:Width) : Width = (w1, w2) match {
- case (IntWidth(i), IntWidth(j)) => IntWidth(max(i,j))
- case _ => MaxWidth(Seq(w1, w2))
- }
- def MINUS (w1:Width, w2:Width) : Width = (w1, w2) match {
- case (IntWidth(i), IntWidth(j)) => IntWidth(i - j)
- case _ => MinusWidth(w1, w2)
- }
- def POW (w1:Width) : Width = w1 match {
- case IntWidth(i) => IntWidth(pow_minus_one(BigInt(2), i))
- case _ => ExpWidth(w1)
- }
- def MIN (w1:Width, w2:Width) : Width = (w1, w2) match {
- case (IntWidth(i), IntWidth(j)) => IntWidth(min(i,j))
- case _ => MinWidth(Seq(w1, w2))
- }
+ def PLUS(w1: Width, w2: Width): Constraint = IsAdd(w1, w2)
+ def MAX(w1: Width, w2: Width): Constraint = IsMax(w1, w2)
+ def MINUS(w1: Width, w2: Width): Constraint = IsAdd(w1, IsNeg(w2))
+ def MIN(w1: Width, w2: Width): Constraint = IsMin(w1, w2)
- // Borrowed from Stanza implementation
- def set_primop_type (e:DoPrim) : DoPrim = {
- //println-all(["Inferencing primop type: " e])
- def t1 = e.args.head.tpe
- def t2 = e.args(1).tpe
- def w1 = getWidth(e.args.head.tpe)
- def w2 = getWidth(e.args(1).tpe)
- def p1 = t1 match { case FixedType(w, p) => p } //Intentional
- def p2 = t2 match { case FixedType(w, p) => p } //Intentional
- def c1 = IntWidth(e.consts.head)
- def c2 = IntWidth(e.consts(1))
- e copy (tpe = e.op match {
- case Add | Sub => (t1, t2) match {
- case (_: UIntType, _: UIntType) => UIntType(PLUS(MAX(w1, w2), IntWidth(1)))
- case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
- case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2), MAX(MINUS(w1, p1), MINUS(w2, p2))), IntWidth(1)), MAX(p1, p2))
- case _ => UnknownType
- }
- case Mul => (t1, t2) match {
- case (_: UIntType, _: UIntType) => UIntType(PLUS(w1, w2))
- case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, w2))
- case (_: FixedType, _: FixedType) => FixedType(PLUS(w1, w2), PLUS(p1, p2))
- case _ => UnknownType
- }
- case Div => (t1, t2) match {
- case (_: UIntType, _: UIntType) => UIntType(w1)
- case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, IntWidth(1)))
- case _ => UnknownType
- }
- case Rem => (t1, t2) match {
- case (_: UIntType, _: UIntType) => UIntType(MIN(w1, w2))
- case (_: SIntType, _: SIntType) => SIntType(MIN(w1, w2))
- case _ => UnknownType
- }
- case Lt | Leq | Gt | Geq | Eq | Neq => (t1, t2) match {
- case (_: UIntType, _: UIntType) => Utils.BoolType
- case (_: SIntType, _: SIntType) => Utils.BoolType
- case (_: FixedType, _: FixedType) => Utils.BoolType
- case _ => UnknownType
- }
- case Pad => t1 match {
- case _: UIntType => UIntType(MAX(w1, c1))
- case _: SIntType => SIntType(MAX(w1, c1))
- case _: FixedType => FixedType(MAX(w1, c1), p1)
- case _ => UnknownType
- }
- case AsUInt => t1 match {
- case (_: UIntType | _: SIntType | _: FixedType | _: AnalogType) => UIntType(w1)
- case ClockType | AsyncResetType | ResetType => UIntType(IntWidth(1))
- case _ => UnknownType
- }
- case AsSInt => t1 match {
- case (_: UIntType | _: SIntType | _: FixedType | _: AnalogType) => SIntType(w1)
- case ClockType | AsyncResetType | ResetType => SIntType(IntWidth(1))
- case _ => UnknownType
- }
- case AsFixedPoint => t1 match {
- case (_: UIntType | _: SIntType | _: FixedType | _: AnalogType) => FixedType(w1, c1)
- case ClockType | AsyncResetType | ResetType => FixedType(IntWidth(1), c1)
- case _ => UnknownType
- }
- case AsClock => t1 match {
- case (_: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType) => ClockType
- case _ => UnknownType
- }
- case AsAsyncReset => t1 match {
- case (_: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType) => AsyncResetType
- case _ => UnknownType
- }
- case Shl => t1 match {
- case _: UIntType => UIntType(PLUS(w1, c1))
- case _: SIntType => SIntType(PLUS(w1, c1))
- case _: FixedType => FixedType(PLUS(w1,c1), p1)
- case _ => UnknownType
- }
- case Shr => t1 match {
- case _: UIntType => UIntType(MAX(MINUS(w1, c1), IntWidth(1)))
- case _: SIntType => SIntType(MAX(MINUS(w1, c1), IntWidth(1)))
- case _: FixedType => FixedType(MAX(MAX(MINUS(w1,c1), IntWidth(1)), p1), p1)
- case _ => UnknownType
- }
- case Dshl => t1 match {
- case _: UIntType => UIntType(PLUS(w1, POW(w2)))
- case _: SIntType => SIntType(PLUS(w1, POW(w2)))
- case _: FixedType => FixedType(PLUS(w1, POW(w2)), p1)
- case _ => UnknownType
- }
- case Dshr => t1 match {
- case _: UIntType => UIntType(w1)
- case _: SIntType => SIntType(w1)
- case _: FixedType => FixedType(w1, p1)
- case _ => UnknownType
- }
- case Cvt => t1 match {
- case _: UIntType => SIntType(PLUS(w1, IntWidth(1)))
- case _: SIntType => SIntType(w1)
- case _ => UnknownType
- }
- case Neg => t1 match {
- case (_: UIntType | _: SIntType) => SIntType(PLUS(w1, IntWidth(1)))
- case _ => UnknownType
- }
- case Not => t1 match {
- case (_: UIntType | _: SIntType) => UIntType(w1)
- case _ => UnknownType
- }
- case And | Or | Xor => (t1, t2) match {
- case (_: SIntType | _: UIntType, _: SIntType | _: UIntType) => UIntType(MAX(w1, w2))
- case _ => UnknownType
- }
- case Andr | Orr | Xorr => t1 match {
- case (_: UIntType | _: SIntType) => Utils.BoolType
- case _ => UnknownType
- }
- case Cat => (t1, t2) match {
- case (_: UIntType | _: SIntType | _: FixedType, _: UIntType | _: SIntType | _: FixedType) => UIntType(PLUS(w1, w2))
- case (t1, t2) => UnknownType
- }
- case Bits => t1 match {
- case (_: UIntType | _: SIntType | _: FixedType) => UIntType(PLUS(MINUS(c1, c2), IntWidth(1)))
- case _ => UnknownType
- }
- case Head => t1 match {
- case (_: UIntType | _: SIntType | _: FixedType) => UIntType(c1)
- case _ => UnknownType
- }
- case Tail => t1 match {
- case (_: UIntType | _: SIntType | _: FixedType) => UIntType(MINUS(w1, c1))
- case _ => UnknownType
- }
- case BPShl => t1 match {
- case _: FixedType => FixedType(PLUS(w1,c1), PLUS(p1, c1))
- case _ => UnknownType
- }
- case BPShr => t1 match {
- case _: FixedType => FixedType(MINUS(w1,c1), MINUS(p1, c1))
- case _ => UnknownType
- }
- case BPSet => t1 match {
- case _: FixedType => FixedType(PLUS(c1, MINUS(w1, p1)), c1)
- case _ => UnknownType
- }
- })
- }
+ def set_primop_type(e: DoPrim): DoPrim = DoPrim(e.op, e.args, e.consts, e.op.propagateType(e))
}
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 4bf2d14d..8a76aca6 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -10,6 +10,8 @@ import firrtl.WrappedExpression._
import scala.collection.mutable
import scala.util.matching.Regex
+import Implicits.{constraint2bound, constraint2width, width2constraint}
+import firrtl.constraint.{IsMax, IsMin}
import firrtl.annotations.{ReferenceTarget, TargetToken}
import _root_.logger.LazyLogging
@@ -219,7 +221,10 @@ object Utils extends LazyLogging {
def indent(str: String) = str replaceAllLiterally ("\n", "\n ")
implicit def toWrappedExpression (x:Expression): WrappedExpression = new WrappedExpression(x)
- def ceilLog2(x: BigInt): Int = (x-1).bitLength
+ def getSIntWidth(s: BigInt): Int = s.bitLength + 1
+ def getUIntWidth(u: BigInt): Int = u.bitLength
+ def dec2string(v: BigDecimal): String = v.underlying().stripTrailingZeros().toPlainString
+ def trim(v: BigDecimal): BigDecimal = BigDecimal(dec2string(v))
def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b
def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a
def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1
@@ -375,6 +380,7 @@ object Utils extends LazyLogging {
case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth)
case (t1: SIntType, t2: SIntType) => SIntType(UnknownWidth)
case (t1: FixedType, t2: FixedType) => FixedType(UnknownWidth, UnknownWidth)
+ case (t1: IntervalType, t2: IntervalType) => IntervalType(UnknownBound, UnknownBound, UnknownWidth)
case (t1: VectorType, t2: VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size)
case (t1: BundleType, t2: BundleType) => BundleType(t1.fields zip t2.fields map {
case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe))
@@ -386,15 +392,17 @@ object Utils extends LazyLogging {
def mux_type_and_widths(t1: Type, t2: Type): Type = {
def wmax(w1: Width, w2: Width): Width = (w1, w2) match {
case (w1x: IntWidth, w2x: IntWidth) => IntWidth(w1x.width max w2x.width)
- case (w1x, w2x) => MaxWidth(Seq(w1x, w2x))
+ case (w1x, w2x) => IsMax(w1x, w2x)
}
(t1, t2) match {
case (ClockType, ClockType) => ClockType
case (AsyncResetType, AsyncResetType) => AsyncResetType
- case (t1x: UIntType, t2x: UIntType) => UIntType(wmax(t1x.width, t2x.width))
- case (t1x: SIntType, t2x: SIntType) => SIntType(wmax(t1x.width, t2x.width))
+ case (t1x: UIntType, t2x: UIntType) => UIntType(IsMax(t1x.width, t2x.width))
+ case (t1x: SIntType, t2x: SIntType) => SIntType(IsMax(t1x.width, t2x.width))
case (FixedType(w1, p1), FixedType(w2, p2)) =>
FixedType(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))), MAX(p1, p2))
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) =>
+ IntervalType(IsMin(l1, l2), constraint.IsMax(u1, u2), MAX(p1, p2))
case (t1x: VectorType, t2x: VectorType) => VectorType(
mux_type_and_widths(t1x.tpe, t2x.tpe), t1x.size)
case (t1x: BundleType, t2x: BundleType) => BundleType(t1x.fields zip t2x.fields map {
diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala
index 01de8f15..112343d1 100644
--- a/src/main/scala/firrtl/Visitor.scala
+++ b/src/main/scala/firrtl/Visitor.scala
@@ -30,6 +30,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
private val HexPattern = """\"*h([+\-]?[a-zA-Z0-9]+)\"*""".r
private val DecPattern = """([+\-]?[1-9]\d*)""".r
private val ZeroPattern = "0".r
+ private val DecimalPattern = """([+\-]?[0-9]\d*\.[0-9]\d*)""".r
private def string2BigInt(s: String): BigInt = {
// private define legal patterns
@@ -41,6 +42,16 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
}
}
+ private def string2BigDecimal(s: String): BigDecimal = {
+ // private define legal patterns
+ s match {
+ case ZeroPattern(_*) => BigDecimal(0)
+ case DecPattern(num) => BigDecimal(num)
+ case DecimalPattern(num) => BigDecimal(num)
+ case _ => throw new Exception("Invalid String for conversion to BigDecimal " + s)
+ }
+ }
+
private def string2Int(s: String): Int = string2BigInt(s).toInt
private def visitInfo(ctx: Option[InfoContext], parentCtx: ParserRuleContext): Info = {
@@ -129,6 +140,30 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
}
case 2 => FixedType(getWidth(ctx.intLit(0)), getWidth(ctx.intLit(1)))
}
+ case "Interval" => ctx.boundValue.size match {
+ case 0 =>
+ val point = ctx.intLit.size match {
+ case 0 => UnknownWidth
+ case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText))
+ }
+ IntervalType(UnknownBound, UnknownBound, point)
+ case 2 =>
+ val lower = (ctx.lowerBound.getText, ctx.boundValue(0).getText) match {
+ case (_, "?") => UnknownBound
+ case ("(", v) => Open(string2BigDecimal(v))
+ case ("[", v) => Closed(string2BigDecimal(v))
+ }
+ val upper = (ctx.upperBound.getText, ctx.boundValue(1).getText) match {
+ case (_, "?") => UnknownBound
+ case (")", v) => Open(string2BigDecimal(v))
+ case ("]", v) => Closed(string2BigDecimal(v))
+ }
+ val point = ctx.intLit.size match {
+ case 0 => UnknownWidth
+ case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText))
+ }
+ IntervalType(lower, upper, point)
+ }
case "Clock" => ClockType
case "AsyncReset" => AsyncResetType
case "Reset" => ResetType
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 475f5e9c..eb4a665f 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -162,11 +162,49 @@ case class WDefInstanceConnector(
}
// Resultant width is the same as the maximum input width
-case object Addw extends PrimOp { override def toString = "addw" }
+case object Addw extends PrimOp {
+ override def toString = "addw"
+ import constraint._
+ import PrimOps._
+ import Implicits.{constraint2width, width2constraint}
+
+ override def propagateType(e: DoPrim): Type = {
+ (e.args(0).tpe, e.args(1).tpe) match {
+ case (_: UIntType, _: UIntType) => UIntType(IsMax(w1(e), w2(e)))
+ case (_: SIntType, _: SIntType) => SIntType(IsMax(w1(e), w2(e)))
+ case _ => UnknownType
+ }
+ }
+}
// Resultant width is the same as the maximum input width
-case object Subw extends PrimOp { override def toString = "subw" }
+case object Subw extends PrimOp {
+ override def toString = "subw"
+ import constraint._
+ import PrimOps._
+ import Implicits.{constraint2width, width2constraint}
+
+ override def propagateType(e: DoPrim): Type = {
+ (e.args(0).tpe, e.args(1).tpe) match {
+ case (_: UIntType, _: UIntType) => UIntType(IsMax(w1(e), w2(e)))
+ case (_: SIntType, _: SIntType) => SIntType(IsMax(w1(e), w2(e)))
+ case _ => UnknownType
+ }
+ }
+}
// Resultant width is the same as input argument width
-case object Dshlw extends PrimOp { override def toString = "dshlw" }
+case object Dshlw extends PrimOp {
+ override def toString = "dshlw"
+
+ import PrimOps._
+
+ override def propagateType(e: DoPrim): Type = {
+ e.args(0).tpe match {
+ case _: UIntType => UIntType(w1(e))
+ case _: SIntType => SIntType(w1(e))
+ case _ => UnknownType
+ }
+ }
+}
object WrappedExpression {
def apply(e: Expression) = new WrappedExpression(e)
@@ -200,30 +238,6 @@ class WrappedExpression(val e1: Expression) {
private[firrtl] sealed trait HasMapWidth {
def mapWidth(f: Width => Width): Width
}
-case class VarWidth(name: String) extends Width with HasMapWidth {
- def serialize: String = name
- def mapWidth(f: Width => Width): Width = this
-}
-case class PlusWidth(arg1: Width, arg2: Width) extends Width with HasMapWidth {
- def serialize: String = "(" + arg1.serialize + " + " + arg2.serialize + ")"
- def mapWidth(f: Width => Width): Width = PlusWidth(f(arg1), f(arg2))
-}
-case class MinusWidth(arg1: Width, arg2: Width) extends Width with HasMapWidth {
- def serialize: String = "(" + arg1.serialize + " - " + arg2.serialize + ")"
- def mapWidth(f: Width => Width): Width = MinusWidth(f(arg1), f(arg2))
-}
-case class MaxWidth(args: Seq[Width]) extends Width with HasMapWidth {
- def serialize: String = args map (_.serialize) mkString ("max(", ", ", ")")
- def mapWidth(f: Width => Width): Width = MaxWidth(args map f)
-}
-case class MinWidth(args: Seq[Width]) extends Width with HasMapWidth {
- def serialize: String = args map (_.serialize) mkString ("min(", ", ", ")")
- def mapWidth(f: Width => Width): Width = MinWidth(args map f)
-}
-case class ExpWidth(arg1: Width) extends Width with HasMapWidth {
- def serialize: String = "exp(" + arg1.serialize + " )"
- def mapWidth(f: Width => Width): Width = ExpWidth(f(arg1))
-}
object WrappedType {
def apply(t: Type) = new WrappedType(t)
@@ -240,6 +254,7 @@ object WrappedType {
case (ResetType, tpe) => legalResetType(tpe)
case (tpe, ResetType) => legalResetType(tpe)
case (_: FixedType, _: FixedType) => true
+ case (_: IntervalType, _: IntervalType) => true
// Analog totally skips out of the Firrtl type system.
// The only way Analog can play with another Analog component is through Attach.
// Ohterwise, we'd need to special case it during ExpandWhens, Lowering,
@@ -280,29 +295,13 @@ class WrappedWidth (val w: Width) {
def ww(w: Width): WrappedWidth = new WrappedWidth(w)
override def toString = w match {
case (w: VarWidth) => w.name
- case (w: MaxWidth) => s"max(${w.args.mkString})"
- case (w: MinWidth) => s"min(${w.args.mkString})"
- case (w: PlusWidth) => s"(${w.arg1} + ${w.arg2})"
- case (w: MinusWidth) => s"(${w.arg1} -${w.arg2})"
- case (w: ExpWidth) => s"exp(${w.arg1})"
case (w: IntWidth) => w.width.toString
case UnknownWidth => "?"
}
override def equals(o: Any): Boolean = o match {
case (w2: WrappedWidth) => (w, w2.w) match {
case (w1: VarWidth, w2: VarWidth) => w1.name.equals(w2.name)
- case (w1: MaxWidth, w2: MaxWidth) => w1.args.size == w2.args.size &&
- (w1.args forall (a1 => w2.args exists (a2 => eqw(a1, a2))))
- case (w1: MinWidth, w2: MinWidth) => w1.args.size == w2.args.size &&
- (w1.args forall (a1 => w2.args exists (a2 => eqw(a1, a2))))
case (w1: IntWidth, w2: IntWidth) => w1.width == w2.width
- case (w1: PlusWidth, w2: PlusWidth) =>
- (ww(w1.arg1) == ww(w2.arg1) && ww(w1.arg2) == ww(w2.arg2)) ||
- (ww(w1.arg1) == ww(w2.arg2) && ww(w1.arg2) == ww(w2.arg1))
- case (w1: MinusWidth,w2: MinusWidth) =>
- (ww(w1.arg1) == ww(w2.arg1) && ww(w1.arg2) == ww(w2.arg2)) ||
- (ww(w1.arg1) == ww(w2.arg2) && ww(w1.arg2) == ww(w2.arg1))
- case (w1: ExpWidth, w2: ExpWidth) => ww(w1.arg1) == ww(w2.arg1)
case (UnknownWidth, UnknownWidth) => true
case _ => false
}
@@ -310,18 +309,6 @@ class WrappedWidth (val w: Width) {
}
}
-trait Constraint
-class WGeq(val loc: Width, val exp: Width) extends Constraint {
- override def toString = {
- val wloc = new WrappedWidth(loc)
- val wexp = new WrappedWidth(exp)
- wloc.toString + " >= " + wexp.toString
- }
-}
-object WGeq {
- def apply(loc: Width, exp: Width) = new WGeq(loc, exp)
-}
-
abstract class MPortDir extends FirrtlNode
case object MInfer extends MPortDir {
def serialize: String = "infer"
diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala
index 313f1dc2..1571f98e 100644
--- a/src/main/scala/firrtl/annotations/Target.scala
+++ b/src/main/scala/firrtl/annotations/Target.scala
@@ -3,7 +3,7 @@
package firrtl
package annotations
-import firrtl.ir.{Expression, Type}
+import firrtl.ir.{Field => _, _}
import firrtl.Utils.{sub_type, field_type}
import AnnotationUtils.{toExp, validComponentName, validModuleName}
import TargetToken._
@@ -41,7 +41,7 @@ sealed trait Target extends Named {
case Ref(r) => s">$r"
case Instance(i) => s"/$i"
case OfModule(o) => s":$o"
- case Field(f) => s".$f"
+ case TargetToken.Field(f) => s".$f"
case Index(v) => s"[$v]"
case Clock => s"@clock"
case Reset => s"@reset"
@@ -103,6 +103,21 @@ sealed trait Target extends Named {
}
object Target {
+ def asTarget(m: ModuleTarget)(e: Expression): ReferenceTarget = e match {
+ case w: WRef => m.ref(w.name)
+ case r: ir.Reference => m.ref(r.name)
+ case w: WSubIndex => asTarget(m)(w.expr).index(w.value)
+ case s: ir.SubIndex => asTarget(m)(s.expr).index(s.value)
+ case w: WSubField => asTarget(m)(w.expr).field(w.name)
+ case s: ir.SubField => asTarget(m)(s.expr).field(s.name)
+ case w: WSubAccess => asTarget(m)(w.expr).field("@" + w.index.serialize)
+ case s: ir.SubAccess => asTarget(m)(s.expr).field("@" + s.index.serialize)
+ case d: DoPrim => m.ref("@" + d.serialize)
+ case d: Mux => m.ref("@" + d.serialize)
+ case d: ValidIf => m.ref("@" + d.serialize)
+ case d: Literal => m.ref("@" + d.serialize)
+ case other => sys.error(s"Unsupported: $other")
+ }
def apply(circuitOpt: Option[String], moduleOpt: Option[String], reference: Seq[TargetToken]): GenericTarget =
GenericTarget(circuitOpt, moduleOpt, reference.toVector)
diff --git a/src/main/scala/firrtl/constraint/Constraint.scala b/src/main/scala/firrtl/constraint/Constraint.scala
new file mode 100644
index 00000000..247593ee
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/Constraint.scala
@@ -0,0 +1,22 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+/** Trait for all Constraint Solver expressions */
+trait Constraint {
+ def serialize: String
+ def map(f: Constraint => Constraint): Constraint
+ val children: Vector[Constraint]
+ def reduce(): Constraint
+}
+
+/** Trait for constraints with more than one argument */
+trait MultiAry extends Constraint {
+ def op(a: IsKnown, b: IsKnown): IsKnown
+ def merge(b1: Option[IsKnown], b2: Option[IsKnown]): Option[IsKnown] = (b1, b2) match {
+ case (Some(x), Some(y)) => Some(op(x, y))
+ case (_, y: Some[_]) => y
+ case (x: Some[_], _) => x
+ case _ => None
+ }
+}
diff --git a/src/main/scala/firrtl/constraint/ConstraintSolver.scala b/src/main/scala/firrtl/constraint/ConstraintSolver.scala
new file mode 100644
index 00000000..52440b15
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/ConstraintSolver.scala
@@ -0,0 +1,357 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils.throwInternalError
+import firrtl.annotations.ReferenceTarget
+
+import scala.collection.mutable
+
+/** Forwards-Backwards Constraint Solver
+ *
+ * Used for computing [[Width]] and [[Bound]] constraints
+ *
+ * Note - this is an O(N) algorithm, but requires exponential memory. We rely on aggressive early optimization
+ * of constraint expressions to (usually) get around this.
+ */
+class ConstraintSolver {
+
+ /** Initial, mutable constraint list, with function to add the constraint */
+ private val constraints = mutable.ArrayBuffer[Inequality]()
+
+ /** Solved constraints */
+ type ConstraintMap = mutable.HashMap[String, (Constraint, Boolean)]
+ private val solvedConstraintMap = new ConstraintMap()
+
+
+ /** Clear all previously recorded/solved constraints */
+ def clear(): Unit = {
+ constraints.clear()
+ solvedConstraintMap.clear()
+ }
+
+ /** Updates internal list of inequalities with a new [[GreaterOrEqual]]
+ * @param big The larger constraint, must be either known or a variable
+ * @param small The smaller constraint
+ */
+ def addGeq(big: Constraint, small: Constraint, r1: String, r2: String): Unit = (big, small) match {
+ case (IsVar(name), other: Constraint) => add(GreaterOrEqual(name, other))
+ case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints
+ }
+
+ /** Updates internal list of inequalities with a new [[GreaterOrEqual]]
+ * @param big The larger constraint, must be either known or a variable
+ * @param small The smaller constraint
+ */
+ def addGeq(big: Width, small: Width, r1: String, r2: String): Unit = (big, small) match {
+ case (IsVar(name), other: CalcWidth) => add(GreaterOrEqual(name, other.arg))
+ case (IsVar(name), other: IsVar) => add(GreaterOrEqual(name, other))
+ case (IsVar(name), other: IntWidth) => add(GreaterOrEqual(name, Implicits.width2constraint(other)))
+ case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints
+ }
+
+ /** Updates internal list of inequalities with a new [[LesserOrEqual]]
+ * @param small The smaller constraint, must be either known or a variable
+ * @param big The larger constraint
+ */
+ def addLeq(small: Constraint, big: Constraint, r1: String, r2: String): Unit = (small, big) match {
+ case (IsVar(name), other: Constraint) => add(LesserOrEqual(name, other))
+ case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints
+ }
+
+ /** Updates internal list of inequalities with a new [[LesserOrEqual]]
+ * @param small The smaller constraint, must be either known or a variable
+ * @param big The larger constraint
+ */
+ def addLeq(small: Width, big: Width, r1: String, r2: String): Unit = (small, big) match {
+ case (IsVar(name), other: CalcWidth) => add(LesserOrEqual(name, other.arg))
+ case (IsVar(name), other: IsVar) => add(LesserOrEqual(name, other))
+ case (IsVar(name), other: IntWidth) => add(LesserOrEqual(name, Implicits.width2constraint(other)))
+ case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints
+ }
+
+ /** Returns a solved constraint, if it exists and is solved
+ * @param b
+ * @return
+ */
+ def get(b: Constraint): Option[IsKnown] = {
+ val name = b match {
+ case IsVar(name) => name
+ case x => ""
+ }
+ solvedConstraintMap.get(name) match {
+ case None => None
+ case Some((k: IsKnown, _)) => Some(k)
+ case Some(_) => None
+ }
+ }
+
+ /** Returns a solved width, if it exists and is solved
+ * @param b
+ * @return
+ */
+ def get(b: Width): Option[IsKnown] = {
+ val name = b match {
+ case IsVar(name) => name
+ case x => ""
+ }
+ solvedConstraintMap.get(name) match {
+ case None => None
+ case Some((k: IsKnown, _)) => Some(k)
+ case Some(_) => None
+ }
+ }
+
+
+ private def add(c: Inequality) = constraints += c
+
+
+ /** Creates an Inequality given a variable name, constraint, and whether its >= or <=
+ * @param left
+ * @param right
+ * @param geq
+ * @return
+ */
+ private def genConst(left: String, right: Constraint, geq: Boolean): Inequality = geq match {
+ case true => GreaterOrEqual(left, right)
+ case false => LesserOrEqual(left, right)
+ }
+
+ /** For debugging, can serialize the initial constraints */
+ def serializeConstraints: String = constraints.mkString("\n")
+
+ /** For debugging, can serialize the solved constraints */
+ def serializeSolutions: String = solvedConstraintMap.map{
+ case (k, (v, true)) => s"$k >= ${v.serialize}"
+ case (k, (v, false)) => s"$k <= ${v.serialize}"
+ }.mkString("\n")
+
+
+
+ /************* Constraint Solver Engine ****************/
+
+ /** Merges constraints on the same variable
+ *
+ * Returns a new list of Inequalities with a single Inequality per variable
+ *
+ * For example, given:
+ * a >= 1 + b
+ * a >= 3
+ *
+ * Will return:
+ * a >= max(3, 1 + b)
+ *
+ * @param constraints
+ * @return
+ */
+ private def mergeConstraints(constraints: Seq[Inequality]): Seq[Inequality] = {
+ val mergedMap = mutable.HashMap[String, Inequality]()
+ constraints.foreach {
+ case c if c.geq && mergedMap.contains(c.left) =>
+ mergedMap(c.left) = genConst(c.left, IsMax(mergedMap(c.left).right, c.right), true)
+ case c if !c.geq && mergedMap.contains(c.left) =>
+ mergedMap(c.left) = genConst(c.left, IsMin(mergedMap(c.left).right, c.right), false)
+ case c =>
+ mergedMap(c.left) = c
+ }
+ mergedMap.values.toList
+ }
+
+
+ /** Attempts to substitute variables with their corresponding forward-solved constraints
+ * If no corresponding constraint has been visited yet, keep variable as is
+ *
+ * @param forwardSolved ConstraintMap containing earlier forward-solved constraints
+ * @param constraint Constraint to forward solve
+ * @return Forward solved constraint
+ */
+ private def forwardSubstitution(forwardSolved: ConstraintMap)(constraint: Constraint): Constraint = {
+ val x = constraint map forwardSubstitution(forwardSolved)
+ x match {
+ case isVar: IsVar => forwardSolved get isVar.name match {
+ case None => isVar.asInstanceOf[Constraint]
+ case Some((p, geq)) =>
+ val newT = forwardSubstitution(forwardSolved)(p)
+ forwardSolved(isVar.name) = (newT, geq)
+ newT
+ }
+ case other => other
+ }
+ }
+
+ /** Attempts to substitute variables with their corresponding backwards-solved constraints
+ * If no corresponding constraint is solved, keep variable as is (as an unsolved constraint,
+ * which will be reported later)
+ *
+ * @param backwardSolved ConstraintMap containing earlier backward-solved constraints
+ * @param constraint Constraint to backward solve
+ * @return Backward solved constraint
+ */
+ private def backwardSubstitution(backwardSolved: ConstraintMap)(constraint: Constraint): Constraint = {
+ constraint match {
+ case isVar: IsVar => backwardSolved.get(isVar.name) match {
+ case Some((p, geq)) => p
+ case _ => isVar
+ }
+ case other => other map backwardSubstitution(backwardSolved)
+ }
+ }
+
+ /** Remove solvable cycles in an inequality
+ *
+ * For example:
+ * a >= max(1, a)
+ *
+ * Can be simplified to:
+ * a >= 1
+ * @param name Name of the variable on left side of inequality
+ * @param geq Whether inequality is >= or <=
+ * @param constraint Constraint expression
+ * @return
+ */
+ private def removeCycle(name: String, geq: Boolean)(constraint: Constraint): Constraint =
+ if(geq) removeGeqCycle(name)(constraint) else removeLeqCycle(name)(constraint)
+
+ /** Removes solvable cycles of <= inequalities
+ * @param name Name of the variable on left side of inequality
+ * @param constraint Constraint expression
+ * @return
+ */
+ private def removeLeqCycle(name: String)(constraint: Constraint): Constraint = constraint match {
+ case x if greaterEqThan(name)(x) => VarCon(name)
+ case isMin: IsMin => IsMin(isMin.children.filter{ c => !greaterEqThan(name)(c)})
+ case x => x
+ }
+
+ /** Removes solvable cycles of >= inequalities
+ * @param name Name of the variable on left side of inequality
+ * @param constraint Constraint expression
+ * @return
+ */
+ private def removeGeqCycle(name: String)(constraint: Constraint): Constraint = constraint match {
+ case x if lessEqThan(name)(x) => VarCon(name)
+ case isMax: IsMax => IsMax(isMax.children.filter{c => !lessEqThan(name)(c)})
+ case x => x
+ }
+
+ private def greaterEqThan(name: String)(constraint: Constraint): Boolean = constraint match {
+ case isMin: IsMin => isMin.children.map(greaterEqThan(name)).reduce(_ && _)
+ case isAdd: IsAdd => isAdd.children match {
+ case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true
+ case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true
+ case _ => false
+ }
+ case isMul: IsMul => isMul.children match {
+ case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true
+ case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true
+ case _ => false
+ }
+ case isVar: IsVar if isVar.name == name => true
+ case _ => false
+ }
+
+ private def lessEqThan(name: String)(constraint: Constraint): Boolean = constraint match {
+ case isMax: IsMax => isMax.children.map(lessEqThan(name)).reduce(_ && _)
+ case isAdd: IsAdd => isAdd.children match {
+ case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true
+ case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true
+ case _ => false
+ }
+ case isMul: IsMul => isMul.children match {
+ case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true
+ case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true
+ case _ => false
+ }
+ case isVar: IsVar if isVar.name == name => true
+ case isNeg: IsNeg => isNeg.child match {
+ case isVar: IsVar if isVar.name == name => true
+ case _ => false
+ }
+ case _ => false
+ }
+
+ /** Whether a constraint contains the named variable
+ * @param name Name of variable
+ * @param constraint Constraint to check
+ * @return
+ */
+ private def hasVar(name: String)(constraint: Constraint): Boolean = {
+ var has = false
+ def rec(constraint: Constraint): Constraint = {
+ constraint match {
+ case isVar: IsVar if isVar.name == name => has = true
+ case _ =>
+ }
+ constraint map rec
+ }
+ rec(constraint)
+ has
+ }
+
+ /** Returns illegal constraints, where both a >= and <= inequality are used on the same variable
+ * @return
+ */
+ def check(): Seq[Inequality] = {
+ val checkMap = new mutable.HashMap[String, Inequality]()
+ constraints.foldLeft(Seq[Inequality]()) { (seq, c) =>
+ checkMap.get(c.left) match {
+ case None =>
+ checkMap(c.left) = c
+ seq ++ Nil
+ case Some(x) if x.geq != c.geq => seq ++ Seq(x, c)
+ case Some(x) => seq ++ Nil
+ }
+ }
+ }
+
+ /** Solves constraints present in collected inequalities
+ *
+ * Constraint solving steps:
+ * 1) Assert no variable has both >= and <= inequalities (it can have multiple of the same kind of inequality)
+ * 2) Merge constraints of variables having multiple inequalities
+ * 3) Forward solve inequalities
+ * a. Iterate through inequalities top-to-bottom, replacing previously seen variables with corresponding
+ * constraint
+ * b. For each forward-solved inequality, attempt to remove circular constraints
+ * c. Forward-solved inequalities without circular constraints are recorded
+ * 4) Backwards solve inequalities
+ * a. Iterate through successful forward-solved inequalities bottom-to-top, replacing previously seen variables
+ * with corresponding constraint
+ * b. Record solved constraints
+ */
+ def solve(): Unit = {
+ // 1) Check if any variable has both >= and <= inequalities (which is illegal)
+ val illegals = check()
+ if (illegals != Nil) throwInternalError(s"Constraints cannot have both >= and <= inequalities: $illegals")
+
+ // 2) Merge constraints
+ val uniqueConstraints = mergeConstraints(constraints.toSeq)
+
+ // 3) Forward Solve
+ val forwardConstraintMap = new ConstraintMap
+ val orderedVars = mutable.HashMap[Int, String]()
+
+ var index = 0
+ for (constraint <- uniqueConstraints) {
+ //TODO: Risky if used improperly... need to check whether substitution from a leq to a geq is negated (always).
+ val subbedRight = forwardSubstitution(forwardConstraintMap)(constraint.right)
+ val name = constraint.left
+ val finishedRight = removeCycle(name, constraint.geq)(subbedRight)
+ if (!hasVar(name)(finishedRight)) {
+ forwardConstraintMap(name) = (finishedRight, constraint.geq)
+ orderedVars(index) = name
+ index += 1
+ }
+ }
+
+ // 4) Backwards Solve
+ for (i <- (orderedVars.size - 1) to 0 by -1) {
+ val name = orderedVars(i) // Should visit `orderedVars` backward
+ val (forwardRight, forwardGeq) = forwardConstraintMap(name)
+ val solvedRight = backwardSubstitution(solvedConstraintMap)(forwardRight)
+ solvedConstraintMap(name) = (solvedRight, forwardGeq)
+ }
+ }
+}
diff --git a/src/main/scala/firrtl/constraint/Inequality.scala b/src/main/scala/firrtl/constraint/Inequality.scala
new file mode 100644
index 00000000..0fa1d2eb
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/Inequality.scala
@@ -0,0 +1,24 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+/** Represents either greater or equal to or less than or equal to
+ * Is passed to the constraint solver to resolve
+ */
+trait Inequality {
+ def left: String
+ def right: Constraint
+ def geq: Boolean
+}
+
+case class GreaterOrEqual(left: String, right: Constraint) extends Inequality {
+ val geq = true
+ override def toString: String = s"$left >= ${right.serialize}"
+}
+
+case class LesserOrEqual(left: String, right: Constraint) extends Inequality {
+ val geq = false
+ override def toString: String = s"$left <= ${right.serialize}"
+}
+
+
diff --git a/src/main/scala/firrtl/constraint/IsAdd.scala b/src/main/scala/firrtl/constraint/IsAdd.scala
new file mode 100644
index 00000000..e177a8b9
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/IsAdd.scala
@@ -0,0 +1,52 @@
+// See LICENSE for license details.
+
+
+package firrtl.constraint
+
+// Is case class because writing tests is easier due to equality is not object equality
+case class IsAdd private (known: Option[IsKnown],
+ maxs: Vector[IsMax],
+ mins: Vector[IsMin],
+ others: Vector[Constraint]) extends Constraint with MultiAry {
+
+ def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 + b2
+
+ lazy val children: Vector[Constraint] = {
+ if(known.nonEmpty) known.get +: (maxs ++ mins ++ others) else maxs ++ mins ++ others
+ }
+
+ def addChild(x: Constraint): IsAdd = x match {
+ case k: IsKnown => new IsAdd(merge(Some(k), known), maxs, mins, others)
+ case add: IsAdd => new IsAdd(merge(known, add.known), maxs ++ add.maxs, mins ++ add.mins, others ++ add.others)
+ case max: IsMax => new IsAdd(known, maxs :+ max, mins, others)
+ case min: IsMin => new IsAdd(known, maxs, mins :+ min, others)
+ case other => new IsAdd(known, maxs, mins, others :+ other)
+ }
+
+ override def serialize: String = "(" + children.map(_.serialize).mkString(" + ") + ")"
+
+ override def map(f: Constraint=>Constraint): Constraint = IsAdd(children.map(f))
+
+ def reduce(): Constraint = {
+ if(children.size == 1) children.head else {
+ (known, maxs, mins, others) match {
+ case (Some(k), _, _, _) if k.value == 0 => new IsAdd(None, maxs, mins, others).reduce()
+ case (Some(k), Vector(max), Vector(), Vector()) => max.map { o => IsAdd(k, o) }.reduce()
+ case (Some(k), Vector(), Vector(min), Vector()) => min.map { o => IsAdd(k, o) }.reduce()
+ case _ => this
+ }
+ }
+ }
+}
+
+object IsAdd {
+ def apply(left: Constraint, right: Constraint): Constraint = (left, right) match {
+ case (l: IsKnown, r: IsKnown) => l + r
+ case _ => apply(Seq(left, right))
+ }
+ def apply(children: Seq[Constraint]): Constraint = {
+ children.foldLeft(new IsAdd(None, Vector(), Vector(), Vector())) { (add, c) =>
+ add.addChild(c)
+ }.reduce()
+ }
+} \ No newline at end of file
diff --git a/src/main/scala/firrtl/constraint/IsFloor.scala b/src/main/scala/firrtl/constraint/IsFloor.scala
new file mode 100644
index 00000000..5de4697e
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/IsFloor.scala
@@ -0,0 +1,32 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+object IsFloor {
+ def apply(child: Constraint): Constraint = new IsFloor(child, 0).reduce()
+}
+
+case class IsFloor private (child: Constraint, dummyArg: Int) extends Constraint {
+
+ override def reduce(): Constraint = child match {
+ case k: IsKnown => k.floor
+ case x: IsAdd => this
+ case x: IsMul => this
+ case x: IsNeg => this
+ case x: IsPow => this
+ // floor(max(a, b)) -> max(floor(a), floor(b))
+ case x: IsMax => IsMax(x.children.map {b => IsFloor(b)})
+ case x: IsMin => IsMin(x.children.map {b => IsFloor(b)})
+ case x: IsVar => this
+ // floor(floor(x)) -> floor(x)
+ case x: IsFloor => x
+ case _ => this
+ }
+ val children = Vector(child)
+
+ override def map(f: Constraint=>Constraint): Constraint = IsFloor(f(child))
+
+ override def serialize: String = "floor(" + child.serialize + ")"
+}
+
+
diff --git a/src/main/scala/firrtl/constraint/IsKnown.scala b/src/main/scala/firrtl/constraint/IsKnown.scala
new file mode 100644
index 00000000..5bd25f92
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/IsKnown.scala
@@ -0,0 +1,44 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+object IsKnown {
+ def unapply(b: Constraint): Option[BigDecimal] = b match {
+ case k: IsKnown => Some(k.value)
+ case _ => None
+ }
+}
+
+/** Constant values must extend this trait see [[firrtl.ir.Closed and firrtl.ir.Open]] */
+trait IsKnown extends Constraint {
+ val value: BigDecimal
+
+ /** Addition */
+ def +(that: IsKnown): IsKnown
+
+ /** Multiplication */
+ def *(that: IsKnown): IsKnown
+
+ /** Max */
+ def max(that: IsKnown): IsKnown
+
+ /** Min */
+ def min(that: IsKnown): IsKnown
+
+ /** Negate */
+ def neg: IsKnown
+
+ /** 2 << value */
+ def pow: IsKnown
+
+ /** Floor */
+ def floor: IsKnown
+
+ override def map(f: Constraint=>Constraint): Constraint = this
+
+ val children: Vector[Constraint] = Vector.empty[Constraint]
+
+ def reduce(): IsKnown = this
+}
+
+
diff --git a/src/main/scala/firrtl/constraint/IsMax.scala b/src/main/scala/firrtl/constraint/IsMax.scala
new file mode 100644
index 00000000..3f24b7c0
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/IsMax.scala
@@ -0,0 +1,59 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+object IsMax {
+ def apply(left: Constraint, right: Constraint): Constraint = (left, right) match {
+ case (l: IsKnown, r: IsKnown) => l max r
+ case _ => apply(Seq(left, right))
+ }
+ def apply(children: Seq[Constraint]): Constraint = {
+ val x = children.foldLeft(new IsMax(None, Vector(), Vector())) { (add, c) =>
+ add.addChild(c)
+ }
+ x.reduce()
+ }
+}
+
+case class IsMax private[constraint](known: Option[IsKnown],
+ mins: Vector[IsMin],
+ others: Vector[Constraint]
+ ) extends MultiAry {
+
+ def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 max b2
+
+ override def serialize: String = "max(" + children.map(_.serialize).mkString(", ") + ")"
+
+ override def map(f: Constraint=>Constraint): Constraint = IsMax(children.map(f))
+
+ lazy val children: Vector[Constraint] = {
+ if(known.nonEmpty) known.get +: (mins ++ others) else mins ++ others
+ }
+
+ def reduce(): Constraint = {
+ if(children.size == 1) children.head else {
+ (known, mins, others) match {
+ case (Some(IsKnown(a)), _, _) =>
+ // Eliminate minimums who have a known minimum value which is smaller than known maximum value
+ val filteredMins = mins.filter {
+ case IsMin(Some(IsKnown(i)), _, _) if i <= a => false
+ case other => true
+ }
+ // If a successful filter, rerun reduce
+ val newMax = new IsMax(known, filteredMins, others)
+ if(filteredMins.size != mins.size) {
+ newMax.reduce()
+ } else newMax
+ case _ => this
+ }
+ }
+ }
+
+ def addChild(x: Constraint): IsMax = x match {
+ case k: IsKnown => new IsMax(known = merge(Some(k), known), mins, others)
+ case max: IsMax => new IsMax(known = merge(known, max.known), max.mins ++ mins, others ++ max.others)
+ case min: IsMin => new IsMax(known, mins :+ min, others)
+ case other => new IsMax(known, mins, others :+ other)
+ }
+}
+
diff --git a/src/main/scala/firrtl/constraint/IsMin.scala b/src/main/scala/firrtl/constraint/IsMin.scala
new file mode 100644
index 00000000..ee97e298
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/IsMin.scala
@@ -0,0 +1,57 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+object IsMin {
+ def apply(left: Constraint, right: Constraint): Constraint = (left, right) match {
+ case (l: IsKnown, r: IsKnown) => l min r
+ case _ => apply(Seq(left, right))
+ }
+ def apply(children: Seq[Constraint]): Constraint = {
+ children.foldLeft(new IsMin(None, Vector(), Vector())) { (add, c) =>
+ add.addChild(c)
+ }.reduce()
+ }
+}
+
+case class IsMin private[constraint](known: Option[IsKnown],
+ maxs: Vector[IsMax],
+ others: Vector[Constraint]
+ ) extends MultiAry {
+
+ def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 min b2
+
+ override def serialize: String = "min(" + children.map(_.serialize).mkString(", ") + ")"
+
+ override def map(f: Constraint=>Constraint): Constraint = IsMin(children.map(f))
+
+ lazy val children: Vector[Constraint] = {
+ if(known.nonEmpty) known.get +: (maxs ++ others) else maxs ++ others
+ }
+
+ def reduce(): Constraint = {
+ if(children.size == 1) children.head else {
+ (known, maxs, others) match {
+ case (Some(IsKnown(i)), _, _) =>
+ // Eliminate maximums who have a known maximum value which is larger than known minimum value
+ val filteredMaxs = maxs.filter {
+ case IsMax(Some(IsKnown(a)), _, _) if a >= i => false
+ case other => true
+ }
+ // If a successful filter, rerun reduce
+ val newMin = new IsMin(known, filteredMaxs, others)
+ if(filteredMaxs.size != maxs.size) {
+ newMin.reduce()
+ } else newMin
+ case _ => this
+ }
+ }
+ }
+
+ def addChild(x: Constraint): IsMin = x match {
+ case k: IsKnown => new IsMin(merge(Some(k), known), maxs, others)
+ case max: IsMax => new IsMin(known, maxs :+ max, others)
+ case min: IsMin => new IsMin(merge(min.known, known), maxs ++ min.maxs, others ++ min.others)
+ case other => new IsMin(known, maxs, others :+ other)
+ }
+}
diff --git a/src/main/scala/firrtl/constraint/IsMul.scala b/src/main/scala/firrtl/constraint/IsMul.scala
new file mode 100644
index 00000000..3f637d75
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/IsMul.scala
@@ -0,0 +1,52 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+import firrtl.ir.Closed
+
+object IsMul {
+ def apply(left: Constraint, right: Constraint): Constraint = (left, right) match {
+ case (l: IsKnown, r: IsKnown) => l * r
+ case _ => apply(Seq(left, right))
+ }
+ def apply(children: Seq[Constraint]): Constraint = {
+ children.foldLeft(new IsMul(None, Vector())) { (add, c) =>
+ add.addChild(c)
+ }.reduce()
+ }
+}
+
+case class IsMul private (known: Option[IsKnown], others: Vector[Constraint]) extends MultiAry {
+
+ def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 * b2
+
+ lazy val children: Vector[Constraint] = if(known.nonEmpty) known.get +: others else others
+
+ def addChild(x: Constraint): IsMul = x match {
+ case k: IsKnown => new IsMul(known = merge(Some(k), known), others)
+ case mul: IsMul => new IsMul(merge(known, mul.known), others ++ mul.others)
+ case other => new IsMul(known, others :+ other)
+ }
+
+ override def reduce(): Constraint = {
+ if(children.size == 1) children.head else {
+ (known, others) match {
+ case (Some(Closed(x)), _) if x == BigDecimal(1) => new IsMul(None, others).reduce()
+ case (Some(Closed(x)), _) if x == BigDecimal(0) => Closed(0)
+ case (Some(Closed(x)), Vector(m: IsMax)) if x > 0 =>
+ IsMax(m.children.map { c => IsMul(Closed(x), c) })
+ case (Some(Closed(x)), Vector(m: IsMax)) if x < 0 =>
+ IsMin(m.children.map { c => IsMul(Closed(x), c) })
+ case (Some(Closed(x)), Vector(m: IsMin)) if x > 0 =>
+ IsMin(m.children.map { c => IsMul(Closed(x), c) })
+ case (Some(Closed(x)), Vector(m: IsMin)) if x < 0 =>
+ IsMax(m.children.map { c => IsMul(Closed(x), c) })
+ case _ => this
+ }
+ }
+ }
+
+ override def map(f: Constraint=>Constraint): Constraint = IsMul(children.map(f))
+
+ override def serialize: String = "(" + children.map(_.serialize).mkString(" * ") + ")"
+}
diff --git a/src/main/scala/firrtl/constraint/IsNeg.scala b/src/main/scala/firrtl/constraint/IsNeg.scala
new file mode 100644
index 00000000..46f739c6
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/IsNeg.scala
@@ -0,0 +1,32 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+object IsNeg {
+ def apply(child: Constraint): Constraint = new IsNeg(child, 0).reduce()
+}
+
+// Dummy arg is to get around weird Scala issue that can't differentiate between a
+// private constructor and public apply that share the same arguments
+case class IsNeg private (child: Constraint, dummyArg: Int) extends Constraint {
+ override def reduce(): Constraint = child match {
+ case k: IsKnown => k.neg
+ case x: IsAdd => IsAdd(x.children.map { b => IsNeg(b) })
+ case x: IsMul => IsMul(Seq(IsNeg(x.children.head)) ++ x.children.tail)
+ case x: IsNeg => x.child
+ case x: IsPow => this
+ // -[max(a, b)] -> min[-a, -b]
+ case x: IsMax => IsMin(x.children.map { b => IsNeg(b) })
+ case x: IsMin => IsMax(x.children.map { b => IsNeg(b) })
+ case x: IsVar => this
+ case _ => this
+ }
+
+ lazy val children = Vector(child)
+
+ override def map(f: Constraint=>Constraint): Constraint = IsNeg(f(child))
+
+ override def serialize: String = "(-" + child.serialize + ")"
+}
+
+
diff --git a/src/main/scala/firrtl/constraint/IsPow.scala b/src/main/scala/firrtl/constraint/IsPow.scala
new file mode 100644
index 00000000..54a06bf8
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/IsPow.scala
@@ -0,0 +1,33 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+object IsPow {
+ def apply(child: Constraint): Constraint = new IsPow(child, 0).reduce()
+}
+
+// Dummy arg is to get around weird Scala issue that can't differentiate between a
+// private constructor and public apply that share the same arguments
+case class IsPow private (child: Constraint, dummyArg: Int) extends Constraint {
+ override def reduce(): Constraint = child match {
+ case k: IsKnown => k.pow
+ // 2^(a + b) -> 2^a * 2^b
+ case x: IsAdd => IsMul(x.children.map { b => IsPow(b)})
+ case x: IsMul => this
+ case x: IsNeg => this
+ case x: IsPow => this
+ // 2^(max(a, b)) -> max(2^a, 2^b) since two is always positive, so a, b control magnitude
+ case x: IsMax => IsMax(x.children.map {b => IsPow(b)})
+ case x: IsMin => IsMin(x.children.map {b => IsPow(b)})
+ case x: IsVar => this
+ case _ => this
+ }
+
+ val children = Vector(child)
+
+ override def map(f: Constraint=>Constraint): Constraint = IsPow(f(child))
+
+ override def serialize: String = "(2^" + child.serialize + ")"
+}
+
+
diff --git a/src/main/scala/firrtl/constraint/IsVar.scala b/src/main/scala/firrtl/constraint/IsVar.scala
new file mode 100644
index 00000000..98396fa0
--- /dev/null
+++ b/src/main/scala/firrtl/constraint/IsVar.scala
@@ -0,0 +1,27 @@
+// See LICENSE for license details.
+
+package firrtl.constraint
+
+object IsVar {
+ def unapply(i: Constraint): Option[String] = i match {
+ case i: IsVar => Some(i.name)
+ case _ => None
+ }
+}
+
+/** Extend to be a constraint variable */
+trait IsVar extends Constraint {
+
+ def name: String
+
+ override def serialize: String = name
+
+ override def map(f: Constraint=>Constraint): Constraint = this
+
+ override def reduce() = this
+
+ val children = Vector()
+}
+
+case class VarCon(name: String) extends IsVar
+
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index e721363c..63c620d1 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -3,7 +3,11 @@
package firrtl
package ir
-import Utils.indent
+import Utils.{dec2string, indent, trim}
+import firrtl.constraint.{Constraint, IsKnown, IsVar}
+
+import scala.math.BigDecimal.RoundingMode._
+import scala.collection.mutable
/** Intermediate Representation */
abstract class FirrtlNode {
@@ -103,6 +107,29 @@ object StringLit {
*/
abstract class PrimOp extends FirrtlNode {
def serialize: String = this.toString
+ def propagateType(e: DoPrim): Type = UnknownType
+ def apply(args: Any*): DoPrim = {
+ val groups = args.groupBy {
+ case x: Expression => "exp"
+ case x: BigInt => "int"
+ case x: Int => "int"
+ case other => "other"
+ }
+ val exprs = groups.getOrElse("exp", Nil).collect {
+ case e: Expression => e
+ }
+ val consts = groups.getOrElse("int", Nil).map {
+ _ match {
+ case i: BigInt => i
+ case i: Int => BigInt(i)
+ }
+ }
+ groups.get("other") match {
+ case None =>
+ case Some(x) => sys.error(s"Shouldn't be here: $x")
+ }
+ DoPrim(this, exprs, consts, UnknownType)
+ }
}
abstract class Expression extends FirrtlNode {
@@ -151,7 +178,7 @@ case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Exp
def foreachType(f: Type => Unit): Unit = f(tpe)
def foreachWidth(f: Width => Unit): Unit = Unit
}
-case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression {
+case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type = UnknownType) extends Expression {
def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})"
def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe)
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
@@ -535,6 +562,12 @@ class IntWidth(val width: BigInt) extends Width with Product {
case object UnknownWidth extends Width {
def serialize: String = ""
}
+case class CalcWidth(arg: Constraint) extends Width {
+ def serialize: String = s"calcw(${arg.serialize})"
+}
+case class VarWidth(name: String) extends Width with IsVar {
+ override def serialize: String = s"<$name>"
+}
/** Orientation of [[Field]] */
abstract class Orientation extends FirrtlNode
@@ -550,6 +583,67 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode
def serialize: String = flip.serialize + name + " : " + tpe.serialize
}
+
+/** Bounds of [[IntervalType]] */
+
+trait Bound extends Constraint {
+ def serialize: String
+}
+case object UnknownBound extends Bound {
+ def serialize: String = "?"
+ def map(f: Constraint=>Constraint): Constraint = this
+ override def reduce(): Constraint = this
+ val children = Vector()
+}
+case class CalcBound(arg: Constraint) extends Bound {
+ def serialize: String = s"calcb(${arg.serialize})"
+ def map(f: Constraint=>Constraint): Constraint = f(arg)
+ override def reduce(): Constraint = arg
+ val children = Vector(arg)
+}
+case class VarBound(name: String) extends IsVar with Bound
+object KnownBound {
+ def unapply(b: Constraint): Option[BigDecimal] = b match {
+ case k: IsKnown => Some(k.value)
+ case _ => None
+ }
+ def unapply(b: Bound): Option[BigDecimal] = b match {
+ case k: IsKnown => Some(k.value)
+ case _ => None
+ }
+}
+case class Open(value: BigDecimal) extends IsKnown with Bound {
+ def serialize = s"o($value)"
+ def +(that: IsKnown): IsKnown = Open(value + that.value)
+ def *(that: IsKnown): IsKnown = that match {
+ case Closed(x) if x == 0 => Closed(x)
+ case _ => Open(value * that.value)
+ }
+ def min(that: IsKnown): IsKnown = if(value < that.value) this else that
+ def max(that: IsKnown): IsKnown = if(value > that.value) this else that
+ def neg: IsKnown = Open(-value)
+ def floor: IsKnown = Open(value.setScale(0, BigDecimal.RoundingMode.FLOOR))
+ def pow: IsKnown = if(value.isBinaryDouble) Open(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here")
+}
+case class Closed(value: BigDecimal) extends IsKnown with Bound {
+ def serialize = s"c($value)"
+ def +(that: IsKnown): IsKnown = that match {
+ case Open(x) => Open(value + x)
+ case Closed(x) => Closed(value + x)
+ }
+ def *(that: IsKnown): IsKnown = that match {
+ case IsKnown(x) if value == BigInt(0) => Closed(0)
+ case Open(x) => Open(value * x)
+ case Closed(x) => Closed(value * x)
+ }
+ def min(that: IsKnown): IsKnown = if(value <= that.value) this else that
+ def max(that: IsKnown): IsKnown = if(value >= that.value) this else that
+ def neg: IsKnown = Closed(-value)
+ def floor: IsKnown = Closed(value.setScale(0, BigDecimal.RoundingMode.FLOOR))
+ def pow: IsKnown = if(value.isBinaryDouble) Closed(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here")
+}
+
+/** Types of [[FirrtlNode]] */
abstract class Type extends FirrtlNode {
def mapType(f: Type => Type): Type
def mapWidth(f: Width => Width): Type
@@ -586,6 +680,84 @@ case class FixedType(width: Width, point: Width) extends GroundType {
def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point))
def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) }
}
+case class IntervalType(lower: Bound, upper: Bound, point: Width) extends GroundType {
+ override def serialize: String = {
+ val lowerString = lower match {
+ case Open(l) => s"(${dec2string(l)}, "
+ case Closed(l) => s"[${dec2string(l)}, "
+ case UnknownBound => s"[?, "
+ case _ => s"[?, "
+ }
+ val upperString = upper match {
+ case Open(u) => s"${dec2string(u)})"
+ case Closed(u) => s"${dec2string(u)}]"
+ case UnknownBound => s"?]"
+ case _ => s"?]"
+ }
+ val bounds = (lower, upper) match {
+ case (k1: IsKnown, k2: IsKnown) => lowerString + upperString
+ case _ => ""
+ }
+ val pointString = point match {
+ case IntWidth(i) => "." + i.toString
+ case _ => ""
+ }
+ "Interval" + bounds + pointString
+ }
+
+ private lazy val bp = point.asInstanceOf[IntWidth].width.toInt
+ private def precision: Option[BigDecimal] = point match {
+ case IntWidth(width) =>
+ val bp = width.toInt
+ if(bp >= 0) Some(BigDecimal(1) / BigDecimal(BigInt(1) << bp)) else Some(BigDecimal(BigInt(1) << -bp))
+ case other => None
+ }
+
+ def min: Option[BigDecimal] = (lower, precision) match {
+ case (Open(a), Some(prec)) => a / prec match {
+ case x if trim(x).isWhole => Some(a + prec) // add precision for open lower bound i.e. (-4 -> [3 for bp = 0
+ case x => Some(x.setScale(0, CEILING) * prec) // Deal with unrepresentable bound representations (finite BP) -- new closed form l > original l
+ }
+ case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, CEILING) * prec)
+ case other => None
+ }
+
+ def max: Option[BigDecimal] = (upper, precision) match {
+ case (Open(a), Some(prec)) => a / prec match {
+ case x if trim(x).isWhole => Some(a - prec) // subtract precision for open upper bound
+ case x => Some(x.setScale(0, FLOOR) * prec)
+ }
+ case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, FLOOR) * prec)
+ }
+
+ def minAdjusted: Option[BigInt] = min.map(_ * BigDecimal(BigInt(1) << bp) match {
+ case x if trim(x).isWhole | x.doubleValue == 0.0 => x.toBigInt
+ case x => sys.error(s"MinAdjusted should be a whole number: $x. Min is $min. BP is $bp. Precision is $precision. Lower is ${lower}.")
+ })
+
+ def maxAdjusted: Option[BigInt] = max.map(_ * BigDecimal(BigInt(1) << bp) match {
+ case x if trim(x).isWhole => x.toBigInt
+ case x => sys.error(s"MaxAdjusted should be a whole number: $x")
+ })
+
+ /** If bounds are known, calculates the width, otherwise returns UnknownWidth */
+ lazy val width: Width = (point, lower, upper) match {
+ case (IntWidth(i), l: IsKnown, u: IsKnown) =>
+ IntWidth(Math.max(Utils.getSIntWidth(minAdjusted.get), Utils.getSIntWidth(maxAdjusted.get)))
+ case _ => UnknownWidth
+ }
+
+ /** If bounds are known, returns a sequence of all possible values inside this interval */
+ lazy val range: Option[Seq[BigDecimal]] = (lower, upper, point) match {
+ case (l: IsKnown, u: IsKnown, p: IntWidth) =>
+ if(min.get > max.get) Some(Nil) else Some(Range.BigDecimal(min.get, max.get, precision.get))
+ case _ => None
+ }
+
+ override def mapWidth(f: Width => Width): Type = this.copy(point = f(point))
+ override def foreachWidth(f: Width => Unit): Unit = f(point)
+}
+
case class BundleType(fields: Seq[Field]) extends AggregateType {
def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}"
def mapType(f: Type => Type): Type =
@@ -645,6 +817,7 @@ case class Port(
direction: Direction,
tpe: Type) extends FirrtlNode with IsDeclaration {
def serialize: String = s"${direction.serialize} $name : ${tpe.serialize}" + info.serialize
+ def mapType(f: Type => Type): Port = Port(info, name, direction, f(tpe))
}
/** Parameters for external modules */
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala
index 07784e19..5ae5dad4 100644
--- a/src/main/scala/firrtl/passes/CheckWidths.scala
+++ b/src/main/scala/firrtl/passes/CheckWidths.scala
@@ -7,14 +7,21 @@ import firrtl.ir._
import firrtl.PrimOps._
import firrtl.traversals.Foreachers._
import firrtl.Utils._
+import firrtl.constraint.IsKnown
import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken}
object CheckWidths extends Pass {
/** The maximum allowed width for any circuit element */
val MaxWidth = 1000000
- val DshlMaxWidth = ceilLog2(MaxWidth + 1)
+ val DshlMaxWidth = getUIntWidth(MaxWidth)
class UninferredWidth (info: Info, target: String) extends PassException(
- s"""|$info : Uninferred width for target below. (Did you forget to assign to it?)
+ s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?)
+ |$target""".stripMargin)
+ class UninferredBound (info: Info, target: String, bound: String) extends PassException(
+ s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?)
+ |$target""".stripMargin)
+ class InvalidRange (info: Info, target: String, i: IntervalType) extends PassException(
+ s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?)
|$target""".stripMargin)
class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException(
s"$info : [target $mname] Width too small for constant $b.")
@@ -32,17 +39,25 @@ object CheckWidths extends Pass {
s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.")
class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) extends PassException(
s"$info: [target $mname] Attach source $source and expression $eName must have identical widths.")
+ class DisjointSqueeze(info: Info, mname: String, squeeze: DoPrim)
+ extends PassException({
+ val toSqz = squeeze.args.head.serialize
+ val toSqzTpe = squeeze.args.head.tpe.serialize
+ val sqzTo = squeeze.args(1).serialize
+ val sqzToTpe = squeeze.args(1).tpe.serialize
+ s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe"
+ })
def run(c: Circuit): Circuit = {
val errors = new Errors()
- def check_width_w(info: Info, target: Target)(w: Width): Unit = {
- w match {
- case IntWidth(width) if width >= MaxWidth =>
+ def check_width_w(info: Info, target: Target, t: Type)(w: Width): Unit = {
+ (w, t) match {
+ case (IntWidth(width), _) if width >= MaxWidth =>
errors.append(new WidthTooBig(info, target.serialize, width))
- case w: IntWidth if w.width >= 0 =>
- case _: IntWidth =>
+ case (w: IntWidth, f: FixedType) if (w.width < 0 && w.width == f.width) =>
errors append new NegWidthException(info, target.serialize)
+ case (_: IntWidth, _) =>
case _ =>
errors append new UninferredWidth(info, target.prettyPrint(" "))
}
@@ -57,9 +72,28 @@ object CheckWidths extends Pass {
def check_width_t(info: Info, target: Target)(t: Type): Unit = {
t match {
case tt: BundleType => tt.fields.foreach(check_width_f(info, target))
+ //Supports when l = u (if closed)
+ case i@IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i
+ case i:IntervalType if i.range == Some(Nil) =>
+ errors append new InvalidRange(info, target.prettyPrint(" "), i)
+ i
+ case i@IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u =>
+ errors append new InvalidRange(info, target.prettyPrint(" "), i)
+ i
+ case i@IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i
+ case i@IntervalType(_: IsKnown, _, _) =>
+ errors append new UninferredBound(info, target.prettyPrint(" "), "upper")
+ i
+ case i@IntervalType(_, _: IsKnown, _) =>
+ errors append new UninferredBound(info, target.prettyPrint(" "), "lower")
+ i
+ case i@IntervalType(_, _, _) =>
+ errors append new UninferredBound(info, target.prettyPrint(" "), "lower")
+ errors append new UninferredBound(info, target.prettyPrint(" "), "upper")
+ i
case tt => tt foreach check_width_t(info, target)
}
- t foreach check_width_w(info, target)
+ t foreach check_width_w(info, target, t)
}
def check_width_f(info: Info, target: Target)(f: Field): Unit =
@@ -77,6 +111,12 @@ object CheckWidths extends Pass {
errors append new WidthTooSmall(info, target.serialize, e.value)
case _ =>
}
+ case sqz@DoPrim(Squeeze, Seq(a, b), _, IntervalType(Closed(min), Closed(max), _)) =>
+ (a.tpe, b.tpe) match {
+ case (IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) =>
+ errors append new DisjointSqueeze(info, target.serialize, sqz)
+ case other =>
+ }
case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) =>
errors append new BitsWidthException(info, target.serialize, hi, bitWidth(a.tpe), e.serialize)
case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) =>
@@ -87,7 +127,6 @@ object CheckWidths extends Pass {
errors append new DshlTooBig(info, target.serialize)
case _ =>
}
- //e map check_width_t(info, mname) map check_width_e(info, mname)
e foreach check_width_e(info, target)
}
@@ -111,11 +150,15 @@ object CheckWidths extends Pass {
case ResetType =>
case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name))
}
+ if(!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) {
+ val conMsg = sx.copy(info = NoInfo).serialize
+ errors.append(new CheckTypes.InvalidConnect(info, target.module, conMsg, WRef(sx), sx.init))
+ }
case _ =>
}
}
- def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target)(p.tpe)
+ def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target.ref(p.name))(p.tpe)
def check_width_m(circuit: CircuitTarget)(m: DefModule): Unit = {
m foreach check_width_p(m.info, circuit.module(m.name))
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 85ed7de0..4239247c 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -8,6 +8,7 @@ import firrtl.PrimOps._
import firrtl.Utils._
import firrtl.traversals.Foreachers._
import firrtl.WrappedType._
+import firrtl.constraint.{Constraint, IsKnown}
trait CheckHighFormLike {
type NameSet = collection.mutable.HashSet[String]
@@ -84,11 +85,11 @@ trait CheckHighFormLike {
e.op match {
case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq |
- Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw =>
+ Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | Clip | Wrap | Squeeze =>
correctNum(Option(2), 0)
case AsUInt | AsSInt | AsClock | AsAsyncReset | Cvt | Neq | Not =>
correctNum(Option(1), 0)
- case AsFixedPoint | Pad | Head | Tail | BPShl | BPShr | BPSet =>
+ case AsFixedPoint | Pad | Head | Tail | IncP | DecP | SetP =>
correctNum(Option(1), 1)
case Shl | Shr =>
correctNum(Option(1), 1)
@@ -101,6 +102,8 @@ trait CheckHighFormLike {
if (lsb > msb) {
errors.append(new LsbLargerThanMsbException(info, mname, e.op.toString, lsb, msb))
}
+ case AsInterval =>
+ correctNum(Option(1), 3)
case Andr | Orr | Xorr | Neg =>
correctNum(None,0)
}
@@ -137,7 +140,9 @@ trait CheckHighFormLike {
def checkHighFormT(info: Info, mname: String)(t: Type): Unit = {
t foreach checkHighFormT(info, mname)
t match {
- case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname))
+ case tx: VectorType if tx.size < 0 =>
+ errors.append(new NegVecSizeException(info, mname))
+ case i: IntervalType => i
case _ => t foreach checkHighFormW(info, mname)
}
}
@@ -146,6 +151,7 @@ trait CheckHighFormLike {
e match {
case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error
+ case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error
case _ => errors.append(new InvalidAccessException(info, mname))
}
}
@@ -164,8 +170,8 @@ trait CheckHighFormLike {
case ex: WSubAccess => validSubexp(info, mname)(ex.expr)
case ex => ex foreach validSubexp(info, mname)
}
- e foreach checkHighFormW(info, mname)
- e foreach checkHighFormT(info, mname)
+ e foreach checkHighFormW(info, mname + "/" + e.serialize)
+ e foreach checkHighFormT(info, mname + "/" + e.serialize)
e foreach checkHighFormE(info, mname, names)
}
@@ -215,8 +221,7 @@ trait CheckHighFormLike {
if (names(p.name))
errors.append(new NotUniqueException(NoInfo, mname, p.name))
names += p.name
- p.tpe foreach checkHighFormT(p.info, mname)
- p.tpe foreach checkHighFormW(p.info, mname)
+ checkHighFormT(p.info, mname)(p.tpe)
}
// Search for ResetType Ports of direction
@@ -339,6 +344,11 @@ object CheckTypes extends Pass {
s"$info: [module $mname] Uninferred type: $exp."
)
+ def fits(bigger: Constraint, smaller: Constraint): Boolean = (bigger, smaller) match {
+ case (IsKnown(v1), IsKnown(v2)) if v1 < v2 => false
+ case _ => true
+ }
+
def legalResetType(tpe: Type): Boolean = tpe match {
case UIntType(IntWidth(w)) if w == 1 => true
case AsyncResetType => true
@@ -355,6 +365,9 @@ object CheckTypes extends Pass {
case (_: UIntType, _: UIntType) => flip1 == flip2
case (_: SIntType, _: SIntType) => flip1 == flip2
case (_: FixedType, _: FixedType) => flip1 == flip2
+ case (i1: IntervalType, i2: IntervalType) =>
+ import Implicits.width2constraint
+ fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point)
case (_: AnalogType, _: AnalogType) => true
case (AsyncResetType, AsyncResetType) => flip1 == flip2
case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2
@@ -375,7 +388,17 @@ object CheckTypes extends Pass {
}
}
- def validConnect(con: Connect): Boolean = wt(con.loc.tpe).superTypeOf(wt(con.expr.tpe))
+ def validConnect(locTpe: Type, expTpe: Type): Boolean = {
+ val itFits = (locTpe, expTpe) match {
+ case (i1: IntervalType, i2: IntervalType) =>
+ import Implicits.width2constraint
+ fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point)
+ case _ => true
+ }
+ wt(locTpe).superTypeOf(wt(expTpe)) && itFits
+ }
+
+ def validConnect(con: Connect): Boolean = validConnect(con.loc.tpe, con.expr.tpe)
def validPartialConnect(con: PartialConnect): Boolean =
bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default)
@@ -393,44 +416,51 @@ object CheckTypes extends Pass {
case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe))
case tx => true
}
+
def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = {
- def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean): Unit = {
- exprs.foldLeft((false, false, false, false, false)) {
- case ((isUInt, isSInt, isClock, isFix, isAsync), expr) => expr.tpe match {
- case u: UIntType => (true, isSInt, isClock, isFix, isAsync)
- case s: SIntType => (isUInt, true, isClock, isFix, isAsync)
- case ClockType => (isUInt, isSInt, true, isFix, isAsync)
- case f: FixedType => (isUInt, isSInt, isClock, true, isAsync)
- case AsyncResetType => (isUInt, isSInt, isClock, isFix, true)
+ def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean, okInterval: Boolean): Unit = {
+ exprs.foldLeft((false, false, false, false, false, false)) {
+ case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => expr.tpe match {
+ case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval)
+ case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval)
+ case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval)
+ case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval)
+ case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval)
+ case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true)
case UnknownType =>
errors.append(new IllegalUnknownType(info, mname, e.serialize))
- (isUInt, isSInt, isClock, isFix, isAsync)
+ (isUInt, isSInt, isClock, isFix, isAsync, isInterval)
case other => throwInternalError(s"Illegal Type: ${other.serialize}")
}
} match {
- // (UInt, SInt, Clock, Fixed)
- case (isAll, false, false, false, false) if isAll == okUInt =>
- case (false, isAll, false, false, false) if isAll == okSInt =>
- case (false, false, isAll, false, false) if isAll == okClock =>
- case (false, false, false, isAll, false) if isAll == okFix =>
- case (false, false, false, false, isAll) if isAll == okAsync =>
+ // (UInt, SInt, Clock, Fixed, Async, Interval)
+ case (isAll, false, false, false, false, false) if isAll == okUInt =>
+ case (false, isAll, false, false, false, false) if isAll == okSInt =>
+ case (false, false, isAll, false, false, false) if isAll == okClock =>
+ case (false, false, false, isAll, false, false) if isAll == okFix =>
+ case (false, false, false, false, isAll, false) if isAll == okAsync =>
+ case (false, false, false, false, false, isAll) if isAll == okInterval =>
case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize)))
}
}
e.op match {
- case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset =>
+ case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset | AsInterval =>
// All types are ok
case Dshl | Dshr =>
- checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false)
- checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false)
+ checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false, okInterval=false)
case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false)
- case Pad | Shl | Shr | Cat | Bits | Head | Tail =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false)
- case BPShl | BPShr | BPSet =>
- checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false)
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ case Pad | Bits | Head | Tail =>
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=false)
+ case Shl | Shr | Cat =>
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ case IncP | DecP | SetP =>
+ checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ case Wrap | Clip | Squeeze =>
+ checkAllTypes(e.args, okUInt = false, okSInt = false, okClock = false, okFix = false, okAsync=false, okInterval = true)
case _ =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false)
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false, okInterval=false)
}
}
@@ -494,6 +524,9 @@ object CheckTypes extends Pass {
sx.tpe match {
case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname))
+ case t if !validConnect(sx.tpe, sx.init.tpe) =>
+ val conMsg = sx.copy(info = NoInfo).serialize
+ errors.append(new CheckTypes.InvalidConnect(info, mname, conMsg, WRef(sx), sx.init))
case t =>
}
if (!legalResetType(sx.reset.tpe)) {
diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
index 67fdfea0..05a000c5 100644
--- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
+++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
@@ -39,9 +39,9 @@ object ConvertFixedToSInt extends Pass {
def updateExpType(e:Expression): Expression = e match {
case DoPrim(Mul, args, consts, tpe) => e map updateExpType
case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) map updateExpType
- case DoPrim(BPShl, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType
- case DoPrim(BPShr, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType
- case DoPrim(BPSet, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType
+ case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType
+ case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType
+ case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType
case DoPrim(op, args, consts, tpe) =>
val point = calcPoint(args)
val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType)
diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
new file mode 100644
index 00000000..258c9697
--- /dev/null
+++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
@@ -0,0 +1,101 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target}
+import firrtl.constraint.ConstraintSolver
+
+class InferBinaryPoints extends Pass {
+ private val constraintSolver = new ConstraintSolver()
+
+ private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match {
+ case (UIntType(w1), UIntType(w2)) =>
+ case (SIntType(w1), SIntType(w2)) =>
+ case (ClockType, ClockType) =>
+ case (ResetType, _) =>
+ case (_, ResetType) =>
+ case (AsyncResetType, AsyncResetType) =>
+ case (FixedType(w1, p1), FixedType(w2, p2)) =>
+ constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) =>
+ constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (AnalogType(w1), AnalogType(w2)) =>
+ case (t1: BundleType, t2: BundleType) =>
+ (t1.fields zip t2.fields) foreach { case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
+ }
+ }
+ case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
+ case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other")
+ }
+ private def addDecConstraints(t: Type): Type = t map addDecConstraints
+ private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match {
+ case c: Connect =>
+ val n = get_size(c.loc.tpe)
+ val locs = create_exps(c.loc)
+ val exps = create_exps(c.expr)
+ (locs zip exps) foreach { case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
+ }
+ c
+ case pc: PartialConnect =>
+ val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
+ val locs = create_exps(pc.loc)
+ val exps = create_exps(pc.expr)
+ ls foreach { case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
+ }
+ pc
+ case r: DefRegister =>
+ addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
+ r
+ case x => x map addStmtConstraints(mt)
+ }
+ private def fixWidth(w: Width): Width = constraintSolver.get(w) match {
+ case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
+ case None => w
+ case _ => sys.error("Shouldn't be here")
+ }
+ private def fixType(t: Type): Type = t map fixType map fixWidth match {
+ case IntervalType(l, u, p) =>
+ val px = constraintSolver.get(p) match {
+ case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
+ case None => p
+ case _ => sys.error("Shouldn't be here")
+ }
+ IntervalType(l, u, px)
+ case FixedType(w, p) =>
+ val px = constraintSolver.get(p) match {
+ case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
+ case None => p
+ case _ => sys.error("Shouldn't be here")
+ }
+ FixedType(w, px)
+ case x => x
+ }
+ private def fixStmt(s: Statement): Statement = s map fixStmt map fixType
+ private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe))
+ def run (c: Circuit): Circuit = {
+ val ct = CircuitTarget(c.main)
+ c.modules foreach (m => m map addStmtConstraints(ct.module(m.name)))
+ c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)})
+ constraintSolver.solve()
+ InferTypes.run(c.copy(modules = c.modules map (_
+ map fixPort
+ map fixStmt)))
+ }
+}
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala
index 288b62ba..3c5cf7fb 100644
--- a/src/main/scala/firrtl/passes/InferTypes.scala
+++ b/src/main/scala/firrtl/passes/InferTypes.scala
@@ -14,13 +14,23 @@ object InferTypes extends Pass {
val namespace = Namespace()
val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
+ def remove_unknowns_b(b: Bound): Bound = b match {
+ case UnknownBound => VarBound(namespace.newName("b"))
+ case k => k
+ }
+
def remove_unknowns_w(w: Width): Width = w match {
case UnknownWidth => VarWidth(namespace.newName("w"))
case wx => wx
}
- def remove_unknowns(t: Type): Type =
- t map remove_unknowns map remove_unknowns_w
+ def remove_unknowns(t: Type): Type = {
+ t map remove_unknowns map remove_unknowns_w match {
+ case IntervalType(l, u, p) =>
+ IntervalType(remove_unknowns_b(l), remove_unknowns_b(u), p)
+ case x => x
+ }
+ }
def infer_types_e(types: TypeMap)(e: Expression): Expression =
e map infer_types_e(types) match {
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 9c58da2c..2211d238 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -3,15 +3,20 @@
package firrtl.passes
// Datastructures
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.immutable.ListMap
-
import firrtl._
import firrtl.annotations.{Annotation, ReferenceTarget}
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
-import firrtl.traversals.Foreachers._
+import firrtl.Implicits.width2constraint
+import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target}
+import firrtl.constraint.{ConstraintSolver, IsMax}
+
+object InferWidths {
+ def apply(): InferWidths = new InferWidths()
+ def run(c: Circuit): Circuit = new InferWidths().run(c)
+ def execute(state: CircuitState): CircuitState = new InferWidths().execute(state)
+}
case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarget) extends Annotation {
def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = {
@@ -33,369 +38,146 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg
}
}
+/** Infers the widths of all signals with unknown widths
+ *
+ * Is a global width inference algorithm
+ * - Instances of the same module with unknown input port widths will be assigned the
+ * largest width of all assignments to each of its instance ports
+ * - If you don't want the global inference behavior, then be sure to define all your input widths
+ *
+ * Infers the smallest width is larger than all assigned widths to a signal
+ * - Note that this means that dummy assignments that are overwritten by last-connect-semantics
+ * can still influence width inference
+ * - E.g.
+ * wire x: UInt
+ * x <= UInt<5>(15)
+ * x <= UInt<1>(1)
+ *
+ * Since width inference occurs before lowering, it infers x's width to be 5 but with an assignment of UInt(1):
+ *
+ * wire x: UInt<5>
+ * x <= UInt<1>(1)
+ *
+ * Uses firrtl.constraint package to infer widths
+ */
class InferWidths extends Transform with ResolvedAnnotationPaths {
def inputForm: CircuitForm = UnknownForm
def outputForm: CircuitForm = UnknownForm
- val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
+ private val constraintSolver = new ConstraintSolver()
- type ConstraintMap = collection.mutable.LinkedHashMap[String, Width]
-
- def solve_constraints(l: Seq[WGeq]): ConstraintMap = {
- def unique(ls: Seq[Width]) : Seq[Width] =
- (ls map (new WrappedWidth(_))).distinct map (_.w)
- // Combines constraints on the same VarWidth into the same constraint
- def make_unique(ls: Seq[WGeq]): ListMap[String,Width] = {
- ls.foldLeft(ListMap.empty[String, Width])((acc, wgeq) => wgeq.loc match {
- case VarWidth(name) => acc.get(name) match {
- case None => acc + (name -> wgeq.exp)
- // Avoid constructing massive MaxWidth chains
- case Some(MaxWidth(args)) => acc + (name -> MaxWidth(wgeq.exp +: args))
- case Some(width) => acc + (name -> MaxWidth(Seq(wgeq.exp, width)))
- }
- case _ => acc
- })
- }
- def pullMinMax(w: Width): Width = w map pullMinMax match {
- case PlusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i))))
- case PlusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i))))
- case MinusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => MinusWidth(m, IntWidth(i))))
- case MinusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => MinusWidth(IntWidth(i), m)))
- case PlusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i))))
- case PlusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i))))
- case MinusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => MinusWidth(m, IntWidth(i))))
- case MinusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => MinusWidth(IntWidth(i), m)))
- case wx => wx
- }
- def collectMinMax(w: Width): Width = w map collectMinMax match {
- case MinWidth(args) => MinWidth(unique(args.foldLeft(List[Width]()) {
- case (res, wxx: MinWidth) => wxx.args ++: res
- case (res, wxx) => wxx +: res
- }))
- case MaxWidth(args) => MaxWidth(unique(args.foldLeft(List[Width]()) {
- case (res, wxx: MaxWidth) => wxx.args ++: res
- case (res, wxx) => wxx +: res
- }))
- case wx => wx
- }
- def mergePlusMinus(w: Width): Width = w map mergePlusMinus match {
- case wx: PlusWidth => (wx.arg1, wx.arg2) match {
- case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width)
- case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
- case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
- case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1)
- case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1)
- case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(y - x), w1)
- case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(y - x), w1)
- case _ => wx
- }
- case wx: MinusWidth => (wx.arg1, wx.arg2) match {
- case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
- case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
- case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
- case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
- case _ => wx
- }
- case wx: ExpWidth => wx.arg1 match {
- case w1: IntWidth => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong))
- case _ => wx
- }
- case wx => wx
- }
- def removeZeros(w: Width): Width = w map removeZeros match {
- case wx: PlusWidth => (wx.arg1, wx.arg2) match {
- case (w1, IntWidth(x)) if x == 0 => w1
- case (IntWidth(x), w1) if x == 0 => w1
- case _ => wx
- }
- case wx: MinusWidth => (wx.arg1, wx.arg2) match {
- case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
- case (w1, IntWidth(x)) if x == 0 => w1
- case _ => wx
- }
- case wx => wx
- }
- def simplify(w: Width): Width = {
- val opts = Seq(
- pullMinMax _,
- collectMinMax _,
- mergePlusMinus _,
- removeZeros _
- )
- opts.foldLeft(w) { (width, opt) => opt(width) }
- }
-
- def substitute(h: ConstraintMap)(w: Width): Width = {
- //;println-all-debug(["Substituting for [" w "]"])
- val wx = simplify(w)
- //;println-all-debug(["After Simplify: [" wx "]"])
- wx map substitute(h) match {
- //;("matched println-debugvarwidth!")
- case wxx: VarWidth => h get wxx.name match {
- case None => wxx
- case Some(p) =>
- //;println-debug("Contained!")
- //;println-all-debug(["Width: " wxx])
- //;println-all-debug(["Accessed: " h[name(wxx)]])
- val t = simplify(substitute(h)(p))
- h(wxx.name) = t
- t
- }
- case wxx => wxx
- //;println-all-debug(["not varwidth!" w])
- }
- }
-
- def b_sub(h: ConstraintMap)(w: Width): Width = {
- w map b_sub(h) match {
- case wx: VarWidth => h getOrElse (wx.name, wx)
- case wx => wx
- }
- }
-
- def remove_cycle(n: String)(w: Width): Width = {
- //;println-all-debug(["Removing cycle for " n " inside " w])
- w match {
- case wx: MaxWidth => MaxWidth(wx.args filter {
- case wxx: VarWidth => !(n equals wxx.name)
- case MinusWidth(VarWidth(name), IntWidth(i)) if ((i >= 0) && (n == name)) => false
- case _ => true
- })
- case wx: MinusWidth => wx.arg1 match {
- case v: VarWidth if n == v.name => v
- case v => wx
- }
- case wx => wx
- }
- //;println-all-debug(["After removing cycle for " n ", returning " wx])
- }
+ val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
- def hasVarWidth(n: String)(w: Width): Boolean = {
- var has = false
- def rec(w: Width): Width = {
- w match {
- case wx: VarWidth if wx.name == n => has = true
- case _ =>
+ private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match {
+ case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (ClockType, ClockType) =>
+ case (FixedType(w1, p1), FixedType(w2, p2)) =>
+ constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
+ constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) =>
+ constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
+ constraintSolver.addLeq(l1, l2, r1.prettyPrint(""), r2.prettyPrint(""))
+ constraintSolver.addGeq(u1, u2, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (AnalogType(w1), AnalogType(w2)) =>
+ constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
+ constraintSolver.addGeq(w2, w1, r1.prettyPrint(""), r2.prettyPrint(""))
+ case (t1: BundleType, t2: BundleType) =>
+ (t1.fields zip t2.fields) foreach { case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
}
- w map rec
- }
- rec(w)
- has
- }
-
- //; Forward solve
- //; Returns a solved list where each constraint undergoes:
- //; 1) Continuous Solving (using triangular solving)
- //; 2) Remove Cycles
- //; 3) Move to solved if not self-recursive
- val u = make_unique(l)
-
- //println("======== UNIQUE CONSTRAINTS ========")
- //for (x <- u) { println(x) }
- //println("====================================")
-
- val f = new ConstraintMap
- val o = ArrayBuffer[String]()
- for ((n, e) <- u) {
- //println("==== SOLUTIONS TABLE ====")
- //for (x <- f) println(x)
- //println("=========================")
-
- val e_sub = simplify(substitute(f)(e))
-
- //println("Solving " + n + " => " + e)
- //println("After Substitute: " + n + " => " + e_sub)
- //println("==== SOLUTIONS TABLE (Post Substitute) ====")
- //for (x <- f) println(x)
- //println("=========================")
-
- val ex = remove_cycle(n)(e_sub)
-
- //println("After Remove Cycle: " + n + " => " + ex)
- if (!hasVarWidth(n)(ex)) {
- //println("Not rec!: " + n + " => " + ex)
- //println("Adding [" + n + "=>" + ex + "] to Solutions Table")
- f(n) = ex
- o += n
}
- }
-
- //println("Forward Solved Constraints")
- //for (x <- f) println(x)
-
- //; Backwards Solve
- val b = new ConstraintMap
- for (i <- (o.size - 1) to 0 by -1) {
- val n = o(i) // Should visit `o` backward
- /*
- println("SOLVE BACK: [" + n + " => " + f(n) + "]")
- println("==== SOLUTIONS TABLE ====")
- for (x <- b) println(x)
- println("=========================")
- */
- val ex = simplify(b_sub(b)(f(n)))
- /*
- println("BACK RETURN: [" + n + " => " + ex + "]")
- */
- b(n) = ex
- /*
- println("==== SOLUTIONS TABLE (Post backsolve) ====")
- for (x <- b) println(x)
- println("=========================")
- */
- }
- b
- }
-
- def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
- case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
- case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
- case (ClockType, ClockType) => Nil
+ case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
case (AsyncResetType, AsyncResetType) => Nil
- case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
- case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1))
- case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
- res ++ (f1.flip match {
- case Default => get_constraints_t(f1.tpe, f2.tpe)
- case Flip => get_constraints_t(f2.tpe, f1.tpe)
- })
- }
- case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
case (ResetType, _) => Nil
case (_, ResetType) => Nil
}
- def run(c: Circuit, extra: Seq[WGeq]): Circuit = {
- val v = ArrayBuffer[WGeq]() ++ extra
+ private def addExpConstraints(e: Expression): Expression = e map addExpConstraints match {
+ case m@Mux(p, tVal, fVal, t) =>
+ constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
+ m
+ case other => other
+ }
- def get_constraints_e(e: Expression): Unit = {
- e match {
- case (e: Mux) => v ++= Seq(
- WGeq(getWidth(e.cond), IntWidth(1)),
- WGeq(IntWidth(1), getWidth(e.cond))
- )
- case _ =>
+ private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addExpConstraints match {
+ case c: Connect =>
+ val n = get_size(c.loc.tpe)
+ val locs = create_exps(c.loc)
+ val exps = create_exps(c.expr)
+ (locs zip exps).foreach { case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
- e.foreach(get_constraints_e)
- }
-
- def get_constraints_declared_type (t: Type): Type = t match {
- case FixedType(_, p) =>
- v += WGeq(p,IntWidth(0))
- t
- case _ => t map get_constraints_declared_type
- }
-
- def get_constraints_s(s: Statement): Unit = {
- s map get_constraints_declared_type match {
- case (s: Connect) =>
- val locs = create_exps(s.loc)
- val exps = create_exps(s.expr)
- v ++= locs.zip(exps).flatMap { case (locx, expx) =>
- to_flip(flow(locx)) match {
- case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx))
- case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx))
- }
- }
- case (s: PartialConnect) =>
- val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default)
- val locs = create_exps(s.loc)
- val exps = create_exps(s.expr)
- v ++= (ls flatMap {case (x, y) =>
- val locx = locs(x)
- val expx = exps(y)
- to_flip(flow(locx)) match {
- case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx))
- case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx))
- }
- })
- case (s: DefRegister) =>
- if (s.reset.tpe != AsyncResetType ) {
- v ++= (
- get_constraints_t(s.reset.tpe, UIntType(IntWidth(1))) ++
- get_constraints_t(UIntType(IntWidth(1)), s.reset.tpe))
- }
- v ++= get_constraints_t(s.tpe, s.init.tpe)
- case (s:Conditionally) => v ++=
- get_constraints_t(s.pred.tpe, UIntType(IntWidth(1))) ++
- get_constraints_t(UIntType(IntWidth(1)), s.pred.tpe)
- case Attach(_, exprs) =>
- // All widths must be equal
- val widths = exprs map (e => getWidth(e.tpe))
- v ++= widths.tail map (WGeq(widths.head, _))
- case _ =>
+ c
+ case pc: PartialConnect =>
+ val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
+ val locs = create_exps(pc.loc)
+ val exps = create_exps(pc.expr)
+ ls foreach { case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
- s.foreach(get_constraints_e)
- s.foreach(get_constraints_s)
- }
-
- c.modules.foreach(_.foreach(get_constraints_s))
- c.modules.foreach(_.ports.foreach({p => get_constraints_declared_type(p.tpe)}))
-
- //println("======== ALL CONSTRAINTS ========")
- //for(x <- v) println(x)
- //println("=================================")
- val h = solve_constraints(v)
- //println("======== SOLVED CONSTRAINTS ========")
- //for(x <- h) println(x)
- //println("====================================")
-
- def evaluate(w: Width): Width = {
- def map2(a: Option[BigInt], b: Option[BigInt], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
- for (a_num <- a; b_num <- b) yield f(a_num, b_num)
- def reduceOptions(l: Seq[Option[BigInt]], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
- l.reduce(map2(_, _, f))
+ pc
+ case r: DefRegister =>
+ if (r.reset.tpe != AsyncResetType ) {
+ addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1)))
+ }
+ addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
+ r
+ case a@Attach(_, exprs) =>
+ val widths = exprs map (e => (e, getWidth(e.tpe)))
+ val maxWidth = IsMax(widths.map(x => width2constraint(x._2)))
+ widths.foreach { case (e, w) =>
+ constraintSolver.addGeq(w, CalcWidth(maxWidth), Target.asTarget(mt)(e).prettyPrint(""), mt.ref(a.serialize).prettyPrint(""))
+ }
+ a
+ case c: Conditionally =>
+ addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1)))
+ c map addStmtConstraints(mt)
+ case x => x map addStmtConstraints(mt)
+ }
+ private def fixWidth(w: Width): Width = constraintSolver.get(w) match {
+ case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
+ case None => w
+ case _ => sys.error("Shouldn't be here")
+ }
+ private def fixType(t: Type): Type = t map fixType map fixWidth match {
+ case IntervalType(l, u, p) =>
+ val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match {
+ case (Some(x: Bound), Some(y: Bound)) => (x, y)
+ case (None, None) => (l, u)
+ case x => sys.error(s"Shouldn't be here: $x")
- // This function shouldn't be necessary
- // Added as protection in case a constraint accidentally uses MinWidth/MaxWidth
- // without any actual Widths. This should be elevated to an earlier error
- def forceNonEmpty(in: Seq[Option[BigInt]], default: Option[BigInt]): Seq[Option[BigInt]] =
- if (in.isEmpty) Seq(default)
- else in
- def solve(w: Width): Option[BigInt] = w match {
- case wx: VarWidth =>
- for{
- v <- h.get(wx.name) if !v.isInstanceOf[VarWidth]
- result <- solve(v)
- } yield result
- case wx: MaxWidth => reduceOptions(forceNonEmpty(wx.args.map(solve), Some(BigInt(0))), max)
- case wx: MinWidth => reduceOptions(forceNonEmpty(wx.args.map(solve), None), min)
- case wx: PlusWidth => map2(solve(wx.arg1), solve(wx.arg2), {_ + _})
- case wx: MinusWidth => map2(solve(wx.arg1), solve(wx.arg2), {_ - _})
- case wx: ExpWidth => map2(Some(BigInt(2)), solve(wx.arg1), pow_minus_one)
- case wx: IntWidth => Some(wx.width)
- case wx => throwInternalError(s"solve: shouldn't be here - %$wx")
}
+ IntervalType(lx, ux, fixWidth(p))
+ case FixedType(w, p) => FixedType(w, fixWidth(p))
+ case x => x
+ }
+ private def fixStmt(s: Statement): Statement = s map fixStmt map fixType
+ private def fixPort(p: Port): Port = {
+ Port(p.info, p.name, p.direction, fixType(p.tpe))
+ }
- solve(w) match {
- case None => w
- case Some(s) => IntWidth(s)
- }
- }
-
- def reduce_var_widths_w(w: Width): Width = {
- //println-all-debug(["REPLACE: " w])
- evaluate(w)
- //println-all-debug(["WITH: " wx])
- }
-
- def reduce_var_widths_t(t: Type): Type = {
- t map reduce_var_widths_t map reduce_var_widths_w
- }
-
- def reduce_var_widths_s(s: Statement): Statement = {
- s map reduce_var_widths_s map reduce_var_widths_t
- }
-
- def reduce_var_widths_p(p: Port): Port = {
- Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe))
- }
-
- InferTypes.run(c.copy(modules = c.modules map (_
- map reduce_var_widths_p
- map reduce_var_widths_s)))
+ def run (c: Circuit): Circuit = {
+ val ct = CircuitTarget(c.main)
+ c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name)))
+ constraintSolver.solve()
+ val ret = InferTypes.run(c.copy(modules = c.modules map (_
+ map fixPort
+ map fixStmt)))
+ constraintSolver.clear()
+ ret
}
def execute(state: CircuitState): CircuitState = {
@@ -426,7 +208,7 @@ class InferWidths extends Transform with ResolvedAnnotationPaths {
}
}
- val extraConstraints = state.annotations.flatMap {
+ state.annotations.foreach {
case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target =>
val baseType = typeMap.getOrElse(target.copy(component = Seq.empty),
@@ -440,10 +222,12 @@ class InferWidths extends Transform with ResolvedAnnotationPaths {
leafType
}
- get_constraints_t(locType, expType)
- case other => Seq.empty
+ //get_constraints_t(locType, expType)
+ addTypeConstraints(anno.loc, anno.exp)(locType, expType)
+ case other =>
}
- state.copy(circuit = run(state.circuit, extraConstraints))
+ state.copy(circuit = run(state.circuit))
}
+
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 48b5f041..921ec3c7 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -72,7 +72,7 @@ object RemoveCHIRRTL extends Transform {
refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match {
case sx: CDefMemory =>
types(sx.name) = sx.tpe
- val taddr = UIntType(IntWidth(1 max ceilLog2(sx.size)))
+ val taddr = UIntType(IntWidth(1 max getUIntWidth(sx.size - 1)))
val tdata = sx.tpe
def set_poison(vec: Seq[MPort]) = vec flatMap (r => Seq(
IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)),
diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala
new file mode 100644
index 00000000..73f59b59
--- /dev/null
+++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala
@@ -0,0 +1,173 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+
+import firrtl.PrimOps._
+import firrtl.ir._
+import firrtl._
+import firrtl.Mappers._
+import Implicits.{bigint2WInt}
+import firrtl.constraint.IsKnown
+
+import scala.math.BigDecimal.RoundingMode._
+
+class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim)
+ extends PassException({
+ val toWrap = wrap.args.head.serialize
+ val toWrapTpe = wrap.args.head.tpe.serialize
+ val wrapTo = wrap.args(1).serialize
+ val wrapToTpe = wrap.args(1).tpe.serialize
+ s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe"
+ })
+
+
+/** Replaces IntervalType with SIntType, three AST walks:
+ * 1) Align binary points
+ * - adds shift operators to primop args and connections
+ * - does not affect declaration- or inferred-types
+ * 2) Replace Interval [[DefNode]] with [[DefWire]] + [[Connect]]
+ * - You have to do this to capture the smaller bitwidths of nodes that intervals give you. Otherwise, any future
+ * InferTypes would reinfer the larger widths on these nodes from SInt width inference rules
+ * 3) Replace declaration IntervalType's with SIntType's
+ * - for each declaration:
+ * a. remove non-zero binary points
+ * b. remove open bounds
+ * c. replace with SIntType
+ * 3) Run InferTypes
+ */
+class RemoveIntervals extends Pass {
+
+ def run(c: Circuit): Circuit = {
+ val alignedCircuit = c
+ val errors = new Errors()
+ val wiredCircuit = alignedCircuit map makeWireModule
+ val replacedCircuit = wiredCircuit map replaceModuleInterval(errors)
+ errors.trigger()
+ InferTypes.run(replacedCircuit)
+ }
+
+ /* Replace interval types */
+ private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule =
+ m map replaceStmtInterval(errors, m.name) map replacePortInterval
+
+ private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = {
+ val info = s match {
+ case h: HasInfo => h.info
+ case _ => NoInfo
+ }
+ s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname)
+
+ }
+
+ private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match {
+ case _: WRef | _: WSubIndex | _: WSubField => e
+ case o =>
+ o map replaceExprInterval(errors, info, mname) match {
+ case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe)
+ case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe)
+ case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe)
+ case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) =>
+ // Output interval (pre-calculated)
+ val clipLo = tpe.minAdjusted.get
+ val clipHi = tpe.maxAdjusted.get
+ // Input interval
+ val (inLow, inHigh) = a1.tpe match {
+ case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
+ case _ => sys.error("Shouldn't be here")
+ }
+ val gtOpt = clipHi >= inHigh
+ val ltOpt = clipLo <= inLow
+ (gtOpt, ltOpt) match {
+ // input range within output range -> no optimization
+ case (true, true) => a1
+ case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1)
+ case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1)
+ case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1))
+ }
+
+ case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) =>
+ // Using (conditional) reassign interval w/o adding mux
+ val a1tpe = a1.tpe.asInstanceOf[IntervalType]
+ val a2tpe = a2.tpe.asInstanceOf[IntervalType]
+ val min2 = a2tpe.min.get * BigDecimal(BigInt(1) << a1tpe.point.asInstanceOf[IntWidth].width.toInt)
+ val max2 = a2tpe.max.get * BigDecimal(BigInt(1) << a1tpe.point.asInstanceOf[IntWidth].width.toInt)
+ val w1 = Seq(a1tpe.minAdjusted.get.bitLength, a1tpe.maxAdjusted.get.bitLength).max + 1
+ // Conservative
+ val minOpt2 = min2.setScale(0, FLOOR).toBigInt
+ val maxOpt2 = max2.setScale(0, CEILING).toBigInt
+ val w2 = Seq(minOpt2.bitLength, maxOpt2.bitLength).max + 1
+ if (w1 < w2) {
+ a1
+ } else {
+ val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2)))
+ DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2)))
+ }
+ case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match {
+ // If a2 type is Interval wrap around range. If UInt, wrap around width
+ case t: IntervalType =>
+ // Need to match binary points before getting *adjusted!
+ val (wrapLo, wrapHi) = t.copy(point = tpe.point) match {
+ case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get)
+ case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType")
+ }
+ val (inLo, inHi) = a1.tpe match {
+ case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
+ case _ => sys.error("Shouldn't be here")
+ }
+ // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap)
+ val range = wrapHi - wrapLo
+ val ltOpt = Add(a1, (range + 1).S)
+ val gtOpt = Sub(a1, (range + 1).S)
+ // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it.
+ // If x < wl
+ // output: wh - (wl - x) + 1 AKA x + r + 1
+ // worst case: wh - (wl - xl) + 1 = wl
+ // -> xl + wr + 1 = wl
+ // If x > wh
+ // output: wl + (x - wh) - 1 AKA x - r - 1
+ // worst case: wl + (xh - wh) - 1 = wh
+ // -> xh - wr - 1 = wh
+ val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S)
+ (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match {
+ case (true, true, _, _) => a1
+ case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1)
+ case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1)
+ // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work)
+ case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1))
+ case _ =>
+ errors.append(new WrapWithRemainder(info, mname, w))
+ default
+ }
+ case _ => sys.error("Shouldn't be here")
+ }
+ case other => other
+ }
+ }
+
+ private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+
+ private def replaceTypeInterval(t: Type): Type = t match {
+ case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width)
+ case i: IntervalType => sys.error(s"Shouldn't be here: $i")
+ case v => v map replaceTypeInterval
+ }
+
+ /** Replace Interval Nodes with Interval Wires
+ *
+ * You have to do this to capture the smaller bitwidths of nodes that intervals give you. Otherwise,
+ * any future InferTypes would reinfer the larger widths on these nodes from SInt width inference rules
+ * @param m module to replace nodes with wire + connection
+ * @return
+ */
+ private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt
+
+ private def makeWireStmt(s: Statement): Statement = s match {
+ case DefNode(info, name, value) => value.tpe match {
+ case IntervalType(l, u, p) =>
+ val newType = IntervalType(l, u, p)
+ Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, FEMALE), value)))
+ case other => s
+ }
+ case other => other map makeWireStmt
+ }
+}
diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala
new file mode 100644
index 00000000..dec64ee7
--- /dev/null
+++ b/src/main/scala/firrtl/passes/TrimIntervals.scala
@@ -0,0 +1,97 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+
+import scala.collection.mutable
+import firrtl.PrimOps._
+import firrtl.ir._
+import firrtl._
+import firrtl.Mappers._
+import firrtl.Utils.{error, field_type, getUIntWidth, max, module_type, sub_type}
+import Implicits.{bigint2WInt, int2WInt}
+import firrtl.constraint.{IsFloor, IsKnown, IsMul}
+
+/** Replaces IntervalType with SIntType, three AST walks:
+ * 1) Align binary points
+ * - adds shift operators to primop args and connections
+ * - does not affect declaration- or inferred-types
+ * 2) Replace declaration IntervalType's with SIntType's
+ * - for each declaration:
+ * a. remove non-zero binary points
+ * b. remove open bounds
+ * c. replace with SIntType
+ * 3) Run InferTypes
+ */
+class TrimIntervals extends Pass {
+ def run(c: Circuit): Circuit = {
+ // Open -> closed
+ val firstPass = InferTypes.run(c map replaceModuleInterval)
+ // Align binary points and adjust range accordingly (loss of precision changes range)
+ firstPass map alignModuleBP
+ }
+
+ /* Replace interval types */
+ private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval
+
+ private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval
+
+ private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+
+ private def replaceTypeInterval(t: Type): Type = t match {
+ case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) =>
+ IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p))
+ case i: IntervalType => i
+ case v => v map replaceTypeInterval
+ }
+
+ /* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */
+ private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP
+
+ private def alignStmtBP(s: Statement): Statement = s map alignExpBP match {
+ case c@Connect(info, loc, expr) => loc.tpe match {
+ case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case c@PartialConnect(info, loc, expr) => loc.tpe match {
+ case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case other => other map alignStmtBP
+ }
+
+ // Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned
+ // Note - Mul does not need its binary points aligned, because multiplication is cool like that
+ private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/)
+
+ private def alignExpBP(e: Expression): Expression = e map alignExpBP match {
+ case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg)
+ case DoPrim(o, args, consts, t) if opsToFix.contains(o) &&
+ (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size =>
+ val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _)
+ DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t)
+ case Mux(cond, tval, fval, t: IntervalType) =>
+ val maxBP = Seq(tval, fval).map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _)
+ Mux(cond, fixBP(maxBP)(tval), fixBP(maxBP)(fval), t)
+ case other => other
+ }
+ private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match {
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current =>
+ DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired)))
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current =>
+ val shiftAmt = current - desired
+ val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt)
+ val shiftMul = Closed(BigDecimal(1) / shiftGain)
+ val bpGain = BigDecimal(BigInt(1) << current.toInt)
+ // BP is inferred at this point
+ // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt)
+ val newBPRes = Closed(shiftGain / bpGain)
+ val bpResInv = Closed(bpGain)
+ val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes)
+ val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), newBPRes)
+ DoPrim(DecP, Seq(e), Seq(current - desired), IntervalType(CalcBound(newL), CalcBound(newU), IntWidth(desired)))
+ case x => sys.error(s"Shouldn't be here: $x")
+ }
+}
+
+// vim: set ts=4 sw=4 et:
diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
index bb441ebb..69c6b284 100644
--- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
@@ -56,7 +56,7 @@ object MemPortUtils {
type Modules = collection.mutable.ArrayBuffer[DefModule]
def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq(
- Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))),
+ Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))),
Field("en", Default, BoolType),
Field("clk", Default, ClockType)
)
diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala
index 70de3ccd..9fe01a07 100644
--- a/src/main/scala/firrtl/proto/ToProto.scala
+++ b/src/main/scala/firrtl/proto/ToProto.scala
@@ -99,9 +99,13 @@ object ToProto {
Bits -> Op.OP_EXTRACT_BITS,
Head -> Op.OP_HEAD,
Tail -> Op.OP_TAIL,
- BPShl -> Op.OP_SHIFT_BINARY_POINT_LEFT,
- BPShr -> Op.OP_SHIFT_BINARY_POINT_RIGHT,
- BPSet -> Op.OP_SET_BINARY_POINT
+ IncP -> Op.OP_INCREASE_PRECISION,
+ DecP -> Op.OP_DECREASE_PRECISION,
+ SetP -> Op.OP_SET_PRECISION,
+ AsInterval -> Op.OP_AS_INTERVAL,
+ Squeeze -> Op.OP_SQUEEZE,
+ Wrap -> Op.OP_WRAP,
+ Clip -> Op.OP_CLIP
)
def convert(ruw: ir.ReadUnderWrite.Value): ReadUnderWrite = ruw match {
diff --git a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala
index d06344af..89f2ec07 100644
--- a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala
+++ b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala
@@ -3,14 +3,12 @@
package firrtl.stage.phases.tests
import org.scalatest.{FlatSpec, Matchers, PrivateMethodTester}
-
import java.io.File
import firrtl._
-import firrtl.stage._
import firrtl.stage.phases.DriverCompatibility._
-
import firrtl.options.{InputAnnotationFileAnnotation, Phase, TargetDirAnnotation}
+import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation}
import firrtl.stage.phases.DriverCompatibility
class DriverCompatibilitySpec extends FlatSpec with Matchers with PrivateMethodTester {
diff --git a/src/test/scala/firrtlTests/AsyncResetSpec.scala b/src/test/scala/firrtlTests/AsyncResetSpec.scala
index 6fcb647a..8ad397b3 100644
--- a/src/test/scala/firrtlTests/AsyncResetSpec.scala
+++ b/src/test/scala/firrtlTests/AsyncResetSpec.scala
@@ -51,16 +51,19 @@ class AsyncResetSpec extends FirrtlFlatSpec {
it should "support casting to other types" in {
val result = compileBody(s"""
|input a : AsyncReset
+ |output u : Interval[0, 1].0
|output v : UInt<1>
|output w : SInt<1>
|output x : Clock
|output y : Fixed<1><<0>>
|output z : AsyncReset
+ |u <= asInterval(a, 0, 1, 0)
|v <= asUInt(a)
|w <= asSInt(a)
|x <= asClock(a)
|y <= asFixedPoint(a, 0)
- |z <= asAsyncReset(a)""".stripMargin
+ |z <= asAsyncReset(a)
+ |""".stripMargin
)
result should containLine ("assign v = $unsigned(a);")
result should containLine ("assign w = $signed(a);")
@@ -76,22 +79,26 @@ class AsyncResetSpec extends FirrtlFlatSpec {
|input c : Clock
|input d : Fixed<1><<0>>
|input e : AsyncReset
+ |input f : Interval[0, 1].0
+ |output u : AsyncReset
|output v : AsyncReset
|output w : AsyncReset
|output x : AsyncReset
|output y : AsyncReset
|output z : AsyncReset
- |v <= asAsyncReset(a)
- |w <= asAsyncReset(a)
- |x <= asAsyncReset(a)
- |y <= asAsyncReset(a)
- |z <= asAsyncReset(a)""".stripMargin
+ |u <= asAsyncReset(a)
+ |v <= asAsyncReset(b)
+ |w <= asAsyncReset(c)
+ |x <= asAsyncReset(d)
+ |y <= asAsyncReset(e)
+ |z <= asAsyncReset(f)""".stripMargin
)
- result should containLine ("assign v = a;")
- result should containLine ("assign w = a;")
- result should containLine ("assign x = a;")
- result should containLine ("assign y = a;")
- result should containLine ("assign z = a;")
+ result should containLine ("assign u = a;")
+ result should containLine ("assign v = b;")
+ result should containLine ("assign w = c;")
+ result should containLine ("assign x = d;")
+ result should containLine ("assign y = e;")
+ result should containLine ("assign z = f;")
}
"Non-literals" should "NOT be allowed as reset values for AsyncReset" in {
diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala
index fba81ec7..b82637b6 100644
--- a/src/test/scala/firrtlTests/ChirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala
@@ -70,8 +70,8 @@ class ChirrtlSpec extends FirrtlFlatSpec {
behavior of "Uniqueness"
for ((description, input) <- CheckSpec.nonUniqueExamples) {
it should s"be asserted for $description" in {
- assertThrows[CheckChirrtl.NotUniqueException] {
- Seq(ToWorkingIR, CheckChirrtl).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) }
+ assertThrows[CheckHighForm.NotUniqueException] {
+ Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) }
}
}
}
diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala
index dbc997cd..9d6206af 100644
--- a/src/test/scala/firrtlTests/InfoSpec.scala
+++ b/src/test/scala/firrtlTests/InfoSpec.scala
@@ -66,7 +66,7 @@ class InfoSpec extends FirrtlFlatSpec {
result should containLine (s"assign n = w | x; //$Info3")
}
- they should "be propagated on memories" in {
+ it should "be propagated on memories" in {
val result = compileBody(s"""
|input clock : Clock
|input addr : UInt<5>
@@ -102,7 +102,7 @@ class InfoSpec extends FirrtlFlatSpec {
result should containLine (s"m[m_w_addr] <= m_w_data; //$Info1")
}
- they should "be propagated on instances" in {
+ it should "be propagated on instances" in {
val result = compile(s"""
|circuit Test :
| module Child :
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index be9d738b..69379c51 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -12,7 +12,7 @@ import firrtl.transforms._
import firrtl._
class LowerTypesSpec extends FirrtlFlatSpec {
- private val transforms = Seq(
+ private def transforms = Seq(
ToWorkingIR,
CheckHighForm,
ResolveKinds,
diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala
index afb82384..e64e9105 100644
--- a/src/test/scala/firrtlTests/UniquifySpec.scala
+++ b/src/test/scala/firrtlTests/UniquifySpec.scala
@@ -16,7 +16,7 @@ import firrtl.util.TestOptions
class UniquifySpec extends FirrtlFlatSpec {
- private val transforms = Seq(
+ private def transforms = Seq(
ToWorkingIR,
CheckHighForm,
ResolveKinds,
diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala
index eb3d1a96..f1dadcee 100644
--- a/src/test/scala/firrtlTests/ZeroWidthTests.scala
+++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala
@@ -11,7 +11,7 @@ import firrtl.Parser
import firrtl.passes._
class ZeroWidthTests extends FirrtlFlatSpec {
- val transforms = Seq(
+ def transforms = Seq(
ToWorkingIR,
ResolveKinds,
InferTypes,
diff --git a/src/test/scala/firrtlTests/constraint/InequalitySpec.scala b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala
new file mode 100644
index 00000000..02a853cb
--- /dev/null
+++ b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala
@@ -0,0 +1,197 @@
+package firrtlTests.constraint
+
+import firrtl.constraint._
+import org.scalatest.{FlatSpec, Matchers}
+import firrtl.ir.Closed
+
+class InequalitySpec extends FlatSpec with Matchers {
+
+ behavior of "Constraints"
+
+ "IsConstraints" should "reduce properly" in {
+ IsMin(Closed(0), Closed(1)) should be (Closed(0))
+ IsMin(Closed(-1), Closed(1)) should be (Closed(-1))
+ IsMax(Closed(-1), Closed(1)) should be (Closed(1))
+ IsNeg(IsMul(Closed(-1), Closed(-2))) should be (Closed(-2))
+ val x = IsMin(IsMul(Closed(1), VarCon("a")), Closed(2))
+ x.children.toSet should be (IsMin(Closed(2), IsMul(Closed(1), VarCon("a"))).children.toSet)
+ }
+
+ "IsAdd" should "reduce properly" in {
+ // All constants
+ IsAdd(Closed(-1), Closed(1)) should be (Closed(0))
+
+ // Pull Out IsMax
+ IsAdd(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsAdd(VarCon("a"), Closed(1))))
+ IsAdd(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be (
+ IsMax(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1))))
+ )
+
+ // Pull Out IsMin
+ IsAdd(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsAdd(VarCon("a"), Closed(1))))
+ IsAdd(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be (
+ IsMin(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1))))
+ )
+
+ // Add Zero
+ IsAdd(Closed(0), VarCon("a")) should be (VarCon("a"))
+
+ // One argument
+ IsAdd(Seq(VarCon("a"))) should be (VarCon("a"))
+ }
+
+ "IsMax" should "reduce properly" in {
+ // All constants
+ IsMax(Closed(-1), Closed(1)) should be (Closed(1))
+
+ // Flatten nested IsMax
+ IsMax(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(1), VarCon("a")))
+ IsMax(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be (
+ IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))
+ )
+
+ // Eliminate IsMins if possible
+ IsMax(Closed(2), IsMin(Closed(1), VarCon("a"))) should be (Closed(2))
+ IsMax(Seq(
+ Closed(2),
+ IsMin(Closed(1), VarCon("a")),
+ IsMin(Closed(3), VarCon("b"))
+ )) should be (
+ IsMax(Seq(
+ Closed(2),
+ IsMin(Closed(3), VarCon("b"))
+ ))
+ )
+
+ // One argument
+ IsMax(Seq(VarCon("a"))) should be (VarCon("a"))
+ IsMax(Seq(Closed(0))) should be (Closed(0))
+ IsMax(Seq(IsMin(VarCon("a"), Closed(0)))) should be (IsMin(VarCon("a"), Closed(0)))
+ }
+
+ "IsMin" should "reduce properly" in {
+ // All constants
+ IsMin(Closed(-1), Closed(1)) should be (Closed(-1))
+
+ // Flatten nested IsMin
+ IsMin(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(1), VarCon("a")))
+ IsMin(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be (
+ IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))
+ )
+
+ // Eliminate IsMaxs if possible
+ IsMin(Closed(1), IsMax(Closed(2), VarCon("a"))) should be (Closed(1))
+ IsMin(Seq(
+ Closed(2),
+ IsMax(Closed(1), VarCon("a")),
+ IsMax(Closed(3), VarCon("b"))
+ )) should be (
+ IsMin(Seq(
+ Closed(2),
+ IsMax(Closed(1), VarCon("a"))
+ ))
+ )
+
+ // One argument
+ IsMin(Seq(VarCon("a"))) should be (VarCon("a"))
+ IsMin(Seq(Closed(0))) should be (Closed(0))
+ IsMin(Seq(IsMax(VarCon("a"), Closed(0)))) should be (IsMax(VarCon("a"), Closed(0)))
+ }
+
+ "IsMul" should "reduce properly" in {
+ // All constants
+ IsMul(Closed(2), Closed(3)) should be (Closed(6))
+
+ // Pull out max, if positive stays max
+ IsMul(Closed(2), IsMax(Closed(3), VarCon("a"))) should be(
+ IsMax(Closed(6), IsMul(Closed(2), VarCon("a")))
+ )
+
+ // Pull out max, if negative is min
+ IsMul(Closed(-2), IsMax(Closed(3), VarCon("a"))) should be(
+ IsMin(Closed(-6), IsMul(Closed(-2), VarCon("a")))
+ )
+
+ // Pull out min, if positive stays min
+ IsMul(Closed(2), IsMin(Closed(3), VarCon("a"))) should be(
+ IsMin(Closed(6), IsMul(Closed(2), VarCon("a")))
+ )
+
+ // Pull out min, if negative is max
+ IsMul(Closed(-2), IsMin(Closed(3), VarCon("a"))) should be(
+ IsMax(Closed(-6), IsMul(Closed(-2), VarCon("a")))
+ )
+
+ // Times zero
+ IsMul(Closed(0), VarCon("x")) should be (Closed(0))
+
+ // Times 1
+ IsMul(Closed(1), VarCon("x")) should be (VarCon("x"))
+
+ // One argument
+ IsMul(Seq(Closed(0))) should be (Closed(0))
+ IsMul(Seq(VarCon("a"))) should be (VarCon("a"))
+
+ // No optimizations
+ val isMax = IsMax(VarCon("x"), VarCon("y"))
+ val isMin = IsMin(VarCon("x"), VarCon("y"))
+ val a = VarCon("a")
+ IsMul(a, isMax).children should be (Vector(a, isMax)) //non-known multiply
+ IsMul(a, isMin).children should be (Vector(a, isMin)) //non-known multiply
+ IsMul(Seq(Closed(2), isMin, isMin)).children should be (Vector(Closed(2), isMin, isMin)) //>1 min
+ IsMul(Seq(Closed(2), isMax, isMax)).children should be (Vector(Closed(2), isMax, isMax)) //>1 max
+ IsMul(Seq(Closed(2), isMin, isMax)).children should be (Vector(Closed(2), isMin, isMax)) //mixed min/max
+ }
+
+ "IsNeg" should "reduce properly" in {
+ // All constants
+ IsNeg(Closed(1)) should be (Closed(-1))
+ // Pull out max
+ IsNeg(IsMax(Closed(1), VarCon("a"))) should be (IsMin(Closed(-1), IsNeg(VarCon("a"))))
+ // Pull out min
+ IsNeg(IsMin(Closed(1), VarCon("a"))) should be (IsMax(Closed(-1), IsNeg(VarCon("a"))))
+ // Pull out add
+ IsNeg(IsAdd(Closed(1), VarCon("a"))) should be (IsAdd(Closed(-1), IsNeg(VarCon("a"))))
+ // Pull out mul
+ IsNeg(IsMul(Closed(2), VarCon("a"))) should be (IsMul(Closed(-2), VarCon("a")))
+ // No optimizations
+ // (pow), (floor?)
+ IsNeg(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x"))))
+ IsNeg(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x"))))
+ }
+
+ "IsPow" should "reduce properly" in {
+ // All constants
+ IsPow(Closed(1)) should be (Closed(2))
+ // Pull out max
+ IsPow(IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsPow(VarCon("a"))))
+ // Pull out min
+ IsPow(IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsPow(VarCon("a"))))
+ // Pull out add
+ IsPow(IsAdd(Closed(1), VarCon("a"))) should be (IsMul(Closed(2), IsPow(VarCon("a"))))
+ // No optimizations
+ // (mul), (pow), (floor?)
+ IsPow(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x"))))
+ IsPow(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x"))))
+ IsPow(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x"))))
+ }
+
+ "IsFloor" should "reduce properly" in {
+ // All constants
+ IsFloor(Closed(1.9)) should be (Closed(1))
+ IsFloor(Closed(-1.9)) should be (Closed(-2))
+ // Pull out max
+ IsFloor(IsMax(Closed(1.9), VarCon("a"))) should be (IsMax(Closed(1), IsFloor(VarCon("a"))))
+ // Pull out min
+ IsFloor(IsMin(Closed(1.9), VarCon("a"))) should be (IsMin(Closed(1), IsFloor(VarCon("a"))))
+ // Cancel with another floor
+ IsFloor(IsFloor(VarCon("a"))) should be (IsFloor(VarCon("a")))
+ // No optimizations
+ // (add), (mul), (pow)
+ IsFloor(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x"))))
+ IsFloor(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x"))))
+ IsFloor(IsAdd(Closed(1), VarCon("x"))).children should be (Vector(IsAdd(Closed(1), VarCon("x"))))
+ }
+
+}
+
diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
index 6bf86479..a34145ac 100644
--- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
@@ -21,6 +21,36 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
}
}
+ "Fixed types" should "infer add correctly if only precision unspecified" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ ResolveFlows,
+ CheckFlows,
+ new InferWidths,
+ CheckWidths)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input a : Fixed<10><<2>>
+ | input b : Fixed<10><<0>>
+ | input c : Fixed<4><<3>>
+ | output d : Fixed<13>
+ | d <= add(a, add(b, c))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input a : Fixed<10><<2>>
+ | input b : Fixed<10><<0>>
+ | input c : Fixed<4><<3>>
+ | output d : Fixed<13><<3>>
+ | d <= add(a, add(b, c))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
"Fixed types" should "infer add correctly" in {
val passes = Seq(
ToWorkingIR,
@@ -36,7 +66,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
- | input b : Fixed<10>
+ | input b : Fixed<10><<0>>
| input c : Fixed<4><<3>>
| output d : Fixed
| d <= add(a, add(b, c))""".stripMargin
@@ -119,13 +149,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed
- | d <= bpshl(a, 2)""".stripMargin
+ | d <= incp(a, 2)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<12><<4>>
- | d <= bpshl(a, 2)""".stripMargin
+ | d <= incp(a, 2)""".stripMargin
executeTest(input, check.split("\n") map normalized, passes)
}
@@ -145,13 +175,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed
- | d <= bpshr(a, 2)""".stripMargin
+ | d <= decp(a, 2)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<8><<0>>
- | d <= bpshr(a, 2)""".stripMargin
+ | d <= decp(a, 2)""".stripMargin
executeTest(input, check.split("\n") map normalized, passes)
}
@@ -171,13 +201,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed
- | d <= bpset(a, 3)""".stripMargin
+ | d <= setp(a, 3)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<11><<3>>
- | d <= bpset(a, 3)""".stripMargin
+ | d <= setp(a, 3)""".stripMargin
executeTest(input, check.split("\n") map normalized, passes)
}
diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
index 8686bd0f..f5b16e45 100644
--- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
@@ -14,7 +14,6 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
(c: CircuitState, p: Transform) => p.runTransform(c)
}.circuit
val lines = c.serialize.split("\n") map normalized
- println(c.serialize)
expected foreach { e =>
lines should contain(e)
@@ -37,7 +36,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
- | input b : Fixed<10>
+ | input b : Fixed<10><<0>>
| input c : Fixed<4><<3>>
| output d : Fixed<<5>>
| d <= add(a, add(b, c))""".stripMargin
@@ -67,7 +66,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
- | input b : Fixed<10>
+ | input b : Fixed<10><<0>>
| input c : Fixed<4><<3>>
| output d : Fixed<<5>>
| d <- add(a, add(b, c))""".stripMargin
@@ -99,7 +98,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<12><<4>>
- | d <= bpshl(a, 2)""".stripMargin
+ | d <= incp(a, 2)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
@@ -126,7 +125,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<9><<1>>
- | d <= bpshr(a, 1)""".stripMargin
+ | d <= decp(a, 1)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
@@ -153,7 +152,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed
- | d <= bpset(a, 3)""".stripMargin
+ | d <= setp(a, 3)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
@@ -181,7 +180,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
class CheckChirrtlTransform extends SeqTransform {
def inputForm = ChirrtlForm
def outputForm = ChirrtlForm
- val transforms = Seq(passes.CheckChirrtl)
+ def transforms = Seq(passes.CheckChirrtl)
}
val chirrtlTransform = new CheckChirrtlTransform
diff --git a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala
new file mode 100644
index 00000000..20fdeee1
--- /dev/null
+++ b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala
@@ -0,0 +1,183 @@
+// See LICENSE for license details.
+
+package firrtlTests.interval
+
+import firrtl.Implicits.constraint2bound
+import firrtl.{ChirrtlForm, CircuitState, LowFirrtlCompiler, Parser}
+import firrtl.ir._
+
+import scala.math.BigDecimal.RoundingMode._
+import firrtl.Parser.IgnoreInfo
+import firrtl.constraint._
+import firrtlTests.FirrtlFlatSpec
+
+class IntervalMathSpec extends FirrtlFlatSpec {
+ val SumPattern = """.*output sum.*<(\d+)>.*""".r
+ val ProductPattern = """.*output product.*<(\d+)>.*""".r
+ val DifferencePattern = """.*output difference.*<(\d+)>.*""".r
+ val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r
+ val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r
+ val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r
+ val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r
+ val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r
+ val ArithAssignPattern = """\s*(\w+) <= asSInt\(bits\((\w+)\((.*)\).*\)\)\s*""".r
+ def getBound(bound: String, value: Double): IsKnown = bound match {
+ case "[" => Closed(BigDecimal(value))
+ case "]" => Closed(BigDecimal(value))
+ case "(" => Open(BigDecimal(value))
+ case ")" => Open(BigDecimal(value))
+ }
+
+ val prec = 0.5
+
+ for {
+ lb1 <- Seq("[", "(")
+ lv1 <- Range.Double(-1.0, 1.0, prec)
+ uv1 <- if(lb1 == "[") Range.Double(lv1, 1.0, prec) else Range.Double(lv1 + prec, 1.0, prec)
+ ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")")
+ bp1 <- 0 to 1
+ lb2 <- Seq("[", "(")
+ lv2 <- Range.Double(-1.0, 1.0, prec)
+ uv2 <- if(lb2 == "[") Range.Double(lv2, 1.0, prec) else Range.Double(lv2 + prec, 1.0, prec)
+ ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")")
+ bp2 <- 0 to 1
+ } {
+ val it1 = IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1.toInt))
+ val it2 = IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2.toInt))
+ (it1.range, it2.range) match {
+ case (Some(Nil), _) =>
+ case (_, Some(Nil)) =>
+ case _ =>
+ def config = s"$lb1$lv1,$uv1$ub1.$bp1 and $lb2$lv2,$uv2$ub2.$bp2"
+
+ s"Configuration $config" should "pass" in {
+
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1
+ | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2
+ | input amt : UInt<3>
+ | output sum : Interval
+ | output difference : Interval
+ | output product : Interval
+ | output shl : Interval
+ | output shr : Interval
+ | output dshl : Interval
+ | output dshr : Interval
+ | output lt : UInt
+ | output leq : UInt
+ | output gt : UInt
+ | output geq : UInt
+ | output eq : UInt
+ | output neq : UInt
+ | output cat : UInt
+ | sum <= add(in1, in2)
+ | difference <= sub(in1, in2)
+ | product <= mul(in1, in2)
+ | shl <= shl(in1, 3)
+ | shr <= shr(in1, 3)
+ | dshl <= dshl(in1, amt)
+ | dshr <= dshr(in1, amt)
+ | lt <= lt(in1, in2)
+ | leq <= leq(in1, in2)
+ | gt <= gt(in1, in2)
+ | geq <= geq(in1, in2)
+ | eq <= eq(in1, in2)
+ | neq <= lt(in1, in2)
+ | cat <= cat(in1, in2)
+ | """.stripMargin
+
+ val lowerer = new LowFirrtlCompiler
+ val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm))
+ val output = res.getEmittedCircuit.value split "\n"
+ val min1 = Closed(it1.min.get)
+ val max1 = Closed(it1.max.get)
+ val min2 = Closed(it2.min.get)
+ val max2 = Closed(it2.max.get)
+ for (line <- output) {
+ line match {
+ case SumPattern(varWidth) =>
+ val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt))
+ val it = IntervalType(IsAdd(min1, min2), IsAdd(max1, max2), bp)
+ assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, s"$line,${it.range}")
+ case ProductPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt + bp2.toInt)
+ val lv = IsMin(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2)))
+ val uv = IsMax(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2)))
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "product")
+ case DifferencePattern(varWidth) =>
+ val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt))
+ val lv = min1 + max2.neg
+ val uv = max1 + min2.neg
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "diff")
+ case ShiftLeftPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt)
+ val lv = min1 * Closed(8)
+ val uv = max1 * Closed(8)
+ val it = IntervalType(lv, uv, bp)
+ assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, "shl")
+ case ShiftRightPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt)
+ val lv = min1 * Closed(1/3)
+ val uv = max1 * Closed(1/3)
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "shr")
+ case DShiftLeftPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt)
+ val lv = min1 * Closed(128)
+ val uv = max1 * Closed(128)
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshl")
+ case DShiftRightPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt)
+ val lv = min1
+ val uv = max1
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshr")
+ case ComparisonPattern(varWidth) => assert(varWidth.toInt == 1, "==")
+ case ArithAssignPattern(varName, operation, args) =>
+ val arg1 = if(IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) """SInt<1>("h0")""" else "in1"
+ val arg2 = if(IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) """SInt<1>("h0")""" else "in2"
+ varName match {
+ case "sum" =>
+ assert(operation === "add", s"""var sum should be result of an add in ${output.mkString("\n")}""")
+ if (bp1 > bp2) {
+ if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(args.contains(s"shl($arg2, ${bp1 - bp2})"),
+ s"$config second arg incorrect in $line")
+ } else if (bp1 < bp2) {
+ assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"),
+ s"$config second arg incorrect in $line")
+ assert(!args.contains("shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ } else {
+ assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ }
+ case "product" =>
+ assert(operation === "mul", s"var sum should be result of an add in $line")
+ assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ case "difference" =>
+ assert(operation === "sub", s"var difference should be result of an sub in $line")
+ if (bp1 > bp2) {
+ if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(args.contains(s"shl($arg2, ${bp1 - bp2})"),
+ s"$config second arg incorrect in $line")
+ } else if (bp1 < bp2) {
+ assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"),
+ s"$config second arg incorrect in $line")
+ if (arg1 != arg2) assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ } else {
+ assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ }
+ case _ =>
+ }
+ case _ =>
+ }
+ }
+ }
+ }
+ }
+}
+
+
+// vim: set ts=4 sw=4 et:
diff --git a/src/test/scala/firrtlTests/interval/IntervalSpec.scala b/src/test/scala/firrtlTests/interval/IntervalSpec.scala
new file mode 100644
index 00000000..37d79c84
--- /dev/null
+++ b/src/test/scala/firrtlTests/interval/IntervalSpec.scala
@@ -0,0 +1,530 @@
+package firrtlTests
+package interval
+
+import java.io._
+
+import firrtl._
+import firrtl.ir.Circuit
+import firrtl.passes._
+import firrtl.Parser.IgnoreInfo
+import firrtl.passes.CheckTypes.InvalidConnect
+import firrtl.passes.CheckWidths.DisjointSqueeze
+
+class IntervalSpec extends FirrtlFlatSpec {
+ private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = {
+ val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
+ (c: Circuit, p: Transform) =>
+ p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit
+ }
+ val lines = c.serialize.split("\n") map normalized
+
+ expected foreach { e =>
+ lines should contain(e)
+ }
+ }
+
+ "Interval types" should "parse correctly" in {
+ val passes = Seq(ToWorkingIR)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].4
+ | input in2 : Interval(-0.32, 10].4
+ | input in3 : Interval[-3, 10.1).4
+ | input in4 : Interval(-0.32, 10.1)
+ | input in5 : Interval.4
+ | input in6 : Interval
+ | output out0 : Interval.2
+ | output out1 : Interval
+ | out0 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))
+ | out1 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))""".stripMargin
+ executeTest(input, input.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "infer bp correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].3
+ | input in2 : Interval(-0.32, 10].2
+ | output out0 : Interval
+ | out0 <= add(in0, add(in1, in2))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].3
+ | input in2 : Interval(-0.32, 10].2
+ | output out0 : Interval.4
+ | out0 <= add(in0, add(in1, in2))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "trim known intervals correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].3
+ | input in2 : Interval(-0.32, 10].2
+ | output out0 : Interval
+ | out0 <= add(in0, add(in1, in2))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval[-0.3125, 10.0625].4
+ | input in1 : Interval[0, 10].3
+ | input in2 : Interval[-0.25, 10].2
+ | output out0 : Interval.4
+ | out0 <= add(in0, incp(add(in1, incp(in2, 1)), 1))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "infer intervals correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(0, 10).4
+ | input in1 : Interval(0, 10].3
+ | input in2 : Interval(-1, 3].2
+ | output out0 : Interval
+ | output out1 : Interval
+ | output out2 : Interval
+ | out0 <= add(in0, add(in1, in2))
+ | out1 <= mul(in0, mul(in1, in2))
+ | out2 <= sub(in0, sub(in1, in2))""".stripMargin
+ val check =
+ """output out0 : Interval[-0.5625, 22.9375].4
+ |output out1 : Interval[-74.53125, 298.125].9
+ |output out2 : Interval[-10.6875, 12.8125].4""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "be removed correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(0, 10).4
+ | input in1 : Interval(0, 10].3
+ | input in2 : Interval(-1, 3].2
+ | output out0 : Interval
+ | output out1 : Interval
+ | output out2 : Interval
+ | out0 <= add(in0, add(in1, in2))
+ | out1 <= mul(in0, mul(in1, in2))
+ | out2 <= sub(in0, sub(in1, in2))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : SInt<9>
+ | input in1 : SInt<8>
+ | input in2 : SInt<5>
+ | output out0 : SInt<10>
+ | output out1 : SInt<19>
+ | output out2 : SInt<9>
+ | out0 <= add(in0, shl(add(in1, shl(in2, 1)), 1))
+ | out1 <= mul(in0, mul(in1, in2))
+ | out2 <= sub(in0, shl(sub(in1, shl(in2, 1)), 1))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+"Interval types" should "infer multiplication by zero correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1 : Interval[0, 0.5].1
+ | input in2 : Interval[0, 0].1
+ | output mul : Interval
+ | mul <= mul(in2, in1)
+ | """.stripMargin
+ val check = s"""output mul : Interval[0, 0].2 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+}
+
+ "Interval types" should "infer muxes correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input p : UInt<1>
+ | input in1 : Interval[0, 0.5].1
+ | input in2 : Interval[0, 0].1
+ | output out : Interval
+ | out <= mux(p, in2, in1)
+ | """.stripMargin
+ val check = s"""output out : Interval[0, 0.5].1 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "infer dshl correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds, ResolveGenders, new InferBinaryPoints(), new TrimIntervals, new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input p : UInt<3>
+ | input in1 : Interval[-1, 1].0
+ | output out : Interval
+ | out <= dshl(in1, p)
+ | """.stripMargin
+ val check = s"""output out : Interval[-128, 128].0 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "infer asInterval correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input p : UInt<3>
+ | output out : Interval
+ | out <= asInterval(p, 0, 4, 1)
+ | """.stripMargin
+ val check = s"""output out : Interval[0, 2].1 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "do wrap/clip correctly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input s: SInt<2>
+ | input u: UInt<3>
+ | input in1: Interval[-3, 5].0
+ | output wrap3: Interval
+ | output wrap4: Interval
+ | output wrap5: Interval
+ | output wrap6: Interval
+ | output wrap7: Interval
+ | output clip3: Interval
+ | output clip4: Interval
+ | output clip5: Interval
+ | output clip6: Interval
+ | output clip7: Interval
+ | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0))
+ | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0))
+ | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0))
+ | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0))
+ | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0))
+ | clip3 <= clip(in1, asInterval(s, -2, 4, 0))
+ | clip4 <= clip(in1, asInterval(s, -1, 1, 0))
+ | clip5 <= clip(in1, asInterval(s, -4, 4, 0))
+ | clip6 <= clip(in1, asInterval(s, -1, 7, 0))
+ | clip7 <= clip(in1, asInterval(s, -4, 7, 0))
+ """.stripMargin
+ //| output wrap1: Interval
+ //| output wrap2: Interval
+ //| output clip1: Interval
+ //| output clip2: Interval
+ //| wrap1 <= wrap(in1, u, 0)
+ //| wrap2 <= wrap(in1, s, 0)
+ //| clip1 <= clip(in1, u)
+ //| clip2 <= clip(in1, s)
+ val check = s"""
+ | output wrap3 : Interval[-2, 4].0
+ | output wrap4 : Interval[-1, 1].0
+ | output wrap5 : Interval[-4, 4].0
+ | output wrap6 : Interval[-1, 7].0
+ | output wrap7 : Interval[-4, 7].0
+ | output clip3 : Interval[-2, 4].0
+ | output clip4 : Interval[-1, 1].0
+ | output clip5 : Interval[-3, 4].0
+ | output clip6 : Interval[-1, 5].0
+ | output clip7 : Interval[-3, 5].0 """.stripMargin
+ // TODO: this optimization
+ //| output wrap1 : Interval[0, 7].0
+ //| output wrap2 : Interval[-2, 1].0
+ //| output clip1 : Interval[0, 5].0
+ //| output clip2 : Interval[-2, 1].0
+ //| output wrap7 : Interval[-3, 5].0
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "remove wrap/clip correctly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck(), new RemoveIntervals())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input s: SInt<2>
+ | input u: UInt<3>
+ | input in1: Interval[-3, 5].0
+ | output wrap3: Interval
+ | output wrap5: Interval
+ | output wrap6: Interval
+ | output wrap7: Interval
+ | output clip3: Interval
+ | output clip4: Interval
+ | output clip5: Interval
+ | output clip6: Interval
+ | output clip7: Interval
+ | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0))
+ | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0))
+ | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0))
+ | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0))
+ | clip3 <= clip(in1, asInterval(s, -2, 4, 0))
+ | clip4 <= clip(in1, asInterval(s, -1, 1, 0))
+ | clip5 <= clip(in1, asInterval(s, -4, 4, 0))
+ | clip6 <= clip(in1, asInterval(s, -1, 7, 0))
+ | clip7 <= clip(in1, asInterval(s, -4, 7, 0))
+ | """.stripMargin
+ val check = s"""
+ | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1))
+ | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1)
+ | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1)
+ | wrap7 <= in1
+ | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1))
+ | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1))
+ | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1)
+ | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)
+ | clip7 <= in1
+ """.stripMargin
+ //| output wrap4: Interval
+ //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0)
+ //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1"))
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "shift wrap/clip correctly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input s: SInt<2>
+ | input in1: Interval[-3, 5].1
+ | output wrap1: Interval
+ | output clip1: Interval
+ | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0))
+ | clip1 <= clip(in1, asInterval(s, -2, 2, 0))
+ | """.stripMargin
+ val check = s"""
+ | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1))
+ | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1))
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "infer negative binary points" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[-2, 4].-1
+ | input in2: Interval[-4, 8].-2
+ | output out: Interval
+ | out <= add(in1, in2)
+ | """.stripMargin
+ val check = s"""
+ | output out : Interval[-6, 12].-1
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "remove negative binary points" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[-2, 4].-1
+ | input in2: Interval[-4, 8].-2
+ | output out: Interval.0
+ | out <= add(in1, in2)
+ | """.stripMargin
+ val check = s"""
+ | output out : SInt<5>
+ | out <= shl(add(in1, shl(in2, 1)), 1)
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "implement squz properly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input min: Interval[-1, 4].1
+ | input max: Interval[-3, 5].1
+ | input left: Interval[-3, 3].1
+ | input right: Interval[0, 5].1
+ | input off: Interval[-1, 4].2
+ | output minMax: Interval
+ | output maxMin: Interval
+ | output minLeft: Interval
+ | output leftMin: Interval
+ | output minRight: Interval
+ | output rightMin: Interval
+ | output minOff: Interval
+ | output offMin: Interval
+ |
+ | minMax <= squz(min, max)
+ | maxMin <= squz(max, min)
+ | minLeft <= squz(min, left)
+ | leftMin <= squz(left, min)
+ | minRight <= squz(min, right)
+ | rightMin <= squz(right, min)
+ | minOff <= squz(min, off)
+ | offMin <= squz(off, min)
+ | """.stripMargin
+ val check =
+ s"""
+ | output minMax : Interval[-1, 4].1
+ | output maxMin : Interval[-1, 4].1
+ | output minLeft : Interval[-1, 3].1
+ | output leftMin : Interval[-1, 3].1
+ | output minRight : Interval[0, 4].1
+ | output rightMin : Interval[0, 4].1
+ | output minOff : Interval[-1, 4].1
+ | output offMin : Interval[-1, 4].2
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "lower squz properly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input min: Interval[-1, 4].1
+ | input max: Interval[-3, 5].1
+ | input left: Interval[-3, 3].1
+ | input right: Interval[0, 5].1
+ | input off: Interval[-1, 4].2
+ | output minMax: Interval
+ | output maxMin: Interval
+ | output minLeft: Interval
+ | output leftMin: Interval
+ | output minRight: Interval
+ | output rightMin: Interval
+ | output minOff: Interval
+ | output offMin: Interval
+ |
+ | minMax <= squz(min, max)
+ | maxMin <= squz(max, min)
+ | minLeft <= squz(min, left)
+ | leftMin <= squz(left, min)
+ | minRight <= squz(min, right)
+ | rightMin <= squz(right, min)
+ | minOff <= squz(min, off)
+ | offMin <= squz(off, min)
+ | """.stripMargin
+ val check =
+ s"""
+ | minMax <= asSInt(bits(min, 4, 0))
+ | maxMin <= asSInt(bits(max, 4, 0))
+ | minLeft <= asSInt(bits(min, 3, 0))
+ | leftMin <= left
+ | minRight <= asSInt(bits(min, 4, 0))
+ | rightMin <= asSInt(bits(right, 4, 0))
+ | minOff <= asSInt(bits(min, 4, 0))
+ | offMin <= asSInt(bits(off, 5, 0))
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Assigning a larger interval to a smaller interval" should "error!" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in: Interval[1, 4].1
+ | output out: Interval[2, 3].1
+ | out <= in
+ | """.stripMargin
+ intercept[InvalidConnect]{
+ executeTest(input, Nil, passes)
+ }
+ }
+ "Assigning a more precise interval to a less precise interval" should "error!" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in: Interval[2, 3].3
+ | output out: Interval[2, 3].1
+ | out <= in
+ | """.stripMargin
+ intercept[InvalidConnect]{
+ executeTest(input, Nil, passes)
+ }
+ }
+ "Chick's example" should "work" in {
+ val input =
+ s"""circuit IntervalChainedSubTester :
+ | module IntervalChainedSubTester :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | node _GEN_0 = sub(SInt<6>("h11"), SInt<6>("h2")) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26]
+ | node _GEN_1 = bits(_GEN_0, 4, 0) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26]
+ | node intervalResult = asSInt(_GEN_1) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26]
+ | skip
+ | node _T_1 = asUInt(intervalResult) @[IntervalSpec.scala 338:50]
+ | skip
+ | node _T_3 = eq(reset, UInt<1>("h0")) @[IntervalSpec.scala 338:9]
+ | node _T_4 = eq(intervalResult, SInt<5>("hf")) @[IntervalSpec.scala 339:25]
+ | skip
+ | node _T_6 = or(_T_4, reset) @[IntervalSpec.scala 339:9]
+ | node _T_7 = eq(_T_6, UInt<1>("h0")) @[IntervalSpec.scala 339:9]
+ | skip
+ | skip
+ | printf(clock, _T_3, "Interval result: %d", _T_1) @[IntervalSpec.scala 338:9]
+ | printf(clock, _T_7, "Assertion failed at IntervalSpec.scala:339 assert(intervalResult === 15.I)") @[IntervalSpec.scala 339:9]
+ | stop(clock, _T_7, 1) @[IntervalSpec.scala 339:9]
+ | stop(clock, _T_3, 0) @[IntervalSpec.scala 340:7]
+ |
+ """.stripMargin
+ compileToVerilog(input)
+ }
+
+ "Squeeze with disjoint intervals" should "error" in {
+ intercept[DisjointSqueeze] {
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[3, 6].3
+ | node out = squz(in1, in2)
+ """.stripMargin
+ compileToVerilog(input)
+ }
+ intercept[DisjointSqueeze] {
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[3, 6].3
+ | node out = squz(in2, in1)
+ """.stripMargin
+ compileToVerilog(input)
+ }
+ }
+
+ "Clip with disjoint intervals" should "work" in {
+ compileToVerilog(
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[3, 6].3
+ | output out: Interval
+ | out <= clip(in1, in2)
+ """.stripMargin
+ )
+ compileToVerilog(
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[4, 6].3
+ | node out = clip(in1, in2)
+ """.stripMargin
+ )
+ }
+
+
+ "Wrap with remainder" should "error" in {
+ intercept[WrapWithRemainder] {
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[0, 300).3
+ | input in2: Interval[3, 6].3
+ | node out = wrap(in1, in2)
+ """.stripMargin
+ compileToVerilog(input)
+ }
+ }
+}
diff --git a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala
index 46fb310a..88095830 100644
--- a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala
+++ b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala
@@ -152,7 +152,7 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec {
}
"InferWidthsWithAnnos" should "work with WiringTransform" in {
- def transforms = Seq(
+ def transforms() = Seq(
ToWorkingIR,
ResolveKinds,
InferTypes,