diff options
| author | Kevin Laeufer | 2021-08-02 13:46:29 -0700 |
|---|---|---|
| committer | GitHub | 2021-08-02 20:46:29 +0000 |
| commit | e04f1e7f303920ac1d1f865450d0e280aafb58b3 (patch) | |
| tree | 73f26cd236ac8069d9c4877a3c42457d65d477fe /src/main/scala/firrtl/passes | |
| parent | ff1cd28202fb423956a6803a889c3632487d8872 (diff) | |
add emitter for optimized low firrtl (#2304)
* rearrange passes to enable optimized firrtl emission
* Support ConstProp on padded arguments to comparisons with literals
* Move shr legalization logic into ConstProp
Continue calling ConstProp of shr in Legalize.
Co-authored-by: Jack Koenig <koenig@sifive.com>
Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'src/main/scala/firrtl/passes')
8 files changed, 84 insertions, 165 deletions
diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index e70346d4..70da011e 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -9,15 +9,7 @@ import firrtl.options.Dependency object CommonSubexpressionElimination extends Pass { - override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( - Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats] - ) - + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala deleted file mode 100644 index e1a39fbe..00000000 --- a/src/main/scala/firrtl/passes/Legalize.scala +++ /dev/null @@ -1,108 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtl.passes - -import firrtl.PrimOps._ -import firrtl.Utils.{error, getGroundZero, zero, BoolType} -import firrtl.ir._ -import firrtl.options.Dependency -import firrtl.transforms.ConstantPropagation -import firrtl.{bitWidth, getWidth, Transform} -import firrtl.Mappers._ - -// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts -// TODO replace UInt with zero-width wire instead -object Legalize extends Pass { - - override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) - - override def optionalPrerequisites = Seq.empty - - override def optionalPrerequisiteOf = Seq.empty - - override def invalidates(a: Transform) = false - - private def legalizeShiftRight(e: DoPrim): Expression = { - require(e.op == Shr) - e.args.head match { - case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e) - case _ => - val amount = e.consts.head.toInt - val width = bitWidth(e.args.head.tpe) - lazy val msb = width - 1 - if (amount >= width) { - e.tpe match { - case UIntType(_) => zero - case SIntType(_) => - val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) - DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) - case t => error(s"Unsupported type $t for Primop Shift Right") - } - } else { - e - } - } - } - private def legalizeBitExtract(expr: DoPrim): Expression = { - expr.args.head match { - case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr) - case _ => expr - } - } - private def legalizePad(expr: DoPrim): Expression = expr.args.head match { - case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head => - UIntLiteral(value, IntWidth(expr.consts.head)) - case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head => - SIntLiteral(value, IntWidth(expr.consts.head)) - case _ => expr - } - // Convert `-x` to `0 - x` - private def legalizeNeg(expr: DoPrim): Expression = { - val arg = expr.args.head - arg.tpe match { - case tpe: SIntType => - val zero = getGroundZero(tpe) - DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe) - case tpe: UIntType => - val zero = getGroundZero(tpe) - val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1))) - DoPrim(AsSInt, Seq(sub), Nil, expr.tpe) - } - } - private def legalizeConnect(c: Connect): Statement = { - val t = c.loc.tpe - val w = bitWidth(t) - if (w >= bitWidth(c.expr.tpe)) { - c - } else { - val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w))) - val expr = t match { - case UIntType(_) => bits - case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) - case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t) - } - Connect(c.info, c.loc, expr) - } - } - def run(c: Circuit): Circuit = { - def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match { - case prim: DoPrim => - prim.op match { - case Shr => legalizeShiftRight(prim) - case Pad => legalizePad(prim) - case Bits | Head | Tail => legalizeBitExtract(prim) - case Neg => legalizeNeg(prim) - case _ => prim - } - case e => e // respect pre-order traversal - } - def legalizeS(s: Statement): Statement = { - val legalizedStmt = s match { - case c: Connect => legalizeConnect(c) - case _ => s - } - legalizedStmt.map(legalizeS).map(legalizeE) - } - c.copy(modules = c.modules.map(_.map(legalizeS))) - } -} diff --git a/src/main/scala/firrtl/passes/LegalizeConnects.scala b/src/main/scala/firrtl/passes/LegalizeConnects.scala new file mode 100644 index 00000000..2f29de10 --- /dev/null +++ b/src/main/scala/firrtl/passes/LegalizeConnects.scala @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.passes + +import firrtl.ir._ +import firrtl.options.Dependency +import firrtl.{bitWidth, Transform} + +/** Ensures that all connects + register inits have the same bit-width on the rhs and the lhs. + * The rhs is padded or bit-extacted to fit the width of the lhs. + * @note technically, width(rhs) > width(lhs) is not legal firrtl, however, we do not error for historic reasons. + */ +object LegalizeConnects extends Pass { + + override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Seq.empty + override def invalidates(a: Transform) = false + + def onStmt(s: Statement): Statement = s match { + case c: Connect => + c.copy(expr = PadWidths.forceWidth(bitWidth(c.loc.tpe).toInt)(c.expr)) + case r: DefRegister => + r.copy(init = PadWidths.forceWidth(bitWidth(r.tpe).toInt)(r.init)) + case other => other.mapStmt(onStmt) + } + + def run(c: Circuit): Circuit = { + c.copy(modules = c.modules.map(_.mapStmt(onStmt))) + } +} diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 1a430778..02e94975 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -7,63 +7,59 @@ import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Mappers._ import firrtl.options.Dependency - -import scala.collection.mutable +import firrtl.transforms.ConstantPropagation // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { - override def prerequisites = - ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq - - override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation]) + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case _: firrtl.transforms.ConstantPropagation | Legalize => true - case _ => false + case SplitExpressions => true // we generate pad and bits operations inline which need to be split up + case _ => false } - private def width(t: Type): Int = bitWidth(t).toInt - private def width(e: Expression): Int = width(e.tpe) - // Returns an expression with the correct integer width - private def fixup(i: Int)(e: Expression) = { - def tx = e.tpe match { - case t: UIntType => UIntType(IntWidth(i)) - case t: SIntType => SIntType(IntWidth(i)) - // default case should never be reached - } - width(e) match { - case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx) - case j if i < j => - val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i))) - // Bit Select always returns UInt, cast if selecting from SInt - e.tpe match { - case UIntType(_) => e2 - case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i))) - } - case _ => e + /** Adds padding or a bit extract to ensure that the expression is of the with specified. + * @note only works on UInt and SInt type expressions, other expressions will yield a match error + */ + private[firrtl] def forceWidth(width: Int)(e: Expression): Expression = { + val old = getWidth(e) + if (width == old) { e } + else if (width > old) { + // padding retains the signedness + val newType = e.tpe match { + case _: UIntType => UIntType(IntWidth(width)) + case _: SIntType => SIntType(IntWidth(width)) + case other => throw new RuntimeException(s"forceWidth does not support expressions of type $other") + } + ConstantPropagation.constPropPad(DoPrim(Pad, Seq(e), Seq(width), newType)) + } else { + val extract = DoPrim(Bits, Seq(e), Seq(width - 1, 0), UIntType(IntWidth(width))) + val e2 = ConstantPropagation.constPropBitExtract(extract) + // Bit Select always returns UInt, cast if selecting from SInt + e.tpe match { + case UIntType(_) => e2 + case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(width))) + } } } + private def getWidth(t: Type): Int = bitWidth(t).toInt + private def getWidth(e: Expression): Int = getWidth(e.tpe) + // Recursive, updates expression so children exp's have correct widths private def onExp(e: Expression): Expression = e.map(onExp) match { case Mux(cond, tval, fval, tpe) => - Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe) - case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value)) + Mux(cond, forceWidth(getWidth(tpe))(tval), forceWidth(getWidth(tpe))(fval), tpe) + case ex: ValidIf => ex.copy(value = forceWidth(getWidth(ex.tpe))(ex.value)) case ex: DoPrim => ex.op match { - case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Rem | Shr => - // sensitive ops - ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max))) - case Dshl => - // special case as args aren't all same width - ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1))) + // pad arguments to ops where the result width is determined as max(w_1, w_2) (+ const)? + case Lt | Leq | Gt | Geq | Eq | Neq | And | Or | Xor | Add | Sub => + ex.map(forceWidth(ex.args.map(getWidth).max)) case _ => ex } case ex => ex @@ -72,9 +68,17 @@ object PadWidths extends Pass { // Recursive. Fixes assignments and register initialization widths private def onStmt(s: Statement): Statement = s.map(onExp) match { case sx: Connect => - sx.copy(expr = fixup(width(sx.loc))(sx.expr)) + assert( + getWidth(sx.loc) == getWidth(sx.expr), + "Connection widths should have been taken care of by LegalizeConnects!" + ) + sx case sx: DefRegister => - sx.copy(init = fixup(width(sx.tpe))(sx.init)) + assert( + getWidth(sx.tpe) == getWidth(sx.init), + "Register init widths should have been taken care of by LegalizeConnects!" + ) + sx case sx => sx.map(onStmt) } diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 03214f83..dc4e70ff 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -31,7 +31,7 @@ object RemoveValidIf extends Pass { Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case Legalize | _: firrtl.transforms.ConstantPropagation => true + case _: firrtl.transforms.ConstantPropagation => true // switching out the validifs allows for more constant prop case _ => false } diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 1b4ed1cc..26088e9c 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -8,6 +8,7 @@ import firrtl.ir._ import firrtl.options.Dependency import firrtl.Mappers._ import firrtl.Utils.{flow, get_info, kind} +import firrtl.transforms.InlineBooleanExpressions // Datastructures import scala.collection.mutable @@ -16,15 +17,14 @@ import scala.collection.mutable // and named intermediate nodes object SplitExpressions extends Pass { - override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays)) - + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform) = a match { case ResolveKinds => true - case _ => false + case _: InlineBooleanExpressions => true // SplitExpressions undoes the inlining! + case _ => false } private def onModule(m: Module): Module = { diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index c7b0fbcd..331dd43e 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -153,7 +153,7 @@ class ReplSeqMem extends SeqTransform with HasShellOptions with DependencyAPIMig val transforms: Seq[Transform] = Seq( - new SimpleMidTransform(Legalize), + new SimpleMidTransform(LegalizeConnects), new SimpleMidTransform(ToMemIR), new SimpleMidTransform(ResolveMaskGranularity), new SimpleMidTransform(RenameAnnotatedMemoryPorts), diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index 11184e60..3778f4da 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -235,8 +235,8 @@ object VerilogMemDelays extends Pass { Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case _: transforms.ConstantPropagation | ResolveFlows => true - case _ => false + case ResolveFlows => true + case _ => false } private def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed |
