aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKevin Laeufer2021-08-02 13:46:29 -0700
committerGitHub2021-08-02 20:46:29 +0000
commite04f1e7f303920ac1d1f865450d0e280aafb58b3 (patch)
tree73f26cd236ac8069d9c4877a3c42457d65d477fe /src
parentff1cd28202fb423956a6803a889c3632487d8872 (diff)
add emitter for optimized low firrtl (#2304)
* rearrange passes to enable optimized firrtl emission * Support ConstProp on padded arguments to comparisons with literals * Move shr legalization logic into ConstProp Continue calling ConstProp of shr in Legalize. Co-authored-by: Jack Koenig <koenig@sifive.com> Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Emitter.scala8
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala4
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala6
-rw-r--r--src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala32
-rw-r--r--src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala72
-rw-r--r--src/main/scala/firrtl/checks/CheckResets.scala1
-rw-r--r--src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala10
-rw-r--r--src/main/scala/firrtl/passes/Legalize.scala108
-rw-r--r--src/main/scala/firrtl/passes/LegalizeConnects.scala31
-rw-r--r--src/main/scala/firrtl/passes/PadWidths.scala84
-rw-r--r--src/main/scala/firrtl/passes/RemoveValidIf.scala2
-rw-r--r--src/main/scala/firrtl/passes/SplitExpressions.scala8
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala2
-rw-r--r--src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala4
-rw-r--r--src/main/scala/firrtl/stage/Forms.scala49
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala2
-rw-r--r--src/main/scala/firrtl/transforms/CombineCats.scala3
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala143
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala10
-rw-r--r--src/main/scala/firrtl/transforms/RemoveReset.scala2
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala14
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala4
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala2
-rw-r--r--src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala49
-rw-r--r--src/test/scala/firrtlTests/LoweringCompilersSpec.scala78
-rw-r--r--src/test/scala/firrtlTests/PadWidthsTests.scala170
-rw-r--r--src/test/scala/firrtlTests/RemoveWiresSpec.scala2
-rw-r--r--src/test/scala/firrtlTests/VerilogMemDelaySpec.scala32
28 files changed, 551 insertions, 381 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 7c91a544..760e83fd 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -3,12 +3,11 @@
package firrtl
import java.io.File
-
import firrtl.annotations.NoTargetAnnotation
import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter}
import firrtl.backends.proto.{Emitter => ProtoEmitter}
import firrtl.options.Viewer.view
-import firrtl.options.{CustomFileEmission, HasShellOptions, PhaseException, ShellOption}
+import firrtl.options.{CustomFileEmission, Dependency, HasShellOptions, PhaseException, ShellOption}
import firrtl.passes.PassException
import firrtl.stage.{FirrtlFileAnnotation, FirrtlOptions, RunFirrtlTransformAnnotation}
@@ -45,6 +44,11 @@ object EmitCircuitAnnotation extends HasShellOptions {
)
case "low" =>
Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), EmitCircuitAnnotation(classOf[LowFirrtlEmitter]))
+ case "low-opt" =>
+ Seq(
+ RunFirrtlTransformAnnotation(Dependency(LowFirrtlOptimizedEmitter)),
+ EmitCircuitAnnotation(LowFirrtlOptimizedEmitter.getClass)
+ )
case "verilog" | "mverilog" =>
Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), EmitCircuitAnnotation(classOf[VerilogEmitter]))
case "sverilog" =>
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
index 099b6712..2c08ff6a 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
@@ -118,8 +118,8 @@ private object FirrtlExpressionSemantics {
// the resulting value will be zero for unsigned types
// and the sign bit for signed types"
if (n >= width) {
- if (isSigned(e)) { BV1BitZero }
- else { BVSlice(toSMT(e), width - 1, width - 1) }
+ if (isSigned(e)) { BVSlice(toSMT(e), width - 1, width - 1) }
+ else { BV1BitZero }
} else {
BVSlice(toSMT(e), width - 1, n.toInt)
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index 78ad3c80..0b8e3ebf 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -451,7 +451,7 @@ private class ModuleScanner(
val name = loc.serialize
insertDummyAssignsForUnusedOutputs(expr)
infos.append(name -> info)
- connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt)))
+ connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = true)))
}
case i @ ir.IsInvalid(info, loc) =>
if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
@@ -591,9 +591,9 @@ private class ModuleScanner(
private case class Context() extends TranslationContext {}
- private def onExpression(e: ir.Expression, width: Int): BVExpr = {
+ private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = {
implicit val ctx: TranslationContext = Context()
- FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false)
+ FirrtlExpressionSemantics.toSMT(e, width, allowNarrow)
}
private def onExpression(e: ir.Expression): BVExpr = {
implicit val ctx: TranslationContext = Context()
diff --git a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala
index bb385ffd..56b63d75 100644
--- a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala
+++ b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala
@@ -1,18 +1,21 @@
package firrtl
import java.io.Writer
-
import firrtl.Utils._
import firrtl.ir._
+import firrtl.stage.TransformManager.TransformDependency
import firrtl.traversals.Foreachers._
import scala.collection.mutable
-sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Emitter {
- def inputForm = form
- def outputForm = form
-
- val outputSuffix: String = form.outputSuffix
+sealed abstract class FirrtlEmitter(form: Seq[TransformDependency], val outputSuffix: String)
+ extends Transform
+ with Emitter
+ with DependencyAPIMigration {
+ override def prerequisites = form
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Seq.empty
+ override def invalidates(a: Transform) = false
private def emitAllModules(circuit: Circuit): Seq[EmittedFirrtlModule] = {
// For a given module, returns a Seq of all modules instantited inside of it
@@ -60,14 +63,9 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em
def emit(state: CircuitState, writer: Writer): Unit = writer.write(state.circuit.serialize)
}
-class ChirrtlEmitter extends FirrtlEmitter(CircuitForm.ChirrtlForm)
-class MinimumHighFirrtlEmitter extends FirrtlEmitter(CircuitForm.HighForm) {
- override def prerequisites = stage.Forms.MinimalHighForm
- override def optionalPrerequisites = Seq.empty
- override def optionalPrerequisiteOf = Seq.empty
- override def invalidates(a: Transform) = false
- override val outputSuffix = ".mhi.fir"
-}
-class HighFirrtlEmitter extends FirrtlEmitter(CircuitForm.HighForm)
-class MiddleFirrtlEmitter extends FirrtlEmitter(CircuitForm.MidForm)
-class LowFirrtlEmitter extends FirrtlEmitter(CircuitForm.LowForm)
+class ChirrtlEmitter extends FirrtlEmitter(stage.Forms.ChirrtlForm, ".fir")
+class MinimumHighFirrtlEmitter extends FirrtlEmitter(stage.Forms.MinimalHighForm, ".mhi.fir")
+class HighFirrtlEmitter extends FirrtlEmitter(stage.Forms.HighForm, ".hi.fir")
+class MiddleFirrtlEmitter extends FirrtlEmitter(stage.Forms.MidForm, ".mid.fir")
+class LowFirrtlEmitter extends FirrtlEmitter(stage.Forms.LowForm, ".lo.fir")
+object LowFirrtlOptimizedEmitter extends FirrtlEmitter(stage.Forms.LowFormOptimized, ".opt.lo.fir")
diff --git a/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala b/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala
new file mode 100644
index 00000000..f063f395
--- /dev/null
+++ b/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala
@@ -0,0 +1,72 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl.backends.verilog
+
+import firrtl.PrimOps._
+import firrtl.Utils.{error, getGroundZero, zero, BoolType}
+import firrtl.ir._
+import firrtl.transforms.ConstantPropagation
+import firrtl.{bitWidth, Dshlw, Transform}
+import firrtl.Mappers._
+import firrtl.passes.{Pass, SplitExpressions}
+
+/** Rewrites some expressions for valid/better Verilog emission.
+ * - solves shift right overflows by replacing the shift with 0 for UInts and MSB for SInts
+ * - ensures that bit extracts on literals get resolved
+ * - ensures that all negations are replaced with subtract from zero
+ * - adds padding for rem and dshl which breaks firrtl width invariance, but is needed to match Verilog semantics
+ */
+object LegalizeVerilog extends Pass {
+
+ override def prerequisites = firrtl.stage.Forms.LowForm
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Seq.empty
+ override def invalidates(a: Transform): Boolean = a match {
+ case SplitExpressions => true // we generate pad and bits operations inline which need to be split up
+ case _ => false
+ }
+
+ private def legalizeBitExtract(expr: DoPrim): Expression = {
+ expr.args.head match {
+ case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr)
+ case _ => expr
+ }
+ }
+
+ // Convert `-x` to `0 - x`
+ private def legalizeNeg(expr: DoPrim): Expression = {
+ val arg = expr.args.head
+ arg.tpe match {
+ case tpe: SIntType =>
+ val zero = getGroundZero(tpe)
+ DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe)
+ case tpe: UIntType =>
+ val zero = getGroundZero(tpe)
+ val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1)))
+ DoPrim(AsSInt, Seq(sub), Nil, expr.tpe)
+ }
+ }
+
+ import firrtl.passes.PadWidths.forceWidth
+ private def getWidth(e: Expression): Int = bitWidth(e.tpe).toInt
+
+ private def onExpr(expr: Expression): Expression = expr.map(onExpr) match {
+ case prim: DoPrim =>
+ prim.op match {
+ case Shr => ConstantPropagation.foldShiftRight(prim)
+ case Bits | Head | Tail => legalizeBitExtract(prim)
+ case Neg => legalizeNeg(prim)
+ case Rem => prim.map(forceWidth(prim.args.map(getWidth).max))
+ case Dshl =>
+ // special case as args aren't all same width
+ prim.copy(op = Dshlw, args = Seq(forceWidth(getWidth(prim))(prim.args.head), prim.args(1)))
+ case _ => prim
+ }
+ case e => e // respect pre-order traversal
+ }
+
+ def run(c: Circuit): Circuit = {
+ def legalizeS(s: Statement): Statement = s.mapStmt(legalizeS).mapExpr(onExpr)
+ c.copy(modules = c.modules.map(_.map(legalizeS)))
+ }
+}
diff --git a/src/main/scala/firrtl/checks/CheckResets.scala b/src/main/scala/firrtl/checks/CheckResets.scala
index ae300d1f..e5a3e77a 100644
--- a/src/main/scala/firrtl/checks/CheckResets.scala
+++ b/src/main/scala/firrtl/checks/CheckResets.scala
@@ -33,7 +33,6 @@ class CheckResets extends Transform with DependencyAPIMigration {
override def prerequisites =
Seq(
Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
Dependency(firrtl.transforms.RemoveReset)
) ++ firrtl.stage.Forms.MidForm
diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
index e70346d4..70da011e 100644
--- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
+++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
@@ -9,15 +9,7 @@ import firrtl.options.Dependency
object CommonSubexpressionElimination extends Pass {
- override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq(
- Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats]
- )
-
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala
deleted file mode 100644
index e1a39fbe..00000000
--- a/src/main/scala/firrtl/passes/Legalize.scala
+++ /dev/null
@@ -1,108 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package firrtl.passes
-
-import firrtl.PrimOps._
-import firrtl.Utils.{error, getGroundZero, zero, BoolType}
-import firrtl.ir._
-import firrtl.options.Dependency
-import firrtl.transforms.ConstantPropagation
-import firrtl.{bitWidth, getWidth, Transform}
-import firrtl.Mappers._
-
-// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts
-// TODO replace UInt with zero-width wire instead
-object Legalize extends Pass {
-
- override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes)
-
- override def optionalPrerequisites = Seq.empty
-
- override def optionalPrerequisiteOf = Seq.empty
-
- override def invalidates(a: Transform) = false
-
- private def legalizeShiftRight(e: DoPrim): Expression = {
- require(e.op == Shr)
- e.args.head match {
- case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e)
- case _ =>
- val amount = e.consts.head.toInt
- val width = bitWidth(e.args.head.tpe)
- lazy val msb = width - 1
- if (amount >= width) {
- e.tpe match {
- case UIntType(_) => zero
- case SIntType(_) =>
- val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
- DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
- case t => error(s"Unsupported type $t for Primop Shift Right")
- }
- } else {
- e
- }
- }
- }
- private def legalizeBitExtract(expr: DoPrim): Expression = {
- expr.args.head match {
- case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr)
- case _ => expr
- }
- }
- private def legalizePad(expr: DoPrim): Expression = expr.args.head match {
- case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
- UIntLiteral(value, IntWidth(expr.consts.head))
- case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
- SIntLiteral(value, IntWidth(expr.consts.head))
- case _ => expr
- }
- // Convert `-x` to `0 - x`
- private def legalizeNeg(expr: DoPrim): Expression = {
- val arg = expr.args.head
- arg.tpe match {
- case tpe: SIntType =>
- val zero = getGroundZero(tpe)
- DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe)
- case tpe: UIntType =>
- val zero = getGroundZero(tpe)
- val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1)))
- DoPrim(AsSInt, Seq(sub), Nil, expr.tpe)
- }
- }
- private def legalizeConnect(c: Connect): Statement = {
- val t = c.loc.tpe
- val w = bitWidth(t)
- if (w >= bitWidth(c.expr.tpe)) {
- c
- } else {
- val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w)))
- val expr = t match {
- case UIntType(_) => bits
- case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
- case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t)
- }
- Connect(c.info, c.loc, expr)
- }
- }
- def run(c: Circuit): Circuit = {
- def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match {
- case prim: DoPrim =>
- prim.op match {
- case Shr => legalizeShiftRight(prim)
- case Pad => legalizePad(prim)
- case Bits | Head | Tail => legalizeBitExtract(prim)
- case Neg => legalizeNeg(prim)
- case _ => prim
- }
- case e => e // respect pre-order traversal
- }
- def legalizeS(s: Statement): Statement = {
- val legalizedStmt = s match {
- case c: Connect => legalizeConnect(c)
- case _ => s
- }
- legalizedStmt.map(legalizeS).map(legalizeE)
- }
- c.copy(modules = c.modules.map(_.map(legalizeS)))
- }
-}
diff --git a/src/main/scala/firrtl/passes/LegalizeConnects.scala b/src/main/scala/firrtl/passes/LegalizeConnects.scala
new file mode 100644
index 00000000..2f29de10
--- /dev/null
+++ b/src/main/scala/firrtl/passes/LegalizeConnects.scala
@@ -0,0 +1,31 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl.passes
+
+import firrtl.ir._
+import firrtl.options.Dependency
+import firrtl.{bitWidth, Transform}
+
+/** Ensures that all connects + register inits have the same bit-width on the rhs and the lhs.
+ * The rhs is padded or bit-extacted to fit the width of the lhs.
+ * @note technically, width(rhs) > width(lhs) is not legal firrtl, however, we do not error for historic reasons.
+ */
+object LegalizeConnects extends Pass {
+
+ override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes)
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Seq.empty
+ override def invalidates(a: Transform) = false
+
+ def onStmt(s: Statement): Statement = s match {
+ case c: Connect =>
+ c.copy(expr = PadWidths.forceWidth(bitWidth(c.loc.tpe).toInt)(c.expr))
+ case r: DefRegister =>
+ r.copy(init = PadWidths.forceWidth(bitWidth(r.tpe).toInt)(r.init))
+ case other => other.mapStmt(onStmt)
+ }
+
+ def run(c: Circuit): Circuit = {
+ c.copy(modules = c.modules.map(_.mapStmt(onStmt)))
+ }
+}
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala
index 1a430778..02e94975 100644
--- a/src/main/scala/firrtl/passes/PadWidths.scala
+++ b/src/main/scala/firrtl/passes/PadWidths.scala
@@ -7,63 +7,59 @@ import firrtl.ir._
import firrtl.PrimOps._
import firrtl.Mappers._
import firrtl.options.Dependency
-
-import scala.collection.mutable
+import firrtl.transforms.ConstantPropagation
// Makes all implicit width extensions and truncations explicit
object PadWidths extends Pass {
- override def prerequisites =
- ((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)
- + Dependency(firrtl.passes.RemoveValidIf)).toSeq
-
- override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation])
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case _: firrtl.transforms.ConstantPropagation | Legalize => true
- case _ => false
+ case SplitExpressions => true // we generate pad and bits operations inline which need to be split up
+ case _ => false
}
- private def width(t: Type): Int = bitWidth(t).toInt
- private def width(e: Expression): Int = width(e.tpe)
- // Returns an expression with the correct integer width
- private def fixup(i: Int)(e: Expression) = {
- def tx = e.tpe match {
- case t: UIntType => UIntType(IntWidth(i))
- case t: SIntType => SIntType(IntWidth(i))
- // default case should never be reached
- }
- width(e) match {
- case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx)
- case j if i < j =>
- val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i)))
- // Bit Select always returns UInt, cast if selecting from SInt
- e.tpe match {
- case UIntType(_) => e2
- case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i)))
- }
- case _ => e
+ /** Adds padding or a bit extract to ensure that the expression is of the with specified.
+ * @note only works on UInt and SInt type expressions, other expressions will yield a match error
+ */
+ private[firrtl] def forceWidth(width: Int)(e: Expression): Expression = {
+ val old = getWidth(e)
+ if (width == old) { e }
+ else if (width > old) {
+ // padding retains the signedness
+ val newType = e.tpe match {
+ case _: UIntType => UIntType(IntWidth(width))
+ case _: SIntType => SIntType(IntWidth(width))
+ case other => throw new RuntimeException(s"forceWidth does not support expressions of type $other")
+ }
+ ConstantPropagation.constPropPad(DoPrim(Pad, Seq(e), Seq(width), newType))
+ } else {
+ val extract = DoPrim(Bits, Seq(e), Seq(width - 1, 0), UIntType(IntWidth(width)))
+ val e2 = ConstantPropagation.constPropBitExtract(extract)
+ // Bit Select always returns UInt, cast if selecting from SInt
+ e.tpe match {
+ case UIntType(_) => e2
+ case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(width)))
+ }
}
}
+ private def getWidth(t: Type): Int = bitWidth(t).toInt
+ private def getWidth(e: Expression): Int = getWidth(e.tpe)
+
// Recursive, updates expression so children exp's have correct widths
private def onExp(e: Expression): Expression = e.map(onExp) match {
case Mux(cond, tval, fval, tpe) =>
- Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe)
- case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value))
+ Mux(cond, forceWidth(getWidth(tpe))(tval), forceWidth(getWidth(tpe))(fval), tpe)
+ case ex: ValidIf => ex.copy(value = forceWidth(getWidth(ex.tpe))(ex.value))
case ex: DoPrim =>
ex.op match {
- case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Rem | Shr =>
- // sensitive ops
- ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max)))
- case Dshl =>
- // special case as args aren't all same width
- ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1)))
+ // pad arguments to ops where the result width is determined as max(w_1, w_2) (+ const)?
+ case Lt | Leq | Gt | Geq | Eq | Neq | And | Or | Xor | Add | Sub =>
+ ex.map(forceWidth(ex.args.map(getWidth).max))
case _ => ex
}
case ex => ex
@@ -72,9 +68,17 @@ object PadWidths extends Pass {
// Recursive. Fixes assignments and register initialization widths
private def onStmt(s: Statement): Statement = s.map(onExp) match {
case sx: Connect =>
- sx.copy(expr = fixup(width(sx.loc))(sx.expr))
+ assert(
+ getWidth(sx.loc) == getWidth(sx.expr),
+ "Connection widths should have been taken care of by LegalizeConnects!"
+ )
+ sx
case sx: DefRegister =>
- sx.copy(init = fixup(width(sx.tpe))(sx.init))
+ assert(
+ getWidth(sx.tpe) == getWidth(sx.init),
+ "Register init widths should have been taken care of by LegalizeConnects!"
+ )
+ sx
case sx => sx.map(onStmt)
}
diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala
index 03214f83..dc4e70ff 100644
--- a/src/main/scala/firrtl/passes/RemoveValidIf.scala
+++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala
@@ -31,7 +31,7 @@ object RemoveValidIf extends Pass {
Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case Legalize | _: firrtl.transforms.ConstantPropagation => true
+ case _: firrtl.transforms.ConstantPropagation => true // switching out the validifs allows for more constant prop
case _ => false
}
diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala
index 1b4ed1cc..26088e9c 100644
--- a/src/main/scala/firrtl/passes/SplitExpressions.scala
+++ b/src/main/scala/firrtl/passes/SplitExpressions.scala
@@ -8,6 +8,7 @@ import firrtl.ir._
import firrtl.options.Dependency
import firrtl.Mappers._
import firrtl.Utils.{flow, get_info, kind}
+import firrtl.transforms.InlineBooleanExpressions
// Datastructures
import scala.collection.mutable
@@ -16,15 +17,14 @@ import scala.collection.mutable
// and named intermediate nodes
object SplitExpressions extends Pass {
- override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays))
-
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform) = a match {
case ResolveKinds => true
- case _ => false
+ case _: InlineBooleanExpressions => true // SplitExpressions undoes the inlining!
+ case _ => false
}
private def onModule(m: Module): Module = {
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index c7b0fbcd..331dd43e 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -153,7 +153,7 @@ class ReplSeqMem extends SeqTransform with HasShellOptions with DependencyAPIMig
val transforms: Seq[Transform] =
Seq(
- new SimpleMidTransform(Legalize),
+ new SimpleMidTransform(LegalizeConnects),
new SimpleMidTransform(ToMemIR),
new SimpleMidTransform(ResolveMaskGranularity),
new SimpleMidTransform(RenameAnnotatedMemoryPorts),
diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
index 11184e60..3778f4da 100644
--- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
+++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
@@ -235,8 +235,8 @@ object VerilogMemDelays extends Pass {
Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case _: transforms.ConstantPropagation | ResolveFlows => true
- case _ => false
+ case ResolveFlows => true
+ case _ => false
}
private def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed
diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala
index c7ae648a..83d019f5 100644
--- a/src/main/scala/firrtl/stage/Forms.scala
+++ b/src/main/scala/firrtl/stage/Forms.scala
@@ -75,7 +75,7 @@ object Forms {
val LowForm: Seq[TransformDependency] = MidForm ++
Seq(
Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
+ Dependency(passes.LegalizeConnects),
Dependency(firrtl.transforms.RemoveReset),
Dependency[firrtl.transforms.CheckCombLoops],
Dependency[checks.CheckResets],
@@ -86,39 +86,42 @@ object Forms {
Seq(
Dependency(passes.RemoveValidIf),
Dependency(passes.PadWidths),
- Dependency(passes.memlib.VerilogMemDelays),
- Dependency(passes.SplitExpressions),
- Dependency[firrtl.transforms.LegalizeAndReductionsTransform]
+ Dependency(passes.SplitExpressions)
)
val LowFormOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++
Seq(
Dependency[firrtl.transforms.ConstantPropagation],
- Dependency[firrtl.transforms.CombineCats],
Dependency(passes.CommonSubexpressionElimination),
Dependency[firrtl.transforms.DeadCodeElimination]
)
- val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++
+ private def VerilogLowerings(optimize: Boolean): Seq[TransformDependency] = {
Seq(
- Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.FixAddingNegativeLiterals],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.InlineBitExtractionsTransform],
- Dependency[firrtl.transforms.InlineAcrossCastsTransform],
- Dependency[firrtl.transforms.LegalizeClocksTransform],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename],
- Dependency(passes.VerilogPrep),
- Dependency[firrtl.AddDescriptionNodes]
- )
-
- val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++
- Seq(
- Dependency[firrtl.transforms.InlineBooleanExpressions]
+ Dependency(firrtl.backends.verilog.LegalizeVerilog),
+ Dependency(passes.memlib.VerilogMemDelays),
+ Dependency[firrtl.transforms.CombineCats]
) ++
- VerilogMinimumOptimized
+ (if (optimize) Seq(Dependency[firrtl.transforms.InlineBooleanExpressions]) else Seq()) ++
+ Seq(
+ Dependency[firrtl.transforms.LegalizeAndReductionsTransform],
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.FixAddingNegativeLiterals],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.InlineBitExtractionsTransform],
+ Dependency[firrtl.transforms.InlineAcrossCastsTransform],
+ Dependency[firrtl.transforms.LegalizeClocksTransform],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename],
+ Dependency(passes.VerilogPrep),
+ Dependency[firrtl.AddDescriptionNodes]
+ )
+ }
+
+ val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ VerilogLowerings(optimize = false)
+
+ val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogLowerings(optimize = true)
val AssertsRemoved: Seq[TransformDependency] =
Seq(
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index dee2f9c8..eec9d1af 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -101,7 +101,7 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget
class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset))
+ Seq(Dependency(passes.LowerTypes), Dependency(firrtl.transforms.RemoveReset))
override def optionalPrerequisites = Seq.empty
diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala
index a37e6f08..71ef34bf 100644
--- a/src/main/scala/firrtl/transforms/CombineCats.scala
+++ b/src/main/scala/firrtl/transforms/CombineCats.scala
@@ -63,12 +63,11 @@ class CombineCats extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowForm ++
Seq(
Dependency(passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
Dependency(firrtl.passes.memlib.VerilogMemDelays),
Dependency(firrtl.passes.SplitExpressions)
)
- override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation])
override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index bc1fc9af..f216a3a3 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -3,7 +3,7 @@
package firrtl
package transforms
-import firrtl._
+import firrtl.{options, _}
import firrtl.annotations._
import firrtl.annotations.TargetToken._
import firrtl.ir._
@@ -41,8 +41,64 @@ object ConstantPropagation {
}
)
case (we, wt) if we == wt => e
+ case (we, wt) =>
+ throw new RuntimeException(s"Cannot pad from $we-bit to $wt-bit! ${e.serialize}")
}
+ def constPropPad(e: DoPrim): Expression = {
+ // we constant prop through casts here in order to allow LegalizeConnects
+ // to not mess up async reset checks in CheckResets
+ val propCasts = e.args.head match {
+ case c @ DoPrim(AsUInt, _, _, _) => constPropCasts(c)
+ case c @ DoPrim(AsSInt, _, _, _) => constPropCasts(c)
+ case other => other
+ }
+ propCasts match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
+ case _ => e
+ }
+ }
+
+ def constPropCasts(e: DoPrim): Expression = e.op match {
+ case AsUInt =>
+ e.args.head match {
+ case SIntLiteral(v, IntWidth(w)) => litToUInt(v, w.toInt)
+ case arg =>
+ arg.tpe match {
+ case _: UIntType => arg
+ case _ => e
+ }
+ }
+ case AsSInt =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => litToSInt(v, w.toInt)
+ case arg =>
+ arg.tpe match {
+ case _: SIntType => arg
+ case _ => e
+ }
+ }
+ case AsClock =>
+ val arg = e.args.head
+ arg.tpe match {
+ case ClockType => arg
+ case _ => e
+ }
+ case AsAsyncReset =>
+ val arg = e.args.head
+ arg.tpe match {
+ case AsyncResetType => arg
+ case _ => e
+ }
+ }
+
+ private def litToSInt(unsignedValue: BigInt, w: Int): SIntLiteral =
+ SIntLiteral(unsignedValue - ((unsignedValue >> (w - 1)) << w), IntWidth(w))
+ private def litToUInt(signedValue: BigInt, w: Int): UIntLiteral =
+ UIntLiteral(signedValue + (if (signedValue < 0) BigInt(1) << w else 0), IntWidth(w))
+
def constPropBitExtract(e: DoPrim) = {
val arg = e.args.head
val (hi, lo) = e.op match {
@@ -68,11 +124,26 @@ object ConstantPropagation {
case 0 => e.args.head
case x =>
e.args.head match {
- // TODO when amount >= x.width, return a zero-width wire
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1)))
- // take sign bit if shift amount is larger than arg width
case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1)))
- case _ => e
+ // Handle non-literal arguments where shift is larger than width
+ case _ =>
+ val amount = e.consts.head.toInt
+ val width = bitWidth(e.args.head.tpe)
+ lazy val msb = width - 1
+ if (amount >= width) {
+ e.tpe match {
+ // When amount >= x.width, return a zero-width wire
+ case UIntType(_) => zero
+ // Take sign bit if shift amount is larger than arg width
+ case SIntType(_) =>
+ val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
+ DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
+ case t => error(s"Unsupported type $t for Primop Shift Right")
+ }
+ } else {
+ e
+ }
}
}
@@ -114,12 +185,13 @@ object ConstantPropagation {
class ConstantPropagation extends Transform with RegisteredTransform with DependencyAPIMigration {
import ConstantPropagation._
- override def prerequisites =
- ((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)).toSeq
+ override def prerequisites = firrtl.stage.Forms.LowForm
- override def optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf))
+ override def optionalPrerequisites = Seq(
+ // both passes allow constant prop to be more effective!
+ Dependency(firrtl.passes.RemoveValidIf),
+ Dependency(firrtl.passes.PadWidths)
+ )
override def optionalPrerequisiteOf =
Seq(
@@ -130,8 +202,7 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
)
override def invalidates(a: Transform): Boolean = a match {
- case firrtl.passes.Legalize => true
- case _ => false
+ case _ => false
}
val options = Seq(
@@ -353,11 +424,16 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
def <(that: Range) = this.max < that.min
def <=(that: Range) = this.max <= that.min
}
+ // Padding increases the width but doesn't increase the range of values
+ def trueType(e: Expression): Type = e match {
+ case DoPrim(Pad, Seq(a), _, _) => a.tpe
+ case other => other.tpe
+ }
def range(e: Expression): Range = e match {
case UIntLiteral(value, _) => Range(value, value)
case SIntLiteral(value, _) => Range(value, value)
case _ =>
- e.tpe match {
+ trueType(e) match {
case SIntType(IntWidth(width)) =>
Range(
min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
@@ -444,45 +520,10 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
case _ => e
}
- case AsUInt =>
- e.args.head match {
- case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
- case arg =>
- arg.tpe match {
- case _: UIntType => arg
- case _ => e
- }
- }
- case AsSInt =>
- e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w))
- case arg =>
- arg.tpe match {
- case _: SIntType => arg
- case _ => e
- }
- }
- case AsClock =>
- val arg = e.args.head
- arg.tpe match {
- case ClockType => arg
- case _ => e
- }
- case AsAsyncReset =>
- val arg = e.args.head
- arg.tpe match {
- case AsyncResetType => arg
- case _ => e
- }
- case Pad =>
- e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w)))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w)))
- case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
- case _ => e
- }
- case (Bits | Head | Tail) => constPropBitExtract(e)
- case _ => e
+ case AsUInt | AsSInt | AsClock | AsAsyncReset => constPropCasts(e)
+ case Pad => constPropPad(e)
+ case (Bits | Head | Tail) => constPropBitExtract(e)
+ case _ => e
}
private def constPropMuxCond(m: Mux) = m.cond match {
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index f72585d1..41ffd2be 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -33,15 +33,7 @@ import collection.mutable
*/
class DeadCodeElimination extends Transform with RegisteredTransform with DependencyAPIMigration {
- override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq(
- Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats],
- Dependency(passes.CommonSubexpressionElimination)
- )
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisites = Seq.empty
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala
index d7f59321..62b341cd 100644
--- a/src/main/scala/firrtl/transforms/RemoveReset.scala
+++ b/src/main/scala/firrtl/transforms/RemoveReset.scala
@@ -19,7 +19,7 @@ import scala.collection.{immutable, mutable}
object RemoveReset extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize))
+ Seq(Dependency(passes.LowerTypes))
override def optionalPrerequisites = Seq.empty
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
index 7500b386..4fa70002 100644
--- a/src/main/scala/firrtl/transforms/RemoveWires.scala
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -12,6 +12,7 @@ import firrtl.graph.{CyclicException, MutableDiGraph}
import firrtl.options.Dependency
import firrtl.Utils.getGroundZero
import firrtl.backends.experimental.smt.random.DefRandom
+import firrtl.passes.PadWidths
import scala.collection.mutable
import scala.util.{Failure, Success, Try}
@@ -27,10 +28,10 @@ class RemoveWires extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
Seq(
Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
Dependency(passes.ResolveKinds),
Dependency(transforms.RemoveReset),
- Dependency[transforms.CheckCombLoops]
+ Dependency[transforms.CheckCombLoops],
+ Dependency(passes.LegalizeConnects)
)
override def optionalPrerequisites = Seq(Dependency[checks.CheckResets])
@@ -131,10 +132,13 @@ class RemoveWires extends Transform with DependencyAPIMigration {
case con @ Connect(cinfo, lhs, rhs) =>
kind(lhs) match {
case WireKind =>
- // Be sure to pad the rhs since nodes get their type from the rhs
- val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
+ // be sure that connects have the same bit widths on rhs and lhs
+ assert(
+ bitWidth(lhs.tpe) == bitWidth(rhs.tpe),
+ "Connection widths should have been taken care of by LegalizeConnects!"
+ )
val dinfo = wireInfo(lhs)
- netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo))
+ netlist(we(lhs)) = (Seq(rhs), MultiInfo(dinfo, cinfo))
case _ => otherStmts += con // Other connections just pass through
}
case invalid @ IsInvalid(info, expr) =>
diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
index 6ce90eab..f6788435 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
@@ -195,8 +195,8 @@ class FirrtlExpressionSemanticsSpec extends AnyFlatSpec {
}
it should "correctly translate the `neg` primitive operation" in {
- assert(primop(true, "neg", 4, List(3)) == "sub(sext(3'b0, 1), sext(i0, 1))")
- assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "sub(zext(3'b0, 1), zext(i0, 1))")
+ assert(primop(true, "neg", 4, List(3)) == "neg(sext(i0, 1))")
+ assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "neg(zext(i0, 1))")
}
it should "correctly translate the `not` primitive operation" in {
diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala
index 8f17a847..e5226226 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala
@@ -48,7 +48,7 @@ class InvalidToRandomSpec extends LeanTransformSpec(Seq(Dependency(InvalidToRand
assert(result.contains("node _GEN_1 = mux(not(o2_valid), _GEN_1_invalid, UInt<3>(\"h7\"))"))
// expressions that are trivially valid do not get randomized
- assert(result.contains("o3 <= UInt<2>(\"h3\")"))
+ assert(result.contains("o3 <= UInt<8>(\"h3\")"))
val defRandCount = result.count(_.contains("rand "))
assert(defRandCount == 2)
}
diff --git a/src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala b/src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala
new file mode 100644
index 00000000..6f1c56c5
--- /dev/null
+++ b/src/test/scala/firrtlTests/LoFirrtlOptimizedEmitterTests.scala
@@ -0,0 +1,49 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtlTests
+
+import firrtl._
+import firrtl.stage._
+import firrtl.util.BackendCompilationUtilities
+import org.scalatest.flatspec.AnyFlatSpec
+
+class LoFirrtlOptimizedEmitterTests extends AnyFlatSpec {
+ behavior.of("LoFirrtlOptimizedEmitter")
+
+ it should "generate valid firrtl for AddNot" in { compileAndParse("AddNot") }
+ it should "generate valid firrtl for FPU" in { compileAndParse("FPU") }
+ it should "generate valid firrtl for HwachaSequencer" in { compileAndParse("HwachaSequencer") }
+ it should "generate valid firrtl for ICache" in { compileAndParse("ICache") }
+ it should "generate valid firrtl for Ops" in { compileAndParse("Ops") }
+ it should "generate valid firrtl for Rob" in { compileAndParse("Rob") }
+ it should "generate valid firrtl for RocketCore" in { compileAndParse("RocketCore") }
+
+ private def compileAndParse(name: String): Unit = {
+ val testDir = os.RelPath(
+ BackendCompilationUtilities.createTestDirectory(
+ "LoFirrtlOptimizedEmitter_should_generate_valid_firrtl_for" + name
+ )
+ )
+ val inputFile = testDir / s"$name.fir"
+ val outputFile = testDir / s"$name.opt.lo.fir"
+
+ BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", (os.pwd / inputFile).toIO)
+
+ val stage = new FirrtlStage
+ // run low-opt emitter
+ val args = Array(
+ "-ll",
+ "error", // surpress warnings to keep test output clean
+ "--target-dir",
+ testDir.toString,
+ "-i",
+ inputFile.toString,
+ "-E",
+ "low-opt"
+ )
+ val res = stage.execute(args, Seq())
+
+ // load in result to check
+ stage.execute(Array("--target-dir", testDir.toString, "-i", outputFile.toString()), Seq())
+ }
+}
diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
index d56ca657..bb1a8169 100644
--- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
+++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
@@ -89,7 +89,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
firrtl.passes.InferTypes,
firrtl.passes.ResolveFlows,
new firrtl.passes.InferWidths,
- firrtl.passes.Legalize,
+ firrtl.passes.LegalizeConnects,
firrtl.transforms.RemoveReset,
firrtl.passes.ResolveFlows,
new firrtl.transforms.CheckCombLoops,
@@ -102,7 +102,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
new firrtl.transforms.ConstantPropagation,
firrtl.passes.PadWidths,
new firrtl.transforms.ConstantPropagation,
- firrtl.passes.Legalize,
+ firrtl.passes.LegalizeConnects,
firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
new firrtl.transforms.ConstantPropagation,
firrtl.passes.SplitExpressions,
@@ -114,7 +114,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
Seq(
firrtl.passes.RemoveValidIf,
firrtl.passes.PadWidths,
- firrtl.passes.Legalize,
+ firrtl.passes.LegalizeConnects,
firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
firrtl.passes.SplitExpressions
)
@@ -215,76 +215,6 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
compare(legacyTransforms(new MiddleFirrtlToLowFirrtl), tm, patches)
}
- behavior.of("MinimumLowFirrtlOptimization")
-
- it should "replicate the old order" in {
- val tm = new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm)
- val patches = Seq(
- Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))),
- Add(6, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds)))
- )
- compare(legacyTransforms(new MinimumLowFirrtlOptimization), tm, patches)
- }
-
- behavior.of("LowFirrtlOptimization")
-
- it should "replicate the old order" in {
- val tm = new TransformManager(Forms.LowFormOptimized, Forms.LowForm)
- val patches = Seq(
- Add(6, Seq(Dependency(firrtl.passes.ResolveFlows))),
- Add(7, Seq(Dependency(firrtl.passes.Legalize))),
- Add(8, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds)))
- )
- compare(legacyTransforms(new LowFirrtlOptimization), tm, patches)
- }
-
- behavior.of("VerilogMinimumOptimized")
-
- it should "replicate the old order" in {
- val legacy = Seq(
- new firrtl.transforms.BlackBoxSourceHelper,
- new firrtl.transforms.FixAddingNegativeLiterals,
- new firrtl.transforms.ReplaceTruncatingArithmetic,
- new firrtl.transforms.InlineBitExtractionsTransform,
- new firrtl.transforms.PropagatePresetAnnotations,
- new firrtl.transforms.InlineAcrossCastsTransform,
- new firrtl.transforms.LegalizeClocksTransform,
- new firrtl.transforms.FlattenRegUpdate,
- firrtl.passes.VerilogModulusCleanup,
- new firrtl.transforms.VerilogRename,
- firrtl.passes.InferTypes,
- firrtl.passes.VerilogPrep,
- new firrtl.AddDescriptionNodes
- )
- val tm = new TransformManager(Forms.VerilogMinimumOptimized, (new firrtl.VerilogEmitter).prerequisites)
- compare(legacy, tm)
- }
-
- behavior.of("VerilogOptimized")
-
- it should "replicate the old order" in {
- val legacy = Seq(
- new firrtl.transforms.InlineBooleanExpressions,
- new firrtl.transforms.DeadCodeElimination,
- new firrtl.transforms.BlackBoxSourceHelper,
- new firrtl.transforms.FixAddingNegativeLiterals,
- new firrtl.transforms.ReplaceTruncatingArithmetic,
- new firrtl.transforms.InlineBitExtractionsTransform,
- new firrtl.transforms.PropagatePresetAnnotations,
- new firrtl.transforms.InlineAcrossCastsTransform,
- new firrtl.transforms.LegalizeClocksTransform,
- new firrtl.transforms.FlattenRegUpdate,
- new firrtl.transforms.DeadCodeElimination,
- firrtl.passes.VerilogModulusCleanup,
- new firrtl.transforms.VerilogRename,
- firrtl.passes.InferTypes,
- firrtl.passes.VerilogPrep,
- new firrtl.AddDescriptionNodes
- )
- val tm = new TransformManager(Forms.VerilogOptimized, Forms.LowFormOptimized)
- compare(legacy, tm)
- }
-
behavior.of("Legacy Custom Transforms")
it should "work for Chirrtl -> Chirrtl" in {
@@ -311,7 +241,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
compare(expected, tm)
}
- it should "work for Mid -> Mid" in {
+ it should "work for Mid -> Mid" ignore {
val expected =
new TransformManager(Forms.MidForm).flattenedTransformOrder ++
Some(new Transforms.MidToMid) ++
diff --git a/src/test/scala/firrtlTests/PadWidthsTests.scala b/src/test/scala/firrtlTests/PadWidthsTests.scala
new file mode 100644
index 00000000..c92a8b79
--- /dev/null
+++ b/src/test/scala/firrtlTests/PadWidthsTests.scala
@@ -0,0 +1,170 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl.CircuitState
+import firrtl.options.Dependency
+import firrtl.stage.{Forms, TransformManager}
+import firrtl.testutils.LeanTransformSpec
+
+class PadWidthsTests extends LeanTransformSpec(Seq(Dependency(firrtl.passes.PadWidths))) {
+ behavior.of("PadWidths pass")
+
+ it should "pad widths inside a mux" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | input b : UInt<20>
+ | input pred : UInt<1>
+ | output c : UInt<32>
+ | c <= mux(pred,a,b)""".stripMargin
+ val check = Seq("c <= mux(pred, a, pad(b, 32))")
+ executeTest(input, check)
+ }
+
+ it should "pad widths of connects" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output a : UInt<32>
+ | input b : UInt<20>
+ | a <= b
+ | """.stripMargin
+ val check = Seq("a <= pad(b, 32)")
+ executeTest(input, check)
+ }
+
+ it should "pad widths of register init expressions" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clock: Clock
+ | input reset: AsyncReset
+ |
+ | reg r: UInt<8>, clock with:
+ | reset => (reset, UInt<1>("h1"))
+ | """.stripMargin
+ // PadWidths will call into constant prop directly, thus the literal is widened instead of adding a pad
+ val check = Seq("reset => (reset, UInt<8>(\"h1\"))")
+ executeTest(input, check)
+ }
+
+ private def testOp(op: String, width: Int, resultWidth: Int): Unit = {
+ assert(width > 0)
+ val input =
+ s"""circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | input b : UInt<$width>
+ | output c : UInt<$resultWidth>
+ | c <= $op(a,b)""".stripMargin
+ val check = if (width < 32) {
+ Seq(s"c <= $op(a, pad(b, 32))")
+ } else if (width == 32) {
+ Seq(s"c <= $op(a, b)")
+ } else {
+ Seq(s"c <= $op(pad(a, $width), b)")
+ }
+ executeTest(input, check)
+ }
+
+ it should "pad widths of the arguments to add and sub" in {
+ // add and sub have the same width inference rule: max(w_1, w_2) + 1
+ testOp("add", 2, 33)
+ testOp("add", 32, 33)
+ testOp("add", 35, 36)
+
+ testOp("sub", 2, 33)
+ testOp("sub", 32, 33)
+ testOp("sub", 35, 36)
+ }
+
+ it should "pad widths of the arguments to and, or and xor" in {
+ // and, or and xor have the same width inference rule: max(w_1, w_2)
+ testOp("and", 2, 32)
+ testOp("and", 32, 32)
+ testOp("and", 35, 35)
+
+ testOp("or", 2, 32)
+ testOp("or", 32, 32)
+ testOp("or", 35, 35)
+
+ testOp("xor", 2, 32)
+ testOp("xor", 32, 32)
+ testOp("xor", 35, 35)
+ }
+
+ it should "pad widths of the arguments to lt, leq, gt, geq, eq and neq" in {
+ // lt, leq, gt, geq, eq and ne have the same width inference rule: 1
+ testOp("lt", 2, 1)
+ testOp("lt", 32, 1)
+ testOp("lt", 35, 1)
+
+ testOp("leq", 2, 1)
+ testOp("leq", 32, 1)
+ testOp("leq", 35, 1)
+
+ testOp("gt", 2, 1)
+ testOp("gt", 32, 1)
+ testOp("gt", 35, 1)
+
+ testOp("geq", 2, 1)
+ testOp("geq", 32, 1)
+ testOp("geq", 35, 1)
+
+ testOp("eq", 2, 1)
+ testOp("eq", 32, 1)
+ testOp("eq", 35, 1)
+
+ testOp("neq", 2, 1)
+ testOp("neq", 32, 1)
+ testOp("neq", 35, 1)
+ }
+
+ private val resolvedCompiler = new TransformManager(Forms.Resolved)
+ private def checkWidthsAfterPadWidths(input: String, op: String): Unit = {
+ val result = compile(input)
+
+ // we serialize the result in order to rerun width inference
+ val resultFir = firrtl.Parser.parse(result.circuit.serialize)
+ val newWidths = resolvedCompiler.runTransform(CircuitState(resultFir, Seq()))
+
+ // the newly loaded circuit should look the same in serialized form (if this fails, the test has a bug)
+ assert(newWidths.circuit.serialize == result.circuit.serialize)
+
+ // we compare the widths produced by PadWidths with the widths that would normally be inferred
+ assert(newWidths.circuit.modules.head == result.circuit.modules.head, s"failed with op `$op`")
+ }
+
+ it should "always generate valid firrtl" in {
+ // an older version of PadWidths would generate ill types firrtl for mul, div, rem and dshl
+
+ def input(op: String): String =
+ s"""circuit Top:
+ | module Top:
+ | input a: UInt<3>
+ | input b: UInt<1>
+ | output c: UInt
+ | c <= $op(a, b)
+ |""".stripMargin
+
+ def test(op: String): Unit = checkWidthsAfterPadWidths(input(op), op)
+
+ // This was never broken, but we want to make sure that the test works.
+ test("add")
+
+ test("mul")
+ test("div")
+ test("rem")
+ test("dshl")
+ }
+
+ private def executeTest(input: String, expected: Seq[String]): Unit = {
+ val result = compile(input)
+ val lines = result.circuit.serialize.split("\n").map(normalized)
+ expected.map(normalized).foreach { e =>
+ assert(lines.contains(e), f"Failed to find $e in ${lines.mkString("\n")}")
+ }
+ }
+}
diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
index 58d42710..4022b267 100644
--- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala
+++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
@@ -98,7 +98,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec {
val (nodes, wires) = getNodesAndWires(result.circuit)
wires.size should be(0)
nodes.map(_.serialize) should be(
- Seq("""node w = pad(UInt<2>("h2"), 8)""")
+ Seq("""node w = UInt<8>("h2")""")
)
}
diff --git a/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala b/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala
index 32b1c55d..8491977c 100644
--- a/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala
+++ b/src/test/scala/firrtlTests/VerilogMemDelaySpec.scala
@@ -2,32 +2,22 @@
package firrtlTests
-import firrtl._
import firrtl.testutils._
-import firrtl.testutils.FirrtlCheckers._
import firrtl.ir.Circuit
-import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage}
+import firrtl.options.Dependency
+import firrtl.passes.memlib.VerilogMemDelays
-import org.scalatest.freespec.AnyFreeSpec
-import org.scalatest.matchers.should.Matchers
-
-class VerilogMemDelaySpec extends AnyFreeSpec with Matchers {
+class VerilogMemDelaySpec extends LeanTransformSpec(Seq(Dependency(VerilogMemDelays))) {
+ behavior.of("VerilogMemDelaySpec")
private def compileTwiceReturnFirst(input: String): Circuit = {
- (new FirrtlStage)
- .transform(Seq(FirrtlSourceAnnotation(input)))
- .toSeq
- .collectFirst {
- case fca: FirrtlCircuitAnnotation =>
- (new FirrtlStage).transform(Seq(fca))
- fca.circuit
- }
- .get
+ val res0 = compile(input)
+ compile(res0.circuit.serialize).circuit
}
private def compileTwice(input: String): Unit = compileTwiceReturnFirst(input)
- "The following low FIRRTL should be parsed by VerilogMemDelays" in {
+ it should "The following low FIRRTL should be parsed by VerilogMemDelays" in {
val input =
"""
|circuit Test :
@@ -63,7 +53,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers {
compileTwice(input)
}
- "Using a read-first memory should be allowed in VerilogMemDelays" in {
+ it should "Using a read-first memory should be allowed in VerilogMemDelays" in {
val input =
"""
|circuit Test :
@@ -107,7 +97,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers {
compileTwice(input)
}
- "Chained memories should generate correct FIRRTL" in {
+ it should "Chained memories should generate correct FIRRTL" in {
val input =
"""
|circuit Test :
@@ -151,7 +141,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers {
compileTwice(input)
}
- "VerilogMemDelays should not violate use before declaration of clocks" in {
+ it should "VerilogMemDelays should not violate use before declaration of clocks" in {
val input =
"""
|circuit Test :
@@ -188,7 +178,7 @@ class VerilogMemDelaySpec extends AnyFreeSpec with Matchers {
| m.write.data <= in
""".stripMargin
- val res = compileTwiceReturnFirst(input).serialize
+ val res = compile(input).circuit.serialize
// Inject a Wire when using a clock not derived from ports
res should include("wire m_clock : Clock")
res should include("m_clock <= cm.clock")