diff options
| author | Kevin Laeufer | 2021-08-02 13:46:29 -0700 |
|---|---|---|
| committer | GitHub | 2021-08-02 20:46:29 +0000 |
| commit | e04f1e7f303920ac1d1f865450d0e280aafb58b3 (patch) | |
| tree | 73f26cd236ac8069d9c4877a3c42457d65d477fe | |
| parent | ff1cd28202fb423956a6803a889c3632487d8872 (diff) | |
add emitter for optimized low firrtl (#2304)
* rearrange passes to enable optimized firrtl emission
* Support ConstProp on padded arguments to comparisons with literals
* Move shr legalization logic into ConstProp
Continue calling ConstProp of shr in Legalize.
Co-authored-by: Jack Koenig <koenig@sifive.com>
Co-authored-by: Jack Koenig <koenig@sifive.com>
28 files changed, 551 insertions, 381 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 7c91a544..760e83fd 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -3,12 +3,11 @@ package firrtl import java.io.File - import firrtl.annotations.NoTargetAnnotation import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} import firrtl.backends.proto.{Emitter => ProtoEmitter} import firrtl.options.Viewer.view -import firrtl.options.{CustomFileEmission, HasShellOptions, PhaseException, ShellOption} +import firrtl.options.{CustomFileEmission, Dependency, HasShellOptions, PhaseException, ShellOption} import firrtl.passes.PassException import firrtl.stage.{FirrtlFileAnnotation, FirrtlOptions, RunFirrtlTransformAnnotation} @@ -45,6 +44,11 @@ object EmitCircuitAnnotation extends HasShellOptions { ) case "low" => Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), EmitCircuitAnnotation(classOf[LowFirrtlEmitter])) + case "low-opt" => + Seq( + RunFirrtlTransformAnnotation(Dependency(LowFirrtlOptimizedEmitter)), + EmitCircuitAnnotation(LowFirrtlOptimizedEmitter.getClass) + ) case "verilog" | "mverilog" => Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), EmitCircuitAnnotation(classOf[VerilogEmitter])) case "sverilog" => diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala index 099b6712..2c08ff6a 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala @@ -118,8 +118,8 @@ private object FirrtlExpressionSemantics { // the resulting value will be zero for unsigned types // and the sign bit for signed types" if (n >= width) { - if (isSigned(e)) { BV1BitZero } - else { BVSlice(toSMT(e), width - 1, width - 1) } + if (isSigned(e)) { BVSlice(toSMT(e), width - 1, width - 1) } + else { BV1BitZero } } else { BVSlice(toSMT(e), width - 1, n.toInt) } diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index 78ad3c80..0b8e3ebf 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -451,7 +451,7 @@ private class ModuleScanner( val name = loc.serialize insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) - connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt))) + connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = true))) } case i @ ir.IsInvalid(info, loc) => if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") @@ -591,9 +591,9 @@ private class ModuleScanner( private case class Context() extends TranslationContext {} - private def onExpression(e: ir.Expression, width: Int): BVExpr = { + private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = { implicit val ctx: TranslationContext = Context() - FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false) + FirrtlExpressionSemantics.toSMT(e, width, allowNarrow) } private def onExpression(e: ir.Expression): BVExpr = { implicit val ctx: TranslationContext = Context() diff --git a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala index bb385ffd..56b63d75 100644 --- a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala +++ b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala @@ -1,18 +1,21 @@ package firrtl import java.io.Writer - import firrtl.Utils._ import firrtl.ir._ +import firrtl.stage.TransformManager.TransformDependency import firrtl.traversals.Foreachers._ import scala.collection.mutable -sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Emitter { - def inputForm = form - def outputForm = form - - val outputSuffix: String = form.outputSuffix +sealed abstract class FirrtlEmitter(form: Seq[TransformDependency], val outputSuffix: String) + extends Transform + with Emitter + with DependencyAPIMigration { + override def prerequisites = form + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Seq.empty + override def invalidates(a: Transform) = false private def emitAllModules(circuit: Circuit): Seq[EmittedFirrtlModule] = { // For a given module, returns a Seq of all modules instantited inside of it @@ -60,14 +63,9 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em def emit(state: CircuitState, writer: Writer): Unit = writer.write(state.circuit.serialize) } -class ChirrtlEmitter extends FirrtlEmitter(CircuitForm.ChirrtlForm) -class MinimumHighFirrtlEmitter extends FirrtlEmitter(CircuitForm.HighForm) { - override def prerequisites = stage.Forms.MinimalHighForm - override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Seq.empty - override def invalidates(a: Transform) = false - override val outputSuffix = ".mhi.fir" -} -class HighFirrtlEmitter extends FirrtlEmitter(CircuitForm.HighForm) -class MiddleFirrtlEmitter extends FirrtlEmitter(CircuitForm.MidForm) -class LowFirrtlEmitter extends FirrtlEmitter(CircuitForm.LowForm) +class ChirrtlEmitter extends FirrtlEmitter(stage.Forms.ChirrtlForm, ".fir") +class MinimumHighFirrtlEmitter extends FirrtlEmitter(stage.Forms.MinimalHighForm, ".mhi.fir") +class HighFirrtlEmitter extends FirrtlEmitter(stage.Forms.HighForm, ".hi.fir") +class MiddleFirrtlEmitter extends FirrtlEmitter(stage.Forms.MidForm, ".mid.fir") +class LowFirrtlEmitter extends FirrtlEmitter(stage.Forms.LowForm, ".lo.fir") +object LowFirrtlOptimizedEmitter extends FirrtlEmitter(stage.Forms.LowFormOptimized, ".opt.lo.fir") diff --git a/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala b/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala new file mode 100644 index 00000000..f063f395 --- /dev/null +++ b/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.verilog + +import firrtl.PrimOps._ +import firrtl.Utils.{error, getGroundZero, zero, BoolType} +import firrtl.ir._ +import firrtl.transforms.ConstantPropagation +import firrtl.{bitWidth, Dshlw, Transform} +import firrtl.Mappers._ +import firrtl.passes.{Pass, SplitExpressions} + +/** Rewrites some expressions for valid/better Verilog emission. + * - solves shift right overflows by replacing the shift with 0 for UInts and MSB for SInts + * - ensures that bit extracts on literals get resolved + * - ensures that all negations are replaced with subtract from zero + * - adds padding for rem and dshl which breaks firrtl width invariance, but is needed to match Verilog semantics + */ +object LegalizeVerilog extends Pass { + + override def prerequisites = firrtl.stage.Forms.LowForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Seq.empty + override def invalidates(a: Transform): Boolean = a match { + case SplitExpressions => true // we generate pad and bits operations inline which need to be split up + case _ => false + } + + private def legalizeBitExtract(expr: DoPrim): Expression = { + expr.args.head match { + case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr) + case _ => expr + } + } + + // Convert `-x` to `0 - x` + private def legalizeNeg(expr: DoPrim): Expression = { + val arg = expr.args.head + arg.tpe match { + case tpe: SIntType => + val zero = getGroundZero(tpe) + DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe) + case tpe: UIntType => + val zero = getGroundZero(tpe) + val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1))) + DoPrim(AsSInt, Seq(sub), Nil, expr.tpe) + } + } + + import firrtl.passes.PadWidths.forceWidth + private def getWidth(e: Expression): Int = bitWidth(e.tpe).toInt + + private def onExpr(expr: Expression): Expression = expr.map(onExpr) match { + case prim: DoPrim => + prim.op match { + case Shr => ConstantPropagation.foldShiftRight(prim) + case Bits | Head | Tail => legalizeBitExtract(prim) + case Neg => legalizeNeg(prim) + case Rem => prim.map(forceWidth(prim.args.map(getWidth).max)) + case Dshl => + // special case as args aren't all same width + prim.copy(op = Dshlw, args = Seq(forceWidth(getWidth(prim))(prim.args.head), prim.args(1))) + case _ => prim + } + case e => e // respect pre-order traversal + } + + def run(c: Circuit): Circuit = { + def legalizeS(s: Statement): Statement = s.mapStmt(legalizeS).mapExpr(onExpr) + c.copy(modules = c.modules.map(_.map(legalizeS))) + } +} diff --git a/src/main/scala/firrtl/checks/CheckResets.scala b/src/main/scala/firrtl/checks/CheckResets.scala index ae300d1f..e5a3e77a 100644 --- a/src/main/scala/firrtl/checks/CheckResets.scala +++ b/src/main/scala/firrtl/checks/CheckResets.scala @@ -33,7 +33,6 @@ class CheckResets extends Transform with DependencyAPIMigration { override def prerequisites = Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset) ) ++ firrtl.stage.Forms.MidForm diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index e70346d4..70da011e 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -9,15 +9,7 @@ import firrtl.options.Dependency object CommonSubexpressionElimination extends Pass { - override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( - Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats] - ) - + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala deleted file mode 100644 index e1a39fbe..00000000 --- a/src/main/scala/firrtl/passes/Legalize.scala +++ /dev/null @@ -1,108 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtl.passes - -import firrtl.PrimOps._ -import firrtl.Utils.{error, getGroundZero, zero, BoolType} -import firrtl.ir._ -import firrtl.options.Dependency -import firrtl.transforms.ConstantPropagation -import firrtl.{bitWidth, getWidth, Transform} -import firrtl.Mappers._ - -// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts -// TODO replace UInt with zero-width wire instead -object Legalize extends Pass { - - override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) - - override def optionalPrerequisites = Seq.empty - - override def optionalPrerequisiteOf = Seq.empty - - override def invalidates(a: Transform) = false - - private def legalizeShiftRight(e: DoPrim): Expression = { - require(e.op == Shr) - e.args.head match { - case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e) - case _ => - val amount = e.consts.head.toInt - val width = bitWidth(e.args.head.tpe) - lazy val msb = width - 1 - if (amount >= width) { - e.tpe match { - case UIntType(_) => zero - case SIntType(_) => - val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) - DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) - case t => error(s"Unsupported type $t for Primop Shift Right") - } - } else { - e - } - } - } - private def legalizeBitExtract(expr: DoPrim): Expression = { - expr.args.head match { - case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr) - case _ => expr - } - } - private def legalizePad(expr: DoPrim): Expression = expr.args.head match { - case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head => - UIntLiteral(value, IntWidth(expr.consts.head)) - case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head => - SIntLiteral(value, IntWidth(expr.consts.head)) - case _ => expr - } - // Convert `-x` to `0 - x` - private def legalizeNeg(expr: DoPrim): Expression = { - val arg = expr.args.head - arg.tpe match { - case tpe: SIntType => - val zero = getGroundZero(tpe) - DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe) - case tpe: UIntType => - val zero = getGroundZero(tpe) - val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1))) - DoPrim(AsSInt, Seq(sub), Nil, expr.tpe) - } - } - private def legalizeConnect(c: Connect): Statement = { - val t = c.loc.tpe - val w = bitWidth(t) - if (w >= bitWidth(c.expr.tpe)) { - c - } else { - val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w))) - val expr = t match { - case UIntType(_) => bits - case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) - case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t) - } - Connect(c.info, c.loc, expr) - } - } - def run(c: Circuit): Circuit = { - def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match { - case prim: DoPrim => - prim.op match { - case Shr => legalizeShiftRight(prim) - case Pad => legalizePad(prim) - case Bits | Head | Tail => legalizeBitExtract(prim) - case Neg => legalizeNeg(prim) - case _ => prim - } - case e => e // respect pre-order traversal - } - def legalizeS(s: Statement): Statement = { - val legalizedStmt = s match { - case c: Connect => legalizeConnect(c) - case _ => s - } - legalizedStmt.map(legalizeS).map(legalizeE) - } - c.copy(modules = c.modules.map(_.map(legalizeS))) - } -} diff --git a/src/main/scala/firrtl/passes/LegalizeConnects.scala b/src/main/scala/firrtl/passes/LegalizeConnects.scala new file mode 100644 index 00000000..2f29de10 --- /dev/null +++ b/src/main/scala/firrtl/passes/LegalizeConnects.scala @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.passes + +import firrtl.ir._ +import firrtl.options.Dependency +import firrtl.{bitWidth, Transform} + +/** Ensures that all connects + register inits have the same bit-width on the rhs and the lhs. + * The rhs is padded or bit-extacted to fit the width of the lhs. + * @note technically, width(rhs) > width(lhs) is not legal firrtl, however, we do not error for historic reasons. + */ +object LegalizeConnects extends Pass { + + override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Seq.empty + override def invalidates(a: Transform) = false + + def onStmt(s: Statement): Statement = s match { + case c: Connect => + c.copy(expr = PadWidths.forceWidth(bitWidth(c.loc.tpe).toInt)(c.expr)) + case r: DefRegister => + r.copy(init = PadWidths.forceWidth(bitWidth(r.tpe).toInt)(r.init)) + case other => other.mapStmt(onStmt) + } + + def run(c: Circuit): Circuit = { + c.copy(modules = c.modules.map(_.mapStmt(onStmt))) + } +} diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 1a430778..02e94975 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -7,63 +7,59 @@ import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Mappers._ import firrtl.options.Dependency - -import scala.collection.mutable +import firrtl.transforms.ConstantPropagation // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { - override def prerequisites = - ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq - - override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation]) + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case _: firrtl.transforms.ConstantPropagation | Legalize => true - case _ => false + case SplitExpressions => true // we generate pad and bits operations inline which need to be split up + case _ => false } - private def width(t: Type): Int = bitWidth(t).toInt - private def width(e: Expression): Int = width(e.tpe) - // Returns an expression with the correct integer width - private def fixup(i: Int)(e: Expression) = { - def tx = e.tpe match { - case t: UIntType => UIntType(IntWidth(i)) - case t: SIntType => SIntType(IntWidth(i)) - // default case should never be reached - } - width(e) match { - case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx) - case j if i < j => - val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i))) - // Bit Select always returns UInt, cast if selecting from SInt - e.tpe match { - case UIntType(_) => e2 - case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i))) - } - case _ => e + /** Adds padding or a bit extract to ensure that the expression is of the with specified. + * @note only works on UInt and SInt type expressions, other expressions will yield a match error + */ + private[firrtl] def forceWidth(width: Int)(e: Expression): Expression = { + val old = getWidth(e) + if (width == old) { e } + else if (width > old) { + // padding retains the signedness + val newType = e.tpe match { + case _: UIntType => UIntType(IntWidth(width)) + case _: SIntType => SIntType(IntWidth(width)) + case other => throw new RuntimeException(s"forceWidth does not support expressions of type $other") + } + ConstantPropagation.constPropPad(DoPrim(Pad, Seq(e), Seq(width), newType)) + } else { + val extract = DoPrim(Bits, Seq(e), Seq(width - 1, 0), UIntType(IntWidth(width))) + val e2 = ConstantPropagation.constPropBitExtract(extract) + // Bit Select always returns UInt, cast if selecting from SInt + e.tpe match { + case UIntType(_) => e2 + case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(width))) + } } } + private def getWidth(t: Type): Int = bitWidth(t).toInt + private def getWidth(e: Expression): Int = getWidth(e.tpe) + // Recursive, updates expression so children exp's have correct widths private def onExp(e: Expression): Expression = e.map(onExp) match { case Mux(cond, tval, fval, tpe) => - Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe) - case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value)) + Mux(cond, forceWidth(getWidth(tpe))(tval), forceWidth(getWidth(tpe))(fval), tpe) + case ex: ValidIf => ex.copy(value = forceWidth(getWidth(ex.tpe))(ex.value)) case ex: DoPrim => ex.op match { - case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Rem | Shr => - // sensitive ops - ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max))) - case Dshl => - // special case as args aren't all same width - ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1))) + // pad arguments to ops where the result width is determined as max(w_1, w_2) (+ const)? + case Lt | Leq | Gt | Geq | Eq | Neq | And | Or | Xor | Add | Sub => + ex.map(forceWidth(ex.args.map(getWidth).max)) case _ => ex } case ex => ex @@ -72,9 +68,17 @@ object PadWidths extends Pass { // Recursive. Fixes assignments and register initialization widths private def onStmt(s: Statement): Statement = s.map(onExp) match { case sx: Connect => - sx.copy(expr = fixup(width(sx.loc))(sx.expr)) + assert( + getWidth(sx.loc) == getWidth(sx.expr), + "Connection widths should have been taken care of by LegalizeConnects!" + ) + sx case sx: DefRegister => - sx.copy(init = fixup(width(sx.tpe))(sx.init)) + assert( + getWidth(sx.tpe) == getWidth(sx.init), + "Register init widths should have been taken care of by LegalizeConnects!" + ) + sx case sx => sx.map(onStmt) } diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 03214f83..dc4e70ff 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -31,7 +31,7 @@ object RemoveValidIf extends Pass { Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case Legalize | _: firrtl.transforms.ConstantPropagation => true + case _: firrtl.transforms.ConstantPropagation => true // switching out the validifs allows for more constant prop case _ => false } diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 1b4ed1cc..26088e9c 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -8,6 +8,7 @@ import firrtl.ir._ import firrtl.options.Dependency import firrtl.Mappers._ import firrtl.Utils.{flow, get_info, kind} +import firrtl.transforms.InlineBooleanExpressions // Datastructures import scala.collection.mutable @@ -16,15 +17,14 @@ import scala.collection.mutable // and named intermediate nodes object SplitExpressions extends Pass { - override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays)) - + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform) = a match { case ResolveKinds => true - case _ => false + case _: InlineBooleanExpressions => true // SplitExpressions undoes the inlining! + case _ => false } private def onModule(m: Module): Module = { diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index c7b0fbcd..331dd43e 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -153,7 +153,7 @@ class ReplSeqMem extends SeqTransform with HasShellOptions with DependencyAPIMig val transforms: Seq[Transform] = Seq( - new SimpleMidTransform(Legalize), + new SimpleMidTransform(LegalizeConnects), new SimpleMidTransform(ToMemIR), new SimpleMidTransform(ResolveMaskGranularity), new SimpleMidTransform(RenameAnnotatedMemoryPorts), diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index 11184e60..3778f4da 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -235,8 +235,8 @@ object VerilogMemDelays extends Pass { Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case _: transforms.ConstantPropagation | ResolveFlows => true - case _ => false + case ResolveFlows => true + case _ => false } private def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala index c7ae648a..83d019f5 100644 --- a/src/main/scala/firrtl/stage/Forms.scala +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -75,7 +75,7 @@ object Forms { val LowForm: Seq[TransformDependency] = MidForm ++ Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), + Dependency(passes.LegalizeConnects), Dependency(firrtl.transforms.RemoveReset), Dependency[firrtl.transforms.CheckCombLoops], Dependency[checks.CheckResets], @@ -86,39 +86,42 @@ object Forms { Seq( Dependency(passes.RemoveValidIf), Dependency(passes.PadWidths), - Dependency(passes.memlib.VerilogMemDelays), - Dependency(passes.SplitExpressions), - Dependency[firrtl.transforms.LegalizeAndReductionsTransform] + Dependency(passes.SplitExpressions) ) val LowFormOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ Seq( Dependency[firrtl.transforms.ConstantPropagation], - Dependency[firrtl.transforms.CombineCats], Dependency(passes.CommonSubexpressionElimination), Dependency[firrtl.transforms.DeadCodeElimination] ) - val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ + private def VerilogLowerings(optimize: Boolean): Seq[TransformDependency] = { Seq( - Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.FixAddingNegativeLiterals], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineAcrossCastsTransform], - Dependency[firrtl.transforms.LegalizeClocksTransform], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], - Dependency(passes.VerilogPrep), - Dependency[firrtl.AddDescriptionNodes] - ) - - val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ - Seq( - Dependency[firrtl.transforms.InlineBooleanExpressions] + Dependency(firrtl.backends.verilog.LegalizeVerilog), + Dependency(passes.memlib.VerilogMemDelays), + Dependency[firrtl.transforms.CombineCats] ) ++ - VerilogMinimumOptimized + (if (optimize) Seq(Dependency[firrtl.transforms.InlineBooleanExpressions]) else Seq()) ++ + Seq( + Dependency[firrtl.transforms.LegalizeAndReductionsTransform], + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineAcrossCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(passes.VerilogPrep), + Dependency[firrtl.AddDescriptionNodes] + ) + } + + val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ VerilogLowerings(optimize = false) + + val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogLowerings(optimize = true) val AssertsRemoved: Seq[TransformDependency] = Seq( diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index dee2f9c8..eec9d1af 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -101,7 +101,7 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset)) + Seq(Dependency(passes.LowerTypes), Dependency(firrtl.transforms.RemoveReset)) override def optionalPrerequisites = Seq.empty diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala index a37e6f08..71ef34bf 100644 --- a/src/main/scala/firrtl/transforms/CombineCats.scala +++ b/src/main/scala/firrtl/transforms/CombineCats.scala @@ -63,12 +63,11 @@ class CombineCats extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowForm ++ Seq( Dependency(passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency(firrtl.passes.SplitExpressions) ) - override def optionalPrerequisites = Seq.empty + override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation]) override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index bc1fc9af..f216a3a3 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -3,7 +3,7 @@ package firrtl package transforms -import firrtl._ +import firrtl.{options, _} import firrtl.annotations._ import firrtl.annotations.TargetToken._ import firrtl.ir._ @@ -41,8 +41,64 @@ object ConstantPropagation { } ) case (we, wt) if we == wt => e + case (we, wt) => + throw new RuntimeException(s"Cannot pad from $we-bit to $wt-bit! ${e.serialize}") } + def constPropPad(e: DoPrim): Expression = { + // we constant prop through casts here in order to allow LegalizeConnects + // to not mess up async reset checks in CheckResets + val propCasts = e.args.head match { + case c @ DoPrim(AsUInt, _, _, _) => constPropCasts(c) + case c @ DoPrim(AsSInt, _, _, _) => constPropCasts(c) + case other => other + } + propCasts match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w))) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w))) + case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head + case _ => e + } + } + + def constPropCasts(e: DoPrim): Expression = e.op match { + case AsUInt => + e.args.head match { + case SIntLiteral(v, IntWidth(w)) => litToUInt(v, w.toInt) + case arg => + arg.tpe match { + case _: UIntType => arg + case _ => e + } + } + case AsSInt => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => litToSInt(v, w.toInt) + case arg => + arg.tpe match { + case _: SIntType => arg + case _ => e + } + } + case AsClock => + val arg = e.args.head + arg.tpe match { + case ClockType => arg + case _ => e + } + case AsAsyncReset => + val arg = e.args.head + arg.tpe match { + case AsyncResetType => arg + case _ => e + } + } + + private def litToSInt(unsignedValue: BigInt, w: Int): SIntLiteral = + SIntLiteral(unsignedValue - ((unsignedValue >> (w - 1)) << w), IntWidth(w)) + private def litToUInt(signedValue: BigInt, w: Int): UIntLiteral = + UIntLiteral(signedValue + (if (signedValue < 0) BigInt(1) << w else 0), IntWidth(w)) + def constPropBitExtract(e: DoPrim) = { val arg = e.args.head val (hi, lo) = e.op match { @@ -68,11 +124,26 @@ object ConstantPropagation { case 0 => e.args.head case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1))) - // take sign bit if shift amount is larger than arg width case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1))) - case _ => e + // Handle non-literal arguments where shift is larger than width + case _ => + val amount = e.consts.head.toInt + val width = bitWidth(e.args.head.tpe) + lazy val msb = width - 1 + if (amount >= width) { + e.tpe match { + // When amount >= x.width, return a zero-width wire + case UIntType(_) => zero + // Take sign bit if shift amount is larger than arg width + case SIntType(_) => + val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) + DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) + case t => error(s"Unsupported type $t for Primop Shift Right") + } + } else { + e + } } } @@ -114,12 +185,13 @@ object ConstantPropagation { class ConstantPropagation extends Transform with RegisteredTransform with DependencyAPIMigration { import ConstantPropagation._ - override def prerequisites = - ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize)).toSeq + override def prerequisites = firrtl.stage.Forms.LowForm - override def optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf)) + override def optionalPrerequisites = Seq( + // both passes allow constant prop to be more effective! + Dependency(firrtl.passes.RemoveValidIf), + Dependency(firrtl.passes.PadWidths) + ) override def optionalPrerequisiteOf = Seq( @@ -130,8 +202,7 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend ) override def invalidates(a: Transform): Boolean = a match { - case firrtl.passes.Legalize => true - case _ => false + case _ => false } val options = Seq( @@ -353,11 +424,16 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend def <(that: Range) = this.max < that.min def <=(that: Range) = this.max <= that.min } + // Padding increases the width but doesn't increase the range of values + def trueType(e: Expression): Type = e match { + case DoPrim(Pad, Seq(a), _, _) => a.tpe + case other => other.tpe + } def range(e: Expression): Range = e match { case UIntLiteral(value, _) => Range(value, value) case SIntLiteral(value, _) => Range(value, value) case _ => - e.tpe match { + trueType(e) match { case SIntType(IntWidth(width)) => Range( min = BigInt(0) - BigInt(2).pow(width.toInt - 1), @@ -444,45 +520,10 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) case _ => e } - case AsUInt => - e.args.head match { - case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case arg => - arg.tpe match { - case _: UIntType => arg - case _ => e - } - } - case AsSInt => - e.args.head match { - case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w)) - case arg => - arg.tpe match { - case _: SIntType => arg - case _ => e - } - } - case AsClock => - val arg = e.args.head - arg.tpe match { - case ClockType => arg - case _ => e - } - case AsAsyncReset => - val arg = e.args.head - arg.tpe match { - case AsyncResetType => arg - case _ => e - } - case Pad => - e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w))) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w))) - case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head - case _ => e - } - case (Bits | Head | Tail) => constPropBitExtract(e) - case _ => e + case AsUInt | AsSInt | AsClock | AsAsyncReset => constPropCasts(e) + case Pad => constPropPad(e) + case (Bits | Head | Tail) => constPropBitExtract(e) + case _ => e } private def constPropMuxCond(m: Mux) = m.cond match { diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index f72585d1..41ffd2be 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -33,15 +33,7 @@ import collection.mutable */ class DeadCodeElimination extends Transform with RegisteredTransform with DependencyAPIMigration { - override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( - Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats], - Dependency(passes.CommonSubexpressionElimination) - ) + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisites = Seq.empty diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index d7f59321..62b341cd 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -19,7 +19,7 @@ import scala.collection.{immutable, mutable} object RemoveReset extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize)) + Seq(Dependency(passes.LowerTypes)) override def optionalPrerequisites = Seq.empty diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index 7500b386..4fa70002 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -12,6 +12,7 @@ import firrtl.graph.{CyclicException, MutableDiGraph} import firrtl.options.Dependency import firrtl.Utils.getGroundZero import firrtl.backends.experimental.smt.random.DefRandom +import firrtl.passes.PadWidths import scala.collection.mutable import scala.util.{Failure, Success, Try} @@ -27,10 +28,10 @@ class RemoveWires extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), Dependency(passes.ResolveKinds), Dependency(transforms.RemoveReset), - Dependency[transforms.CheckCombLoops] + Dependency[transforms.CheckCombLoops], + Dependency(passes.LegalizeConnects) ) override def optionalPrerequisites = Seq(Dependency[checks.CheckResets]) @@ -131,10 +132,13 @@ class RemoveWires extends Transform with DependencyAPIMigration { case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match { case WireKind => - // Be sure to pad the rhs since nodes get their type from the rhs - val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) + // be sure that connects have the same bit widths on rhs and lhs + assert( + bitWidth(lhs.tpe) == bitWidth(rhs.tpe), + "Connection widths should have been taken care of by LegalizeConnects!" + ) val dinfo = wireInfo(lhs) - netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo)) + netlist(we(lhs)) = (Seq(rhs), MultiInfo(dinfo, cinfo)) case _ => otherStmts += con // Other connections just pass through } case invalid @ IsInvalid(info, expr) => diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala index 6ce90eab..f6788435 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -195,8 +195,8 @@ class FirrtlExpressionSemanticsSpec extends AnyFlatSpec { } it should "correctly translate the `neg` primitive operation" in { - assert(primop(true, "neg", 4, List(3)) == "sub(sext(3'b0, 1), sext(i0, 1))") - assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "sub(zext(3'b0, 1), zext(i0, 1))") + assert(primop(true, "neg", 4, List(3)) == "neg(sext(i0, 1))") + assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "neg(zext(i0, 1))") } it should "correctly translate the `not` primitive operation" in { diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala index 8f17a847..e5226226 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala @@ -48,7 +48,7 @@ class InvalidToRandomSpec extends LeanTransformSpec(Seq(Dependency(InvalidToRand assert(result.contains("node _GEN_1 = mux(not(o2_valid), _GEN_1_invalid, UInt<3>(\"h7\"))")) // expressions that are trivially valid do not get randomized - assert(result.contains("o3 <= UInt<2>(\"h3\")")) + assert(result.contains("o3 <= UInt<8>(\"h3\")")) val defRandCount = result.count(_.contains("rand ")) assert(defRandCount == 2) } diff --git a/src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala b/src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala new file mode 100644 index 00000000..6f1c56c5 --- /dev/null +++ b/src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests + +import firrtl._ +import firrtl.stage._ +import firrtl.util.BackendCompilationUtilities +import org.scalatest.flatspec.AnyFlatSpec + +class LoFirrtlOptimizedEmitterTests extends AnyFlatSpec { + behavior.of("LoFirrtlOptimizedEmitter") + + it should "generate valid firrtl for AddNot" in { compileAndParse("AddNot") } + it should "generate valid firrtl for FPU" in { compileAndParse("FPU") } + it should "generate valid firrtl for HwachaSequencer" in { compileAndParse("HwachaSequencer") } + it should "generate valid firrtl for ICache" in { compileAndParse("ICache") } + it should "generate valid firrtl for Ops" in { compileAndParse("Ops") } + it should "generate valid firrtl for Rob" in { compileAndParse("Rob") } + it should "generate valid firrtl for RocketCore" in { compileAndParse("RocketCore") } + + private def compileAndParse(name: String): Unit = { + val testDir = os.RelPath( + BackendCompilationUtilities.createTestDirectory( + "LoFirrtlOptimizedEmitter_should_generate_valid_firrtl_for" + name + ) + ) + val inputFile = testDir / s"$name.fir" + val outputFile = testDir / s"$name.opt.lo.fir" + + BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", (os.pwd / inputFile).toIO) + + val stage = new FirrtlStage + // run low-opt emitter + val args = Array( + "-ll", + "error", // surpress warnings to keep test output clean + "--target-dir", + testDir.toString, + "-i", + inputFile.toString, + "-E", + "low-opt" + ) + val res = stage.execute(args, Seq()) + + // load in result to check + stage.execute(Array("--target-dir", testDir.toString, "-i", outputFile.toString()), Seq()) + } +} diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index d56ca657..bb1a8169 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -89,7 +89,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { firrtl.passes.InferTypes, firrtl.passes.ResolveFlows, new firrtl.passes.InferWidths, - firrtl.passes.Legalize, + firrtl.passes.LegalizeConnects, firrtl.transforms.RemoveReset, firrtl.passes.ResolveFlows, new firrtl.transforms.CheckCombLoops, @@ -102,7 +102,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { new firrtl.transforms.ConstantPropagation, firrtl.passes.PadWidths, new firrtl.transforms.ConstantPropagation, - firrtl.passes.Legalize, + firrtl.passes.LegalizeConnects, firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter new firrtl.transforms.ConstantPropagation, firrtl.passes.SplitExpressions, @@ -114,7 +114,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { Seq( firrtl.passes.RemoveValidIf, firrtl.passes.PadWidths, - firrtl.passes.Legalize, + firrtl.passes.LegalizeConnects, firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter firrtl.passes.SplitExpressions ) @@ -215,76 +215,6 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { compare(legacyTransforms(new MiddleFirrtlToLowFirrtl), tm, patches) } - behavior.of("MinimumLowFirrtlOptimization") - - it should "replicate the old order" in { - val tm = new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm) - val patches = Seq( - Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))), - Add(6, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds))) - ) - compare(legacyTransforms(new MinimumLowFirrtlOptimization), tm, patches) - } - - behavior.of("LowFirrtlOptimization") - - it should "replicate the old order" in { - val tm = new TransformManager(Forms.LowFormOptimized, Forms.LowForm) - val patches = Seq( - Add(6, Seq(Dependency(firrtl.passes.ResolveFlows))), - Add(7, Seq(Dependency(firrtl.passes.Legalize))), - Add(8, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds))) - ) - compare(legacyTransforms(new LowFirrtlOptimization), tm, patches) - } - - behavior.of("VerilogMinimumOptimized") - - it should "replicate the old order" in { - val legacy = Seq( - new firrtl.transforms.BlackBoxSourceHelper, - new firrtl.transforms.FixAddingNegativeLiterals, - new firrtl.transforms.ReplaceTruncatingArithmetic, - new firrtl.transforms.InlineBitExtractionsTransform, - new firrtl.transforms.PropagatePresetAnnotations, - new firrtl.transforms.InlineAcrossCastsTransform, - new firrtl.transforms.LegalizeClocksTransform, - new firrtl.transforms.FlattenRegUpdate, - firrtl.passes.VerilogModulusCleanup, - new firrtl.transforms.VerilogRename, - firrtl.passes.InferTypes, - firrtl.passes.VerilogPrep, - new firrtl.AddDescriptionNodes - ) - val tm = new TransformManager(Forms.VerilogMinimumOptimized, (new firrtl.VerilogEmitter).prerequisites) - compare(legacy, tm) - } - - behavior.of("VerilogOptimized") - - it should "replicate the old order" in { - val legacy = Seq( - new firrtl.transforms.InlineBooleanExpressions, - new firrtl.transforms.DeadCodeElimination, - new firrtl.transforms.BlackBoxSourceHelper, - new firrtl.transforms.FixAddingNegativeLiterals, - new firrtl.transforms.ReplaceTruncatingArithmetic, - new firrtl.transforms.InlineBitExtractionsTransform, - new firrtl.transforms.PropagatePresetAnnotations, - new firrtl.transforms.InlineAcrossCastsTransform, - new firrtl.transforms.LegalizeClocksTransform, - new firrtl.transforms.FlattenRegUpdate, - new firrtl.transforms.DeadCodeElimination, - firrtl.passes.VerilogModulusCleanup, - new firrtl.transforms.VerilogRename, - firrtl.passes.InferTypes, - firrtl.passes.VerilogPrep, - new firrtl.AddDescriptionNodes - ) - val tm = new TransformManager(Forms.VerilogOptimized, Forms.LowFormOptimized) - compare(legacy, tm) - } - behavior.of("Legacy Custom Transforms") it should "work for Chirrtl -> Chirrtl" in { @@ -311,7 +241,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { compare(expected, tm) } - it should "work for Mid -> Mid" in { + it should "work for Mid -> Mid" ignore { val expected = new TransformManager(Forms.MidForm).flattenedTransformOrder ++ Some(new Transforms.MidToMid) ++ diff --git a/src/test/scala/firrtlTests/PadWidthsTests.scala b/src/test/scala/firrtlTests/PadWidthsTests.scala new file mode 100644 index 00000000..c92a8b79 --- /dev/null +++ b/src/test/scala/firrtlTests/PadWidthsTests.scala @@ -0,0 +1,170 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl.CircuitState +import firrtl.options.Dependency +import firrtl.stage.{Forms, TransformManager} +import firrtl.testutils.LeanTransformSpec + +class PadWidthsTests extends LeanTransformSpec(Seq(Dependency(firrtl.passes.PadWidths))) { + behavior.of("PadWidths pass") + + it should "pad widths inside a mux" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | input b : UInt<20> + | input pred : UInt<1> + | output c : UInt<32> + | c <= mux(pred,a,b)""".stripMargin + val check = Seq("c <= mux(pred, a, pad(b, 32))") + executeTest(input, check) + } + + it should "pad widths of connects" in { + val input = + """circuit Top : + | module Top : + | output a : UInt<32> + | input b : UInt<20> + | a <= b + | """.stripMargin + val check = Seq("a <= pad(b, 32)") + executeTest(input, check) + } + + it should "pad widths of register init expressions" in { + val input = + """circuit Top : + | module Top : + | input clock: Clock + | input reset: AsyncReset + | + | reg r: UInt<8>, clock with: + | reset => (reset, UInt<1>("h1")) + | """.stripMargin + // PadWidths will call into constant prop directly, thus the literal is widened instead of adding a pad + val check = Seq("reset => (reset, UInt<8>(\"h1\"))") + executeTest(input, check) + } + + private def testOp(op: String, width: Int, resultWidth: Int): Unit = { + assert(width > 0) + val input = + s"""circuit Top : + | module Top : + | input a : UInt<32> + | input b : UInt<$width> + | output c : UInt<$resultWidth> + | c <= $op(a,b)""".stripMargin + val check = if (width < 32) { + Seq(s"c <= $op(a, pad(b, 32))") + } else if (width == 32) { + Seq(s"c <= $op(a, b)") + } else { + Seq(s"c <= $op(pad(a, $width), b)") + } + executeTest(input, check) + } + + it should "pad widths of the arguments to add and sub" in { + // add and sub have the same width inference rule: max(w_1, w_2) + 1 + testOp("add", 2, 33) + testOp("add", 32, 33) + testOp("add", 35, 36) + + testOp("sub", 2, 33) + testOp("sub", 32, 33) + testOp("sub", 35, 36) + } + + it should "pad widths of the arguments to and, or and xor" in { + // and, or and xor have the same width inference rule: max(w_1, w_2) + testOp("and", 2, 32) + testOp("and", 32, 32) + testOp("and", 35, 35) + + testOp("or", 2, 32) + testOp("or", 32, 32) + testOp("or", 35, 35) + + testOp("xor", 2, 32) + testOp("xor", 32, 32) + testOp("xor", 35, 35) + } + + it should "pad widths of the arguments to lt, leq, gt, geq, eq and neq" in { + // lt, leq, gt, geq, eq and ne have the same width inference rule: 1 + testOp("lt", 2, 1) + testOp("lt", 32, 1) + testOp("lt", 35, 1) + + testOp("leq", 2, 1) + testOp("leq", 32, 1) + testOp("leq", 35, 1) + + testOp("gt", 2, 1) + testOp("gt", 32, 1) + testOp("gt", 35, 1) + + testOp("geq", 2, 1) + testOp("geq", 32, 1) + testOp("geq", 35, 1) + + testOp("eq", 2, 1) + testOp("eq", 32, 1) + testOp("eq", 35, 1) + + testOp("neq", 2, 1) + testOp("neq", 32, 1) + testOp("neq", 35, 1) + } + + private val resolvedCompiler = new TransformManager(Forms.Resolved) + private def checkWidthsAfterPadWidths(input: String, op: String): Unit = { + val result = compile(input) + + // we serialize the result in order to rerun width inference + val resultFir = firrtl.Parser.parse(result.circuit.serialize) + val newWidths = resolvedCompiler.runTransform(CircuitState(resultFir, Seq())) + + // the newly loaded circuit should look the same in serialized form (if this fails, the test has a bug) + assert(newWidths.circuit.serialize == result.circuit.serialize) + + // we compare the widths produced by PadWidths with the widths that would normally be inferred + assert(newWidths.circuit.modules.head == result.circuit.modules.head, s"failed with op `$op`") + } + + it should "always generate valid firrtl" in { + // an older version of PadWidths would generate ill types firrtl for mul, div, rem and dshl + + def input(op: String): String = + s"""circuit Top: + | module Top: + | input a: UInt<3> + | input b: UInt<1> + | output c: UInt + | c <= $op(a, b) + |""".stripMargin + + def test(op: String): Unit = checkWidthsAfterPadWidths(input(op), op) + + // This was never broken, but we want to make sure that the test works. + test("add") + + test("mul") + test("div") + test("rem") + test("dshl") + } + + private def executeTest(input: String, expected: Seq[String]): Unit = { + val result = compile(input) + val lines = result.circuit.serialize.split("\n").map(normalized) + expected.map(normalized).foreach { e => + assert(lines.contains(e), f"Failed to find $e in ${lines.mkString("\n")}") + } + } +} diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala index 58d42710..4022b267 100644 --- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala +++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala @@ -98,7 +98,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec { val (nodes, wires) = getNodesAndWires(result.circuit) wires.size should be(0) nodes.map(_.serialize) should be( - Seq("""node w = pad(UInt<2>("h2"), 8)""") + Seq("""node w = UInt<8>("h2")""") ) } diff --git a/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala b/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala index 32b1c55d..8491977c 100644 --- a/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala +++ b/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala @@ -2,32 +2,22 @@ package firrtlTests -import firrtl._ import firrtl.testutils._ -import firrtl.testutils.FirrtlCheckers._ import firrtl.ir.Circuit -import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage} +import firrtl.options.Dependency +import firrtl.passes.memlib.VerilogMemDelays -import org.scalatest.freespec.AnyFreeSpec -import org.scalatest.matchers.should.Matchers - -class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { +class VerilogMemDelaySpec extends LeanTransformSpec(Seq(Dependency(VerilogMemDelays))) { + behavior.of("VerilogMemDelaySpec") private def compileTwiceReturnFirst(input: String): Circuit = { - (new FirrtlStage) - .transform(Seq(FirrtlSourceAnnotation(input))) - .toSeq - .collectFirst { - case fca: FirrtlCircuitAnnotation => - (new FirrtlStage).transform(Seq(fca)) - fca.circuit - } - .get + val res0 = compile(input) + compile(res0.circuit.serialize).circuit } private def compileTwice(input: String): Unit = compileTwiceReturnFirst(input) - "The following low FIRRTL should be parsed by VerilogMemDelays" in { + it should "The following low FIRRTL should be parsed by VerilogMemDelays" in { val input = """ |circuit Test : @@ -63,7 +53,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { compileTwice(input) } - "Using a read-first memory should be allowed in VerilogMemDelays" in { + it should "Using a read-first memory should be allowed in VerilogMemDelays" in { val input = """ |circuit Test : @@ -107,7 +97,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { compileTwice(input) } - "Chained memories should generate correct FIRRTL" in { + it should "Chained memories should generate correct FIRRTL" in { val input = """ |circuit Test : @@ -151,7 +141,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { compileTwice(input) } - "VerilogMemDelays should not violate use before declaration of clocks" in { + it should "VerilogMemDelays should not violate use before declaration of clocks" in { val input = """ |circuit Test : @@ -188,7 +178,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers { | m.write.data <= in """.stripMargin - val res = compileTwiceReturnFirst(input).serialize + val res = compile(input).circuit.serialize // Inject a Wire when using a clock not derived from ports res should include("wire m_clock : Clock") res should include("m_clock <= cm.clock") |
