aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/TrimIntervals.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/TrimIntervals.scala')
-rw-r--r--src/main/scala/firrtl/passes/TrimIntervals.scala97
1 files changed, 97 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala
new file mode 100644
index 00000000..dec64ee7
--- /dev/null
+++ b/src/main/scala/firrtl/passes/TrimIntervals.scala
@@ -0,0 +1,97 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+
+import scala.collection.mutable
+import firrtl.PrimOps._
+import firrtl.ir._
+import firrtl._
+import firrtl.Mappers._
+import firrtl.Utils.{error, field_type, getUIntWidth, max, module_type, sub_type}
+import Implicits.{bigint2WInt, int2WInt}
+import firrtl.constraint.{IsFloor, IsKnown, IsMul}
+
+/** Replaces IntervalType with SIntType, three AST walks:
+ * 1) Align binary points
+ * - adds shift operators to primop args and connections
+ * - does not affect declaration- or inferred-types
+ * 2) Replace declaration IntervalType's with SIntType's
+ * - for each declaration:
+ * a. remove non-zero binary points
+ * b. remove open bounds
+ * c. replace with SIntType
+ * 3) Run InferTypes
+ */
+class TrimIntervals extends Pass {
+ def run(c: Circuit): Circuit = {
+ // Open -> closed
+ val firstPass = InferTypes.run(c map replaceModuleInterval)
+ // Align binary points and adjust range accordingly (loss of precision changes range)
+ firstPass map alignModuleBP
+ }
+
+ /* Replace interval types */
+ private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval
+
+ private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval
+
+ private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+
+ private def replaceTypeInterval(t: Type): Type = t match {
+ case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) =>
+ IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p))
+ case i: IntervalType => i
+ case v => v map replaceTypeInterval
+ }
+
+ /* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */
+ private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP
+
+ private def alignStmtBP(s: Statement): Statement = s map alignExpBP match {
+ case c@Connect(info, loc, expr) => loc.tpe match {
+ case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case c@PartialConnect(info, loc, expr) => loc.tpe match {
+ case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case other => other map alignStmtBP
+ }
+
+ // Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned
+ // Note - Mul does not need its binary points aligned, because multiplication is cool like that
+ private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/)
+
+ private def alignExpBP(e: Expression): Expression = e map alignExpBP match {
+ case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg)
+ case DoPrim(o, args, consts, t) if opsToFix.contains(o) &&
+ (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size =>
+ val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _)
+ DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t)
+ case Mux(cond, tval, fval, t: IntervalType) =>
+ val maxBP = Seq(tval, fval).map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _)
+ Mux(cond, fixBP(maxBP)(tval), fixBP(maxBP)(fval), t)
+ case other => other
+ }
+ private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match {
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current =>
+ DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired)))
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current =>
+ val shiftAmt = current - desired
+ val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt)
+ val shiftMul = Closed(BigDecimal(1) / shiftGain)
+ val bpGain = BigDecimal(BigInt(1) << current.toInt)
+ // BP is inferred at this point
+ // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt)
+ val newBPRes = Closed(shiftGain / bpGain)
+ val bpResInv = Closed(bpGain)
+ val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes)
+ val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), newBPRes)
+ DoPrim(DecP, Seq(e), Seq(current - desired), IntervalType(CalcBound(newL), CalcBound(newU), IntWidth(desired)))
+ case x => sys.error(s"Shouldn't be here: $x")
+ }
+}
+
+// vim: set ts=4 sw=4 et: