diff options
| author | Adam Izraelevitz | 2019-10-18 19:01:19 -0700 |
|---|---|---|
| committer | GitHub | 2019-10-18 19:01:19 -0700 |
| commit | fd981848c7d2a800a15f9acfbf33b57dd1c6225b (patch) | |
| tree | 3609a301cb0ec867deefea4a0d08425810b00418 /src | |
| parent | 973ecf516c0ef2b222f2eb68dc8b514767db59af (diff) | |
Upstream intervals (#870)
Major features:
- Added Interval type, as well as PrimOps asInterval, clip, wrap, and sqz.
- Changed PrimOp names: bpset -> setp, bpshl -> incp, bpshr -> decp
- Refactored width/bound inferencer into a separate constraint solver
- Added transforms to infer, trim, and remove interval bounds
- Tests for said features
Plan to be released with 1.3
Diffstat (limited to 'src')
47 files changed, 3229 insertions, 723 deletions
diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index be15ab7c..518cb698 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -50,6 +50,7 @@ type : 'UInt' ('<' intLit '>')? | 'SInt' ('<' intLit '>')? | 'Fixed' ('<' intLit '>')? ('<' '<' intLit '>' '>')? + | 'Interval' (lowerBound boundValue boundValue upperBound)? ('.' intLit)? | 'Clock' | 'AsyncReset' | 'Reset' @@ -187,6 +188,23 @@ intLit | HexLit ; +lowerBound + : '[' + | '(' + ; + +upperBound + : ']' + | ')' + ; + +boundValue + : '?' + | DoubleLit + | UnsignedInt + | SignedInt + ; + // Keywords that are also legal ids keywordAsId : 'circuit' @@ -253,6 +271,8 @@ primop | 'asAsyncReset(' | 'asSInt(' | 'asClock(' + | 'asFixedPoint(' + | 'asInterval(' | 'shl(' | 'shr(' | 'dshl(' @@ -270,10 +290,12 @@ primop | 'bits(' | 'head(' | 'tail(' - | 'asFixedPoint(' - | 'bpshl(' - | 'bpshr(' - | 'bpset(' + | 'incp(' + | 'decp(' + | 'setp(' + | 'wrap(' + | 'clip(' + | 'squz(' ; /*------------------------------------------------------------------ diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto index 0cf14f41..3d2c89f1 100644 --- a/src/main/proto/firrtl.proto +++ b/src/main/proto/firrtl.proto @@ -451,10 +451,14 @@ message Firrtl { OP_AS_FIXED_POINT = 32; OP_AND_REDUCE = 33; OP_OR_REDUCE = 34; - OP_SHIFT_BINARY_POINT_LEFT = 35; - OP_SHIFT_BINARY_POINT_RIGHT = 36; - OP_SET_BINARY_POINT = 37; + OP_INCREASE_PRECISION = 35; + OP_DECREASE_PRECISION = 36; + OP_SET_PRECISION = 37; OP_AS_ASYNC_RESET = 38; + OP_WRAP = 39; + OP_CLIP = 40; + OP_SQUEEZE = 41; + OP_AS_INTERVAL = 42; } // Required. diff --git a/src/main/scala/firrtl/Implicits.scala b/src/main/scala/firrtl/Implicits.scala new file mode 100644 index 00000000..ec1cf3d6 --- /dev/null +++ b/src/main/scala/firrtl/Implicits.scala @@ -0,0 +1,30 @@ +// See LICENSE for license details. + +package firrtl + +import firrtl.ir._ +import Utils.trim +import firrtl.constraint.Constraint + +object Implicits { + implicit def int2WInt(i: Int): WrappedInt = WrappedInt(BigInt(i)) + implicit def bigint2WInt(i: BigInt): WrappedInt = WrappedInt(i) + implicit def constraint2bound(c: Constraint): Bound = c match { + case x: Bound => x + case x => CalcBound(x) + } + implicit def constraint2width(c: Constraint): Width = c match { + case Closed(x) if trim(x).isWhole => IntWidth(x.toBigInt) + case x => CalcWidth(x) + } + implicit def width2constraint(w: Width): Constraint = w match { + case CalcWidth(x: Constraint) => x + case IntWidth(x) => Closed(BigDecimal(x)) + case UnknownWidth => UnknownBound + case v: Constraint => v + } +} +case class WrappedInt(value: BigInt) { + def U: Expression = UIntLiteral(value, IntWidth(Utils.getUIntWidth(value))) + def S: Expression = SIntLiteral(value, IntWidth(Utils.getSIntWidth(value))) +} diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 05cdbe96..75645319 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -45,6 +45,8 @@ class ResolveAndCheck extends CoreTransform { passes.InferTypes, passes.ResolveFlows, passes.CheckFlows, + new passes.InferBinaryPoints(), + new passes.TrimIntervals(), new passes.InferWidths, passes.CheckWidths, new firrtl.transforms.InferResets) @@ -73,6 +75,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { passes.ResolveFlows, new passes.InferWidths, passes.CheckWidths, + new passes.RemoveIntervals(), passes.ConvertFixedToSInt, passes.ZeroWidth, passes.InferTypes) diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala index e8283d93..b30b4518 100644 --- a/src/main/scala/firrtl/Mappers.scala +++ b/src/main/scala/firrtl/Mappers.scala @@ -6,6 +6,19 @@ import firrtl.ir._ // TODO: Implement remaining mappers and recursive mappers object Mappers { + // ********** Port Mappers ********** + private trait PortMagnet { + def map(p: Port): Port + } + private object PortMagnet { + implicit def forType(f: Type => Type): PortMagnet = new PortMagnet { + override def map(port: Port): Port = port mapType f + } + } + implicit class PortMap(val _port: Port) extends AnyVal { + def map[T](f: T => T)(implicit magnet: (T => T) => PortMagnet): Port = magnet(f).map(_port) + } + // ********** Stmt Mappers ********** private trait StmtMagnet { diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index af8328da..02404f70 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -3,92 +3,487 @@ package firrtl import firrtl.ir._ -import firrtl.Utils.{min, max, pow_minus_one} - import com.typesafe.scalalogging.LazyLogging +import Implicits.{constraint2bound, constraint2width, width2constraint} +import firrtl.constraint._ /** Definitions and Utility functions for [[ir.PrimOp]]s */ object PrimOps extends LazyLogging { + def t1(e: DoPrim): Type = e.args.head.tpe + def t2(e: DoPrim): Type = e.args(1).tpe + def t3(e: DoPrim): Type = e.args(2).tpe + def w1(e: DoPrim): Width = getWidth(t1(e)) + def w2(e: DoPrim): Width = getWidth(t2(e)) + def p1(e: DoPrim): Width = t1(e) match { + case FixedType(w, p) => p + case IntervalType(min, max, p) => p + case _ => sys.error(s"Cannot get binary point from ${t1(e)}") + } + def p2(e: DoPrim): Width = t2(e) match { + case FixedType(w, p) => p + case IntervalType(min, max, p) => p + case _ => sys.error(s"Cannot get binary point from ${t1(e)}") + } + def c1(e: DoPrim) = IntWidth(e.consts.head) + def c2(e: DoPrim) = IntWidth(e.consts(1)) + def o1(e: DoPrim) = e.consts(0) + def o2(e: DoPrim) = e.consts(1) + def o3(e: DoPrim) = e.consts(2) + /** Addition */ - case object Add extends PrimOp { override def toString = "add" } + case object Add extends PrimOp { + override def toString = "add" + override def propagateType(e: DoPrim): Type = { + (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => UIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1))) + case (_: SIntType, _: SIntType) => SIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1))) + case (_: FixedType, _: FixedType) => FixedType(IsAdd(IsAdd(IsMax(p1(e), p2(e)), IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))), IntWidth(1)), IsMax(p1(e), p2(e))) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => IntervalType(IsAdd(l1, l2), IsAdd(u1, u2), IsMax(p1, p2)) + case _ => UnknownType + } + } + } + /** Subtraction */ - case object Sub extends PrimOp { override def toString = "sub" } + case object Sub extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => UIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1))) + case (_: SIntType, _: SIntType) => SIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1))) + case (_: FixedType, _: FixedType) => FixedType(IsAdd(IsAdd(IsMax(p1(e), p2(e)),IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))),IntWidth(1)), IsMax(p1(e), p2(e))) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => IntervalType(IsAdd(l1, IsNeg(u2)), IsAdd(u1, IsNeg(l2)), IsMax(p1, p2)) + case _ => UnknownType + } + override def toString = "sub" + } + /** Multiplication */ - case object Mul extends PrimOp { override def toString = "mul" } + case object Mul extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => UIntType(IsAdd(w1(e), w2(e))) + case (_: SIntType, _: SIntType) => SIntType(IsAdd(w1(e), w2(e))) + case (_: FixedType, _: FixedType) => FixedType(IsAdd(w1(e), w2(e)), IsAdd(p1(e), p2(e))) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => + IntervalType( + IsMin(Seq(IsMul(l1, l2), IsMul(l1, u2), IsMul(u1, l2), IsMul(u1, u2))), + IsMax(Seq(IsMul(l1, l2), IsMul(l1, u2), IsMul(u1, l2), IsMul(u1, u2))), + IsAdd(p1, p2) + ) + case _ => UnknownType + } + override def toString = "mul" } + /** Division */ - case object Div extends PrimOp { override def toString = "div" } + case object Div extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => UIntType(w1(e)) + case (_: SIntType, _: SIntType) => SIntType(IsAdd(w1(e), IntWidth(1))) + case _ => UnknownType + } + override def toString = "div" } + /** Remainder */ - case object Rem extends PrimOp { override def toString = "rem" } + case object Rem extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => UIntType(MIN(w1(e), w2(e))) + case (_: SIntType, _: SIntType) => SIntType(MIN(w1(e), w2(e))) + case _ => UnknownType + } + override def toString = "rem" } /** Less Than */ - case object Lt extends PrimOp { override def toString = "lt" } + case object Lt extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => Utils.BoolType + case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType + case (_: IntervalType, _: IntervalType) => Utils.BoolType + case _ => UnknownType + } + override def toString = "lt" } /** Less Than Or Equal To */ - case object Leq extends PrimOp { override def toString = "leq" } + case object Leq extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => Utils.BoolType + case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType + case (_: IntervalType, _: IntervalType) => Utils.BoolType + case _ => UnknownType + } + override def toString = "leq" } /** Greater Than */ - case object Gt extends PrimOp { override def toString = "gt" } + case object Gt extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => Utils.BoolType + case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType + case (_: IntervalType, _: IntervalType) => Utils.BoolType + case _ => UnknownType + } + override def toString = "gt" } /** Greater Than Or Equal To */ - case object Geq extends PrimOp { override def toString = "geq" } + case object Geq extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => Utils.BoolType + case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType + case (_: IntervalType, _: IntervalType) => Utils.BoolType + case _ => UnknownType + } + override def toString = "geq" } /** Equal To */ - case object Eq extends PrimOp { override def toString = "eq" } + case object Eq extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => Utils.BoolType + case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType + case (_: IntervalType, _: IntervalType) => Utils.BoolType + case _ => UnknownType + } + override def toString = "eq" } /** Not Equal To */ - case object Neq extends PrimOp { override def toString = "neq" } + case object Neq extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType, _: UIntType) => Utils.BoolType + case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType + case (_: IntervalType, _: IntervalType) => Utils.BoolType + case _ => UnknownType + } + override def toString = "neq" } /** Padding */ - case object Pad extends PrimOp { override def toString = "pad" } - /** Interpret As UInt */ - case object AsUInt extends PrimOp { override def toString = "asUInt" } - /** Interpret As SInt */ - case object AsSInt extends PrimOp { override def toString = "asSInt" } - /** Interpret As Clock */ - case object AsClock extends PrimOp { override def toString = "asClock" } - /** Interpret As AsyncReset */ - case object AsAsyncReset extends PrimOp { override def toString = "asAsyncReset" } + case object Pad extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => UIntType(IsMax(w1(e), c1(e))) + case _: SIntType => SIntType(IsMax(w1(e), c1(e))) + case _: FixedType => FixedType(IsMax(w1(e), c1(e)), p1(e)) + case _ => UnknownType + } + override def toString = "pad" } /** Static Shift Left */ - case object Shl extends PrimOp { override def toString = "shl" } + case object Shl extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => UIntType(IsAdd(w1(e), c1(e))) + case _: SIntType => SIntType(IsAdd(w1(e), c1(e))) + case _: FixedType => FixedType(IsAdd(w1(e),c1(e)), p1(e)) + case IntervalType(l, u, p) => IntervalType(IsMul(l, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), IsMul(u, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), p) + case _ => UnknownType + } + override def toString = "shl" } /** Static Shift Right */ - case object Shr extends PrimOp { override def toString = "shr" } + case object Shr extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => UIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1))) + case _: SIntType => SIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1))) + case _: FixedType => FixedType(IsMax(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)), p1(e)), p1(e)) + case IntervalType(l, u, IntWidth(p)) => + val shiftMul = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt)) + // BP is inferred at this point + val bpRes = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << p.toInt)) + val bpResInv = Closed(BigDecimal(BigInt(1) << p.toInt)) + val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), bpRes) + val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), bpRes) + // BP doesn't grow + IntervalType(newL, newU, IntWidth(p)) + case _ => UnknownType + } + override def toString = "shr" + } /** Dynamic Shift Left */ - case object Dshl extends PrimOp { override def toString = "dshl" } + case object Dshl extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => UIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1)))) + case _: SIntType => SIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1)))) + case _: FixedType => FixedType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))), p1(e)) + case IntervalType(l, u, p) => + val maxShiftAmt = IsAdd(IsPow(w2(e)), Closed(-1)) + val shiftMul = IsPow(maxShiftAmt) + // Magnitude matters! i.e. if l is negative, shifting by the largest amount makes the outcome more negative + // whereas if l is positive, shifting by the largest amount makes the outcome more positive (in this case, the lower bound is the previous l) + val newL = IsMin(l, IsMul(l, shiftMul)) + val newU = IsMax(u, IsMul(u, shiftMul)) + // BP doesn't grow + IntervalType(newL, newU, p) + case _ => UnknownType + } + override def toString = "dshl" + } /** Dynamic Shift Right */ - case object Dshr extends PrimOp { override def toString = "dshr" } + case object Dshr extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => UIntType(w1(e)) + case _: SIntType => SIntType(w1(e)) + case _: FixedType => FixedType(w1(e), p1(e)) + // Decreasing magnitude -- don't need more bits + case IntervalType(l, u, p) => IntervalType(l, u, p) + case _ => UnknownType + } + override def toString = "dshr" + } /** Arithmetic Convert to Signed */ - case object Cvt extends PrimOp { override def toString = "cvt" } + case object Cvt extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => SIntType(IsAdd(w1(e), IntWidth(1))) + case _: SIntType => SIntType(w1(e)) + case _ => UnknownType + } + override def toString = "cvt" + } /** Negate */ - case object Neg extends PrimOp { override def toString = "neg" } + case object Neg extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => SIntType(IsAdd(w1(e), IntWidth(1))) + case _: SIntType => SIntType(IsAdd(w1(e), IntWidth(1))) + case _ => UnknownType + } + override def toString = "neg" + } /** Bitwise Complement */ - case object Not extends PrimOp { override def toString = "not" } + case object Not extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => UIntType(w1(e)) + case _: SIntType => UIntType(w1(e)) + case _ => UnknownType + } + override def toString = "not" + } /** Bitwise And */ - case object And extends PrimOp { override def toString = "and" } + case object And extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: SIntType | _: UIntType, _: SIntType | _: UIntType) => UIntType(IsMax(w1(e), w2(e))) + case _ => UnknownType + } + override def toString = "and" + } /** Bitwise Or */ - case object Or extends PrimOp { override def toString = "or" } + case object Or extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: SIntType | _: UIntType, _: SIntType | _: UIntType) => UIntType(IsMax(w1(e), w2(e))) + case _ => UnknownType + } + override def toString = "or" + } /** Bitwise Exclusive Or */ - case object Xor extends PrimOp { override def toString = "xor" } + case object Xor extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: SIntType | _: UIntType, _: SIntType | _: UIntType) => UIntType(IsMax(w1(e), w2(e))) + case _ => UnknownType + } + override def toString = "xor" + } /** Bitwise And Reduce */ - case object Andr extends PrimOp { override def toString = "andr" } + case object Andr extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case (_: UIntType | _: SIntType) => Utils.BoolType + case _ => UnknownType + } + override def toString = "andr" + } /** Bitwise Or Reduce */ - case object Orr extends PrimOp { override def toString = "orr" } + case object Orr extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case (_: UIntType | _: SIntType) => Utils.BoolType + case _ => UnknownType + } + override def toString = "orr" + } /** Bitwise Exclusive Or Reduce */ - case object Xorr extends PrimOp { override def toString = "xorr" } + case object Xorr extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case (_: UIntType | _: SIntType) => Utils.BoolType + case _ => UnknownType + } + override def toString = "xorr" + } /** Concatenate */ - case object Cat extends PrimOp { override def toString = "cat" } + case object Cat extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType, _: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(w1(e), w2(e))) + case (t1, t2) => UnknownType + } + override def toString = "cat" + } /** Bit Extraction */ - case object Bits extends PrimOp { override def toString = "bits" } + case object Bits extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(IsAdd(c1(e), IsNeg(c2(e))), IntWidth(1))) + case _ => UnknownType + } + override def toString = "bits" + } /** Head */ - case object Head extends PrimOp { override def toString = "head" } + case object Head extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(c1(e)) + case _ => UnknownType + } + override def toString = "head" + } /** Tail */ - case object Tail extends PrimOp { override def toString = "tail" } + case object Tail extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(w1(e), IsNeg(c1(e)))) + case _ => UnknownType + } + override def toString = "tail" + } + /** Increase Precision **/ + case object IncP extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: FixedType => FixedType(IsAdd(w1(e),c1(e)), IsAdd(p1(e), c1(e))) + // Keeps the same exact value, but adds more precision for the future i.e. aaa.bbb -> aaa.bbb00 + case IntervalType(l, u, p) => IntervalType(l, u, IsAdd(p, c1(e))) + case _ => UnknownType + } + override def toString = "incp" + } + /** Decrease Precision **/ + case object DecP extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: FixedType => FixedType(IsAdd(w1(e),IsNeg(c1(e))), IsAdd(p1(e), IsNeg(c1(e)))) + case IntervalType(l, u, IntWidth(p)) => + val shiftMul = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt)) + // BP is inferred at this point + // newBPRes is the only difference in calculating bpshr from shr + // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt) + // without amt, same op as shr + val newBPRes = Closed(BigDecimal(BigInt(1) << o1(e).toInt) / BigDecimal(BigInt(1) << p.toInt)) + val bpResInv = Closed(BigDecimal(BigInt(1) << p.toInt)) + val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes) + val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), newBPRes) + // BP doesn't grow + IntervalType(newL, newU, IsAdd(IntWidth(p), IsNeg(c1(e)))) + case _ => UnknownType + } + override def toString = "decp" + } + /** Set Precision **/ + case object SetP extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: FixedType => FixedType(IsAdd(c1(e), IsAdd(w1(e), IsNeg(p1(e)))), c1(e)) + case IntervalType(l, u, p) => + val newBPResInv = Closed(BigDecimal(BigInt(1) << o1(e).toInt)) + val newBPRes = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt)) + val newL = IsMul(IsFloor(IsMul(l, newBPResInv)), newBPRes) + val newU = IsMul(IsFloor(IsMul(u, newBPResInv)), newBPRes) + IntervalType(newL, newU, c1(e)) + case _ => UnknownType + } + override def toString = "setp" + } + /** Interpret As UInt */ + case object AsUInt extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => UIntType(w1(e)) + case _: SIntType => UIntType(w1(e)) + case _: FixedType => UIntType(w1(e)) + case ClockType => UIntType(IntWidth(1)) + case AsyncResetType => UIntType(IntWidth(1)) + case ResetType => UIntType(IntWidth(1)) + case AnalogType(w) => UIntType(w1(e)) + case _: IntervalType => UIntType(w1(e)) + case _ => UnknownType + } + override def toString = "asUInt" + } + /** Interpret As SInt */ + case object AsSInt extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => SIntType(w1(e)) + case _: SIntType => SIntType(w1(e)) + case _: FixedType => SIntType(w1(e)) + case ClockType => SIntType(IntWidth(1)) + case AsyncResetType => SIntType(IntWidth(1)) + case ResetType => SIntType(IntWidth(1)) + case _: AnalogType => SIntType(w1(e)) + case _: IntervalType => SIntType(w1(e)) + case _ => UnknownType + } + override def toString = "asSInt" + } + /** Interpret As Clock */ + case object AsClock extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => ClockType + case _: SIntType => ClockType + case ClockType => ClockType + case AsyncResetType => ClockType + case ResetType => ClockType + case _: AnalogType => ClockType + case _: IntervalType => ClockType + case _ => UnknownType + } + override def toString = "asClock" + } + /** Interpret As AsyncReset */ + case object AsAsyncReset extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType | _: IntervalType | _: FixedType => AsyncResetType + case _ => UnknownType + } + override def toString = "asAsyncReset" + } /** Interpret as Fixed Point **/ - case object AsFixedPoint extends PrimOp { override def toString = "asFixedPoint" } - /** Shift Binary Point Left **/ - case object BPShl extends PrimOp { override def toString = "bpshl" } - /** Shift Binary Point Right **/ - case object BPShr extends PrimOp { override def toString = "bpshr" } - /** Set Binary Point **/ - case object BPSet extends PrimOp { override def toString = "bpset" } + case object AsFixedPoint extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + case _: UIntType => FixedType(w1(e), c1(e)) + case _: SIntType => FixedType(w1(e), c1(e)) + case _: FixedType => FixedType(w1(e), c1(e)) + case ClockType => FixedType(IntWidth(1), c1(e)) + case _: AnalogType => FixedType(w1(e), c1(e)) + case AsyncResetType => FixedType(IntWidth(1), c1(e)) + case ResetType => FixedType(IntWidth(1), c1(e)) + case _: IntervalType => FixedType(w1(e), c1(e)) + case _ => UnknownType + } + override def toString = "asFixedPoint" + } + /** Interpret as Interval (closed lower bound, closed upper bound, binary point) **/ + case object AsInterval extends PrimOp { + override def propagateType(e: DoPrim): Type = t1(e) match { + // Chisel shifts up and rounds first. + case _: UIntType | _: SIntType | _: FixedType | ClockType | AsyncResetType | ResetType | _: AnalogType | _: IntervalType => + IntervalType(Closed(BigDecimal(o1(e))/BigDecimal(BigInt(1) << o3(e).toInt)), Closed(BigDecimal(o2(e))/BigDecimal(BigInt(1) << o3(e).toInt)), IntWidth(o3(e))) + case _ => UnknownType + } + override def toString = "asInterval" + } + /** Try to fit the first argument into the type of the smaller argument **/ + case object Squeeze extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) => + val low = IsMax(l1, l2) + val high = IsMin(u1, u2) + IntervalType(IsMin(low, u2), IsMax(l2, high), p1) + case _ => UnknownType + } + override def toString = "squz" + } + /** Wrap First Operand Around Range/Width of Second Operand **/ + case object Wrap extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) => IntervalType(l2, u2, p1) + case _ => UnknownType + } + override def toString = "wrap" + } + /** Clip First Operand At Range/Width of Second Operand **/ + case object Clip extends PrimOp { + override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) => + val low = IsMax(l1, l2) + val high = IsMin(u1, u2) + IntervalType(IsMin(low, u2), IsMax(l2, high), p1) + case _ => UnknownType + } + override def toString = "clip" + } private lazy val builtinPrimOps: Seq[PrimOp] = - Seq(Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Pad, AsUInt, AsSInt, AsClock, - AsAsyncReset, Shl, Shr, Dshl, Dshr, Neg, Cvt, Not, And, Or, Xor, Andr, Orr, Xorr, Cat, Bits, - Head, Tail, AsFixedPoint, BPShl, BPShr, BPSet) - private lazy val strToPrimOp: Map[String, PrimOp] = builtinPrimOps.map { case op : PrimOp=> op.toString -> op }.toMap + Seq(Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Pad, AsUInt, AsSInt, AsInterval, AsClock, AsAsyncReset, Shl, Shr, + Dshl, Dshr, Neg, Cvt, Not, And, Or, Xor, Andr, Orr, Xorr, Cat, Bits, Head, Tail, AsFixedPoint, IncP, DecP, + SetP, Wrap, Clip, Squeeze) + private lazy val strToPrimOp: Map[String, PrimOp] = { + builtinPrimOps.map { case op : PrimOp=> op.toString -> op }.toMap + } /** Seq of String representations of [[ir.PrimOp]]s */ lazy val listing: Seq[String] = builtinPrimOps map (_.toString) @@ -96,169 +491,10 @@ object PrimOps extends LazyLogging { def fromString(op: String): PrimOp = strToPrimOp(op) // Width Constraint Functions - def PLUS (w1:Width, w2:Width) : Width = (w1, w2) match { - case (IntWidth(i), IntWidth(j)) => IntWidth(i + j) - case _ => PlusWidth(w1, w2) - } - def MAX (w1:Width, w2:Width) : Width = (w1, w2) match { - case (IntWidth(i), IntWidth(j)) => IntWidth(max(i,j)) - case _ => MaxWidth(Seq(w1, w2)) - } - def MINUS (w1:Width, w2:Width) : Width = (w1, w2) match { - case (IntWidth(i), IntWidth(j)) => IntWidth(i - j) - case _ => MinusWidth(w1, w2) - } - def POW (w1:Width) : Width = w1 match { - case IntWidth(i) => IntWidth(pow_minus_one(BigInt(2), i)) - case _ => ExpWidth(w1) - } - def MIN (w1:Width, w2:Width) : Width = (w1, w2) match { - case (IntWidth(i), IntWidth(j)) => IntWidth(min(i,j)) - case _ => MinWidth(Seq(w1, w2)) - } + def PLUS(w1: Width, w2: Width): Constraint = IsAdd(w1, w2) + def MAX(w1: Width, w2: Width): Constraint = IsMax(w1, w2) + def MINUS(w1: Width, w2: Width): Constraint = IsAdd(w1, IsNeg(w2)) + def MIN(w1: Width, w2: Width): Constraint = IsMin(w1, w2) - // Borrowed from Stanza implementation - def set_primop_type (e:DoPrim) : DoPrim = { - //println-all(["Inferencing primop type: " e]) - def t1 = e.args.head.tpe - def t2 = e.args(1).tpe - def w1 = getWidth(e.args.head.tpe) - def w2 = getWidth(e.args(1).tpe) - def p1 = t1 match { case FixedType(w, p) => p } //Intentional - def p2 = t2 match { case FixedType(w, p) => p } //Intentional - def c1 = IntWidth(e.consts.head) - def c2 = IntWidth(e.consts(1)) - e copy (tpe = e.op match { - case Add | Sub => (t1, t2) match { - case (_: UIntType, _: UIntType) => UIntType(PLUS(MAX(w1, w2), IntWidth(1))) - case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1))) - case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2), MAX(MINUS(w1, p1), MINUS(w2, p2))), IntWidth(1)), MAX(p1, p2)) - case _ => UnknownType - } - case Mul => (t1, t2) match { - case (_: UIntType, _: UIntType) => UIntType(PLUS(w1, w2)) - case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, w2)) - case (_: FixedType, _: FixedType) => FixedType(PLUS(w1, w2), PLUS(p1, p2)) - case _ => UnknownType - } - case Div => (t1, t2) match { - case (_: UIntType, _: UIntType) => UIntType(w1) - case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, IntWidth(1))) - case _ => UnknownType - } - case Rem => (t1, t2) match { - case (_: UIntType, _: UIntType) => UIntType(MIN(w1, w2)) - case (_: SIntType, _: SIntType) => SIntType(MIN(w1, w2)) - case _ => UnknownType - } - case Lt | Leq | Gt | Geq | Eq | Neq => (t1, t2) match { - case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: SIntType) => Utils.BoolType - case (_: FixedType, _: FixedType) => Utils.BoolType - case _ => UnknownType - } - case Pad => t1 match { - case _: UIntType => UIntType(MAX(w1, c1)) - case _: SIntType => SIntType(MAX(w1, c1)) - case _: FixedType => FixedType(MAX(w1, c1), p1) - case _ => UnknownType - } - case AsUInt => t1 match { - case (_: UIntType | _: SIntType | _: FixedType | _: AnalogType) => UIntType(w1) - case ClockType | AsyncResetType | ResetType => UIntType(IntWidth(1)) - case _ => UnknownType - } - case AsSInt => t1 match { - case (_: UIntType | _: SIntType | _: FixedType | _: AnalogType) => SIntType(w1) - case ClockType | AsyncResetType | ResetType => SIntType(IntWidth(1)) - case _ => UnknownType - } - case AsFixedPoint => t1 match { - case (_: UIntType | _: SIntType | _: FixedType | _: AnalogType) => FixedType(w1, c1) - case ClockType | AsyncResetType | ResetType => FixedType(IntWidth(1), c1) - case _ => UnknownType - } - case AsClock => t1 match { - case (_: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType) => ClockType - case _ => UnknownType - } - case AsAsyncReset => t1 match { - case (_: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType) => AsyncResetType - case _ => UnknownType - } - case Shl => t1 match { - case _: UIntType => UIntType(PLUS(w1, c1)) - case _: SIntType => SIntType(PLUS(w1, c1)) - case _: FixedType => FixedType(PLUS(w1,c1), p1) - case _ => UnknownType - } - case Shr => t1 match { - case _: UIntType => UIntType(MAX(MINUS(w1, c1), IntWidth(1))) - case _: SIntType => SIntType(MAX(MINUS(w1, c1), IntWidth(1))) - case _: FixedType => FixedType(MAX(MAX(MINUS(w1,c1), IntWidth(1)), p1), p1) - case _ => UnknownType - } - case Dshl => t1 match { - case _: UIntType => UIntType(PLUS(w1, POW(w2))) - case _: SIntType => SIntType(PLUS(w1, POW(w2))) - case _: FixedType => FixedType(PLUS(w1, POW(w2)), p1) - case _ => UnknownType - } - case Dshr => t1 match { - case _: UIntType => UIntType(w1) - case _: SIntType => SIntType(w1) - case _: FixedType => FixedType(w1, p1) - case _ => UnknownType - } - case Cvt => t1 match { - case _: UIntType => SIntType(PLUS(w1, IntWidth(1))) - case _: SIntType => SIntType(w1) - case _ => UnknownType - } - case Neg => t1 match { - case (_: UIntType | _: SIntType) => SIntType(PLUS(w1, IntWidth(1))) - case _ => UnknownType - } - case Not => t1 match { - case (_: UIntType | _: SIntType) => UIntType(w1) - case _ => UnknownType - } - case And | Or | Xor => (t1, t2) match { - case (_: SIntType | _: UIntType, _: SIntType | _: UIntType) => UIntType(MAX(w1, w2)) - case _ => UnknownType - } - case Andr | Orr | Xorr => t1 match { - case (_: UIntType | _: SIntType) => Utils.BoolType - case _ => UnknownType - } - case Cat => (t1, t2) match { - case (_: UIntType | _: SIntType | _: FixedType, _: UIntType | _: SIntType | _: FixedType) => UIntType(PLUS(w1, w2)) - case (t1, t2) => UnknownType - } - case Bits => t1 match { - case (_: UIntType | _: SIntType | _: FixedType) => UIntType(PLUS(MINUS(c1, c2), IntWidth(1))) - case _ => UnknownType - } - case Head => t1 match { - case (_: UIntType | _: SIntType | _: FixedType) => UIntType(c1) - case _ => UnknownType - } - case Tail => t1 match { - case (_: UIntType | _: SIntType | _: FixedType) => UIntType(MINUS(w1, c1)) - case _ => UnknownType - } - case BPShl => t1 match { - case _: FixedType => FixedType(PLUS(w1,c1), PLUS(p1, c1)) - case _ => UnknownType - } - case BPShr => t1 match { - case _: FixedType => FixedType(MINUS(w1,c1), MINUS(p1, c1)) - case _ => UnknownType - } - case BPSet => t1 match { - case _: FixedType => FixedType(PLUS(c1, MINUS(w1, p1)), c1) - case _ => UnknownType - } - }) - } + def set_primop_type(e: DoPrim): DoPrim = DoPrim(e.op, e.args, e.consts, e.op.propagateType(e)) } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 4bf2d14d..8a76aca6 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -10,6 +10,8 @@ import firrtl.WrappedExpression._ import scala.collection.mutable import scala.util.matching.Regex +import Implicits.{constraint2bound, constraint2width, width2constraint} +import firrtl.constraint.{IsMax, IsMin} import firrtl.annotations.{ReferenceTarget, TargetToken} import _root_.logger.LazyLogging @@ -219,7 +221,10 @@ object Utils extends LazyLogging { def indent(str: String) = str replaceAllLiterally ("\n", "\n ") implicit def toWrappedExpression (x:Expression): WrappedExpression = new WrappedExpression(x) - def ceilLog2(x: BigInt): Int = (x-1).bitLength + def getSIntWidth(s: BigInt): Int = s.bitLength + 1 + def getUIntWidth(u: BigInt): Int = u.bitLength + def dec2string(v: BigDecimal): String = v.underlying().stripTrailingZeros().toPlainString + def trim(v: BigDecimal): BigDecimal = BigDecimal(dec2string(v)) def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 @@ -375,6 +380,7 @@ object Utils extends LazyLogging { case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth) case (t1: SIntType, t2: SIntType) => SIntType(UnknownWidth) case (t1: FixedType, t2: FixedType) => FixedType(UnknownWidth, UnknownWidth) + case (t1: IntervalType, t2: IntervalType) => IntervalType(UnknownBound, UnknownBound, UnknownWidth) case (t1: VectorType, t2: VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size) case (t1: BundleType, t2: BundleType) => BundleType(t1.fields zip t2.fields map { case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe)) @@ -386,15 +392,17 @@ object Utils extends LazyLogging { def mux_type_and_widths(t1: Type, t2: Type): Type = { def wmax(w1: Width, w2: Width): Width = (w1, w2) match { case (w1x: IntWidth, w2x: IntWidth) => IntWidth(w1x.width max w2x.width) - case (w1x, w2x) => MaxWidth(Seq(w1x, w2x)) + case (w1x, w2x) => IsMax(w1x, w2x) } (t1, t2) match { case (ClockType, ClockType) => ClockType case (AsyncResetType, AsyncResetType) => AsyncResetType - case (t1x: UIntType, t2x: UIntType) => UIntType(wmax(t1x.width, t2x.width)) - case (t1x: SIntType, t2x: SIntType) => SIntType(wmax(t1x.width, t2x.width)) + case (t1x: UIntType, t2x: UIntType) => UIntType(IsMax(t1x.width, t2x.width)) + case (t1x: SIntType, t2x: SIntType) => SIntType(IsMax(t1x.width, t2x.width)) case (FixedType(w1, p1), FixedType(w2, p2)) => FixedType(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))), MAX(p1, p2)) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => + IntervalType(IsMin(l1, l2), constraint.IsMax(u1, u2), MAX(p1, p2)) case (t1x: VectorType, t2x: VectorType) => VectorType( mux_type_and_widths(t1x.tpe, t2x.tpe), t1x.size) case (t1x: BundleType, t2x: BundleType) => BundleType(t1x.fields zip t2x.fields map { diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 01de8f15..112343d1 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -30,6 +30,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private val HexPattern = """\"*h([+\-]?[a-zA-Z0-9]+)\"*""".r private val DecPattern = """([+\-]?[1-9]\d*)""".r private val ZeroPattern = "0".r + private val DecimalPattern = """([+\-]?[0-9]\d*\.[0-9]\d*)""".r private def string2BigInt(s: String): BigInt = { // private define legal patterns @@ -41,6 +42,16 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } } + private def string2BigDecimal(s: String): BigDecimal = { + // private define legal patterns + s match { + case ZeroPattern(_*) => BigDecimal(0) + case DecPattern(num) => BigDecimal(num) + case DecimalPattern(num) => BigDecimal(num) + case _ => throw new Exception("Invalid String for conversion to BigDecimal " + s) + } + } + private def string2Int(s: String): Int = string2BigInt(s).toInt private def visitInfo(ctx: Option[InfoContext], parentCtx: ParserRuleContext): Info = { @@ -129,6 +140,30 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } case 2 => FixedType(getWidth(ctx.intLit(0)), getWidth(ctx.intLit(1))) } + case "Interval" => ctx.boundValue.size match { + case 0 => + val point = ctx.intLit.size match { + case 0 => UnknownWidth + case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) + } + IntervalType(UnknownBound, UnknownBound, point) + case 2 => + val lower = (ctx.lowerBound.getText, ctx.boundValue(0).getText) match { + case (_, "?") => UnknownBound + case ("(", v) => Open(string2BigDecimal(v)) + case ("[", v) => Closed(string2BigDecimal(v)) + } + val upper = (ctx.upperBound.getText, ctx.boundValue(1).getText) match { + case (_, "?") => UnknownBound + case (")", v) => Open(string2BigDecimal(v)) + case ("]", v) => Closed(string2BigDecimal(v)) + } + val point = ctx.intLit.size match { + case 0 => UnknownWidth + case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) + } + IntervalType(lower, upper, point) + } case "Clock" => ClockType case "AsyncReset" => AsyncResetType case "Reset" => ResetType diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 475f5e9c..eb4a665f 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -162,11 +162,49 @@ case class WDefInstanceConnector( } // Resultant width is the same as the maximum input width -case object Addw extends PrimOp { override def toString = "addw" } +case object Addw extends PrimOp { + override def toString = "addw" + import constraint._ + import PrimOps._ + import Implicits.{constraint2width, width2constraint} + + override def propagateType(e: DoPrim): Type = { + (e.args(0).tpe, e.args(1).tpe) match { + case (_: UIntType, _: UIntType) => UIntType(IsMax(w1(e), w2(e))) + case (_: SIntType, _: SIntType) => SIntType(IsMax(w1(e), w2(e))) + case _ => UnknownType + } + } +} // Resultant width is the same as the maximum input width -case object Subw extends PrimOp { override def toString = "subw" } +case object Subw extends PrimOp { + override def toString = "subw" + import constraint._ + import PrimOps._ + import Implicits.{constraint2width, width2constraint} + + override def propagateType(e: DoPrim): Type = { + (e.args(0).tpe, e.args(1).tpe) match { + case (_: UIntType, _: UIntType) => UIntType(IsMax(w1(e), w2(e))) + case (_: SIntType, _: SIntType) => SIntType(IsMax(w1(e), w2(e))) + case _ => UnknownType + } + } +} // Resultant width is the same as input argument width -case object Dshlw extends PrimOp { override def toString = "dshlw" } +case object Dshlw extends PrimOp { + override def toString = "dshlw" + + import PrimOps._ + + override def propagateType(e: DoPrim): Type = { + e.args(0).tpe match { + case _: UIntType => UIntType(w1(e)) + case _: SIntType => SIntType(w1(e)) + case _ => UnknownType + } + } +} object WrappedExpression { def apply(e: Expression) = new WrappedExpression(e) @@ -200,30 +238,6 @@ class WrappedExpression(val e1: Expression) { private[firrtl] sealed trait HasMapWidth { def mapWidth(f: Width => Width): Width } -case class VarWidth(name: String) extends Width with HasMapWidth { - def serialize: String = name - def mapWidth(f: Width => Width): Width = this -} -case class PlusWidth(arg1: Width, arg2: Width) extends Width with HasMapWidth { - def serialize: String = "(" + arg1.serialize + " + " + arg2.serialize + ")" - def mapWidth(f: Width => Width): Width = PlusWidth(f(arg1), f(arg2)) -} -case class MinusWidth(arg1: Width, arg2: Width) extends Width with HasMapWidth { - def serialize: String = "(" + arg1.serialize + " - " + arg2.serialize + ")" - def mapWidth(f: Width => Width): Width = MinusWidth(f(arg1), f(arg2)) -} -case class MaxWidth(args: Seq[Width]) extends Width with HasMapWidth { - def serialize: String = args map (_.serialize) mkString ("max(", ", ", ")") - def mapWidth(f: Width => Width): Width = MaxWidth(args map f) -} -case class MinWidth(args: Seq[Width]) extends Width with HasMapWidth { - def serialize: String = args map (_.serialize) mkString ("min(", ", ", ")") - def mapWidth(f: Width => Width): Width = MinWidth(args map f) -} -case class ExpWidth(arg1: Width) extends Width with HasMapWidth { - def serialize: String = "exp(" + arg1.serialize + " )" - def mapWidth(f: Width => Width): Width = ExpWidth(f(arg1)) -} object WrappedType { def apply(t: Type) = new WrappedType(t) @@ -240,6 +254,7 @@ object WrappedType { case (ResetType, tpe) => legalResetType(tpe) case (tpe, ResetType) => legalResetType(tpe) case (_: FixedType, _: FixedType) => true + case (_: IntervalType, _: IntervalType) => true // Analog totally skips out of the Firrtl type system. // The only way Analog can play with another Analog component is through Attach. // Ohterwise, we'd need to special case it during ExpandWhens, Lowering, @@ -280,29 +295,13 @@ class WrappedWidth (val w: Width) { def ww(w: Width): WrappedWidth = new WrappedWidth(w) override def toString = w match { case (w: VarWidth) => w.name - case (w: MaxWidth) => s"max(${w.args.mkString})" - case (w: MinWidth) => s"min(${w.args.mkString})" - case (w: PlusWidth) => s"(${w.arg1} + ${w.arg2})" - case (w: MinusWidth) => s"(${w.arg1} -${w.arg2})" - case (w: ExpWidth) => s"exp(${w.arg1})" case (w: IntWidth) => w.width.toString case UnknownWidth => "?" } override def equals(o: Any): Boolean = o match { case (w2: WrappedWidth) => (w, w2.w) match { case (w1: VarWidth, w2: VarWidth) => w1.name.equals(w2.name) - case (w1: MaxWidth, w2: MaxWidth) => w1.args.size == w2.args.size && - (w1.args forall (a1 => w2.args exists (a2 => eqw(a1, a2)))) - case (w1: MinWidth, w2: MinWidth) => w1.args.size == w2.args.size && - (w1.args forall (a1 => w2.args exists (a2 => eqw(a1, a2)))) case (w1: IntWidth, w2: IntWidth) => w1.width == w2.width - case (w1: PlusWidth, w2: PlusWidth) => - (ww(w1.arg1) == ww(w2.arg1) && ww(w1.arg2) == ww(w2.arg2)) || - (ww(w1.arg1) == ww(w2.arg2) && ww(w1.arg2) == ww(w2.arg1)) - case (w1: MinusWidth,w2: MinusWidth) => - (ww(w1.arg1) == ww(w2.arg1) && ww(w1.arg2) == ww(w2.arg2)) || - (ww(w1.arg1) == ww(w2.arg2) && ww(w1.arg2) == ww(w2.arg1)) - case (w1: ExpWidth, w2: ExpWidth) => ww(w1.arg1) == ww(w2.arg1) case (UnknownWidth, UnknownWidth) => true case _ => false } @@ -310,18 +309,6 @@ class WrappedWidth (val w: Width) { } } -trait Constraint -class WGeq(val loc: Width, val exp: Width) extends Constraint { - override def toString = { - val wloc = new WrappedWidth(loc) - val wexp = new WrappedWidth(exp) - wloc.toString + " >= " + wexp.toString - } -} -object WGeq { - def apply(loc: Width, exp: Width) = new WGeq(loc, exp) -} - abstract class MPortDir extends FirrtlNode case object MInfer extends MPortDir { def serialize: String = "infer" diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala index 313f1dc2..1571f98e 100644 --- a/src/main/scala/firrtl/annotations/Target.scala +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -3,7 +3,7 @@ package firrtl package annotations -import firrtl.ir.{Expression, Type} +import firrtl.ir.{Field => _, _} import firrtl.Utils.{sub_type, field_type} import AnnotationUtils.{toExp, validComponentName, validModuleName} import TargetToken._ @@ -41,7 +41,7 @@ sealed trait Target extends Named { case Ref(r) => s">$r" case Instance(i) => s"/$i" case OfModule(o) => s":$o" - case Field(f) => s".$f" + case TargetToken.Field(f) => s".$f" case Index(v) => s"[$v]" case Clock => s"@clock" case Reset => s"@reset" @@ -103,6 +103,21 @@ sealed trait Target extends Named { } object Target { + def asTarget(m: ModuleTarget)(e: Expression): ReferenceTarget = e match { + case w: WRef => m.ref(w.name) + case r: ir.Reference => m.ref(r.name) + case w: WSubIndex => asTarget(m)(w.expr).index(w.value) + case s: ir.SubIndex => asTarget(m)(s.expr).index(s.value) + case w: WSubField => asTarget(m)(w.expr).field(w.name) + case s: ir.SubField => asTarget(m)(s.expr).field(s.name) + case w: WSubAccess => asTarget(m)(w.expr).field("@" + w.index.serialize) + case s: ir.SubAccess => asTarget(m)(s.expr).field("@" + s.index.serialize) + case d: DoPrim => m.ref("@" + d.serialize) + case d: Mux => m.ref("@" + d.serialize) + case d: ValidIf => m.ref("@" + d.serialize) + case d: Literal => m.ref("@" + d.serialize) + case other => sys.error(s"Unsupported: $other") + } def apply(circuitOpt: Option[String], moduleOpt: Option[String], reference: Seq[TargetToken]): GenericTarget = GenericTarget(circuitOpt, moduleOpt, reference.toVector) diff --git a/src/main/scala/firrtl/constraint/Constraint.scala b/src/main/scala/firrtl/constraint/Constraint.scala new file mode 100644 index 00000000..247593ee --- /dev/null +++ b/src/main/scala/firrtl/constraint/Constraint.scala @@ -0,0 +1,22 @@ +// See LICENSE for license details. + +package firrtl.constraint + +/** Trait for all Constraint Solver expressions */ +trait Constraint { + def serialize: String + def map(f: Constraint => Constraint): Constraint + val children: Vector[Constraint] + def reduce(): Constraint +} + +/** Trait for constraints with more than one argument */ +trait MultiAry extends Constraint { + def op(a: IsKnown, b: IsKnown): IsKnown + def merge(b1: Option[IsKnown], b2: Option[IsKnown]): Option[IsKnown] = (b1, b2) match { + case (Some(x), Some(y)) => Some(op(x, y)) + case (_, y: Some[_]) => y + case (x: Some[_], _) => x + case _ => None + } +} diff --git a/src/main/scala/firrtl/constraint/ConstraintSolver.scala b/src/main/scala/firrtl/constraint/ConstraintSolver.scala new file mode 100644 index 00000000..52440b15 --- /dev/null +++ b/src/main/scala/firrtl/constraint/ConstraintSolver.scala @@ -0,0 +1,357 @@ +// See LICENSE for license details. + +package firrtl.constraint + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils.throwInternalError +import firrtl.annotations.ReferenceTarget + +import scala.collection.mutable + +/** Forwards-Backwards Constraint Solver + * + * Used for computing [[Width]] and [[Bound]] constraints + * + * Note - this is an O(N) algorithm, but requires exponential memory. We rely on aggressive early optimization + * of constraint expressions to (usually) get around this. + */ +class ConstraintSolver { + + /** Initial, mutable constraint list, with function to add the constraint */ + private val constraints = mutable.ArrayBuffer[Inequality]() + + /** Solved constraints */ + type ConstraintMap = mutable.HashMap[String, (Constraint, Boolean)] + private val solvedConstraintMap = new ConstraintMap() + + + /** Clear all previously recorded/solved constraints */ + def clear(): Unit = { + constraints.clear() + solvedConstraintMap.clear() + } + + /** Updates internal list of inequalities with a new [[GreaterOrEqual]] + * @param big The larger constraint, must be either known or a variable + * @param small The smaller constraint + */ + def addGeq(big: Constraint, small: Constraint, r1: String, r2: String): Unit = (big, small) match { + case (IsVar(name), other: Constraint) => add(GreaterOrEqual(name, other)) + case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints + } + + /** Updates internal list of inequalities with a new [[GreaterOrEqual]] + * @param big The larger constraint, must be either known or a variable + * @param small The smaller constraint + */ + def addGeq(big: Width, small: Width, r1: String, r2: String): Unit = (big, small) match { + case (IsVar(name), other: CalcWidth) => add(GreaterOrEqual(name, other.arg)) + case (IsVar(name), other: IsVar) => add(GreaterOrEqual(name, other)) + case (IsVar(name), other: IntWidth) => add(GreaterOrEqual(name, Implicits.width2constraint(other))) + case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints + } + + /** Updates internal list of inequalities with a new [[LesserOrEqual]] + * @param small The smaller constraint, must be either known or a variable + * @param big The larger constraint + */ + def addLeq(small: Constraint, big: Constraint, r1: String, r2: String): Unit = (small, big) match { + case (IsVar(name), other: Constraint) => add(LesserOrEqual(name, other)) + case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints + } + + /** Updates internal list of inequalities with a new [[LesserOrEqual]] + * @param small The smaller constraint, must be either known or a variable + * @param big The larger constraint + */ + def addLeq(small: Width, big: Width, r1: String, r2: String): Unit = (small, big) match { + case (IsVar(name), other: CalcWidth) => add(LesserOrEqual(name, other.arg)) + case (IsVar(name), other: IsVar) => add(LesserOrEqual(name, other)) + case (IsVar(name), other: IntWidth) => add(LesserOrEqual(name, Implicits.width2constraint(other))) + case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints + } + + /** Returns a solved constraint, if it exists and is solved + * @param b + * @return + */ + def get(b: Constraint): Option[IsKnown] = { + val name = b match { + case IsVar(name) => name + case x => "" + } + solvedConstraintMap.get(name) match { + case None => None + case Some((k: IsKnown, _)) => Some(k) + case Some(_) => None + } + } + + /** Returns a solved width, if it exists and is solved + * @param b + * @return + */ + def get(b: Width): Option[IsKnown] = { + val name = b match { + case IsVar(name) => name + case x => "" + } + solvedConstraintMap.get(name) match { + case None => None + case Some((k: IsKnown, _)) => Some(k) + case Some(_) => None + } + } + + + private def add(c: Inequality) = constraints += c + + + /** Creates an Inequality given a variable name, constraint, and whether its >= or <= + * @param left + * @param right + * @param geq + * @return + */ + private def genConst(left: String, right: Constraint, geq: Boolean): Inequality = geq match { + case true => GreaterOrEqual(left, right) + case false => LesserOrEqual(left, right) + } + + /** For debugging, can serialize the initial constraints */ + def serializeConstraints: String = constraints.mkString("\n") + + /** For debugging, can serialize the solved constraints */ + def serializeSolutions: String = solvedConstraintMap.map{ + case (k, (v, true)) => s"$k >= ${v.serialize}" + case (k, (v, false)) => s"$k <= ${v.serialize}" + }.mkString("\n") + + + + /************* Constraint Solver Engine ****************/ + + /** Merges constraints on the same variable + * + * Returns a new list of Inequalities with a single Inequality per variable + * + * For example, given: + * a >= 1 + b + * a >= 3 + * + * Will return: + * a >= max(3, 1 + b) + * + * @param constraints + * @return + */ + private def mergeConstraints(constraints: Seq[Inequality]): Seq[Inequality] = { + val mergedMap = mutable.HashMap[String, Inequality]() + constraints.foreach { + case c if c.geq && mergedMap.contains(c.left) => + mergedMap(c.left) = genConst(c.left, IsMax(mergedMap(c.left).right, c.right), true) + case c if !c.geq && mergedMap.contains(c.left) => + mergedMap(c.left) = genConst(c.left, IsMin(mergedMap(c.left).right, c.right), false) + case c => + mergedMap(c.left) = c + } + mergedMap.values.toList + } + + + /** Attempts to substitute variables with their corresponding forward-solved constraints + * If no corresponding constraint has been visited yet, keep variable as is + * + * @param forwardSolved ConstraintMap containing earlier forward-solved constraints + * @param constraint Constraint to forward solve + * @return Forward solved constraint + */ + private def forwardSubstitution(forwardSolved: ConstraintMap)(constraint: Constraint): Constraint = { + val x = constraint map forwardSubstitution(forwardSolved) + x match { + case isVar: IsVar => forwardSolved get isVar.name match { + case None => isVar.asInstanceOf[Constraint] + case Some((p, geq)) => + val newT = forwardSubstitution(forwardSolved)(p) + forwardSolved(isVar.name) = (newT, geq) + newT + } + case other => other + } + } + + /** Attempts to substitute variables with their corresponding backwards-solved constraints + * If no corresponding constraint is solved, keep variable as is (as an unsolved constraint, + * which will be reported later) + * + * @param backwardSolved ConstraintMap containing earlier backward-solved constraints + * @param constraint Constraint to backward solve + * @return Backward solved constraint + */ + private def backwardSubstitution(backwardSolved: ConstraintMap)(constraint: Constraint): Constraint = { + constraint match { + case isVar: IsVar => backwardSolved.get(isVar.name) match { + case Some((p, geq)) => p + case _ => isVar + } + case other => other map backwardSubstitution(backwardSolved) + } + } + + /** Remove solvable cycles in an inequality + * + * For example: + * a >= max(1, a) + * + * Can be simplified to: + * a >= 1 + * @param name Name of the variable on left side of inequality + * @param geq Whether inequality is >= or <= + * @param constraint Constraint expression + * @return + */ + private def removeCycle(name: String, geq: Boolean)(constraint: Constraint): Constraint = + if(geq) removeGeqCycle(name)(constraint) else removeLeqCycle(name)(constraint) + + /** Removes solvable cycles of <= inequalities + * @param name Name of the variable on left side of inequality + * @param constraint Constraint expression + * @return + */ + private def removeLeqCycle(name: String)(constraint: Constraint): Constraint = constraint match { + case x if greaterEqThan(name)(x) => VarCon(name) + case isMin: IsMin => IsMin(isMin.children.filter{ c => !greaterEqThan(name)(c)}) + case x => x + } + + /** Removes solvable cycles of >= inequalities + * @param name Name of the variable on left side of inequality + * @param constraint Constraint expression + * @return + */ + private def removeGeqCycle(name: String)(constraint: Constraint): Constraint = constraint match { + case x if lessEqThan(name)(x) => VarCon(name) + case isMax: IsMax => IsMax(isMax.children.filter{c => !lessEqThan(name)(c)}) + case x => x + } + + private def greaterEqThan(name: String)(constraint: Constraint): Boolean = constraint match { + case isMin: IsMin => isMin.children.map(greaterEqThan(name)).reduce(_ && _) + case isAdd: IsAdd => isAdd.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true + case _ => false + } + case isMul: IsMul => isMul.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true + case _ => false + } + case isVar: IsVar if isVar.name == name => true + case _ => false + } + + private def lessEqThan(name: String)(constraint: Constraint): Boolean = constraint match { + case isMax: IsMax => isMax.children.map(lessEqThan(name)).reduce(_ && _) + case isAdd: IsAdd => isAdd.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true + case _ => false + } + case isMul: IsMul => isMul.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true + case _ => false + } + case isVar: IsVar if isVar.name == name => true + case isNeg: IsNeg => isNeg.child match { + case isVar: IsVar if isVar.name == name => true + case _ => false + } + case _ => false + } + + /** Whether a constraint contains the named variable + * @param name Name of variable + * @param constraint Constraint to check + * @return + */ + private def hasVar(name: String)(constraint: Constraint): Boolean = { + var has = false + def rec(constraint: Constraint): Constraint = { + constraint match { + case isVar: IsVar if isVar.name == name => has = true + case _ => + } + constraint map rec + } + rec(constraint) + has + } + + /** Returns illegal constraints, where both a >= and <= inequality are used on the same variable + * @return + */ + def check(): Seq[Inequality] = { + val checkMap = new mutable.HashMap[String, Inequality]() + constraints.foldLeft(Seq[Inequality]()) { (seq, c) => + checkMap.get(c.left) match { + case None => + checkMap(c.left) = c + seq ++ Nil + case Some(x) if x.geq != c.geq => seq ++ Seq(x, c) + case Some(x) => seq ++ Nil + } + } + } + + /** Solves constraints present in collected inequalities + * + * Constraint solving steps: + * 1) Assert no variable has both >= and <= inequalities (it can have multiple of the same kind of inequality) + * 2) Merge constraints of variables having multiple inequalities + * 3) Forward solve inequalities + * a. Iterate through inequalities top-to-bottom, replacing previously seen variables with corresponding + * constraint + * b. For each forward-solved inequality, attempt to remove circular constraints + * c. Forward-solved inequalities without circular constraints are recorded + * 4) Backwards solve inequalities + * a. Iterate through successful forward-solved inequalities bottom-to-top, replacing previously seen variables + * with corresponding constraint + * b. Record solved constraints + */ + def solve(): Unit = { + // 1) Check if any variable has both >= and <= inequalities (which is illegal) + val illegals = check() + if (illegals != Nil) throwInternalError(s"Constraints cannot have both >= and <= inequalities: $illegals") + + // 2) Merge constraints + val uniqueConstraints = mergeConstraints(constraints.toSeq) + + // 3) Forward Solve + val forwardConstraintMap = new ConstraintMap + val orderedVars = mutable.HashMap[Int, String]() + + var index = 0 + for (constraint <- uniqueConstraints) { + //TODO: Risky if used improperly... need to check whether substitution from a leq to a geq is negated (always). + val subbedRight = forwardSubstitution(forwardConstraintMap)(constraint.right) + val name = constraint.left + val finishedRight = removeCycle(name, constraint.geq)(subbedRight) + if (!hasVar(name)(finishedRight)) { + forwardConstraintMap(name) = (finishedRight, constraint.geq) + orderedVars(index) = name + index += 1 + } + } + + // 4) Backwards Solve + for (i <- (orderedVars.size - 1) to 0 by -1) { + val name = orderedVars(i) // Should visit `orderedVars` backward + val (forwardRight, forwardGeq) = forwardConstraintMap(name) + val solvedRight = backwardSubstitution(solvedConstraintMap)(forwardRight) + solvedConstraintMap(name) = (solvedRight, forwardGeq) + } + } +} diff --git a/src/main/scala/firrtl/constraint/Inequality.scala b/src/main/scala/firrtl/constraint/Inequality.scala new file mode 100644 index 00000000..0fa1d2eb --- /dev/null +++ b/src/main/scala/firrtl/constraint/Inequality.scala @@ -0,0 +1,24 @@ +// See LICENSE for license details. + +package firrtl.constraint + +/** Represents either greater or equal to or less than or equal to + * Is passed to the constraint solver to resolve + */ +trait Inequality { + def left: String + def right: Constraint + def geq: Boolean +} + +case class GreaterOrEqual(left: String, right: Constraint) extends Inequality { + val geq = true + override def toString: String = s"$left >= ${right.serialize}" +} + +case class LesserOrEqual(left: String, right: Constraint) extends Inequality { + val geq = false + override def toString: String = s"$left <= ${right.serialize}" +} + + diff --git a/src/main/scala/firrtl/constraint/IsAdd.scala b/src/main/scala/firrtl/constraint/IsAdd.scala new file mode 100644 index 00000000..e177a8b9 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsAdd.scala @@ -0,0 +1,52 @@ +// See LICENSE for license details. + + +package firrtl.constraint + +// Is case class because writing tests is easier due to equality is not object equality +case class IsAdd private (known: Option[IsKnown], + maxs: Vector[IsMax], + mins: Vector[IsMin], + others: Vector[Constraint]) extends Constraint with MultiAry { + + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 + b2 + + lazy val children: Vector[Constraint] = { + if(known.nonEmpty) known.get +: (maxs ++ mins ++ others) else maxs ++ mins ++ others + } + + def addChild(x: Constraint): IsAdd = x match { + case k: IsKnown => new IsAdd(merge(Some(k), known), maxs, mins, others) + case add: IsAdd => new IsAdd(merge(known, add.known), maxs ++ add.maxs, mins ++ add.mins, others ++ add.others) + case max: IsMax => new IsAdd(known, maxs :+ max, mins, others) + case min: IsMin => new IsAdd(known, maxs, mins :+ min, others) + case other => new IsAdd(known, maxs, mins, others :+ other) + } + + override def serialize: String = "(" + children.map(_.serialize).mkString(" + ") + ")" + + override def map(f: Constraint=>Constraint): Constraint = IsAdd(children.map(f)) + + def reduce(): Constraint = { + if(children.size == 1) children.head else { + (known, maxs, mins, others) match { + case (Some(k), _, _, _) if k.value == 0 => new IsAdd(None, maxs, mins, others).reduce() + case (Some(k), Vector(max), Vector(), Vector()) => max.map { o => IsAdd(k, o) }.reduce() + case (Some(k), Vector(), Vector(min), Vector()) => min.map { o => IsAdd(k, o) }.reduce() + case _ => this + } + } + } +} + +object IsAdd { + def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { + case (l: IsKnown, r: IsKnown) => l + r + case _ => apply(Seq(left, right)) + } + def apply(children: Seq[Constraint]): Constraint = { + children.foldLeft(new IsAdd(None, Vector(), Vector(), Vector())) { (add, c) => + add.addChild(c) + }.reduce() + } +}
\ No newline at end of file diff --git a/src/main/scala/firrtl/constraint/IsFloor.scala b/src/main/scala/firrtl/constraint/IsFloor.scala new file mode 100644 index 00000000..5de4697e --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsFloor.scala @@ -0,0 +1,32 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsFloor { + def apply(child: Constraint): Constraint = new IsFloor(child, 0).reduce() +} + +case class IsFloor private (child: Constraint, dummyArg: Int) extends Constraint { + + override def reduce(): Constraint = child match { + case k: IsKnown => k.floor + case x: IsAdd => this + case x: IsMul => this + case x: IsNeg => this + case x: IsPow => this + // floor(max(a, b)) -> max(floor(a), floor(b)) + case x: IsMax => IsMax(x.children.map {b => IsFloor(b)}) + case x: IsMin => IsMin(x.children.map {b => IsFloor(b)}) + case x: IsVar => this + // floor(floor(x)) -> floor(x) + case x: IsFloor => x + case _ => this + } + val children = Vector(child) + + override def map(f: Constraint=>Constraint): Constraint = IsFloor(f(child)) + + override def serialize: String = "floor(" + child.serialize + ")" +} + + diff --git a/src/main/scala/firrtl/constraint/IsKnown.scala b/src/main/scala/firrtl/constraint/IsKnown.scala new file mode 100644 index 00000000..5bd25f92 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsKnown.scala @@ -0,0 +1,44 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsKnown { + def unapply(b: Constraint): Option[BigDecimal] = b match { + case k: IsKnown => Some(k.value) + case _ => None + } +} + +/** Constant values must extend this trait see [[firrtl.ir.Closed and firrtl.ir.Open]] */ +trait IsKnown extends Constraint { + val value: BigDecimal + + /** Addition */ + def +(that: IsKnown): IsKnown + + /** Multiplication */ + def *(that: IsKnown): IsKnown + + /** Max */ + def max(that: IsKnown): IsKnown + + /** Min */ + def min(that: IsKnown): IsKnown + + /** Negate */ + def neg: IsKnown + + /** 2 << value */ + def pow: IsKnown + + /** Floor */ + def floor: IsKnown + + override def map(f: Constraint=>Constraint): Constraint = this + + val children: Vector[Constraint] = Vector.empty[Constraint] + + def reduce(): IsKnown = this +} + + diff --git a/src/main/scala/firrtl/constraint/IsMax.scala b/src/main/scala/firrtl/constraint/IsMax.scala new file mode 100644 index 00000000..3f24b7c0 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsMax.scala @@ -0,0 +1,59 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsMax { + def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { + case (l: IsKnown, r: IsKnown) => l max r + case _ => apply(Seq(left, right)) + } + def apply(children: Seq[Constraint]): Constraint = { + val x = children.foldLeft(new IsMax(None, Vector(), Vector())) { (add, c) => + add.addChild(c) + } + x.reduce() + } +} + +case class IsMax private[constraint](known: Option[IsKnown], + mins: Vector[IsMin], + others: Vector[Constraint] + ) extends MultiAry { + + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 max b2 + + override def serialize: String = "max(" + children.map(_.serialize).mkString(", ") + ")" + + override def map(f: Constraint=>Constraint): Constraint = IsMax(children.map(f)) + + lazy val children: Vector[Constraint] = { + if(known.nonEmpty) known.get +: (mins ++ others) else mins ++ others + } + + def reduce(): Constraint = { + if(children.size == 1) children.head else { + (known, mins, others) match { + case (Some(IsKnown(a)), _, _) => + // Eliminate minimums who have a known minimum value which is smaller than known maximum value + val filteredMins = mins.filter { + case IsMin(Some(IsKnown(i)), _, _) if i <= a => false + case other => true + } + // If a successful filter, rerun reduce + val newMax = new IsMax(known, filteredMins, others) + if(filteredMins.size != mins.size) { + newMax.reduce() + } else newMax + case _ => this + } + } + } + + def addChild(x: Constraint): IsMax = x match { + case k: IsKnown => new IsMax(known = merge(Some(k), known), mins, others) + case max: IsMax => new IsMax(known = merge(known, max.known), max.mins ++ mins, others ++ max.others) + case min: IsMin => new IsMax(known, mins :+ min, others) + case other => new IsMax(known, mins, others :+ other) + } +} + diff --git a/src/main/scala/firrtl/constraint/IsMin.scala b/src/main/scala/firrtl/constraint/IsMin.scala new file mode 100644 index 00000000..ee97e298 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsMin.scala @@ -0,0 +1,57 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsMin { + def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { + case (l: IsKnown, r: IsKnown) => l min r + case _ => apply(Seq(left, right)) + } + def apply(children: Seq[Constraint]): Constraint = { + children.foldLeft(new IsMin(None, Vector(), Vector())) { (add, c) => + add.addChild(c) + }.reduce() + } +} + +case class IsMin private[constraint](known: Option[IsKnown], + maxs: Vector[IsMax], + others: Vector[Constraint] + ) extends MultiAry { + + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 min b2 + + override def serialize: String = "min(" + children.map(_.serialize).mkString(", ") + ")" + + override def map(f: Constraint=>Constraint): Constraint = IsMin(children.map(f)) + + lazy val children: Vector[Constraint] = { + if(known.nonEmpty) known.get +: (maxs ++ others) else maxs ++ others + } + + def reduce(): Constraint = { + if(children.size == 1) children.head else { + (known, maxs, others) match { + case (Some(IsKnown(i)), _, _) => + // Eliminate maximums who have a known maximum value which is larger than known minimum value + val filteredMaxs = maxs.filter { + case IsMax(Some(IsKnown(a)), _, _) if a >= i => false + case other => true + } + // If a successful filter, rerun reduce + val newMin = new IsMin(known, filteredMaxs, others) + if(filteredMaxs.size != maxs.size) { + newMin.reduce() + } else newMin + case _ => this + } + } + } + + def addChild(x: Constraint): IsMin = x match { + case k: IsKnown => new IsMin(merge(Some(k), known), maxs, others) + case max: IsMax => new IsMin(known, maxs :+ max, others) + case min: IsMin => new IsMin(merge(min.known, known), maxs ++ min.maxs, others ++ min.others) + case other => new IsMin(known, maxs, others :+ other) + } +} diff --git a/src/main/scala/firrtl/constraint/IsMul.scala b/src/main/scala/firrtl/constraint/IsMul.scala new file mode 100644 index 00000000..3f637d75 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsMul.scala @@ -0,0 +1,52 @@ +// See LICENSE for license details. + +package firrtl.constraint + +import firrtl.ir.Closed + +object IsMul { + def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { + case (l: IsKnown, r: IsKnown) => l * r + case _ => apply(Seq(left, right)) + } + def apply(children: Seq[Constraint]): Constraint = { + children.foldLeft(new IsMul(None, Vector())) { (add, c) => + add.addChild(c) + }.reduce() + } +} + +case class IsMul private (known: Option[IsKnown], others: Vector[Constraint]) extends MultiAry { + + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 * b2 + + lazy val children: Vector[Constraint] = if(known.nonEmpty) known.get +: others else others + + def addChild(x: Constraint): IsMul = x match { + case k: IsKnown => new IsMul(known = merge(Some(k), known), others) + case mul: IsMul => new IsMul(merge(known, mul.known), others ++ mul.others) + case other => new IsMul(known, others :+ other) + } + + override def reduce(): Constraint = { + if(children.size == 1) children.head else { + (known, others) match { + case (Some(Closed(x)), _) if x == BigDecimal(1) => new IsMul(None, others).reduce() + case (Some(Closed(x)), _) if x == BigDecimal(0) => Closed(0) + case (Some(Closed(x)), Vector(m: IsMax)) if x > 0 => + IsMax(m.children.map { c => IsMul(Closed(x), c) }) + case (Some(Closed(x)), Vector(m: IsMax)) if x < 0 => + IsMin(m.children.map { c => IsMul(Closed(x), c) }) + case (Some(Closed(x)), Vector(m: IsMin)) if x > 0 => + IsMin(m.children.map { c => IsMul(Closed(x), c) }) + case (Some(Closed(x)), Vector(m: IsMin)) if x < 0 => + IsMax(m.children.map { c => IsMul(Closed(x), c) }) + case _ => this + } + } + } + + override def map(f: Constraint=>Constraint): Constraint = IsMul(children.map(f)) + + override def serialize: String = "(" + children.map(_.serialize).mkString(" * ") + ")" +} diff --git a/src/main/scala/firrtl/constraint/IsNeg.scala b/src/main/scala/firrtl/constraint/IsNeg.scala new file mode 100644 index 00000000..46f739c6 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsNeg.scala @@ -0,0 +1,32 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsNeg { + def apply(child: Constraint): Constraint = new IsNeg(child, 0).reduce() +} + +// Dummy arg is to get around weird Scala issue that can't differentiate between a +// private constructor and public apply that share the same arguments +case class IsNeg private (child: Constraint, dummyArg: Int) extends Constraint { + override def reduce(): Constraint = child match { + case k: IsKnown => k.neg + case x: IsAdd => IsAdd(x.children.map { b => IsNeg(b) }) + case x: IsMul => IsMul(Seq(IsNeg(x.children.head)) ++ x.children.tail) + case x: IsNeg => x.child + case x: IsPow => this + // -[max(a, b)] -> min[-a, -b] + case x: IsMax => IsMin(x.children.map { b => IsNeg(b) }) + case x: IsMin => IsMax(x.children.map { b => IsNeg(b) }) + case x: IsVar => this + case _ => this + } + + lazy val children = Vector(child) + + override def map(f: Constraint=>Constraint): Constraint = IsNeg(f(child)) + + override def serialize: String = "(-" + child.serialize + ")" +} + + diff --git a/src/main/scala/firrtl/constraint/IsPow.scala b/src/main/scala/firrtl/constraint/IsPow.scala new file mode 100644 index 00000000..54a06bf8 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsPow.scala @@ -0,0 +1,33 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsPow { + def apply(child: Constraint): Constraint = new IsPow(child, 0).reduce() +} + +// Dummy arg is to get around weird Scala issue that can't differentiate between a +// private constructor and public apply that share the same arguments +case class IsPow private (child: Constraint, dummyArg: Int) extends Constraint { + override def reduce(): Constraint = child match { + case k: IsKnown => k.pow + // 2^(a + b) -> 2^a * 2^b + case x: IsAdd => IsMul(x.children.map { b => IsPow(b)}) + case x: IsMul => this + case x: IsNeg => this + case x: IsPow => this + // 2^(max(a, b)) -> max(2^a, 2^b) since two is always positive, so a, b control magnitude + case x: IsMax => IsMax(x.children.map {b => IsPow(b)}) + case x: IsMin => IsMin(x.children.map {b => IsPow(b)}) + case x: IsVar => this + case _ => this + } + + val children = Vector(child) + + override def map(f: Constraint=>Constraint): Constraint = IsPow(f(child)) + + override def serialize: String = "(2^" + child.serialize + ")" +} + + diff --git a/src/main/scala/firrtl/constraint/IsVar.scala b/src/main/scala/firrtl/constraint/IsVar.scala new file mode 100644 index 00000000..98396fa0 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsVar.scala @@ -0,0 +1,27 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsVar { + def unapply(i: Constraint): Option[String] = i match { + case i: IsVar => Some(i.name) + case _ => None + } +} + +/** Extend to be a constraint variable */ +trait IsVar extends Constraint { + + def name: String + + override def serialize: String = name + + override def map(f: Constraint=>Constraint): Constraint = this + + override def reduce() = this + + val children = Vector() +} + +case class VarCon(name: String) extends IsVar + diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index e721363c..63c620d1 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -3,7 +3,11 @@ package firrtl package ir -import Utils.indent +import Utils.{dec2string, indent, trim} +import firrtl.constraint.{Constraint, IsKnown, IsVar} + +import scala.math.BigDecimal.RoundingMode._ +import scala.collection.mutable /** Intermediate Representation */ abstract class FirrtlNode { @@ -103,6 +107,29 @@ object StringLit { */ abstract class PrimOp extends FirrtlNode { def serialize: String = this.toString + def propagateType(e: DoPrim): Type = UnknownType + def apply(args: Any*): DoPrim = { + val groups = args.groupBy { + case x: Expression => "exp" + case x: BigInt => "int" + case x: Int => "int" + case other => "other" + } + val exprs = groups.getOrElse("exp", Nil).collect { + case e: Expression => e + } + val consts = groups.getOrElse("int", Nil).map { + _ match { + case i: BigInt => i + case i: Int => BigInt(i) + } + } + groups.get("other") match { + case None => + case Some(x) => sys.error(s"Shouldn't be here: $x") + } + DoPrim(this, exprs, consts, UnknownType) + } } abstract class Expression extends FirrtlNode { @@ -151,7 +178,7 @@ case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Exp def foreachType(f: Type => Unit): Unit = f(tpe) def foreachWidth(f: Width => Unit): Unit = Unit } -case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression { +case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type = UnknownType) extends Expression { def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})" def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) @@ -535,6 +562,12 @@ class IntWidth(val width: BigInt) extends Width with Product { case object UnknownWidth extends Width { def serialize: String = "" } +case class CalcWidth(arg: Constraint) extends Width { + def serialize: String = s"calcw(${arg.serialize})" +} +case class VarWidth(name: String) extends Width with IsVar { + override def serialize: String = s"<$name>" +} /** Orientation of [[Field]] */ abstract class Orientation extends FirrtlNode @@ -550,6 +583,67 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode def serialize: String = flip.serialize + name + " : " + tpe.serialize } + +/** Bounds of [[IntervalType]] */ + +trait Bound extends Constraint { + def serialize: String +} +case object UnknownBound extends Bound { + def serialize: String = "?" + def map(f: Constraint=>Constraint): Constraint = this + override def reduce(): Constraint = this + val children = Vector() +} +case class CalcBound(arg: Constraint) extends Bound { + def serialize: String = s"calcb(${arg.serialize})" + def map(f: Constraint=>Constraint): Constraint = f(arg) + override def reduce(): Constraint = arg + val children = Vector(arg) +} +case class VarBound(name: String) extends IsVar with Bound +object KnownBound { + def unapply(b: Constraint): Option[BigDecimal] = b match { + case k: IsKnown => Some(k.value) + case _ => None + } + def unapply(b: Bound): Option[BigDecimal] = b match { + case k: IsKnown => Some(k.value) + case _ => None + } +} +case class Open(value: BigDecimal) extends IsKnown with Bound { + def serialize = s"o($value)" + def +(that: IsKnown): IsKnown = Open(value + that.value) + def *(that: IsKnown): IsKnown = that match { + case Closed(x) if x == 0 => Closed(x) + case _ => Open(value * that.value) + } + def min(that: IsKnown): IsKnown = if(value < that.value) this else that + def max(that: IsKnown): IsKnown = if(value > that.value) this else that + def neg: IsKnown = Open(-value) + def floor: IsKnown = Open(value.setScale(0, BigDecimal.RoundingMode.FLOOR)) + def pow: IsKnown = if(value.isBinaryDouble) Open(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here") +} +case class Closed(value: BigDecimal) extends IsKnown with Bound { + def serialize = s"c($value)" + def +(that: IsKnown): IsKnown = that match { + case Open(x) => Open(value + x) + case Closed(x) => Closed(value + x) + } + def *(that: IsKnown): IsKnown = that match { + case IsKnown(x) if value == BigInt(0) => Closed(0) + case Open(x) => Open(value * x) + case Closed(x) => Closed(value * x) + } + def min(that: IsKnown): IsKnown = if(value <= that.value) this else that + def max(that: IsKnown): IsKnown = if(value >= that.value) this else that + def neg: IsKnown = Closed(-value) + def floor: IsKnown = Closed(value.setScale(0, BigDecimal.RoundingMode.FLOOR)) + def pow: IsKnown = if(value.isBinaryDouble) Closed(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here") +} + +/** Types of [[FirrtlNode]] */ abstract class Type extends FirrtlNode { def mapType(f: Type => Type): Type def mapWidth(f: Width => Width): Type @@ -586,6 +680,84 @@ case class FixedType(width: Width, point: Width) extends GroundType { def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point)) def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } +case class IntervalType(lower: Bound, upper: Bound, point: Width) extends GroundType { + override def serialize: String = { + val lowerString = lower match { + case Open(l) => s"(${dec2string(l)}, " + case Closed(l) => s"[${dec2string(l)}, " + case UnknownBound => s"[?, " + case _ => s"[?, " + } + val upperString = upper match { + case Open(u) => s"${dec2string(u)})" + case Closed(u) => s"${dec2string(u)}]" + case UnknownBound => s"?]" + case _ => s"?]" + } + val bounds = (lower, upper) match { + case (k1: IsKnown, k2: IsKnown) => lowerString + upperString + case _ => "" + } + val pointString = point match { + case IntWidth(i) => "." + i.toString + case _ => "" + } + "Interval" + bounds + pointString + } + + private lazy val bp = point.asInstanceOf[IntWidth].width.toInt + private def precision: Option[BigDecimal] = point match { + case IntWidth(width) => + val bp = width.toInt + if(bp >= 0) Some(BigDecimal(1) / BigDecimal(BigInt(1) << bp)) else Some(BigDecimal(BigInt(1) << -bp)) + case other => None + } + + def min: Option[BigDecimal] = (lower, precision) match { + case (Open(a), Some(prec)) => a / prec match { + case x if trim(x).isWhole => Some(a + prec) // add precision for open lower bound i.e. (-4 -> [3 for bp = 0 + case x => Some(x.setScale(0, CEILING) * prec) // Deal with unrepresentable bound representations (finite BP) -- new closed form l > original l + } + case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, CEILING) * prec) + case other => None + } + + def max: Option[BigDecimal] = (upper, precision) match { + case (Open(a), Some(prec)) => a / prec match { + case x if trim(x).isWhole => Some(a - prec) // subtract precision for open upper bound + case x => Some(x.setScale(0, FLOOR) * prec) + } + case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, FLOOR) * prec) + } + + def minAdjusted: Option[BigInt] = min.map(_ * BigDecimal(BigInt(1) << bp) match { + case x if trim(x).isWhole | x.doubleValue == 0.0 => x.toBigInt + case x => sys.error(s"MinAdjusted should be a whole number: $x. Min is $min. BP is $bp. Precision is $precision. Lower is ${lower}.") + }) + + def maxAdjusted: Option[BigInt] = max.map(_ * BigDecimal(BigInt(1) << bp) match { + case x if trim(x).isWhole => x.toBigInt + case x => sys.error(s"MaxAdjusted should be a whole number: $x") + }) + + /** If bounds are known, calculates the width, otherwise returns UnknownWidth */ + lazy val width: Width = (point, lower, upper) match { + case (IntWidth(i), l: IsKnown, u: IsKnown) => + IntWidth(Math.max(Utils.getSIntWidth(minAdjusted.get), Utils.getSIntWidth(maxAdjusted.get))) + case _ => UnknownWidth + } + + /** If bounds are known, returns a sequence of all possible values inside this interval */ + lazy val range: Option[Seq[BigDecimal]] = (lower, upper, point) match { + case (l: IsKnown, u: IsKnown, p: IntWidth) => + if(min.get > max.get) Some(Nil) else Some(Range.BigDecimal(min.get, max.get, precision.get)) + case _ => None + } + + override def mapWidth(f: Width => Width): Type = this.copy(point = f(point)) + override def foreachWidth(f: Width => Unit): Unit = f(point) +} + case class BundleType(fields: Seq[Field]) extends AggregateType { def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}" def mapType(f: Type => Type): Type = @@ -645,6 +817,7 @@ case class Port( direction: Direction, tpe: Type) extends FirrtlNode with IsDeclaration { def serialize: String = s"${direction.serialize} $name : ${tpe.serialize}" + info.serialize + def mapType(f: Type => Type): Port = Port(info, name, direction, f(tpe)) } /** Parameters for external modules */ diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 07784e19..5ae5dad4 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -7,14 +7,21 @@ import firrtl.ir._ import firrtl.PrimOps._ import firrtl.traversals.Foreachers._ import firrtl.Utils._ +import firrtl.constraint.IsKnown import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken} object CheckWidths extends Pass { /** The maximum allowed width for any circuit element */ val MaxWidth = 1000000 - val DshlMaxWidth = ceilLog2(MaxWidth + 1) + val DshlMaxWidth = getUIntWidth(MaxWidth) class UninferredWidth (info: Info, target: String) extends PassException( - s"""|$info : Uninferred width for target below. (Did you forget to assign to it?) + s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?) + |$target""".stripMargin) + class UninferredBound (info: Info, target: String, bound: String) extends PassException( + s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?) + |$target""".stripMargin) + class InvalidRange (info: Info, target: String, i: IntervalType) extends PassException( + s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?) |$target""".stripMargin) class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException( s"$info : [target $mname] Width too small for constant $b.") @@ -32,17 +39,25 @@ object CheckWidths extends Pass { s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.") class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) extends PassException( s"$info: [target $mname] Attach source $source and expression $eName must have identical widths.") + class DisjointSqueeze(info: Info, mname: String, squeeze: DoPrim) + extends PassException({ + val toSqz = squeeze.args.head.serialize + val toSqzTpe = squeeze.args.head.tpe.serialize + val sqzTo = squeeze.args(1).serialize + val sqzToTpe = squeeze.args(1).tpe.serialize + s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe" + }) def run(c: Circuit): Circuit = { val errors = new Errors() - def check_width_w(info: Info, target: Target)(w: Width): Unit = { - w match { - case IntWidth(width) if width >= MaxWidth => + def check_width_w(info: Info, target: Target, t: Type)(w: Width): Unit = { + (w, t) match { + case (IntWidth(width), _) if width >= MaxWidth => errors.append(new WidthTooBig(info, target.serialize, width)) - case w: IntWidth if w.width >= 0 => - case _: IntWidth => + case (w: IntWidth, f: FixedType) if (w.width < 0 && w.width == f.width) => errors append new NegWidthException(info, target.serialize) + case (_: IntWidth, _) => case _ => errors append new UninferredWidth(info, target.prettyPrint(" ")) } @@ -57,9 +72,28 @@ object CheckWidths extends Pass { def check_width_t(info: Info, target: Target)(t: Type): Unit = { t match { case tt: BundleType => tt.fields.foreach(check_width_f(info, target)) + //Supports when l = u (if closed) + case i@IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i + case i:IntervalType if i.range == Some(Nil) => + errors append new InvalidRange(info, target.prettyPrint(" "), i) + i + case i@IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u => + errors append new InvalidRange(info, target.prettyPrint(" "), i) + i + case i@IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i + case i@IntervalType(_: IsKnown, _, _) => + errors append new UninferredBound(info, target.prettyPrint(" "), "upper") + i + case i@IntervalType(_, _: IsKnown, _) => + errors append new UninferredBound(info, target.prettyPrint(" "), "lower") + i + case i@IntervalType(_, _, _) => + errors append new UninferredBound(info, target.prettyPrint(" "), "lower") + errors append new UninferredBound(info, target.prettyPrint(" "), "upper") + i case tt => tt foreach check_width_t(info, target) } - t foreach check_width_w(info, target) + t foreach check_width_w(info, target, t) } def check_width_f(info: Info, target: Target)(f: Field): Unit = @@ -77,6 +111,12 @@ object CheckWidths extends Pass { errors append new WidthTooSmall(info, target.serialize, e.value) case _ => } + case sqz@DoPrim(Squeeze, Seq(a, b), _, IntervalType(Closed(min), Closed(max), _)) => + (a.tpe, b.tpe) match { + case (IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) => + errors append new DisjointSqueeze(info, target.serialize, sqz) + case other => + } case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) => errors append new BitsWidthException(info, target.serialize, hi, bitWidth(a.tpe), e.serialize) case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) => @@ -87,7 +127,6 @@ object CheckWidths extends Pass { errors append new DshlTooBig(info, target.serialize) case _ => } - //e map check_width_t(info, mname) map check_width_e(info, mname) e foreach check_width_e(info, target) } @@ -111,11 +150,15 @@ object CheckWidths extends Pass { case ResetType => case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) } + if(!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) { + val conMsg = sx.copy(info = NoInfo).serialize + errors.append(new CheckTypes.InvalidConnect(info, target.module, conMsg, WRef(sx), sx.init)) + } case _ => } } - def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target)(p.tpe) + def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target.ref(p.name))(p.tpe) def check_width_m(circuit: CircuitTarget)(m: DefModule): Unit = { m foreach check_width_p(m.info, circuit.module(m.name)) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 85ed7de0..4239247c 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -8,6 +8,7 @@ import firrtl.PrimOps._ import firrtl.Utils._ import firrtl.traversals.Foreachers._ import firrtl.WrappedType._ +import firrtl.constraint.{Constraint, IsKnown} trait CheckHighFormLike { type NameSet = collection.mutable.HashSet[String] @@ -84,11 +85,11 @@ trait CheckHighFormLike { e.op match { case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq | - Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw => + Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | Clip | Wrap | Squeeze => correctNum(Option(2), 0) case AsUInt | AsSInt | AsClock | AsAsyncReset | Cvt | Neq | Not => correctNum(Option(1), 0) - case AsFixedPoint | Pad | Head | Tail | BPShl | BPShr | BPSet => + case AsFixedPoint | Pad | Head | Tail | IncP | DecP | SetP => correctNum(Option(1), 1) case Shl | Shr => correctNum(Option(1), 1) @@ -101,6 +102,8 @@ trait CheckHighFormLike { if (lsb > msb) { errors.append(new LsbLargerThanMsbException(info, mname, e.op.toString, lsb, msb)) } + case AsInterval => + correctNum(Option(1), 3) case Andr | Orr | Xorr | Neg => correctNum(None,0) } @@ -137,7 +140,9 @@ trait CheckHighFormLike { def checkHighFormT(info: Info, mname: String)(t: Type): Unit = { t foreach checkHighFormT(info, mname) t match { - case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname)) + case tx: VectorType if tx.size < 0 => + errors.append(new NegVecSizeException(info, mname)) + case i: IntervalType => i case _ => t foreach checkHighFormW(info, mname) } } @@ -146,6 +151,7 @@ trait CheckHighFormLike { e match { case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error + case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error case _ => errors.append(new InvalidAccessException(info, mname)) } } @@ -164,8 +170,8 @@ trait CheckHighFormLike { case ex: WSubAccess => validSubexp(info, mname)(ex.expr) case ex => ex foreach validSubexp(info, mname) } - e foreach checkHighFormW(info, mname) - e foreach checkHighFormT(info, mname) + e foreach checkHighFormW(info, mname + "/" + e.serialize) + e foreach checkHighFormT(info, mname + "/" + e.serialize) e foreach checkHighFormE(info, mname, names) } @@ -215,8 +221,7 @@ trait CheckHighFormLike { if (names(p.name)) errors.append(new NotUniqueException(NoInfo, mname, p.name)) names += p.name - p.tpe foreach checkHighFormT(p.info, mname) - p.tpe foreach checkHighFormW(p.info, mname) + checkHighFormT(p.info, mname)(p.tpe) } // Search for ResetType Ports of direction @@ -339,6 +344,11 @@ object CheckTypes extends Pass { s"$info: [module $mname] Uninferred type: $exp." ) + def fits(bigger: Constraint, smaller: Constraint): Boolean = (bigger, smaller) match { + case (IsKnown(v1), IsKnown(v2)) if v1 < v2 => false + case _ => true + } + def legalResetType(tpe: Type): Boolean = tpe match { case UIntType(IntWidth(w)) if w == 1 => true case AsyncResetType => true @@ -355,6 +365,9 @@ object CheckTypes extends Pass { case (_: UIntType, _: UIntType) => flip1 == flip2 case (_: SIntType, _: SIntType) => flip1 == flip2 case (_: FixedType, _: FixedType) => flip1 == flip2 + case (i1: IntervalType, i2: IntervalType) => + import Implicits.width2constraint + fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point) case (_: AnalogType, _: AnalogType) => true case (AsyncResetType, AsyncResetType) => flip1 == flip2 case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2 @@ -375,7 +388,17 @@ object CheckTypes extends Pass { } } - def validConnect(con: Connect): Boolean = wt(con.loc.tpe).superTypeOf(wt(con.expr.tpe)) + def validConnect(locTpe: Type, expTpe: Type): Boolean = { + val itFits = (locTpe, expTpe) match { + case (i1: IntervalType, i2: IntervalType) => + import Implicits.width2constraint + fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point) + case _ => true + } + wt(locTpe).superTypeOf(wt(expTpe)) && itFits + } + + def validConnect(con: Connect): Boolean = validConnect(con.loc.tpe, con.expr.tpe) def validPartialConnect(con: PartialConnect): Boolean = bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default) @@ -393,44 +416,51 @@ object CheckTypes extends Pass { case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe)) case tx => true } + def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = { - def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean): Unit = { - exprs.foldLeft((false, false, false, false, false)) { - case ((isUInt, isSInt, isClock, isFix, isAsync), expr) => expr.tpe match { - case u: UIntType => (true, isSInt, isClock, isFix, isAsync) - case s: SIntType => (isUInt, true, isClock, isFix, isAsync) - case ClockType => (isUInt, isSInt, true, isFix, isAsync) - case f: FixedType => (isUInt, isSInt, isClock, true, isAsync) - case AsyncResetType => (isUInt, isSInt, isClock, isFix, true) + def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean, okInterval: Boolean): Unit = { + exprs.foldLeft((false, false, false, false, false, false)) { + case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => expr.tpe match { + case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval) + case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval) + case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval) + case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval) + case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval) + case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true) case UnknownType => errors.append(new IllegalUnknownType(info, mname, e.serialize)) - (isUInt, isSInt, isClock, isFix, isAsync) + (isUInt, isSInt, isClock, isFix, isAsync, isInterval) case other => throwInternalError(s"Illegal Type: ${other.serialize}") } } match { - // (UInt, SInt, Clock, Fixed) - case (isAll, false, false, false, false) if isAll == okUInt => - case (false, isAll, false, false, false) if isAll == okSInt => - case (false, false, isAll, false, false) if isAll == okClock => - case (false, false, false, isAll, false) if isAll == okFix => - case (false, false, false, false, isAll) if isAll == okAsync => + // (UInt, SInt, Clock, Fixed, Async, Interval) + case (isAll, false, false, false, false, false) if isAll == okUInt => + case (false, isAll, false, false, false, false) if isAll == okSInt => + case (false, false, isAll, false, false, false) if isAll == okClock => + case (false, false, false, isAll, false, false) if isAll == okFix => + case (false, false, false, false, isAll, false) if isAll == okAsync => + case (false, false, false, false, false, isAll) if isAll == okInterval => case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize))) } } e.op match { - case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset => + case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset | AsInterval => // All types are ok case Dshl | Dshr => - checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false) - checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false) + checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false, okInterval=false) case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false) - case Pad | Shl | Shr | Cat | Bits | Head | Tail => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false) - case BPShl | BPShr | BPSet => - checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false) + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + case Pad | Bits | Head | Tail => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=false) + case Shl | Shr | Cat => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + case IncP | DecP | SetP => + checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false, okInterval=true) + case Wrap | Clip | Squeeze => + checkAllTypes(e.args, okUInt = false, okSInt = false, okClock = false, okFix = false, okAsync=false, okInterval = true) case _ => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false) + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false, okInterval=false) } } @@ -494,6 +524,9 @@ object CheckTypes extends Pass { sx.tpe match { case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname)) + case t if !validConnect(sx.tpe, sx.init.tpe) => + val conMsg = sx.copy(info = NoInfo).serialize + errors.append(new CheckTypes.InvalidConnect(info, mname, conMsg, WRef(sx), sx.init)) case t => } if (!legalResetType(sx.reset.tpe)) { diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 67fdfea0..05a000c5 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -39,9 +39,9 @@ object ConvertFixedToSInt extends Pass { def updateExpType(e:Expression): Expression = e match { case DoPrim(Mul, args, consts, tpe) => e map updateExpType case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) map updateExpType - case DoPrim(BPShl, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType - case DoPrim(BPShr, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType - case DoPrim(BPSet, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType + case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType + case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType + case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType case DoPrim(op, args, consts, tpe) => val point = calcPoint(args) val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType) diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala new file mode 100644 index 00000000..258c9697 --- /dev/null +++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala @@ -0,0 +1,101 @@ +// See LICENSE for license details. + +package firrtl.passes + +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target} +import firrtl.constraint.ConstraintSolver + +class InferBinaryPoints extends Pass { + private val constraintSolver = new ConstraintSolver() + + private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { + case (UIntType(w1), UIntType(w2)) => + case (SIntType(w1), SIntType(w2)) => + case (ClockType, ClockType) => + case (ResetType, _) => + case (_, ResetType) => + case (AsyncResetType, AsyncResetType) => + case (FixedType(w1, p1), FixedType(w2, p2)) => + constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => + constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) + case (AnalogType(w1), AnalogType(w2)) => + case (t1: BundleType, t2: BundleType) => + (t1.fields zip t2.fields) foreach { case (f1, f2) => + (f1.flip, f2.flip) match { + case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) + case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) + case _ => sys.error("Shouldn't be here") + } + } + case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe) + case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other") + } + private def addDecConstraints(t: Type): Type = t map addDecConstraints + private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match { + case c: Connect => + val n = get_size(c.loc.tpe) + val locs = create_exps(c.loc) + val exps = create_exps(c.expr) + (locs zip exps) foreach { case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } + } + c + case pc: PartialConnect => + val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) + val locs = create_exps(pc.loc) + val exps = create_exps(pc.expr) + ls foreach { case (x, y) => + val loc = locs(x) + val exp = exps(y) + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } + } + pc + case r: DefRegister => + addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) + r + case x => x map addStmtConstraints(mt) + } + private def fixWidth(w: Width): Width = constraintSolver.get(w) match { + case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) + case None => w + case _ => sys.error("Shouldn't be here") + } + private def fixType(t: Type): Type = t map fixType map fixWidth match { + case IntervalType(l, u, p) => + val px = constraintSolver.get(p) match { + case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) + case None => p + case _ => sys.error("Shouldn't be here") + } + IntervalType(l, u, px) + case FixedType(w, p) => + val px = constraintSolver.get(p) match { + case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) + case None => p + case _ => sys.error("Shouldn't be here") + } + FixedType(w, px) + case x => x + } + private def fixStmt(s: Statement): Statement = s map fixStmt map fixType + private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe)) + def run (c: Circuit): Circuit = { + val ct = CircuitTarget(c.main) + c.modules foreach (m => m map addStmtConstraints(ct.module(m.name))) + c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)}) + constraintSolver.solve() + InferTypes.run(c.copy(modules = c.modules map (_ + map fixPort + map fixStmt))) + } +} diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 288b62ba..3c5cf7fb 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -14,13 +14,23 @@ object InferTypes extends Pass { val namespace = Namespace() val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap + def remove_unknowns_b(b: Bound): Bound = b match { + case UnknownBound => VarBound(namespace.newName("b")) + case k => k + } + def remove_unknowns_w(w: Width): Width = w match { case UnknownWidth => VarWidth(namespace.newName("w")) case wx => wx } - def remove_unknowns(t: Type): Type = - t map remove_unknowns map remove_unknowns_w + def remove_unknowns(t: Type): Type = { + t map remove_unknowns map remove_unknowns_w match { + case IntervalType(l, u, p) => + IntervalType(remove_unknowns_b(l), remove_unknowns_b(u), p) + case x => x + } + } def infer_types_e(types: TypeMap)(e: Expression): Expression = e map infer_types_e(types) match { diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 9c58da2c..2211d238 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -3,15 +3,20 @@ package firrtl.passes // Datastructures -import scala.collection.mutable.ArrayBuffer -import scala.collection.immutable.ListMap - import firrtl._ import firrtl.annotations.{Annotation, ReferenceTarget} import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -import firrtl.traversals.Foreachers._ +import firrtl.Implicits.width2constraint +import firrtl.annotations.{CircuitTarget, ModuleTarget, ReferenceTarget, Target} +import firrtl.constraint.{ConstraintSolver, IsMax} + +object InferWidths { + def apply(): InferWidths = new InferWidths() + def run(c: Circuit): Circuit = new InferWidths().run(c) + def execute(state: CircuitState): CircuitState = new InferWidths().execute(state) +} case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarget) extends Annotation { def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = { @@ -33,369 +38,146 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg } } +/** Infers the widths of all signals with unknown widths + * + * Is a global width inference algorithm + * - Instances of the same module with unknown input port widths will be assigned the + * largest width of all assignments to each of its instance ports + * - If you don't want the global inference behavior, then be sure to define all your input widths + * + * Infers the smallest width is larger than all assigned widths to a signal + * - Note that this means that dummy assignments that are overwritten by last-connect-semantics + * can still influence width inference + * - E.g. + * wire x: UInt + * x <= UInt<5>(15) + * x <= UInt<1>(1) + * + * Since width inference occurs before lowering, it infers x's width to be 5 but with an assignment of UInt(1): + * + * wire x: UInt<5> + * x <= UInt<1>(1) + * + * Uses firrtl.constraint package to infer widths + */ class InferWidths extends Transform with ResolvedAnnotationPaths { def inputForm: CircuitForm = UnknownForm def outputForm: CircuitForm = UnknownForm - val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) + private val constraintSolver = new ConstraintSolver() - type ConstraintMap = collection.mutable.LinkedHashMap[String, Width] - - def solve_constraints(l: Seq[WGeq]): ConstraintMap = { - def unique(ls: Seq[Width]) : Seq[Width] = - (ls map (new WrappedWidth(_))).distinct map (_.w) - // Combines constraints on the same VarWidth into the same constraint - def make_unique(ls: Seq[WGeq]): ListMap[String,Width] = { - ls.foldLeft(ListMap.empty[String, Width])((acc, wgeq) => wgeq.loc match { - case VarWidth(name) => acc.get(name) match { - case None => acc + (name -> wgeq.exp) - // Avoid constructing massive MaxWidth chains - case Some(MaxWidth(args)) => acc + (name -> MaxWidth(wgeq.exp +: args)) - case Some(width) => acc + (name -> MaxWidth(Seq(wgeq.exp, width))) - } - case _ => acc - }) - } - def pullMinMax(w: Width): Width = w map pullMinMax match { - case PlusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) - case PlusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) - case MinusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => MinusWidth(m, IntWidth(i)))) - case MinusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => MinusWidth(IntWidth(i), m))) - case PlusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) - case PlusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) - case MinusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => MinusWidth(m, IntWidth(i)))) - case MinusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => MinusWidth(IntWidth(i), m))) - case wx => wx - } - def collectMinMax(w: Width): Width = w map collectMinMax match { - case MinWidth(args) => MinWidth(unique(args.foldLeft(List[Width]()) { - case (res, wxx: MinWidth) => wxx.args ++: res - case (res, wxx) => wxx +: res - })) - case MaxWidth(args) => MaxWidth(unique(args.foldLeft(List[Width]()) { - case (res, wxx: MaxWidth) => wxx.args ++: res - case (res, wxx) => wxx +: res - })) - case wx => wx - } - def mergePlusMinus(w: Width): Width = w map mergePlusMinus match { - case wx: PlusWidth => (wx.arg1, wx.arg2) match { - case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width) - case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) - case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) - case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1) - case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1) - case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(y - x), w1) - case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(y - x), w1) - case _ => wx - } - case wx: MinusWidth => (wx.arg1, wx.arg2) match { - case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) - case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) - case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) - case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) - case _ => wx - } - case wx: ExpWidth => wx.arg1 match { - case w1: IntWidth => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong)) - case _ => wx - } - case wx => wx - } - def removeZeros(w: Width): Width = w map removeZeros match { - case wx: PlusWidth => (wx.arg1, wx.arg2) match { - case (w1, IntWidth(x)) if x == 0 => w1 - case (IntWidth(x), w1) if x == 0 => w1 - case _ => wx - } - case wx: MinusWidth => (wx.arg1, wx.arg2) match { - case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) - case (w1, IntWidth(x)) if x == 0 => w1 - case _ => wx - } - case wx => wx - } - def simplify(w: Width): Width = { - val opts = Seq( - pullMinMax _, - collectMinMax _, - mergePlusMinus _, - removeZeros _ - ) - opts.foldLeft(w) { (width, opt) => opt(width) } - } - - def substitute(h: ConstraintMap)(w: Width): Width = { - //;println-all-debug(["Substituting for [" w "]"]) - val wx = simplify(w) - //;println-all-debug(["After Simplify: [" wx "]"]) - wx map substitute(h) match { - //;("matched println-debugvarwidth!") - case wxx: VarWidth => h get wxx.name match { - case None => wxx - case Some(p) => - //;println-debug("Contained!") - //;println-all-debug(["Width: " wxx]) - //;println-all-debug(["Accessed: " h[name(wxx)]]) - val t = simplify(substitute(h)(p)) - h(wxx.name) = t - t - } - case wxx => wxx - //;println-all-debug(["not varwidth!" w]) - } - } - - def b_sub(h: ConstraintMap)(w: Width): Width = { - w map b_sub(h) match { - case wx: VarWidth => h getOrElse (wx.name, wx) - case wx => wx - } - } - - def remove_cycle(n: String)(w: Width): Width = { - //;println-all-debug(["Removing cycle for " n " inside " w]) - w match { - case wx: MaxWidth => MaxWidth(wx.args filter { - case wxx: VarWidth => !(n equals wxx.name) - case MinusWidth(VarWidth(name), IntWidth(i)) if ((i >= 0) && (n == name)) => false - case _ => true - }) - case wx: MinusWidth => wx.arg1 match { - case v: VarWidth if n == v.name => v - case v => wx - } - case wx => wx - } - //;println-all-debug(["After removing cycle for " n ", returning " wx]) - } + val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) - def hasVarWidth(n: String)(w: Width): Boolean = { - var has = false - def rec(w: Width): Width = { - w match { - case wx: VarWidth if wx.name == n => has = true - case _ => + private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { + case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) + case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) + case (ClockType, ClockType) => + case (FixedType(w1, p1), FixedType(w2, p2)) => + constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) + constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => + constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) + constraintSolver.addLeq(l1, l2, r1.prettyPrint(""), r2.prettyPrint("")) + constraintSolver.addGeq(u1, u2, r1.prettyPrint(""), r2.prettyPrint("")) + case (AnalogType(w1), AnalogType(w2)) => + constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) + constraintSolver.addGeq(w2, w1, r1.prettyPrint(""), r2.prettyPrint("")) + case (t1: BundleType, t2: BundleType) => + (t1.fields zip t2.fields) foreach { case (f1, f2) => + (f1.flip, f2.flip) match { + case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) + case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) + case _ => sys.error("Shouldn't be here") } - w map rec - } - rec(w) - has - } - - //; Forward solve - //; Returns a solved list where each constraint undergoes: - //; 1) Continuous Solving (using triangular solving) - //; 2) Remove Cycles - //; 3) Move to solved if not self-recursive - val u = make_unique(l) - - //println("======== UNIQUE CONSTRAINTS ========") - //for (x <- u) { println(x) } - //println("====================================") - - val f = new ConstraintMap - val o = ArrayBuffer[String]() - for ((n, e) <- u) { - //println("==== SOLUTIONS TABLE ====") - //for (x <- f) println(x) - //println("=========================") - - val e_sub = simplify(substitute(f)(e)) - - //println("Solving " + n + " => " + e) - //println("After Substitute: " + n + " => " + e_sub) - //println("==== SOLUTIONS TABLE (Post Substitute) ====") - //for (x <- f) println(x) - //println("=========================") - - val ex = remove_cycle(n)(e_sub) - - //println("After Remove Cycle: " + n + " => " + ex) - if (!hasVarWidth(n)(ex)) { - //println("Not rec!: " + n + " => " + ex) - //println("Adding [" + n + "=>" + ex + "] to Solutions Table") - f(n) = ex - o += n } - } - - //println("Forward Solved Constraints") - //for (x <- f) println(x) - - //; Backwards Solve - val b = new ConstraintMap - for (i <- (o.size - 1) to 0 by -1) { - val n = o(i) // Should visit `o` backward - /* - println("SOLVE BACK: [" + n + " => " + f(n) + "]") - println("==== SOLUTIONS TABLE ====") - for (x <- b) println(x) - println("=========================") - */ - val ex = simplify(b_sub(b)(f(n))) - /* - println("BACK RETURN: [" + n + " => " + ex + "]") - */ - b(n) = ex - /* - println("==== SOLUTIONS TABLE (Post backsolve) ====") - for (x <- b) println(x) - println("=========================") - */ - } - b - } - - def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match { - case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width)) - case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width)) - case (ClockType, ClockType) => Nil + case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe) case (AsyncResetType, AsyncResetType) => Nil - case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2)) - case (AnalogType(w1), AnalogType(w2)) => Seq(WGeq(w1,w2), WGeq(w2,w1)) - case (t1: BundleType, t2: BundleType) => - (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) => - res ++ (f1.flip match { - case Default => get_constraints_t(f1.tpe, f2.tpe) - case Flip => get_constraints_t(f2.tpe, f1.tpe) - }) - } - case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe) case (ResetType, _) => Nil case (_, ResetType) => Nil } - def run(c: Circuit, extra: Seq[WGeq]): Circuit = { - val v = ArrayBuffer[WGeq]() ++ extra + private def addExpConstraints(e: Expression): Expression = e map addExpConstraints match { + case m@Mux(p, tVal, fVal, t) => + constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W") + m + case other => other + } - def get_constraints_e(e: Expression): Unit = { - e match { - case (e: Mux) => v ++= Seq( - WGeq(getWidth(e.cond), IntWidth(1)), - WGeq(IntWidth(1), getWidth(e.cond)) - ) - case _ => + private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addExpConstraints match { + case c: Connect => + val n = get_size(c.loc.tpe) + val locs = create_exps(c.loc) + val exps = create_exps(c.expr) + (locs zip exps).foreach { case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } - e.foreach(get_constraints_e) - } - - def get_constraints_declared_type (t: Type): Type = t match { - case FixedType(_, p) => - v += WGeq(p,IntWidth(0)) - t - case _ => t map get_constraints_declared_type - } - - def get_constraints_s(s: Statement): Unit = { - s map get_constraints_declared_type match { - case (s: Connect) => - val locs = create_exps(s.loc) - val exps = create_exps(s.expr) - v ++= locs.zip(exps).flatMap { case (locx, expx) => - to_flip(flow(locx)) match { - case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx)) - case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx)) - } - } - case (s: PartialConnect) => - val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default) - val locs = create_exps(s.loc) - val exps = create_exps(s.expr) - v ++= (ls flatMap {case (x, y) => - val locx = locs(x) - val expx = exps(y) - to_flip(flow(locx)) match { - case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx)) - case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx)) - } - }) - case (s: DefRegister) => - if (s.reset.tpe != AsyncResetType ) { - v ++= ( - get_constraints_t(s.reset.tpe, UIntType(IntWidth(1))) ++ - get_constraints_t(UIntType(IntWidth(1)), s.reset.tpe)) - } - v ++= get_constraints_t(s.tpe, s.init.tpe) - case (s:Conditionally) => v ++= - get_constraints_t(s.pred.tpe, UIntType(IntWidth(1))) ++ - get_constraints_t(UIntType(IntWidth(1)), s.pred.tpe) - case Attach(_, exprs) => - // All widths must be equal - val widths = exprs map (e => getWidth(e.tpe)) - v ++= widths.tail map (WGeq(widths.head, _)) - case _ => + c + case pc: PartialConnect => + val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) + val locs = create_exps(pc.loc) + val exps = create_exps(pc.expr) + ls foreach { case (x, y) => + val loc = locs(x) + val exp = exps(y) + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } - s.foreach(get_constraints_e) - s.foreach(get_constraints_s) - } - - c.modules.foreach(_.foreach(get_constraints_s)) - c.modules.foreach(_.ports.foreach({p => get_constraints_declared_type(p.tpe)})) - - //println("======== ALL CONSTRAINTS ========") - //for(x <- v) println(x) - //println("=================================") - val h = solve_constraints(v) - //println("======== SOLVED CONSTRAINTS ========") - //for(x <- h) println(x) - //println("====================================") - - def evaluate(w: Width): Width = { - def map2(a: Option[BigInt], b: Option[BigInt], f: (BigInt,BigInt) => BigInt): Option[BigInt] = - for (a_num <- a; b_num <- b) yield f(a_num, b_num) - def reduceOptions(l: Seq[Option[BigInt]], f: (BigInt,BigInt) => BigInt): Option[BigInt] = - l.reduce(map2(_, _, f)) + pc + case r: DefRegister => + if (r.reset.tpe != AsyncResetType ) { + addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1))) + } + addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) + r + case a@Attach(_, exprs) => + val widths = exprs map (e => (e, getWidth(e.tpe))) + val maxWidth = IsMax(widths.map(x => width2constraint(x._2))) + widths.foreach { case (e, w) => + constraintSolver.addGeq(w, CalcWidth(maxWidth), Target.asTarget(mt)(e).prettyPrint(""), mt.ref(a.serialize).prettyPrint("")) + } + a + case c: Conditionally => + addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1))) + c map addStmtConstraints(mt) + case x => x map addStmtConstraints(mt) + } + private def fixWidth(w: Width): Width = constraintSolver.get(w) match { + case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) + case None => w + case _ => sys.error("Shouldn't be here") + } + private def fixType(t: Type): Type = t map fixType map fixWidth match { + case IntervalType(l, u, p) => + val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match { + case (Some(x: Bound), Some(y: Bound)) => (x, y) + case (None, None) => (l, u) + case x => sys.error(s"Shouldn't be here: $x") - // This function shouldn't be necessary - // Added as protection in case a constraint accidentally uses MinWidth/MaxWidth - // without any actual Widths. This should be elevated to an earlier error - def forceNonEmpty(in: Seq[Option[BigInt]], default: Option[BigInt]): Seq[Option[BigInt]] = - if (in.isEmpty) Seq(default) - else in - def solve(w: Width): Option[BigInt] = w match { - case wx: VarWidth => - for{ - v <- h.get(wx.name) if !v.isInstanceOf[VarWidth] - result <- solve(v) - } yield result - case wx: MaxWidth => reduceOptions(forceNonEmpty(wx.args.map(solve), Some(BigInt(0))), max) - case wx: MinWidth => reduceOptions(forceNonEmpty(wx.args.map(solve), None), min) - case wx: PlusWidth => map2(solve(wx.arg1), solve(wx.arg2), {_ + _}) - case wx: MinusWidth => map2(solve(wx.arg1), solve(wx.arg2), {_ - _}) - case wx: ExpWidth => map2(Some(BigInt(2)), solve(wx.arg1), pow_minus_one) - case wx: IntWidth => Some(wx.width) - case wx => throwInternalError(s"solve: shouldn't be here - %$wx") } + IntervalType(lx, ux, fixWidth(p)) + case FixedType(w, p) => FixedType(w, fixWidth(p)) + case x => x + } + private def fixStmt(s: Statement): Statement = s map fixStmt map fixType + private def fixPort(p: Port): Port = { + Port(p.info, p.name, p.direction, fixType(p.tpe)) + } - solve(w) match { - case None => w - case Some(s) => IntWidth(s) - } - } - - def reduce_var_widths_w(w: Width): Width = { - //println-all-debug(["REPLACE: " w]) - evaluate(w) - //println-all-debug(["WITH: " wx]) - } - - def reduce_var_widths_t(t: Type): Type = { - t map reduce_var_widths_t map reduce_var_widths_w - } - - def reduce_var_widths_s(s: Statement): Statement = { - s map reduce_var_widths_s map reduce_var_widths_t - } - - def reduce_var_widths_p(p: Port): Port = { - Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe)) - } - - InferTypes.run(c.copy(modules = c.modules map (_ - map reduce_var_widths_p - map reduce_var_widths_s))) + def run (c: Circuit): Circuit = { + val ct = CircuitTarget(c.main) + c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name))) + constraintSolver.solve() + val ret = InferTypes.run(c.copy(modules = c.modules map (_ + map fixPort + map fixStmt))) + constraintSolver.clear() + ret } def execute(state: CircuitState): CircuitState = { @@ -426,7 +208,7 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { } } - val extraConstraints = state.annotations.flatMap { + state.annotations.foreach { case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target => val baseType = typeMap.getOrElse(target.copy(component = Seq.empty), @@ -440,10 +222,12 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { leafType } - get_constraints_t(locType, expType) - case other => Seq.empty + //get_constraints_t(locType, expType) + addTypeConstraints(anno.loc, anno.exp)(locType, expType) + case other => } - state.copy(circuit = run(state.circuit, extraConstraints)) + state.copy(circuit = run(state.circuit)) } + } diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 48b5f041..921ec3c7 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -72,7 +72,7 @@ object RemoveCHIRRTL extends Transform { refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match { case sx: CDefMemory => types(sx.name) = sx.tpe - val taddr = UIntType(IntWidth(1 max ceilLog2(sx.size))) + val taddr = UIntType(IntWidth(1 max getUIntWidth(sx.size - 1))) val tdata = sx.tpe def set_poison(vec: Seq[MPort]) = vec flatMap (r => Seq( IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)), diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala new file mode 100644 index 00000000..73f59b59 --- /dev/null +++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala @@ -0,0 +1,173 @@ +// See LICENSE for license details. + +package firrtl.passes + +import firrtl.PrimOps._ +import firrtl.ir._ +import firrtl._ +import firrtl.Mappers._ +import Implicits.{bigint2WInt} +import firrtl.constraint.IsKnown + +import scala.math.BigDecimal.RoundingMode._ + +class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) + extends PassException({ + val toWrap = wrap.args.head.serialize + val toWrapTpe = wrap.args.head.tpe.serialize + val wrapTo = wrap.args(1).serialize + val wrapToTpe = wrap.args(1).tpe.serialize + s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe" + }) + + +/** Replaces IntervalType with SIntType, three AST walks: + * 1) Align binary points + * - adds shift operators to primop args and connections + * - does not affect declaration- or inferred-types + * 2) Replace Interval [[DefNode]] with [[DefWire]] + [[Connect]] + * - You have to do this to capture the smaller bitwidths of nodes that intervals give you. Otherwise, any future + * InferTypes would reinfer the larger widths on these nodes from SInt width inference rules + * 3) Replace declaration IntervalType's with SIntType's + * - for each declaration: + * a. remove non-zero binary points + * b. remove open bounds + * c. replace with SIntType + * 3) Run InferTypes + */ +class RemoveIntervals extends Pass { + + def run(c: Circuit): Circuit = { + val alignedCircuit = c + val errors = new Errors() + val wiredCircuit = alignedCircuit map makeWireModule + val replacedCircuit = wiredCircuit map replaceModuleInterval(errors) + errors.trigger() + InferTypes.run(replacedCircuit) + } + + /* Replace interval types */ + private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule = + m map replaceStmtInterval(errors, m.name) map replacePortInterval + + private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = { + val info = s match { + case h: HasInfo => h.info + case _ => NoInfo + } + s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname) + + } + + private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match { + case _: WRef | _: WSubIndex | _: WSubField => e + case o => + o map replaceExprInterval(errors, info, mname) match { + case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe) + case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) + case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) + case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) => + // Output interval (pre-calculated) + val clipLo = tpe.minAdjusted.get + val clipHi = tpe.maxAdjusted.get + // Input interval + val (inLow, inHigh) = a1.tpe match { + case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) + case _ => sys.error("Shouldn't be here") + } + val gtOpt = clipHi >= inHigh + val ltOpt = clipLo <= inLow + (gtOpt, ltOpt) match { + // input range within output range -> no optimization + case (true, true) => a1 + case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1) + case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1) + case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1)) + } + + case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) => + // Using (conditional) reassign interval w/o adding mux + val a1tpe = a1.tpe.asInstanceOf[IntervalType] + val a2tpe = a2.tpe.asInstanceOf[IntervalType] + val min2 = a2tpe.min.get * BigDecimal(BigInt(1) << a1tpe.point.asInstanceOf[IntWidth].width.toInt) + val max2 = a2tpe.max.get * BigDecimal(BigInt(1) << a1tpe.point.asInstanceOf[IntWidth].width.toInt) + val w1 = Seq(a1tpe.minAdjusted.get.bitLength, a1tpe.maxAdjusted.get.bitLength).max + 1 + // Conservative + val minOpt2 = min2.setScale(0, FLOOR).toBigInt + val maxOpt2 = max2.setScale(0, CEILING).toBigInt + val w2 = Seq(minOpt2.bitLength, maxOpt2.bitLength).max + 1 + if (w1 < w2) { + a1 + } else { + val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2))) + DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2))) + } + case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match { + // If a2 type is Interval wrap around range. If UInt, wrap around width + case t: IntervalType => + // Need to match binary points before getting *adjusted! + val (wrapLo, wrapHi) = t.copy(point = tpe.point) match { + case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get) + case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType") + } + val (inLo, inHi) = a1.tpe match { + case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) + case _ => sys.error("Shouldn't be here") + } + // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap) + val range = wrapHi - wrapLo + val ltOpt = Add(a1, (range + 1).S) + val gtOpt = Sub(a1, (range + 1).S) + // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it. + // If x < wl + // output: wh - (wl - x) + 1 AKA x + r + 1 + // worst case: wh - (wl - xl) + 1 = wl + // -> xl + wr + 1 = wl + // If x > wh + // output: wl + (x - wh) - 1 AKA x - r - 1 + // worst case: wl + (xh - wh) - 1 = wh + // -> xh - wr - 1 = wh + val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S) + (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match { + case (true, true, _, _) => a1 + case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1) + case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1) + // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work) + case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1)) + case _ => + errors.append(new WrapWithRemainder(info, mname, w)) + default + } + case _ => sys.error("Shouldn't be here") + } + case other => other + } + } + + private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + + private def replaceTypeInterval(t: Type): Type = t match { + case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width) + case i: IntervalType => sys.error(s"Shouldn't be here: $i") + case v => v map replaceTypeInterval + } + + /** Replace Interval Nodes with Interval Wires + * + * You have to do this to capture the smaller bitwidths of nodes that intervals give you. Otherwise, + * any future InferTypes would reinfer the larger widths on these nodes from SInt width inference rules + * @param m module to replace nodes with wire + connection + * @return + */ + private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt + + private def makeWireStmt(s: Statement): Statement = s match { + case DefNode(info, name, value) => value.tpe match { + case IntervalType(l, u, p) => + val newType = IntervalType(l, u, p) + Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, FEMALE), value))) + case other => s + } + case other => other map makeWireStmt + } +} diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala new file mode 100644 index 00000000..dec64ee7 --- /dev/null +++ b/src/main/scala/firrtl/passes/TrimIntervals.scala @@ -0,0 +1,97 @@ +// See LICENSE for license details. + +package firrtl.passes + +import scala.collection.mutable +import firrtl.PrimOps._ +import firrtl.ir._ +import firrtl._ +import firrtl.Mappers._ +import firrtl.Utils.{error, field_type, getUIntWidth, max, module_type, sub_type} +import Implicits.{bigint2WInt, int2WInt} +import firrtl.constraint.{IsFloor, IsKnown, IsMul} + +/** Replaces IntervalType with SIntType, three AST walks: + * 1) Align binary points + * - adds shift operators to primop args and connections + * - does not affect declaration- or inferred-types + * 2) Replace declaration IntervalType's with SIntType's + * - for each declaration: + * a. remove non-zero binary points + * b. remove open bounds + * c. replace with SIntType + * 3) Run InferTypes + */ +class TrimIntervals extends Pass { + def run(c: Circuit): Circuit = { + // Open -> closed + val firstPass = InferTypes.run(c map replaceModuleInterval) + // Align binary points and adjust range accordingly (loss of precision changes range) + firstPass map alignModuleBP + } + + /* Replace interval types */ + private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval + + private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval + + private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + + private def replaceTypeInterval(t: Type): Type = t match { + case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) => + IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p)) + case i: IntervalType => i + case v => v map replaceTypeInterval + } + + /* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */ + private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP + + private def alignStmtBP(s: Statement): Statement = s map alignExpBP match { + case c@Connect(info, loc, expr) => loc.tpe match { + case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr)) + case _ => c + } + case c@PartialConnect(info, loc, expr) => loc.tpe match { + case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr)) + case _ => c + } + case other => other map alignStmtBP + } + + // Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned + // Note - Mul does not need its binary points aligned, because multiplication is cool like that + private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/) + + private def alignExpBP(e: Expression): Expression = e map alignExpBP match { + case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg) + case DoPrim(o, args, consts, t) if opsToFix.contains(o) && + (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size => + val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) + DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t) + case Mux(cond, tval, fval, t: IntervalType) => + val maxBP = Seq(tval, fval).map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) + Mux(cond, fixBP(maxBP)(tval), fixBP(maxBP)(fval), t) + case other => other + } + private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match { + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current => + DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired))) + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current => + val shiftAmt = current - desired + val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt) + val shiftMul = Closed(BigDecimal(1) / shiftGain) + val bpGain = BigDecimal(BigInt(1) << current.toInt) + // BP is inferred at this point + // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt) + val newBPRes = Closed(shiftGain / bpGain) + val bpResInv = Closed(bpGain) + val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes) + val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), newBPRes) + DoPrim(DecP, Seq(e), Seq(current - desired), IntervalType(CalcBound(newL), CalcBound(newU), IntWidth(desired))) + case x => sys.error(s"Shouldn't be here: $x") + } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala index bb441ebb..69c6b284 100644 --- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala @@ -56,7 +56,7 @@ object MemPortUtils { type Modules = collection.mutable.ArrayBuffer[DefModule] def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq( - Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))), + Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))), Field("en", Default, BoolType), Field("clk", Default, ClockType) ) diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index 70de3ccd..9fe01a07 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -99,9 +99,13 @@ object ToProto { Bits -> Op.OP_EXTRACT_BITS, Head -> Op.OP_HEAD, Tail -> Op.OP_TAIL, - BPShl -> Op.OP_SHIFT_BINARY_POINT_LEFT, - BPShr -> Op.OP_SHIFT_BINARY_POINT_RIGHT, - BPSet -> Op.OP_SET_BINARY_POINT + IncP -> Op.OP_INCREASE_PRECISION, + DecP -> Op.OP_DECREASE_PRECISION, + SetP -> Op.OP_SET_PRECISION, + AsInterval -> Op.OP_AS_INTERVAL, + Squeeze -> Op.OP_SQUEEZE, + Wrap -> Op.OP_WRAP, + Clip -> Op.OP_CLIP ) def convert(ruw: ir.ReadUnderWrite.Value): ReadUnderWrite = ruw match { diff --git a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala index d06344af..89f2ec07 100644 --- a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala +++ b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala @@ -3,14 +3,12 @@ package firrtl.stage.phases.tests import org.scalatest.{FlatSpec, Matchers, PrivateMethodTester} - import java.io.File import firrtl._ -import firrtl.stage._ import firrtl.stage.phases.DriverCompatibility._ - import firrtl.options.{InputAnnotationFileAnnotation, Phase, TargetDirAnnotation} +import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation} import firrtl.stage.phases.DriverCompatibility class DriverCompatibilitySpec extends FlatSpec with Matchers with PrivateMethodTester { diff --git a/src/test/scala/firrtlTests/AsyncResetSpec.scala b/src/test/scala/firrtlTests/AsyncResetSpec.scala index 6fcb647a..8ad397b3 100644 --- a/src/test/scala/firrtlTests/AsyncResetSpec.scala +++ b/src/test/scala/firrtlTests/AsyncResetSpec.scala @@ -51,16 +51,19 @@ class AsyncResetSpec extends FirrtlFlatSpec { it should "support casting to other types" in { val result = compileBody(s""" |input a : AsyncReset + |output u : Interval[0, 1].0 |output v : UInt<1> |output w : SInt<1> |output x : Clock |output y : Fixed<1><<0>> |output z : AsyncReset + |u <= asInterval(a, 0, 1, 0) |v <= asUInt(a) |w <= asSInt(a) |x <= asClock(a) |y <= asFixedPoint(a, 0) - |z <= asAsyncReset(a)""".stripMargin + |z <= asAsyncReset(a) + |""".stripMargin ) result should containLine ("assign v = $unsigned(a);") result should containLine ("assign w = $signed(a);") @@ -76,22 +79,26 @@ class AsyncResetSpec extends FirrtlFlatSpec { |input c : Clock |input d : Fixed<1><<0>> |input e : AsyncReset + |input f : Interval[0, 1].0 + |output u : AsyncReset |output v : AsyncReset |output w : AsyncReset |output x : AsyncReset |output y : AsyncReset |output z : AsyncReset - |v <= asAsyncReset(a) - |w <= asAsyncReset(a) - |x <= asAsyncReset(a) - |y <= asAsyncReset(a) - |z <= asAsyncReset(a)""".stripMargin + |u <= asAsyncReset(a) + |v <= asAsyncReset(b) + |w <= asAsyncReset(c) + |x <= asAsyncReset(d) + |y <= asAsyncReset(e) + |z <= asAsyncReset(f)""".stripMargin ) - result should containLine ("assign v = a;") - result should containLine ("assign w = a;") - result should containLine ("assign x = a;") - result should containLine ("assign y = a;") - result should containLine ("assign z = a;") + result should containLine ("assign u = a;") + result should containLine ("assign v = b;") + result should containLine ("assign w = c;") + result should containLine ("assign x = d;") + result should containLine ("assign y = e;") + result should containLine ("assign z = f;") } "Non-literals" should "NOT be allowed as reset values for AsyncReset" in { diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala index fba81ec7..b82637b6 100644 --- a/src/test/scala/firrtlTests/ChirrtlSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala @@ -70,8 +70,8 @@ class ChirrtlSpec extends FirrtlFlatSpec { behavior of "Uniqueness" for ((description, input) <- CheckSpec.nonUniqueExamples) { it should s"be asserted for $description" in { - assertThrows[CheckChirrtl.NotUniqueException] { - Seq(ToWorkingIR, CheckChirrtl).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) } + assertThrows[CheckHighForm.NotUniqueException] { + Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) } } } } diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala index dbc997cd..9d6206af 100644 --- a/src/test/scala/firrtlTests/InfoSpec.scala +++ b/src/test/scala/firrtlTests/InfoSpec.scala @@ -66,7 +66,7 @@ class InfoSpec extends FirrtlFlatSpec { result should containLine (s"assign n = w | x; //$Info3") } - they should "be propagated on memories" in { + it should "be propagated on memories" in { val result = compileBody(s""" |input clock : Clock |input addr : UInt<5> @@ -102,7 +102,7 @@ class InfoSpec extends FirrtlFlatSpec { result should containLine (s"m[m_w_addr] <= m_w_data; //$Info1") } - they should "be propagated on instances" in { + it should "be propagated on instances" in { val result = compile(s""" |circuit Test : | module Child : diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index be9d738b..69379c51 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -12,7 +12,7 @@ import firrtl.transforms._ import firrtl._ class LowerTypesSpec extends FirrtlFlatSpec { - private val transforms = Seq( + private def transforms = Seq( ToWorkingIR, CheckHighForm, ResolveKinds, diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala index afb82384..e64e9105 100644 --- a/src/test/scala/firrtlTests/UniquifySpec.scala +++ b/src/test/scala/firrtlTests/UniquifySpec.scala @@ -16,7 +16,7 @@ import firrtl.util.TestOptions class UniquifySpec extends FirrtlFlatSpec { - private val transforms = Seq( + private def transforms = Seq( ToWorkingIR, CheckHighForm, ResolveKinds, diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala index eb3d1a96..f1dadcee 100644 --- a/src/test/scala/firrtlTests/ZeroWidthTests.scala +++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala @@ -11,7 +11,7 @@ import firrtl.Parser import firrtl.passes._ class ZeroWidthTests extends FirrtlFlatSpec { - val transforms = Seq( + def transforms = Seq( ToWorkingIR, ResolveKinds, InferTypes, diff --git a/src/test/scala/firrtlTests/constraint/InequalitySpec.scala b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala new file mode 100644 index 00000000..02a853cb --- /dev/null +++ b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala @@ -0,0 +1,197 @@ +package firrtlTests.constraint + +import firrtl.constraint._ +import org.scalatest.{FlatSpec, Matchers} +import firrtl.ir.Closed + +class InequalitySpec extends FlatSpec with Matchers { + + behavior of "Constraints" + + "IsConstraints" should "reduce properly" in { + IsMin(Closed(0), Closed(1)) should be (Closed(0)) + IsMin(Closed(-1), Closed(1)) should be (Closed(-1)) + IsMax(Closed(-1), Closed(1)) should be (Closed(1)) + IsNeg(IsMul(Closed(-1), Closed(-2))) should be (Closed(-2)) + val x = IsMin(IsMul(Closed(1), VarCon("a")), Closed(2)) + x.children.toSet should be (IsMin(Closed(2), IsMul(Closed(1), VarCon("a"))).children.toSet) + } + + "IsAdd" should "reduce properly" in { + // All constants + IsAdd(Closed(-1), Closed(1)) should be (Closed(0)) + + // Pull Out IsMax + IsAdd(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsAdd(VarCon("a"), Closed(1)))) + IsAdd(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMax(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1)))) + ) + + // Pull Out IsMin + IsAdd(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsAdd(VarCon("a"), Closed(1)))) + IsAdd(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMin(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1)))) + ) + + // Add Zero + IsAdd(Closed(0), VarCon("a")) should be (VarCon("a")) + + // One argument + IsAdd(Seq(VarCon("a"))) should be (VarCon("a")) + } + + "IsMax" should "reduce properly" in { + // All constants + IsMax(Closed(-1), Closed(1)) should be (Closed(1)) + + // Flatten nested IsMax + IsMax(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(1), VarCon("a"))) + IsMax(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMax(Seq(Closed(1), VarCon("a"), VarCon("b"))) + ) + + // Eliminate IsMins if possible + IsMax(Closed(2), IsMin(Closed(1), VarCon("a"))) should be (Closed(2)) + IsMax(Seq( + Closed(2), + IsMin(Closed(1), VarCon("a")), + IsMin(Closed(3), VarCon("b")) + )) should be ( + IsMax(Seq( + Closed(2), + IsMin(Closed(3), VarCon("b")) + )) + ) + + // One argument + IsMax(Seq(VarCon("a"))) should be (VarCon("a")) + IsMax(Seq(Closed(0))) should be (Closed(0)) + IsMax(Seq(IsMin(VarCon("a"), Closed(0)))) should be (IsMin(VarCon("a"), Closed(0))) + } + + "IsMin" should "reduce properly" in { + // All constants + IsMin(Closed(-1), Closed(1)) should be (Closed(-1)) + + // Flatten nested IsMin + IsMin(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(1), VarCon("a"))) + IsMin(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMin(Seq(Closed(1), VarCon("a"), VarCon("b"))) + ) + + // Eliminate IsMaxs if possible + IsMin(Closed(1), IsMax(Closed(2), VarCon("a"))) should be (Closed(1)) + IsMin(Seq( + Closed(2), + IsMax(Closed(1), VarCon("a")), + IsMax(Closed(3), VarCon("b")) + )) should be ( + IsMin(Seq( + Closed(2), + IsMax(Closed(1), VarCon("a")) + )) + ) + + // One argument + IsMin(Seq(VarCon("a"))) should be (VarCon("a")) + IsMin(Seq(Closed(0))) should be (Closed(0)) + IsMin(Seq(IsMax(VarCon("a"), Closed(0)))) should be (IsMax(VarCon("a"), Closed(0))) + } + + "IsMul" should "reduce properly" in { + // All constants + IsMul(Closed(2), Closed(3)) should be (Closed(6)) + + // Pull out max, if positive stays max + IsMul(Closed(2), IsMax(Closed(3), VarCon("a"))) should be( + IsMax(Closed(6), IsMul(Closed(2), VarCon("a"))) + ) + + // Pull out max, if negative is min + IsMul(Closed(-2), IsMax(Closed(3), VarCon("a"))) should be( + IsMin(Closed(-6), IsMul(Closed(-2), VarCon("a"))) + ) + + // Pull out min, if positive stays min + IsMul(Closed(2), IsMin(Closed(3), VarCon("a"))) should be( + IsMin(Closed(6), IsMul(Closed(2), VarCon("a"))) + ) + + // Pull out min, if negative is max + IsMul(Closed(-2), IsMin(Closed(3), VarCon("a"))) should be( + IsMax(Closed(-6), IsMul(Closed(-2), VarCon("a"))) + ) + + // Times zero + IsMul(Closed(0), VarCon("x")) should be (Closed(0)) + + // Times 1 + IsMul(Closed(1), VarCon("x")) should be (VarCon("x")) + + // One argument + IsMul(Seq(Closed(0))) should be (Closed(0)) + IsMul(Seq(VarCon("a"))) should be (VarCon("a")) + + // No optimizations + val isMax = IsMax(VarCon("x"), VarCon("y")) + val isMin = IsMin(VarCon("x"), VarCon("y")) + val a = VarCon("a") + IsMul(a, isMax).children should be (Vector(a, isMax)) //non-known multiply + IsMul(a, isMin).children should be (Vector(a, isMin)) //non-known multiply + IsMul(Seq(Closed(2), isMin, isMin)).children should be (Vector(Closed(2), isMin, isMin)) //>1 min + IsMul(Seq(Closed(2), isMax, isMax)).children should be (Vector(Closed(2), isMax, isMax)) //>1 max + IsMul(Seq(Closed(2), isMin, isMax)).children should be (Vector(Closed(2), isMin, isMax)) //mixed min/max + } + + "IsNeg" should "reduce properly" in { + // All constants + IsNeg(Closed(1)) should be (Closed(-1)) + // Pull out max + IsNeg(IsMax(Closed(1), VarCon("a"))) should be (IsMin(Closed(-1), IsNeg(VarCon("a")))) + // Pull out min + IsNeg(IsMin(Closed(1), VarCon("a"))) should be (IsMax(Closed(-1), IsNeg(VarCon("a")))) + // Pull out add + IsNeg(IsAdd(Closed(1), VarCon("a"))) should be (IsAdd(Closed(-1), IsNeg(VarCon("a")))) + // Pull out mul + IsNeg(IsMul(Closed(2), VarCon("a"))) should be (IsMul(Closed(-2), VarCon("a"))) + // No optimizations + // (pow), (floor?) + IsNeg(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) + IsNeg(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x")))) + } + + "IsPow" should "reduce properly" in { + // All constants + IsPow(Closed(1)) should be (Closed(2)) + // Pull out max + IsPow(IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsPow(VarCon("a")))) + // Pull out min + IsPow(IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsPow(VarCon("a")))) + // Pull out add + IsPow(IsAdd(Closed(1), VarCon("a"))) should be (IsMul(Closed(2), IsPow(VarCon("a")))) + // No optimizations + // (mul), (pow), (floor?) + IsPow(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x")))) + IsPow(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) + IsPow(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x")))) + } + + "IsFloor" should "reduce properly" in { + // All constants + IsFloor(Closed(1.9)) should be (Closed(1)) + IsFloor(Closed(-1.9)) should be (Closed(-2)) + // Pull out max + IsFloor(IsMax(Closed(1.9), VarCon("a"))) should be (IsMax(Closed(1), IsFloor(VarCon("a")))) + // Pull out min + IsFloor(IsMin(Closed(1.9), VarCon("a"))) should be (IsMin(Closed(1), IsFloor(VarCon("a")))) + // Cancel with another floor + IsFloor(IsFloor(VarCon("a"))) should be (IsFloor(VarCon("a"))) + // No optimizations + // (add), (mul), (pow) + IsFloor(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x")))) + IsFloor(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) + IsFloor(IsAdd(Closed(1), VarCon("x"))).children should be (Vector(IsAdd(Closed(1), VarCon("x")))) + } + +} + diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala index 6bf86479..a34145ac 100644 --- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala @@ -21,6 +21,36 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { } } + "Fixed types" should "infer add correctly if only precision unspecified" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveFlows, + CheckFlows, + new InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<10><<0>> + | input c : Fixed<4><<3>> + | output d : Fixed<13> + | d <= add(a, add(b, c))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<10><<0>> + | input c : Fixed<4><<3>> + | output d : Fixed<13><<3>> + | d <= add(a, add(b, c))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Fixed types" should "infer add correctly" in { val passes = Seq( ToWorkingIR, @@ -36,7 +66,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { """circuit Unit : | module Unit : | input a : Fixed<10><<2>> - | input b : Fixed<10> + | input b : Fixed<10><<0>> | input c : Fixed<4><<3>> | output d : Fixed | d <= add(a, add(b, c))""".stripMargin @@ -119,13 +149,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed - | d <= bpshl(a, 2)""".stripMargin + | d <= incp(a, 2)""".stripMargin val check = """circuit Unit : | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<12><<4>> - | d <= bpshl(a, 2)""".stripMargin + | d <= incp(a, 2)""".stripMargin executeTest(input, check.split("\n") map normalized, passes) } @@ -145,13 +175,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed - | d <= bpshr(a, 2)""".stripMargin + | d <= decp(a, 2)""".stripMargin val check = """circuit Unit : | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<8><<0>> - | d <= bpshr(a, 2)""".stripMargin + | d <= decp(a, 2)""".stripMargin executeTest(input, check.split("\n") map normalized, passes) } @@ -171,13 +201,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed - | d <= bpset(a, 3)""".stripMargin + | d <= setp(a, 3)""".stripMargin val check = """circuit Unit : | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<11><<3>> - | d <= bpset(a, 3)""".stripMargin + | d <= setp(a, 3)""".stripMargin executeTest(input, check.split("\n") map normalized, passes) } diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala index 8686bd0f..f5b16e45 100644 --- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala +++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala @@ -14,7 +14,6 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { (c: CircuitState, p: Transform) => p.runTransform(c) }.circuit val lines = c.serialize.split("\n") map normalized - println(c.serialize) expected foreach { e => lines should contain(e) @@ -37,7 +36,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { """circuit Unit : | module Unit : | input a : Fixed<10><<2>> - | input b : Fixed<10> + | input b : Fixed<10><<0>> | input c : Fixed<4><<3>> | output d : Fixed<<5>> | d <= add(a, add(b, c))""".stripMargin @@ -67,7 +66,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { """circuit Unit : | module Unit : | input a : Fixed<10><<2>> - | input b : Fixed<10> + | input b : Fixed<10><<0>> | input c : Fixed<4><<3>> | output d : Fixed<<5>> | d <- add(a, add(b, c))""".stripMargin @@ -99,7 +98,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<12><<4>> - | d <= bpshl(a, 2)""".stripMargin + | d <= incp(a, 2)""".stripMargin val check = """circuit Unit : | module Unit : @@ -126,7 +125,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<9><<1>> - | d <= bpshr(a, 1)""".stripMargin + | d <= decp(a, 1)""".stripMargin val check = """circuit Unit : | module Unit : @@ -153,7 +152,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed - | d <= bpset(a, 3)""".stripMargin + | d <= setp(a, 3)""".stripMargin val check = """circuit Unit : | module Unit : @@ -181,7 +180,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { class CheckChirrtlTransform extends SeqTransform { def inputForm = ChirrtlForm def outputForm = ChirrtlForm - val transforms = Seq(passes.CheckChirrtl) + def transforms = Seq(passes.CheckChirrtl) } val chirrtlTransform = new CheckChirrtlTransform diff --git a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala new file mode 100644 index 00000000..20fdeee1 --- /dev/null +++ b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala @@ -0,0 +1,183 @@ +// See LICENSE for license details. + +package firrtlTests.interval + +import firrtl.Implicits.constraint2bound +import firrtl.{ChirrtlForm, CircuitState, LowFirrtlCompiler, Parser} +import firrtl.ir._ + +import scala.math.BigDecimal.RoundingMode._ +import firrtl.Parser.IgnoreInfo +import firrtl.constraint._ +import firrtlTests.FirrtlFlatSpec + +class IntervalMathSpec extends FirrtlFlatSpec { + val SumPattern = """.*output sum.*<(\d+)>.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r + val DifferencePattern = """.*output difference.*<(\d+)>.*""".r + val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r + val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r + val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r + val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r + val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r + val ArithAssignPattern = """\s*(\w+) <= asSInt\(bits\((\w+)\((.*)\).*\)\)\s*""".r + def getBound(bound: String, value: Double): IsKnown = bound match { + case "[" => Closed(BigDecimal(value)) + case "]" => Closed(BigDecimal(value)) + case "(" => Open(BigDecimal(value)) + case ")" => Open(BigDecimal(value)) + } + + val prec = 0.5 + + for { + lb1 <- Seq("[", "(") + lv1 <- Range.Double(-1.0, 1.0, prec) + uv1 <- if(lb1 == "[") Range.Double(lv1, 1.0, prec) else Range.Double(lv1 + prec, 1.0, prec) + ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") + bp1 <- 0 to 1 + lb2 <- Seq("[", "(") + lv2 <- Range.Double(-1.0, 1.0, prec) + uv2 <- if(lb2 == "[") Range.Double(lv2, 1.0, prec) else Range.Double(lv2 + prec, 1.0, prec) + ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") + bp2 <- 0 to 1 + } { + val it1 = IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1.toInt)) + val it2 = IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2.toInt)) + (it1.range, it2.range) match { + case (Some(Nil), _) => + case (_, Some(Nil)) => + case _ => + def config = s"$lb1$lv1,$uv1$ub1.$bp1 and $lb2$lv2,$uv2$ub2.$bp2" + + s"Configuration $config" should "pass" in { + + val input = + s"""circuit Unit : + | module Unit : + | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 + | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 + | input amt : UInt<3> + | output sum : Interval + | output difference : Interval + | output product : Interval + | output shl : Interval + | output shr : Interval + | output dshl : Interval + | output dshr : Interval + | output lt : UInt + | output leq : UInt + | output gt : UInt + | output geq : UInt + | output eq : UInt + | output neq : UInt + | output cat : UInt + | sum <= add(in1, in2) + | difference <= sub(in1, in2) + | product <= mul(in1, in2) + | shl <= shl(in1, 3) + | shr <= shr(in1, 3) + | dshl <= dshl(in1, amt) + | dshr <= dshr(in1, amt) + | lt <= lt(in1, in2) + | leq <= leq(in1, in2) + | gt <= gt(in1, in2) + | geq <= geq(in1, in2) + | eq <= eq(in1, in2) + | neq <= lt(in1, in2) + | cat <= cat(in1, in2) + | """.stripMargin + + val lowerer = new LowFirrtlCompiler + val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) + val output = res.getEmittedCircuit.value split "\n" + val min1 = Closed(it1.min.get) + val max1 = Closed(it1.max.get) + val min2 = Closed(it2.min.get) + val max2 = Closed(it2.max.get) + for (line <- output) { + line match { + case SumPattern(varWidth) => + val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) + val it = IntervalType(IsAdd(min1, min2), IsAdd(max1, max2), bp) + assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, s"$line,${it.range}") + case ProductPattern(varWidth) => + val bp = IntWidth(bp1.toInt + bp2.toInt) + val lv = IsMin(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) + val uv = IsMax(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "product") + case DifferencePattern(varWidth) => + val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) + val lv = min1 + max2.neg + val uv = max1 + min2.neg + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "diff") + case ShiftLeftPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(8) + val uv = max1 * Closed(8) + val it = IntervalType(lv, uv, bp) + assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, "shl") + case ShiftRightPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(1/3) + val uv = max1 * Closed(1/3) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "shr") + case DShiftLeftPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(128) + val uv = max1 * Closed(128) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshl") + case DShiftRightPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 + val uv = max1 + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshr") + case ComparisonPattern(varWidth) => assert(varWidth.toInt == 1, "==") + case ArithAssignPattern(varName, operation, args) => + val arg1 = if(IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) """SInt<1>("h0")""" else "in1" + val arg2 = if(IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) """SInt<1>("h0")""" else "in2" + varName match { + case "sum" => + assert(operation === "add", s"""var sum should be result of an add in ${output.mkString("\n")}""") + if (bp1 > bp2) { + if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), + s"$config second arg incorrect in $line") + } else if (bp1 < bp2) { + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), + s"$config second arg incorrect in $line") + assert(!args.contains("shl($arg2"), s"$config second arg should be just $arg2 in $line") + } else { + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } + case "product" => + assert(operation === "mul", s"var sum should be result of an add in $line") + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + case "difference" => + assert(operation === "sub", s"var difference should be result of an sub in $line") + if (bp1 > bp2) { + if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), + s"$config second arg incorrect in $line") + } else if (bp1 < bp2) { + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), + s"$config second arg incorrect in $line") + if (arg1 != arg2) assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } else { + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } + case _ => + } + case _ => + } + } + } + } + } +} + + +// vim: set ts=4 sw=4 et: diff --git a/src/test/scala/firrtlTests/interval/IntervalSpec.scala b/src/test/scala/firrtlTests/interval/IntervalSpec.scala new file mode 100644 index 00000000..37d79c84 --- /dev/null +++ b/src/test/scala/firrtlTests/interval/IntervalSpec.scala @@ -0,0 +1,530 @@ +package firrtlTests +package interval + +import java.io._ + +import firrtl._ +import firrtl.ir.Circuit +import firrtl.passes._ +import firrtl.Parser.IgnoreInfo +import firrtl.passes.CheckTypes.InvalidConnect +import firrtl.passes.CheckWidths.DisjointSqueeze + +class IntervalSpec extends FirrtlFlatSpec { + private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Transform) => + p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit + } + val lines = c.serialize.split("\n") map normalized + + expected foreach { e => + lines should contain(e) + } + } + + "Interval types" should "parse correctly" in { + val passes = Seq(ToWorkingIR) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].4 + | input in2 : Interval(-0.32, 10].4 + | input in3 : Interval[-3, 10.1).4 + | input in4 : Interval(-0.32, 10.1) + | input in5 : Interval.4 + | input in6 : Interval + | output out0 : Interval.2 + | output out1 : Interval + | out0 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6)))))) + | out1 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))""".stripMargin + executeTest(input, input.split("\n") map normalized, passes) + } + + "Interval types" should "infer bp correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].3 + | input in2 : Interval(-0.32, 10].2 + | output out0 : Interval + | out0 <= add(in0, add(in1, in2))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].3 + | input in2 : Interval(-0.32, 10].2 + | output out0 : Interval.4 + | out0 <= add(in0, add(in1, in2))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Interval types" should "trim known intervals correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].3 + | input in2 : Interval(-0.32, 10].2 + | output out0 : Interval + | out0 <= add(in0, add(in1, in2))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input in0 : Interval[-0.3125, 10.0625].4 + | input in1 : Interval[0, 10].3 + | input in2 : Interval[-0.25, 10].2 + | output out0 : Interval.4 + | out0 <= add(in0, incp(add(in1, incp(in2, 1)), 1))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Interval types" should "infer intervals correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(0, 10).4 + | input in1 : Interval(0, 10].3 + | input in2 : Interval(-1, 3].2 + | output out0 : Interval + | output out1 : Interval + | output out2 : Interval + | out0 <= add(in0, add(in1, in2)) + | out1 <= mul(in0, mul(in1, in2)) + | out2 <= sub(in0, sub(in1, in2))""".stripMargin + val check = + """output out0 : Interval[-0.5625, 22.9375].4 + |output out1 : Interval[-74.53125, 298.125].9 + |output out2 : Interval[-10.6875, 12.8125].4""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Interval types" should "be removed correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(0, 10).4 + | input in1 : Interval(0, 10].3 + | input in2 : Interval(-1, 3].2 + | output out0 : Interval + | output out1 : Interval + | output out2 : Interval + | out0 <= add(in0, add(in1, in2)) + | out1 <= mul(in0, mul(in1, in2)) + | out2 <= sub(in0, sub(in1, in2))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input in0 : SInt<9> + | input in1 : SInt<8> + | input in2 : SInt<5> + | output out0 : SInt<10> + | output out1 : SInt<19> + | output out2 : SInt<9> + | out0 <= add(in0, shl(add(in1, shl(in2, 1)), 1)) + | out1 <= mul(in0, mul(in1, in2)) + | out2 <= sub(in0, shl(sub(in1, shl(in2, 1)), 1))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + +"Interval types" should "infer multiplication by zero correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output mul : Interval + | mul <= mul(in2, in1) + | """.stripMargin + val check = s"""output mul : Interval[0, 0].2 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) +} + + "Interval types" should "infer muxes correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<1> + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output out : Interval + | out <= mux(p, in2, in1) + | """.stripMargin + val check = s"""output out : Interval[0, 0.5].1 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "infer dshl correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds, ResolveGenders, new InferBinaryPoints(), new TrimIntervals, new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | input in1 : Interval[-1, 1].0 + | output out : Interval + | out <= dshl(in1, p) + | """.stripMargin + val check = s"""output out : Interval[-128, 128].0 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "infer asInterval correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | output out : Interval + | out <= asInterval(p, 0, 4, 1) + | """.stripMargin + val check = s"""output out : Interval[0, 2].1 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "do wrap/clip correctly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck()) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap4: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + """.stripMargin + //| output wrap1: Interval + //| output wrap2: Interval + //| output clip1: Interval + //| output clip2: Interval + //| wrap1 <= wrap(in1, u, 0) + //| wrap2 <= wrap(in1, s, 0) + //| clip1 <= clip(in1, u) + //| clip2 <= clip(in1, s) + val check = s""" + | output wrap3 : Interval[-2, 4].0 + | output wrap4 : Interval[-1, 1].0 + | output wrap5 : Interval[-4, 4].0 + | output wrap6 : Interval[-1, 7].0 + | output wrap7 : Interval[-4, 7].0 + | output clip3 : Interval[-2, 4].0 + | output clip4 : Interval[-1, 1].0 + | output clip5 : Interval[-3, 4].0 + | output clip6 : Interval[-1, 5].0 + | output clip7 : Interval[-3, 5].0 """.stripMargin + // TODO: this optimization + //| output wrap1 : Interval[0, 7].0 + //| output wrap2 : Interval[-2, 1].0 + //| output clip1 : Interval[0, 5].0 + //| output clip2 : Interval[-2, 1].0 + //| output wrap7 : Interval[-3, 5].0 + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "remove wrap/clip correctly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck(), new RemoveIntervals()) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + | """.stripMargin + val check = s""" + | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1)) + | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1) + | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1) + | wrap7 <= in1 + | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1)) + | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)) + | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1) + | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1) + | clip7 <= in1 + """.stripMargin + //| output wrap4: Interval + //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0) + //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1")) + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "shift wrap/clip correctly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals()) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input in1: Interval[-3, 5].1 + | output wrap1: Interval + | output clip1: Interval + | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0)) + | clip1 <= clip(in1, asInterval(s, -2, 2, 0)) + | """.stripMargin + val check = s""" + | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1)) + | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1)) + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "infer negative binary points" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck()) + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval + | out <= add(in1, in2) + | """.stripMargin + val check = s""" + | output out : Interval[-6, 12].-1 + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "remove negative binary points" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval.0 + | out <= add(in1, in2) + | """.stripMargin + val check = s""" + | output out : SInt<5> + | out <= shl(add(in1, shl(in2, 1)), 1) + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "implement squz properly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck) + val input = + s"""circuit Unit : + | module Unit : + | input min: Interval[-1, 4].1 + | input max: Interval[-3, 5].1 + | input left: Interval[-3, 3].1 + | input right: Interval[0, 5].1 + | input off: Interval[-1, 4].2 + | output minMax: Interval + | output maxMin: Interval + | output minLeft: Interval + | output leftMin: Interval + | output minRight: Interval + | output rightMin: Interval + | output minOff: Interval + | output offMin: Interval + | + | minMax <= squz(min, max) + | maxMin <= squz(max, min) + | minLeft <= squz(min, left) + | leftMin <= squz(left, min) + | minRight <= squz(min, right) + | rightMin <= squz(right, min) + | minOff <= squz(min, off) + | offMin <= squz(off, min) + | """.stripMargin + val check = + s""" + | output minMax : Interval[-1, 4].1 + | output maxMin : Interval[-1, 4].1 + | output minLeft : Interval[-1, 3].1 + | output leftMin : Interval[-1, 3].1 + | output minRight : Interval[0, 4].1 + | output rightMin : Interval[0, 4].1 + | output minOff : Interval[-1, 4].1 + | output offMin : Interval[-1, 4].2 + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "lower squz properly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) + val input = + s"""circuit Unit : + | module Unit : + | input min: Interval[-1, 4].1 + | input max: Interval[-3, 5].1 + | input left: Interval[-3, 3].1 + | input right: Interval[0, 5].1 + | input off: Interval[-1, 4].2 + | output minMax: Interval + | output maxMin: Interval + | output minLeft: Interval + | output leftMin: Interval + | output minRight: Interval + | output rightMin: Interval + | output minOff: Interval + | output offMin: Interval + | + | minMax <= squz(min, max) + | maxMin <= squz(max, min) + | minLeft <= squz(min, left) + | leftMin <= squz(left, min) + | minRight <= squz(min, right) + | rightMin <= squz(right, min) + | minOff <= squz(min, off) + | offMin <= squz(off, min) + | """.stripMargin + val check = + s""" + | minMax <= asSInt(bits(min, 4, 0)) + | maxMin <= asSInt(bits(max, 4, 0)) + | minLeft <= asSInt(bits(min, 3, 0)) + | leftMin <= left + | minRight <= asSInt(bits(min, 4, 0)) + | rightMin <= asSInt(bits(right, 4, 0)) + | minOff <= asSInt(bits(min, 4, 0)) + | offMin <= asSInt(bits(off, 5, 0)) + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Assigning a larger interval to a smaller interval" should "error!" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) + val input = + s"""circuit Unit : + | module Unit : + | input in: Interval[1, 4].1 + | output out: Interval[2, 3].1 + | out <= in + | """.stripMargin + intercept[InvalidConnect]{ + executeTest(input, Nil, passes) + } + } + "Assigning a more precise interval to a less precise interval" should "error!" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) + val input = + s"""circuit Unit : + | module Unit : + | input in: Interval[2, 3].3 + | output out: Interval[2, 3].1 + | out <= in + | """.stripMargin + intercept[InvalidConnect]{ + executeTest(input, Nil, passes) + } + } + "Chick's example" should "work" in { + val input = + s"""circuit IntervalChainedSubTester : + | module IntervalChainedSubTester : + | input clock : Clock + | input reset : UInt<1> + | node _GEN_0 = sub(SInt<6>("h11"), SInt<6>("h2")) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26] + | node _GEN_1 = bits(_GEN_0, 4, 0) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26] + | node intervalResult = asSInt(_GEN_1) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26] + | skip + | node _T_1 = asUInt(intervalResult) @[IntervalSpec.scala 338:50] + | skip + | node _T_3 = eq(reset, UInt<1>("h0")) @[IntervalSpec.scala 338:9] + | node _T_4 = eq(intervalResult, SInt<5>("hf")) @[IntervalSpec.scala 339:25] + | skip + | node _T_6 = or(_T_4, reset) @[IntervalSpec.scala 339:9] + | node _T_7 = eq(_T_6, UInt<1>("h0")) @[IntervalSpec.scala 339:9] + | skip + | skip + | printf(clock, _T_3, "Interval result: %d", _T_1) @[IntervalSpec.scala 338:9] + | printf(clock, _T_7, "Assertion failed at IntervalSpec.scala:339 assert(intervalResult === 15.I)") @[IntervalSpec.scala 339:9] + | stop(clock, _T_7, 1) @[IntervalSpec.scala 339:9] + | stop(clock, _T_3, 0) @[IntervalSpec.scala 340:7] + | + """.stripMargin + compileToVerilog(input) + } + + "Squeeze with disjoint intervals" should "error" in { + intercept[DisjointSqueeze] { + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[3, 6].3 + | node out = squz(in1, in2) + """.stripMargin + compileToVerilog(input) + } + intercept[DisjointSqueeze] { + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[3, 6].3 + | node out = squz(in2, in1) + """.stripMargin + compileToVerilog(input) + } + } + + "Clip with disjoint intervals" should "work" in { + compileToVerilog( + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[3, 6].3 + | output out: Interval + | out <= clip(in1, in2) + """.stripMargin + ) + compileToVerilog( + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[4, 6].3 + | node out = clip(in1, in2) + """.stripMargin + ) + } + + + "Wrap with remainder" should "error" in { + intercept[WrapWithRemainder] { + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[0, 300).3 + | input in2: Interval[3, 6].3 + | node out = wrap(in1, in2) + """.stripMargin + compileToVerilog(input) + } + } +} diff --git a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala index 46fb310a..88095830 100644 --- a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala +++ b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala @@ -152,7 +152,7 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { } "InferWidthsWithAnnos" should "work with WiringTransform" in { - def transforms = Seq( + def transforms() = Seq( ToWorkingIR, ResolveKinds, InferTypes, |
