diff options
| author | Kevin Laeufer | 2021-08-02 13:46:29 -0700 |
|---|---|---|
| committer | GitHub | 2021-08-02 20:46:29 +0000 |
| commit | e04f1e7f303920ac1d1f865450d0e280aafb58b3 (patch) | |
| tree | 73f26cd236ac8069d9c4877a3c42457d65d477fe /src/main/scala | |
| parent | ff1cd28202fb423956a6803a889c3632487d8872 (diff) | |
add emitter for optimized low firrtl (#2304)
* rearrange passes to enable optimized firrtl emission
* Support ConstProp on padded arguments to comparisons with literals
* Move shr legalization logic into ConstProp
Continue calling ConstProp of shr in Legalize.
Co-authored-by: Jack Koenig <koenig@sifive.com>
Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'src/main/scala')
21 files changed, 313 insertions, 282 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 7c91a544..760e83fd 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -3,12 +3,11 @@ package firrtl import java.io.File - import firrtl.annotations.NoTargetAnnotation import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} import firrtl.backends.proto.{Emitter => ProtoEmitter} import firrtl.options.Viewer.view -import firrtl.options.{CustomFileEmission, HasShellOptions, PhaseException, ShellOption} +import firrtl.options.{CustomFileEmission, Dependency, HasShellOptions, PhaseException, ShellOption} import firrtl.passes.PassException import firrtl.stage.{FirrtlFileAnnotation, FirrtlOptions, RunFirrtlTransformAnnotation} @@ -45,6 +44,11 @@ object EmitCircuitAnnotation extends HasShellOptions { ) case "low" => Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), EmitCircuitAnnotation(classOf[LowFirrtlEmitter])) + case "low-opt" => + Seq( + RunFirrtlTransformAnnotation(Dependency(LowFirrtlOptimizedEmitter)), + EmitCircuitAnnotation(LowFirrtlOptimizedEmitter.getClass) + ) case "verilog" | "mverilog" => Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), EmitCircuitAnnotation(classOf[VerilogEmitter])) case "sverilog" => diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala index 099b6712..2c08ff6a 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala @@ -118,8 +118,8 @@ private object FirrtlExpressionSemantics { // the resulting value will be zero for unsigned types // and the sign bit for signed types" if (n >= width) { - if (isSigned(e)) { BV1BitZero } - else { BVSlice(toSMT(e), width - 1, width - 1) } + if (isSigned(e)) { BVSlice(toSMT(e), width - 1, width - 1) } + else { BV1BitZero } } else { BVSlice(toSMT(e), width - 1, n.toInt) } diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index 78ad3c80..0b8e3ebf 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -451,7 +451,7 @@ private class ModuleScanner( val name = loc.serialize insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) - connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt))) + connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = true))) } case i @ ir.IsInvalid(info, loc) => if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") @@ -591,9 +591,9 @@ private class ModuleScanner( private case class Context() extends TranslationContext {} - private def onExpression(e: ir.Expression, width: Int): BVExpr = { + private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = { implicit val ctx: TranslationContext = Context() - FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false) + FirrtlExpressionSemantics.toSMT(e, width, allowNarrow) } private def onExpression(e: ir.Expression): BVExpr = { implicit val ctx: TranslationContext = Context() diff --git a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala index bb385ffd..56b63d75 100644 --- a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala +++ b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala @@ -1,18 +1,21 @@ package firrtl import java.io.Writer - import firrtl.Utils._ import firrtl.ir._ +import firrtl.stage.TransformManager.TransformDependency import firrtl.traversals.Foreachers._ import scala.collection.mutable -sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Emitter { - def inputForm = form - def outputForm = form - - val outputSuffix: String = form.outputSuffix +sealed abstract class FirrtlEmitter(form: Seq[TransformDependency], val outputSuffix: String) + extends Transform + with Emitter + with DependencyAPIMigration { + override def prerequisites = form + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Seq.empty + override def invalidates(a: Transform) = false private def emitAllModules(circuit: Circuit): Seq[EmittedFirrtlModule] = { // For a given module, returns a Seq of all modules instantited inside of it @@ -60,14 +63,9 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em def emit(state: CircuitState, writer: Writer): Unit = writer.write(state.circuit.serialize) } -class ChirrtlEmitter extends FirrtlEmitter(CircuitForm.ChirrtlForm) -class MinimumHighFirrtlEmitter extends FirrtlEmitter(CircuitForm.HighForm) { - override def prerequisites = stage.Forms.MinimalHighForm - override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Seq.empty - override def invalidates(a: Transform) = false - override val outputSuffix = ".mhi.fir" -} -class HighFirrtlEmitter extends FirrtlEmitter(CircuitForm.HighForm) -class MiddleFirrtlEmitter extends FirrtlEmitter(CircuitForm.MidForm) -class LowFirrtlEmitter extends FirrtlEmitter(CircuitForm.LowForm) +class ChirrtlEmitter extends FirrtlEmitter(stage.Forms.ChirrtlForm, ".fir") +class MinimumHighFirrtlEmitter extends FirrtlEmitter(stage.Forms.MinimalHighForm, ".mhi.fir") +class HighFirrtlEmitter extends FirrtlEmitter(stage.Forms.HighForm, ".hi.fir") +class MiddleFirrtlEmitter extends FirrtlEmitter(stage.Forms.MidForm, ".mid.fir") +class LowFirrtlEmitter extends FirrtlEmitter(stage.Forms.LowForm, ".lo.fir") +object LowFirrtlOptimizedEmitter extends FirrtlEmitter(stage.Forms.LowFormOptimized, ".opt.lo.fir") diff --git a/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala b/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala new file mode 100644 index 00000000..f063f395 --- /dev/null +++ b/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.verilog + +import firrtl.PrimOps._ +import firrtl.Utils.{error, getGroundZero, zero, BoolType} +import firrtl.ir._ +import firrtl.transforms.ConstantPropagation +import firrtl.{bitWidth, Dshlw, Transform} +import firrtl.Mappers._ +import firrtl.passes.{Pass, SplitExpressions} + +/** Rewrites some expressions for valid/better Verilog emission. + * - solves shift right overflows by replacing the shift with 0 for UInts and MSB for SInts + * - ensures that bit extracts on literals get resolved + * - ensures that all negations are replaced with subtract from zero + * - adds padding for rem and dshl which breaks firrtl width invariance, but is needed to match Verilog semantics + */ +object LegalizeVerilog extends Pass { + + override def prerequisites = firrtl.stage.Forms.LowForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Seq.empty + override def invalidates(a: Transform): Boolean = a match { + case SplitExpressions => true // we generate pad and bits operations inline which need to be split up + case _ => false + } + + private def legalizeBitExtract(expr: DoPrim): Expression = { + expr.args.head match { + case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr) + case _ => expr + } + } + + // Convert `-x` to `0 - x` + private def legalizeNeg(expr: DoPrim): Expression = { + val arg = expr.args.head + arg.tpe match { + case tpe: SIntType => + val zero = getGroundZero(tpe) + DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe) + case tpe: UIntType => + val zero = getGroundZero(tpe) + val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1))) + DoPrim(AsSInt, Seq(sub), Nil, expr.tpe) + } + } + + import firrtl.passes.PadWidths.forceWidth + private def getWidth(e: Expression): Int = bitWidth(e.tpe).toInt + + private def onExpr(expr: Expression): Expression = expr.map(onExpr) match { + case prim: DoPrim => + prim.op match { + case Shr => ConstantPropagation.foldShiftRight(prim) + case Bits | Head | Tail => legalizeBitExtract(prim) + case Neg => legalizeNeg(prim) + case Rem => prim.map(forceWidth(prim.args.map(getWidth).max)) + case Dshl => + // special case as args aren't all same width + prim.copy(op = Dshlw, args = Seq(forceWidth(getWidth(prim))(prim.args.head), prim.args(1))) + case _ => prim + } + case e => e // respect pre-order traversal + } + + def run(c: Circuit): Circuit = { + def legalizeS(s: Statement): Statement = s.mapStmt(legalizeS).mapExpr(onExpr) + c.copy(modules = c.modules.map(_.map(legalizeS))) + } +} diff --git a/src/main/scala/firrtl/checks/CheckResets.scala b/src/main/scala/firrtl/checks/CheckResets.scala index ae300d1f..e5a3e77a 100644 --- a/src/main/scala/firrtl/checks/CheckResets.scala +++ b/src/main/scala/firrtl/checks/CheckResets.scala @@ -33,7 +33,6 @@ class CheckResets extends Transform with DependencyAPIMigration { override def prerequisites = Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset) ) ++ firrtl.stage.Forms.MidForm diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index e70346d4..70da011e 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -9,15 +9,7 @@ import firrtl.options.Dependency object CommonSubexpressionElimination extends Pass { - override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( - Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats] - ) - + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala deleted file mode 100644 index e1a39fbe..00000000 --- a/src/main/scala/firrtl/passes/Legalize.scala +++ /dev/null @@ -1,108 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtl.passes - -import firrtl.PrimOps._ -import firrtl.Utils.{error, getGroundZero, zero, BoolType} -import firrtl.ir._ -import firrtl.options.Dependency -import firrtl.transforms.ConstantPropagation -import firrtl.{bitWidth, getWidth, Transform} -import firrtl.Mappers._ - -// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts -// TODO replace UInt with zero-width wire instead -object Legalize extends Pass { - - override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) - - override def optionalPrerequisites = Seq.empty - - override def optionalPrerequisiteOf = Seq.empty - - override def invalidates(a: Transform) = false - - private def legalizeShiftRight(e: DoPrim): Expression = { - require(e.op == Shr) - e.args.head match { - case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e) - case _ => - val amount = e.consts.head.toInt - val width = bitWidth(e.args.head.tpe) - lazy val msb = width - 1 - if (amount >= width) { - e.tpe match { - case UIntType(_) => zero - case SIntType(_) => - val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) - DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) - case t => error(s"Unsupported type $t for Primop Shift Right") - } - } else { - e - } - } - } - private def legalizeBitExtract(expr: DoPrim): Expression = { - expr.args.head match { - case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr) - case _ => expr - } - } - private def legalizePad(expr: DoPrim): Expression = expr.args.head match { - case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head => - UIntLiteral(value, IntWidth(expr.consts.head)) - case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head => - SIntLiteral(value, IntWidth(expr.consts.head)) - case _ => expr - } - // Convert `-x` to `0 - x` - private def legalizeNeg(expr: DoPrim): Expression = { - val arg = expr.args.head - arg.tpe match { - case tpe: SIntType => - val zero = getGroundZero(tpe) - DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe) - case tpe: UIntType => - val zero = getGroundZero(tpe) - val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1))) - DoPrim(AsSInt, Seq(sub), Nil, expr.tpe) - } - } - private def legalizeConnect(c: Connect): Statement = { - val t = c.loc.tpe - val w = bitWidth(t) - if (w >= bitWidth(c.expr.tpe)) { - c - } else { - val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w))) - val expr = t match { - case UIntType(_) => bits - case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) - case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t) - } - Connect(c.info, c.loc, expr) - } - } - def run(c: Circuit): Circuit = { - def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match { - case prim: DoPrim => - prim.op match { - case Shr => legalizeShiftRight(prim) - case Pad => legalizePad(prim) - case Bits | Head | Tail => legalizeBitExtract(prim) - case Neg => legalizeNeg(prim) - case _ => prim - } - case e => e // respect pre-order traversal - } - def legalizeS(s: Statement): Statement = { - val legalizedStmt = s match { - case c: Connect => legalizeConnect(c) - case _ => s - } - legalizedStmt.map(legalizeS).map(legalizeE) - } - c.copy(modules = c.modules.map(_.map(legalizeS))) - } -} diff --git a/src/main/scala/firrtl/passes/LegalizeConnects.scala b/src/main/scala/firrtl/passes/LegalizeConnects.scala new file mode 100644 index 00000000..2f29de10 --- /dev/null +++ b/src/main/scala/firrtl/passes/LegalizeConnects.scala @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.passes + +import firrtl.ir._ +import firrtl.options.Dependency +import firrtl.{bitWidth, Transform} + +/** Ensures that all connects + register inits have the same bit-width on the rhs and the lhs. + * The rhs is padded or bit-extacted to fit the width of the lhs. + * @note technically, width(rhs) > width(lhs) is not legal firrtl, however, we do not error for historic reasons. + */ +object LegalizeConnects extends Pass { + + override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Seq.empty + override def invalidates(a: Transform) = false + + def onStmt(s: Statement): Statement = s match { + case c: Connect => + c.copy(expr = PadWidths.forceWidth(bitWidth(c.loc.tpe).toInt)(c.expr)) + case r: DefRegister => + r.copy(init = PadWidths.forceWidth(bitWidth(r.tpe).toInt)(r.init)) + case other => other.mapStmt(onStmt) + } + + def run(c: Circuit): Circuit = { + c.copy(modules = c.modules.map(_.mapStmt(onStmt))) + } +} diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 1a430778..02e94975 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -7,63 +7,59 @@ import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Mappers._ import firrtl.options.Dependency - -import scala.collection.mutable +import firrtl.transforms.ConstantPropagation // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { - override def prerequisites = - ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq - - override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation]) + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case _: firrtl.transforms.ConstantPropagation | Legalize => true - case _ => false + case SplitExpressions => true // we generate pad and bits operations inline which need to be split up + case _ => false } - private def width(t: Type): Int = bitWidth(t).toInt - private def width(e: Expression): Int = width(e.tpe) - // Returns an expression with the correct integer width - private def fixup(i: Int)(e: Expression) = { - def tx = e.tpe match { - case t: UIntType => UIntType(IntWidth(i)) - case t: SIntType => SIntType(IntWidth(i)) - // default case should never be reached - } - width(e) match { - case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx) - case j if i < j => - val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i))) - // Bit Select always returns UInt, cast if selecting from SInt - e.tpe match { - case UIntType(_) => e2 - case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i))) - } - case _ => e + /** Adds padding or a bit extract to ensure that the expression is of the with specified. + * @note only works on UInt and SInt type expressions, other expressions will yield a match error + */ + private[firrtl] def forceWidth(width: Int)(e: Expression): Expression = { + val old = getWidth(e) + if (width == old) { e } + else if (width > old) { + // padding retains the signedness + val newType = e.tpe match { + case _: UIntType => UIntType(IntWidth(width)) + case _: SIntType => SIntType(IntWidth(width)) + case other => throw new RuntimeException(s"forceWidth does not support expressions of type $other") + } + ConstantPropagation.constPropPad(DoPrim(Pad, Seq(e), Seq(width), newType)) + } else { + val extract = DoPrim(Bits, Seq(e), Seq(width - 1, 0), UIntType(IntWidth(width))) + val e2 = ConstantPropagation.constPropBitExtract(extract) + // Bit Select always returns UInt, cast if selecting from SInt + e.tpe match { + case UIntType(_) => e2 + case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(width))) + } } } + private def getWidth(t: Type): Int = bitWidth(t).toInt + private def getWidth(e: Expression): Int = getWidth(e.tpe) + // Recursive, updates expression so children exp's have correct widths private def onExp(e: Expression): Expression = e.map(onExp) match { case Mux(cond, tval, fval, tpe) => - Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe) - case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value)) + Mux(cond, forceWidth(getWidth(tpe))(tval), forceWidth(getWidth(tpe))(fval), tpe) + case ex: ValidIf => ex.copy(value = forceWidth(getWidth(ex.tpe))(ex.value)) case ex: DoPrim => ex.op match { - case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Rem | Shr => - // sensitive ops - ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max))) - case Dshl => - // special case as args aren't all same width - ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1))) + // pad arguments to ops where the result width is determined as max(w_1, w_2) (+ const)? + case Lt | Leq | Gt | Geq | Eq | Neq | And | Or | Xor | Add | Sub => + ex.map(forceWidth(ex.args.map(getWidth).max)) case _ => ex } case ex => ex @@ -72,9 +68,17 @@ object PadWidths extends Pass { // Recursive. Fixes assignments and register initialization widths private def onStmt(s: Statement): Statement = s.map(onExp) match { case sx: Connect => - sx.copy(expr = fixup(width(sx.loc))(sx.expr)) + assert( + getWidth(sx.loc) == getWidth(sx.expr), + "Connection widths should have been taken care of by LegalizeConnects!" + ) + sx case sx: DefRegister => - sx.copy(init = fixup(width(sx.tpe))(sx.init)) + assert( + getWidth(sx.tpe) == getWidth(sx.init), + "Register init widths should have been taken care of by LegalizeConnects!" + ) + sx case sx => sx.map(onStmt) } diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 03214f83..dc4e70ff 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -31,7 +31,7 @@ object RemoveValidIf extends Pass { Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case Legalize | _: firrtl.transforms.ConstantPropagation => true + case _: firrtl.transforms.ConstantPropagation => true // switching out the validifs allows for more constant prop case _ => false } diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 1b4ed1cc..26088e9c 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -8,6 +8,7 @@ import firrtl.ir._ import firrtl.options.Dependency import firrtl.Mappers._ import firrtl.Utils.{flow, get_info, kind} +import firrtl.transforms.InlineBooleanExpressions // Datastructures import scala.collection.mutable @@ -16,15 +17,14 @@ import scala.collection.mutable // and named intermediate nodes object SplitExpressions extends Pass { - override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays)) - + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform) = a match { case ResolveKinds => true - case _ => false + case _: InlineBooleanExpressions => true // SplitExpressions undoes the inlining! + case _ => false } private def onModule(m: Module): Module = { diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index c7b0fbcd..331dd43e 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -153,7 +153,7 @@ class ReplSeqMem extends SeqTransform with HasShellOptions with DependencyAPIMig val transforms: Seq[Transform] = Seq( - new SimpleMidTransform(Legalize), + new SimpleMidTransform(LegalizeConnects), new SimpleMidTransform(ToMemIR), new SimpleMidTransform(ResolveMaskGranularity), new SimpleMidTransform(RenameAnnotatedMemoryPorts), diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index 11184e60..3778f4da 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -235,8 +235,8 @@ object VerilogMemDelays extends Pass { Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case _: transforms.ConstantPropagation | ResolveFlows => true - case _ => false + case ResolveFlows => true + case _ => false } private def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala index c7ae648a..83d019f5 100644 --- a/src/main/scala/firrtl/stage/Forms.scala +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -75,7 +75,7 @@ object Forms { val LowForm: Seq[TransformDependency] = MidForm ++ Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), + Dependency(passes.LegalizeConnects), Dependency(firrtl.transforms.RemoveReset), Dependency[firrtl.transforms.CheckCombLoops], Dependency[checks.CheckResets], @@ -86,39 +86,42 @@ object Forms { Seq( Dependency(passes.RemoveValidIf), Dependency(passes.PadWidths), - Dependency(passes.memlib.VerilogMemDelays), - Dependency(passes.SplitExpressions), - Dependency[firrtl.transforms.LegalizeAndReductionsTransform] + Dependency(passes.SplitExpressions) ) val LowFormOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ Seq( Dependency[firrtl.transforms.ConstantPropagation], - Dependency[firrtl.transforms.CombineCats], Dependency(passes.CommonSubexpressionElimination), Dependency[firrtl.transforms.DeadCodeElimination] ) - val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ + private def VerilogLowerings(optimize: Boolean): Seq[TransformDependency] = { Seq( - Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.FixAddingNegativeLiterals], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineAcrossCastsTransform], - Dependency[firrtl.transforms.LegalizeClocksTransform], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], - Dependency(passes.VerilogPrep), - Dependency[firrtl.AddDescriptionNodes] - ) - - val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ - Seq( - Dependency[firrtl.transforms.InlineBooleanExpressions] + Dependency(firrtl.backends.verilog.LegalizeVerilog), + Dependency(passes.memlib.VerilogMemDelays), + Dependency[firrtl.transforms.CombineCats] ) ++ - VerilogMinimumOptimized + (if (optimize) Seq(Dependency[firrtl.transforms.InlineBooleanExpressions]) else Seq()) ++ + Seq( + Dependency[firrtl.transforms.LegalizeAndReductionsTransform], + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineAcrossCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(passes.VerilogPrep), + Dependency[firrtl.AddDescriptionNodes] + ) + } + + val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ VerilogLowerings(optimize = false) + + val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogLowerings(optimize = true) val AssertsRemoved: Seq[TransformDependency] = Seq( diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index dee2f9c8..eec9d1af 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -101,7 +101,7 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset)) + Seq(Dependency(passes.LowerTypes), Dependency(firrtl.transforms.RemoveReset)) override def optionalPrerequisites = Seq.empty diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala index a37e6f08..71ef34bf 100644 --- a/src/main/scala/firrtl/transforms/CombineCats.scala +++ b/src/main/scala/firrtl/transforms/CombineCats.scala @@ -63,12 +63,11 @@ class CombineCats extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowForm ++ Seq( Dependency(passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency(firrtl.passes.SplitExpressions) ) - override def optionalPrerequisites = Seq.empty + override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation]) override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index bc1fc9af..f216a3a3 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -3,7 +3,7 @@ package firrtl package transforms -import firrtl._ +import firrtl.{options, _} import firrtl.annotations._ import firrtl.annotations.TargetToken._ import firrtl.ir._ @@ -41,8 +41,64 @@ object ConstantPropagation { } ) case (we, wt) if we == wt => e + case (we, wt) => + throw new RuntimeException(s"Cannot pad from $we-bit to $wt-bit! ${e.serialize}") } + def constPropPad(e: DoPrim): Expression = { + // we constant prop through casts here in order to allow LegalizeConnects + // to not mess up async reset checks in CheckResets + val propCasts = e.args.head match { + case c @ DoPrim(AsUInt, _, _, _) => constPropCasts(c) + case c @ DoPrim(AsSInt, _, _, _) => constPropCasts(c) + case other => other + } + propCasts match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w))) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w))) + case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head + case _ => e + } + } + + def constPropCasts(e: DoPrim): Expression = e.op match { + case AsUInt => + e.args.head match { + case SIntLiteral(v, IntWidth(w)) => litToUInt(v, w.toInt) + case arg => + arg.tpe match { + case _: UIntType => arg + case _ => e + } + } + case AsSInt => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => litToSInt(v, w.toInt) + case arg => + arg.tpe match { + case _: SIntType => arg + case _ => e + } + } + case AsClock => + val arg = e.args.head + arg.tpe match { + case ClockType => arg + case _ => e + } + case AsAsyncReset => + val arg = e.args.head + arg.tpe match { + case AsyncResetType => arg + case _ => e + } + } + + private def litToSInt(unsignedValue: BigInt, w: Int): SIntLiteral = + SIntLiteral(unsignedValue - ((unsignedValue >> (w - 1)) << w), IntWidth(w)) + private def litToUInt(signedValue: BigInt, w: Int): UIntLiteral = + UIntLiteral(signedValue + (if (signedValue < 0) BigInt(1) << w else 0), IntWidth(w)) + def constPropBitExtract(e: DoPrim) = { val arg = e.args.head val (hi, lo) = e.op match { @@ -68,11 +124,26 @@ object ConstantPropagation { case 0 => e.args.head case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1))) - // take sign bit if shift amount is larger than arg width case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1))) - case _ => e + // Handle non-literal arguments where shift is larger than width + case _ => + val amount = e.consts.head.toInt + val width = bitWidth(e.args.head.tpe) + lazy val msb = width - 1 + if (amount >= width) { + e.tpe match { + // When amount >= x.width, return a zero-width wire + case UIntType(_) => zero + // Take sign bit if shift amount is larger than arg width + case SIntType(_) => + val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) + DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) + case t => error(s"Unsupported type $t for Primop Shift Right") + } + } else { + e + } } } @@ -114,12 +185,13 @@ object ConstantPropagation { class ConstantPropagation extends Transform with RegisteredTransform with DependencyAPIMigration { import ConstantPropagation._ - override def prerequisites = - ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize)).toSeq + override def prerequisites = firrtl.stage.Forms.LowForm - override def optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf)) + override def optionalPrerequisites = Seq( + // both passes allow constant prop to be more effective! + Dependency(firrtl.passes.RemoveValidIf), + Dependency(firrtl.passes.PadWidths) + ) override def optionalPrerequisiteOf = Seq( @@ -130,8 +202,7 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend ) override def invalidates(a: Transform): Boolean = a match { - case firrtl.passes.Legalize => true - case _ => false + case _ => false } val options = Seq( @@ -353,11 +424,16 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend def <(that: Range) = this.max < that.min def <=(that: Range) = this.max <= that.min } + // Padding increases the width but doesn't increase the range of values + def trueType(e: Expression): Type = e match { + case DoPrim(Pad, Seq(a), _, _) => a.tpe + case other => other.tpe + } def range(e: Expression): Range = e match { case UIntLiteral(value, _) => Range(value, value) case SIntLiteral(value, _) => Range(value, value) case _ => - e.tpe match { + trueType(e) match { case SIntType(IntWidth(width)) => Range( min = BigInt(0) - BigInt(2).pow(width.toInt - 1), @@ -444,45 +520,10 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) case _ => e } - case AsUInt => - e.args.head match { - case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case arg => - arg.tpe match { - case _: UIntType => arg - case _ => e - } - } - case AsSInt => - e.args.head match { - case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w)) - case arg => - arg.tpe match { - case _: SIntType => arg - case _ => e - } - } - case AsClock => - val arg = e.args.head - arg.tpe match { - case ClockType => arg - case _ => e - } - case AsAsyncReset => - val arg = e.args.head - arg.tpe match { - case AsyncResetType => arg - case _ => e - } - case Pad => - e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w))) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w))) - case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head - case _ => e - } - case (Bits | Head | Tail) => constPropBitExtract(e) - case _ => e + case AsUInt | AsSInt | AsClock | AsAsyncReset => constPropCasts(e) + case Pad => constPropPad(e) + case (Bits | Head | Tail) => constPropBitExtract(e) + case _ => e } private def constPropMuxCond(m: Mux) = m.cond match { diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index f72585d1..41ffd2be 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -33,15 +33,7 @@ import collection.mutable */ class DeadCodeElimination extends Transform with RegisteredTransform with DependencyAPIMigration { - override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( - Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats], - Dependency(passes.CommonSubexpressionElimination) - ) + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisites = Seq.empty diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index d7f59321..62b341cd 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -19,7 +19,7 @@ import scala.collection.{immutable, mutable} object RemoveReset extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize)) + Seq(Dependency(passes.LowerTypes)) override def optionalPrerequisites = Seq.empty diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index 7500b386..4fa70002 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -12,6 +12,7 @@ import firrtl.graph.{CyclicException, MutableDiGraph} import firrtl.options.Dependency import firrtl.Utils.getGroundZero import firrtl.backends.experimental.smt.random.DefRandom +import firrtl.passes.PadWidths import scala.collection.mutable import scala.util.{Failure, Success, Try} @@ -27,10 +28,10 @@ class RemoveWires extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), Dependency(passes.ResolveKinds), Dependency(transforms.RemoveReset), - Dependency[transforms.CheckCombLoops] + Dependency[transforms.CheckCombLoops], + Dependency(passes.LegalizeConnects) ) override def optionalPrerequisites = Seq(Dependency[checks.CheckResets]) @@ -131,10 +132,13 @@ class RemoveWires extends Transform with DependencyAPIMigration { case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match { case WireKind => - // Be sure to pad the rhs since nodes get their type from the rhs - val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) + // be sure that connects have the same bit widths on rhs and lhs + assert( + bitWidth(lhs.tpe) == bitWidth(rhs.tpe), + "Connection widths should have been taken care of by LegalizeConnects!" + ) val dinfo = wireInfo(lhs) - netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo)) + netlist(we(lhs)) = (Seq(rhs), MultiInfo(dinfo, cinfo)) case _ => otherStmts += con // Other connections just pass through } case invalid @ IsInvalid(info, expr) => |
