diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/TrimIntervals.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/TrimIntervals.scala | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala new file mode 100644 index 00000000..dec64ee7 --- /dev/null +++ b/src/main/scala/firrtl/passes/TrimIntervals.scala @@ -0,0 +1,97 @@ +// See LICENSE for license details. + +package firrtl.passes + +import scala.collection.mutable +import firrtl.PrimOps._ +import firrtl.ir._ +import firrtl._ +import firrtl.Mappers._ +import firrtl.Utils.{error, field_type, getUIntWidth, max, module_type, sub_type} +import Implicits.{bigint2WInt, int2WInt} +import firrtl.constraint.{IsFloor, IsKnown, IsMul} + +/** Replaces IntervalType with SIntType, three AST walks: + * 1) Align binary points + * - adds shift operators to primop args and connections + * - does not affect declaration- or inferred-types + * 2) Replace declaration IntervalType's with SIntType's + * - for each declaration: + * a. remove non-zero binary points + * b. remove open bounds + * c. replace with SIntType + * 3) Run InferTypes + */ +class TrimIntervals extends Pass { + def run(c: Circuit): Circuit = { + // Open -> closed + val firstPass = InferTypes.run(c map replaceModuleInterval) + // Align binary points and adjust range accordingly (loss of precision changes range) + firstPass map alignModuleBP + } + + /* Replace interval types */ + private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval + + private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval + + private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + + private def replaceTypeInterval(t: Type): Type = t match { + case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) => + IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p)) + case i: IntervalType => i + case v => v map replaceTypeInterval + } + + /* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */ + private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP + + private def alignStmtBP(s: Statement): Statement = s map alignExpBP match { + case c@Connect(info, loc, expr) => loc.tpe match { + case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr)) + case _ => c + } + case c@PartialConnect(info, loc, expr) => loc.tpe match { + case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr)) + case _ => c + } + case other => other map alignStmtBP + } + + // Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned + // Note - Mul does not need its binary points aligned, because multiplication is cool like that + private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/) + + private def alignExpBP(e: Expression): Expression = e map alignExpBP match { + case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg) + case DoPrim(o, args, consts, t) if opsToFix.contains(o) && + (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size => + val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) + DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t) + case Mux(cond, tval, fval, t: IntervalType) => + val maxBP = Seq(tval, fval).map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) + Mux(cond, fixBP(maxBP)(tval), fixBP(maxBP)(fval), t) + case other => other + } + private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match { + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current => + DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired))) + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current => + val shiftAmt = current - desired + val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt) + val shiftMul = Closed(BigDecimal(1) / shiftGain) + val bpGain = BigDecimal(BigInt(1) << current.toInt) + // BP is inferred at this point + // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt) + val newBPRes = Closed(shiftGain / bpGain) + val bpResInv = Closed(bpGain) + val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes) + val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), newBPRes) + DoPrim(DecP, Seq(e), Seq(current - desired), IntervalType(CalcBound(newL), CalcBound(newU), IntWidth(desired))) + case x => sys.error(s"Shouldn't be here: $x") + } +} + +// vim: set ts=4 sw=4 et: |
