// SPDX-License-Identifier: Apache-2.0

package firrtl.passes

import firrtl.PrimOps._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.constraint.{IsFloor, IsKnown, IsMul}
import firrtl.options.Dependency
import firrtl.Transform

/** Normalizes the bounds and binary points of IntervalType'd hardware, in two steps:
  *  1) Close bounds
  *     - every IntervalType (in statements and ports) whose bounds are both known is rebuilt
  *       with Closed bounds taken from its `min` / `max`
  *     - InferTypes is then re-run on the result
  *  2) Align binary points
  *     - wraps primop arguments, mux branches and connection sources in IncP / DecP so that the
  *       binary points agree
  *     - dropping fractional bits recomputes the interval bounds, since loss of precision
  *       changes the range
  *
  * This class never constructs SIntType itself; every type it builds is an IntervalType.
  */
class TrimIntervals extends Pass {

  override def prerequisites =
    Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows), Dependency[InferBinaryPoints])

  override def optionalPrerequisiteOf = Seq.empty

  override def invalidates(a: Transform) = false

  /** Closes the interval bounds of every module, re-infers types, then aligns binary points. */
  def run(c: Circuit): Circuit = {
    // Open -> closed bounds, followed by a fresh type inference
    val closedBounds = InferTypes.run(c.map(replaceModuleInterval))
    // Binary point alignment also adjusts ranges (loss of precision changes the range)
    closedBounds.map(alignModuleBP)
  }

  /* ---- Step 1: replace interval types ---- */

  /** Rewrites the types of a module's statements first, then those of its ports. */
  private def replaceModuleInterval(m: DefModule): DefModule = {
    val stmtsDone = m.map(replaceStmtInterval)
    stmtsDone.map(replacePortInterval)
  }

  /** Rewrites the types held by `s` itself, then recurses into its child statements. */
  private def replaceStmtInterval(s: Statement): Statement = {
    val typesDone = s.map(replaceTypeInterval)
    typesDone.map(replaceStmtInterval)
  }

  /** Rewrites the type of a single port. */
  private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval)

  /** Rebuilds an interval with known bounds using Closed(min) / Closed(max).
    * Intervals with any unknown bound are returned untouched; every other type is searched recursively.
    */
  private def replaceTypeInterval(t: Type): Type = t match {
    case known @ IntervalType(_: IsKnown, _: IsKnown, IntWidth(bp)) =>
      IntervalType(Closed(known.min.get), Closed(known.max.get), IntWidth(bp))
    case unresolved: IntervalType => unresolved
    case other => other.map(replaceTypeInterval)
  }

  /* ---- Step 2: align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! ---- */

  /** Aligns binary points in every statement of a module. */
  private def alignModuleBP(m: DefModule): DefModule = m.map(alignStmtBP)

  /** Aligns the expressions of `s`; for (partial) connections whose sink is an interval, the
    * source is additionally shifted to the sink's binary point. All other statements recurse
    * into their children.
    */
  private def alignStmtBP(s: Statement): Statement = {
    val exprsAligned = s.map(alignExpBP)
    exprsAligned match {
      case conn @ Connect(info, loc, expr) =>
        alignToSink(loc, expr).map(Connect(info, loc, _)).getOrElse(conn)
      case conn @ PartialConnect(info, loc, expr) =>
        alignToSink(loc, expr).map(PartialConnect(info, loc, _)).getOrElse(conn)
      case other => other.map(alignStmtBP)
    }
  }

  /** The source `expr` shifted to the binary point of `loc`, or None when `loc` is not an interval. */
  private def alignToSink(loc: Expression, expr: Expression): Option[Expression] = loc.tpe match {
    case IntervalType(_, _, sinkBP) => Some(fixBP(sinkBP)(expr))
    case _                          => None
  }

  // Note - wrap/clip/squeeze ignore the binary point of their second argument, so they need no alignment
  // Note - Mul needs no alignment either: multiplication works with mismatched binary points
  private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq /*, Wrap, Clip, Squeeze*/ )

  /** True when every expression in `exprs` has an IntervalType. */
  private def allIntervalTyped(exprs: Seq[Expression]): Boolean =
    exprs.forall { x =>
      x.tpe match {
        case _: IntervalType => true
        case _ => false
      }
    }

  /** Largest binary point among the interval-typed expressions in `exprs`.
    * Callers only use this when at least one of them is an interval; `reduce` throws otherwise.
    */
  private def widestBinaryPoint(exprs: Seq[Expression]): Width =
    exprs.map(_.tpe).collect { case IntervalType(_, _, bp) => bp }.reduce(_ max _)

  /** Bottom-up rewrite of an expression tree:
    *  - an interval-typed SetP is replaced by its argument shifted to the requested binary point
    *  - the ops in `opsToFix`, when all arguments are intervals, get every argument shifted to the
    *    largest binary point among them
    *  - interval-typed muxes get both branches shifted to the larger of the two binary points
    */
  private def alignExpBP(e: Expression): Expression = {
    val childrenAligned = e.map(alignExpBP)
    childrenAligned match {
      case DoPrim(SetP, Seq(arg), Seq(const), _: IntervalType) =>
        fixBP(IntWidth(const))(arg)
      case DoPrim(op, args, consts, tpe) if opsToFix.contains(op) && allIntervalTyped(args) =>
        val target = widestBinaryPoint(args)
        DoPrim(op, args.map(arg => fixBP(target)(arg)), consts, tpe)
      case Mux(cond, tval, fval, tpe: IntervalType) =>
        val target = widestBinaryPoint(Seq(tval, fval))
        Mux(cond, fixBP(target)(tval), fixBP(target)(fval), tpe)
      case other => other
    }
  }

  /** Shifts the interval-typed expression `e` to the binary point `p`.
    *  - same binary point: `e` is returned as-is
    *  - larger binary point: `e` is wrapped in IncP and keeps its bounds
    *  - smaller binary point: `e` is wrapped in DecP and its bounds are recomputed through a floor
    *
    * Fails via `sys.error` unless `p` is an IntWidth and `e` has an IntervalType with an IntWidth binary point.
    */
  private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match {
    case (IntWidth(desired), IntervalType(lo, hi, IntWidth(current))) =>
      if (desired == current) {
        e
      } else if (desired > current) {
        DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(lo, hi, IntWidth(desired)))
      } else {
        val dropped = current - desired
        val droppedGain = BigDecimal(BigInt(1) << dropped.toInt) // 2^dropped
        val currentGain = BigDecimal(BigInt(1) << current.toInt) // 2^current
        val shiftMul = Closed(BigDecimal(1) / droppedGain)
        // BP is inferred at this point
        // y = floor(x * 2^(-dropped + current)) discards the precision --> y * 2^(-current + dropped)
        val bpResInv = Closed(currentGain)
        val newBPRes = Closed(droppedGain / currentGain)
        def truncated(bound: Bound) =
          IsMul(IsFloor(IsMul(IsMul(bound, shiftMul), bpResInv)), newBPRes)
        val newL = truncated(lo)
        val newU = truncated(hi)
        DoPrim(DecP, Seq(e), Seq(dropped), IntervalType(CalcBound(newL), CalcBound(newU), IntWidth(desired)))
      }
    case x => sys.error(s"Shouldn't be here: $x")
  }
}
// vim: set ts=4 sw=4 et: