aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala')
-rw-r--r--src/main/scala/firrtl/Emitter.scala8
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala4
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala6
-rw-r--r--src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala32
-rw-r--r--src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala72
-rw-r--r--src/main/scala/firrtl/checks/CheckResets.scala1
-rw-r--r--src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala10
-rw-r--r--src/main/scala/firrtl/passes/Legalize.scala108
-rw-r--r--src/main/scala/firrtl/passes/LegalizeConnects.scala31
-rw-r--r--src/main/scala/firrtl/passes/PadWidths.scala84
-rw-r--r--src/main/scala/firrtl/passes/RemoveValidIf.scala2
-rw-r--r--src/main/scala/firrtl/passes/SplitExpressions.scala8
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala2
-rw-r--r--src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala4
-rw-r--r--src/main/scala/firrtl/stage/Forms.scala49
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala2
-rw-r--r--src/main/scala/firrtl/transforms/CombineCats.scala3
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala143
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala10
-rw-r--r--src/main/scala/firrtl/transforms/RemoveReset.scala2
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala14
21 files changed, 313 insertions, 282 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 7c91a544..760e83fd 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -3,12 +3,11 @@
package firrtl
import java.io.File
-
import firrtl.annotations.NoTargetAnnotation
import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter}
import firrtl.backends.proto.{Emitter => ProtoEmitter}
import firrtl.options.Viewer.view
-import firrtl.options.{CustomFileEmission, HasShellOptions, PhaseException, ShellOption}
+import firrtl.options.{CustomFileEmission, Dependency, HasShellOptions, PhaseException, ShellOption}
import firrtl.passes.PassException
import firrtl.stage.{FirrtlFileAnnotation, FirrtlOptions, RunFirrtlTransformAnnotation}
@@ -45,6 +44,11 @@ object EmitCircuitAnnotation extends HasShellOptions {
)
case "low" =>
Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), EmitCircuitAnnotation(classOf[LowFirrtlEmitter]))
+ case "low-opt" =>
+ Seq(
+ RunFirrtlTransformAnnotation(Dependency(LowFirrtlOptimizedEmitter)),
+ EmitCircuitAnnotation(LowFirrtlOptimizedEmitter.getClass)
+ )
case "verilog" | "mverilog" =>
Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), EmitCircuitAnnotation(classOf[VerilogEmitter]))
case "sverilog" =>
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
index 099b6712..2c08ff6a 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
@@ -118,8 +118,8 @@ private object FirrtlExpressionSemantics {
// the resulting value will be zero for unsigned types
// and the sign bit for signed types"
if (n >= width) {
- if (isSigned(e)) { BV1BitZero }
- else { BVSlice(toSMT(e), width - 1, width - 1) }
+ if (isSigned(e)) { BVSlice(toSMT(e), width - 1, width - 1) }
+ else { BV1BitZero }
} else {
BVSlice(toSMT(e), width - 1, n.toInt)
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index 78ad3c80..0b8e3ebf 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -451,7 +451,7 @@ private class ModuleScanner(
val name = loc.serialize
insertDummyAssignsForUnusedOutputs(expr)
infos.append(name -> info)
- connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt)))
+ connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = true)))
}
case i @ ir.IsInvalid(info, loc) =>
if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
@@ -591,9 +591,9 @@ private class ModuleScanner(
private case class Context() extends TranslationContext {}
- private def onExpression(e: ir.Expression, width: Int): BVExpr = {
+ private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = {
implicit val ctx: TranslationContext = Context()
- FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false)
+ FirrtlExpressionSemantics.toSMT(e, width, allowNarrow)
}
private def onExpression(e: ir.Expression): BVExpr = {
implicit val ctx: TranslationContext = Context()
diff --git a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala
index bb385ffd..56b63d75 100644
--- a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala
+++ b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala
@@ -1,18 +1,21 @@
package firrtl
import java.io.Writer
-
import firrtl.Utils._
import firrtl.ir._
+import firrtl.stage.TransformManager.TransformDependency
import firrtl.traversals.Foreachers._
import scala.collection.mutable
-sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Emitter {
- def inputForm = form
- def outputForm = form
-
- val outputSuffix: String = form.outputSuffix
+sealed abstract class FirrtlEmitter(form: Seq[TransformDependency], val outputSuffix: String)
+ extends Transform
+ with Emitter
+ with DependencyAPIMigration {
+ override def prerequisites = form
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Seq.empty
+ override def invalidates(a: Transform) = false
private def emitAllModules(circuit: Circuit): Seq[EmittedFirrtlModule] = {
// For a given module, returns a Seq of all modules instantited inside of it
@@ -60,14 +63,9 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em
def emit(state: CircuitState, writer: Writer): Unit = writer.write(state.circuit.serialize)
}
-class ChirrtlEmitter extends FirrtlEmitter(CircuitForm.ChirrtlForm)
-class MinimumHighFirrtlEmitter extends FirrtlEmitter(CircuitForm.HighForm) {
- override def prerequisites = stage.Forms.MinimalHighForm
- override def optionalPrerequisites = Seq.empty
- override def optionalPrerequisiteOf = Seq.empty
- override def invalidates(a: Transform) = false
- override val outputSuffix = ".mhi.fir"
-}
-class HighFirrtlEmitter extends FirrtlEmitter(CircuitForm.HighForm)
-class MiddleFirrtlEmitter extends FirrtlEmitter(CircuitForm.MidForm)
-class LowFirrtlEmitter extends FirrtlEmitter(CircuitForm.LowForm)
+class ChirrtlEmitter extends FirrtlEmitter(stage.Forms.ChirrtlForm, ".fir")
+class MinimumHighFirrtlEmitter extends FirrtlEmitter(stage.Forms.MinimalHighForm, ".mhi.fir")
+class HighFirrtlEmitter extends FirrtlEmitter(stage.Forms.HighForm, ".hi.fir")
+class MiddleFirrtlEmitter extends FirrtlEmitter(stage.Forms.MidForm, ".mid.fir")
+class LowFirrtlEmitter extends FirrtlEmitter(stage.Forms.LowForm, ".lo.fir")
+object LowFirrtlOptimizedEmitter extends FirrtlEmitter(stage.Forms.LowFormOptimized, ".opt.lo.fir")
diff --git a/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala b/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala
new file mode 100644
index 00000000..f063f395
--- /dev/null
+++ b/src/main/scala/firrtl/backends/verilog/LegalizeVerilog.scala
@@ -0,0 +1,72 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl.backends.verilog
+
+import firrtl.PrimOps._
+import firrtl.Utils.{error, getGroundZero, zero, BoolType}
+import firrtl.ir._
+import firrtl.transforms.ConstantPropagation
+import firrtl.{bitWidth, Dshlw, Transform}
+import firrtl.Mappers._
+import firrtl.passes.{Pass, SplitExpressions}
+
+/** Rewrites some expressions for valid/better Verilog emission.
+ * - solves shift right overflows by replacing the shift with 0 for UInts and MSB for SInts
+ * - ensures that bit extracts on literals get resolved
+ * - ensures that all negations are replaced with subtract from zero
+ * - adds padding for rem and dshl which breaks firrtl width invariance, but is needed to match Verilog semantics
+ */
+object LegalizeVerilog extends Pass {
+
+ override def prerequisites = firrtl.stage.Forms.LowForm
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Seq.empty
+ override def invalidates(a: Transform): Boolean = a match {
+ case SplitExpressions => true // we generate pad and bits operations inline which need to be split up
+ case _ => false
+ }
+
+ private def legalizeBitExtract(expr: DoPrim): Expression = {
+ expr.args.head match {
+ case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr)
+ case _ => expr
+ }
+ }
+
+ // Convert `-x` to `0 - x`
+ private def legalizeNeg(expr: DoPrim): Expression = {
+ val arg = expr.args.head
+ arg.tpe match {
+ case tpe: SIntType =>
+ val zero = getGroundZero(tpe)
+ DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe)
+ case tpe: UIntType =>
+ val zero = getGroundZero(tpe)
+ val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1)))
+ DoPrim(AsSInt, Seq(sub), Nil, expr.tpe)
+ }
+ }
+
+ import firrtl.passes.PadWidths.forceWidth
+ private def getWidth(e: Expression): Int = bitWidth(e.tpe).toInt
+
+ private def onExpr(expr: Expression): Expression = expr.map(onExpr) match {
+ case prim: DoPrim =>
+ prim.op match {
+ case Shr => ConstantPropagation.foldShiftRight(prim)
+ case Bits | Head | Tail => legalizeBitExtract(prim)
+ case Neg => legalizeNeg(prim)
+ case Rem => prim.map(forceWidth(prim.args.map(getWidth).max))
+ case Dshl =>
+ // special case as args aren't all same width
+ prim.copy(op = Dshlw, args = Seq(forceWidth(getWidth(prim))(prim.args.head), prim.args(1)))
+ case _ => prim
+ }
+ case e => e // respect pre-order traversal
+ }
+
+ def run(c: Circuit): Circuit = {
+ def legalizeS(s: Statement): Statement = s.mapStmt(legalizeS).mapExpr(onExpr)
+ c.copy(modules = c.modules.map(_.map(legalizeS)))
+ }
+}
diff --git a/src/main/scala/firrtl/checks/CheckResets.scala b/src/main/scala/firrtl/checks/CheckResets.scala
index ae300d1f..e5a3e77a 100644
--- a/src/main/scala/firrtl/checks/CheckResets.scala
+++ b/src/main/scala/firrtl/checks/CheckResets.scala
@@ -33,7 +33,6 @@ class CheckResets extends Transform with DependencyAPIMigration {
override def prerequisites =
Seq(
Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
Dependency(firrtl.transforms.RemoveReset)
) ++ firrtl.stage.Forms.MidForm
diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
index e70346d4..70da011e 100644
--- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
+++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
@@ -9,15 +9,7 @@ import firrtl.options.Dependency
object CommonSubexpressionElimination extends Pass {
- override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq(
- Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats]
- )
-
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala
deleted file mode 100644
index e1a39fbe..00000000
--- a/src/main/scala/firrtl/passes/Legalize.scala
+++ /dev/null
@@ -1,108 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package firrtl.passes
-
-import firrtl.PrimOps._
-import firrtl.Utils.{error, getGroundZero, zero, BoolType}
-import firrtl.ir._
-import firrtl.options.Dependency
-import firrtl.transforms.ConstantPropagation
-import firrtl.{bitWidth, getWidth, Transform}
-import firrtl.Mappers._
-
-// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts
-// TODO replace UInt with zero-width wire instead
-object Legalize extends Pass {
-
- override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes)
-
- override def optionalPrerequisites = Seq.empty
-
- override def optionalPrerequisiteOf = Seq.empty
-
- override def invalidates(a: Transform) = false
-
- private def legalizeShiftRight(e: DoPrim): Expression = {
- require(e.op == Shr)
- e.args.head match {
- case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e)
- case _ =>
- val amount = e.consts.head.toInt
- val width = bitWidth(e.args.head.tpe)
- lazy val msb = width - 1
- if (amount >= width) {
- e.tpe match {
- case UIntType(_) => zero
- case SIntType(_) =>
- val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
- DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
- case t => error(s"Unsupported type $t for Primop Shift Right")
- }
- } else {
- e
- }
- }
- }
- private def legalizeBitExtract(expr: DoPrim): Expression = {
- expr.args.head match {
- case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr)
- case _ => expr
- }
- }
- private def legalizePad(expr: DoPrim): Expression = expr.args.head match {
- case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
- UIntLiteral(value, IntWidth(expr.consts.head))
- case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
- SIntLiteral(value, IntWidth(expr.consts.head))
- case _ => expr
- }
- // Convert `-x` to `0 - x`
- private def legalizeNeg(expr: DoPrim): Expression = {
- val arg = expr.args.head
- arg.tpe match {
- case tpe: SIntType =>
- val zero = getGroundZero(tpe)
- DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe)
- case tpe: UIntType =>
- val zero = getGroundZero(tpe)
- val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1)))
- DoPrim(AsSInt, Seq(sub), Nil, expr.tpe)
- }
- }
- private def legalizeConnect(c: Connect): Statement = {
- val t = c.loc.tpe
- val w = bitWidth(t)
- if (w >= bitWidth(c.expr.tpe)) {
- c
- } else {
- val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w)))
- val expr = t match {
- case UIntType(_) => bits
- case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
- case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t)
- }
- Connect(c.info, c.loc, expr)
- }
- }
- def run(c: Circuit): Circuit = {
- def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match {
- case prim: DoPrim =>
- prim.op match {
- case Shr => legalizeShiftRight(prim)
- case Pad => legalizePad(prim)
- case Bits | Head | Tail => legalizeBitExtract(prim)
- case Neg => legalizeNeg(prim)
- case _ => prim
- }
- case e => e // respect pre-order traversal
- }
- def legalizeS(s: Statement): Statement = {
- val legalizedStmt = s match {
- case c: Connect => legalizeConnect(c)
- case _ => s
- }
- legalizedStmt.map(legalizeS).map(legalizeE)
- }
- c.copy(modules = c.modules.map(_.map(legalizeS)))
- }
-}
diff --git a/src/main/scala/firrtl/passes/LegalizeConnects.scala b/src/main/scala/firrtl/passes/LegalizeConnects.scala
new file mode 100644
index 00000000..2f29de10
--- /dev/null
+++ b/src/main/scala/firrtl/passes/LegalizeConnects.scala
@@ -0,0 +1,31 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl.passes
+
+import firrtl.ir._
+import firrtl.options.Dependency
+import firrtl.{bitWidth, Transform}
+
+/** Ensures that all connects + register inits have the same bit-width on the rhs and the lhs.
+ * The rhs is padded or bit-extacted to fit the width of the lhs.
+ * @note technically, width(rhs) > width(lhs) is not legal firrtl, however, we do not error for historic reasons.
+ */
+object LegalizeConnects extends Pass {
+
+ override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes)
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Seq.empty
+ override def invalidates(a: Transform) = false
+
+ def onStmt(s: Statement): Statement = s match {
+ case c: Connect =>
+ c.copy(expr = PadWidths.forceWidth(bitWidth(c.loc.tpe).toInt)(c.expr))
+ case r: DefRegister =>
+ r.copy(init = PadWidths.forceWidth(bitWidth(r.tpe).toInt)(r.init))
+ case other => other.mapStmt(onStmt)
+ }
+
+ def run(c: Circuit): Circuit = {
+ c.copy(modules = c.modules.map(_.mapStmt(onStmt)))
+ }
+}
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala
index 1a430778..02e94975 100644
--- a/src/main/scala/firrtl/passes/PadWidths.scala
+++ b/src/main/scala/firrtl/passes/PadWidths.scala
@@ -7,63 +7,59 @@ import firrtl.ir._
import firrtl.PrimOps._
import firrtl.Mappers._
import firrtl.options.Dependency
-
-import scala.collection.mutable
+import firrtl.transforms.ConstantPropagation
// Makes all implicit width extensions and truncations explicit
object PadWidths extends Pass {
- override def prerequisites =
- ((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)
- + Dependency(firrtl.passes.RemoveValidIf)).toSeq
-
- override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation])
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case _: firrtl.transforms.ConstantPropagation | Legalize => true
- case _ => false
+ case SplitExpressions => true // we generate pad and bits operations inline which need to be split up
+ case _ => false
}
- private def width(t: Type): Int = bitWidth(t).toInt
- private def width(e: Expression): Int = width(e.tpe)
- // Returns an expression with the correct integer width
- private def fixup(i: Int)(e: Expression) = {
- def tx = e.tpe match {
- case t: UIntType => UIntType(IntWidth(i))
- case t: SIntType => SIntType(IntWidth(i))
- // default case should never be reached
- }
- width(e) match {
- case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx)
- case j if i < j =>
- val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i)))
- // Bit Select always returns UInt, cast if selecting from SInt
- e.tpe match {
- case UIntType(_) => e2
- case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i)))
- }
- case _ => e
+ /** Adds padding or a bit extract to ensure that the expression is of the with specified.
+ * @note only works on UInt and SInt type expressions, other expressions will yield a match error
+ */
+ private[firrtl] def forceWidth(width: Int)(e: Expression): Expression = {
+ val old = getWidth(e)
+ if (width == old) { e }
+ else if (width > old) {
+ // padding retains the signedness
+ val newType = e.tpe match {
+ case _: UIntType => UIntType(IntWidth(width))
+ case _: SIntType => SIntType(IntWidth(width))
+ case other => throw new RuntimeException(s"forceWidth does not support expressions of type $other")
+ }
+ ConstantPropagation.constPropPad(DoPrim(Pad, Seq(e), Seq(width), newType))
+ } else {
+ val extract = DoPrim(Bits, Seq(e), Seq(width - 1, 0), UIntType(IntWidth(width)))
+ val e2 = ConstantPropagation.constPropBitExtract(extract)
+ // Bit Select always returns UInt, cast if selecting from SInt
+ e.tpe match {
+ case UIntType(_) => e2
+ case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(width)))
+ }
}
}
+ private def getWidth(t: Type): Int = bitWidth(t).toInt
+ private def getWidth(e: Expression): Int = getWidth(e.tpe)
+
// Recursive, updates expression so children exp's have correct widths
private def onExp(e: Expression): Expression = e.map(onExp) match {
case Mux(cond, tval, fval, tpe) =>
- Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe)
- case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value))
+ Mux(cond, forceWidth(getWidth(tpe))(tval), forceWidth(getWidth(tpe))(fval), tpe)
+ case ex: ValidIf => ex.copy(value = forceWidth(getWidth(ex.tpe))(ex.value))
case ex: DoPrim =>
ex.op match {
- case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Rem | Shr =>
- // sensitive ops
- ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max)))
- case Dshl =>
- // special case as args aren't all same width
- ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1)))
+ // pad arguments to ops where the result width is determined as max(w_1, w_2) (+ const)?
+ case Lt | Leq | Gt | Geq | Eq | Neq | And | Or | Xor | Add | Sub =>
+ ex.map(forceWidth(ex.args.map(getWidth).max))
case _ => ex
}
case ex => ex
@@ -72,9 +68,17 @@ object PadWidths extends Pass {
// Recursive. Fixes assignments and register initialization widths
private def onStmt(s: Statement): Statement = s.map(onExp) match {
case sx: Connect =>
- sx.copy(expr = fixup(width(sx.loc))(sx.expr))
+ assert(
+ getWidth(sx.loc) == getWidth(sx.expr),
+ "Connection widths should have been taken care of by LegalizeConnects!"
+ )
+ sx
case sx: DefRegister =>
- sx.copy(init = fixup(width(sx.tpe))(sx.init))
+ assert(
+ getWidth(sx.tpe) == getWidth(sx.init),
+ "Register init widths should have been taken care of by LegalizeConnects!"
+ )
+ sx
case sx => sx.map(onStmt)
}
diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala
index 03214f83..dc4e70ff 100644
--- a/src/main/scala/firrtl/passes/RemoveValidIf.scala
+++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala
@@ -31,7 +31,7 @@ object RemoveValidIf extends Pass {
Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case Legalize | _: firrtl.transforms.ConstantPropagation => true
+ case _: firrtl.transforms.ConstantPropagation => true // switching out the validifs allows for more constant prop
case _ => false
}
diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala
index 1b4ed1cc..26088e9c 100644
--- a/src/main/scala/firrtl/passes/SplitExpressions.scala
+++ b/src/main/scala/firrtl/passes/SplitExpressions.scala
@@ -8,6 +8,7 @@ import firrtl.ir._
import firrtl.options.Dependency
import firrtl.Mappers._
import firrtl.Utils.{flow, get_info, kind}
+import firrtl.transforms.InlineBooleanExpressions
// Datastructures
import scala.collection.mutable
@@ -16,15 +17,14 @@ import scala.collection.mutable
// and named intermediate nodes
object SplitExpressions extends Pass {
- override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays))
-
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform) = a match {
case ResolveKinds => true
- case _ => false
+ case _: InlineBooleanExpressions => true // SplitExpressions undoes the inlining!
+ case _ => false
}
private def onModule(m: Module): Module = {
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index c7b0fbcd..331dd43e 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -153,7 +153,7 @@ class ReplSeqMem extends SeqTransform with HasShellOptions with DependencyAPIMig
val transforms: Seq[Transform] =
Seq(
- new SimpleMidTransform(Legalize),
+ new SimpleMidTransform(LegalizeConnects),
new SimpleMidTransform(ToMemIR),
new SimpleMidTransform(ResolveMaskGranularity),
new SimpleMidTransform(RenameAnnotatedMemoryPorts),
diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
index 11184e60..3778f4da 100644
--- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
+++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
@@ -235,8 +235,8 @@ object VerilogMemDelays extends Pass {
Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case _: transforms.ConstantPropagation | ResolveFlows => true
- case _ => false
+ case ResolveFlows => true
+ case _ => false
}
private def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed
diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala
index c7ae648a..83d019f5 100644
--- a/src/main/scala/firrtl/stage/Forms.scala
+++ b/src/main/scala/firrtl/stage/Forms.scala
@@ -75,7 +75,7 @@ object Forms {
val LowForm: Seq[TransformDependency] = MidForm ++
Seq(
Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
+ Dependency(passes.LegalizeConnects),
Dependency(firrtl.transforms.RemoveReset),
Dependency[firrtl.transforms.CheckCombLoops],
Dependency[checks.CheckResets],
@@ -86,39 +86,42 @@ object Forms {
Seq(
Dependency(passes.RemoveValidIf),
Dependency(passes.PadWidths),
- Dependency(passes.memlib.VerilogMemDelays),
- Dependency(passes.SplitExpressions),
- Dependency[firrtl.transforms.LegalizeAndReductionsTransform]
+ Dependency(passes.SplitExpressions)
)
val LowFormOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++
Seq(
Dependency[firrtl.transforms.ConstantPropagation],
- Dependency[firrtl.transforms.CombineCats],
Dependency(passes.CommonSubexpressionElimination),
Dependency[firrtl.transforms.DeadCodeElimination]
)
- val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++
+ private def VerilogLowerings(optimize: Boolean): Seq[TransformDependency] = {
Seq(
- Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.FixAddingNegativeLiterals],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.InlineBitExtractionsTransform],
- Dependency[firrtl.transforms.InlineAcrossCastsTransform],
- Dependency[firrtl.transforms.LegalizeClocksTransform],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename],
- Dependency(passes.VerilogPrep),
- Dependency[firrtl.AddDescriptionNodes]
- )
-
- val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++
- Seq(
- Dependency[firrtl.transforms.InlineBooleanExpressions]
+ Dependency(firrtl.backends.verilog.LegalizeVerilog),
+ Dependency(passes.memlib.VerilogMemDelays),
+ Dependency[firrtl.transforms.CombineCats]
) ++
- VerilogMinimumOptimized
+ (if (optimize) Seq(Dependency[firrtl.transforms.InlineBooleanExpressions]) else Seq()) ++
+ Seq(
+ Dependency[firrtl.transforms.LegalizeAndReductionsTransform],
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.FixAddingNegativeLiterals],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.InlineBitExtractionsTransform],
+ Dependency[firrtl.transforms.InlineAcrossCastsTransform],
+ Dependency[firrtl.transforms.LegalizeClocksTransform],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename],
+ Dependency(passes.VerilogPrep),
+ Dependency[firrtl.AddDescriptionNodes]
+ )
+ }
+
+ val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ VerilogLowerings(optimize = false)
+
+ val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogLowerings(optimize = true)
val AssertsRemoved: Seq[TransformDependency] =
Seq(
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index dee2f9c8..eec9d1af 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -101,7 +101,7 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget
class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset))
+ Seq(Dependency(passes.LowerTypes), Dependency(firrtl.transforms.RemoveReset))
override def optionalPrerequisites = Seq.empty
diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala
index a37e6f08..71ef34bf 100644
--- a/src/main/scala/firrtl/transforms/CombineCats.scala
+++ b/src/main/scala/firrtl/transforms/CombineCats.scala
@@ -63,12 +63,11 @@ class CombineCats extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowForm ++
Seq(
Dependency(passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
Dependency(firrtl.passes.memlib.VerilogMemDelays),
Dependency(firrtl.passes.SplitExpressions)
)
- override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation])
override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index bc1fc9af..f216a3a3 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -3,7 +3,7 @@
package firrtl
package transforms
-import firrtl._
+import firrtl.{options, _}
import firrtl.annotations._
import firrtl.annotations.TargetToken._
import firrtl.ir._
@@ -41,8 +41,64 @@ object ConstantPropagation {
}
)
case (we, wt) if we == wt => e
+ case (we, wt) =>
+ throw new RuntimeException(s"Cannot pad from $we-bit to $wt-bit! ${e.serialize}")
}
+ def constPropPad(e: DoPrim): Expression = {
+ // we constant prop through casts here in order to allow LegalizeConnects
+ // to not mess up async reset checks in CheckResets
+ val propCasts = e.args.head match {
+ case c @ DoPrim(AsUInt, _, _, _) => constPropCasts(c)
+ case c @ DoPrim(AsSInt, _, _, _) => constPropCasts(c)
+ case other => other
+ }
+ propCasts match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
+ case _ => e
+ }
+ }
+
+ def constPropCasts(e: DoPrim): Expression = e.op match {
+ case AsUInt =>
+ e.args.head match {
+ case SIntLiteral(v, IntWidth(w)) => litToUInt(v, w.toInt)
+ case arg =>
+ arg.tpe match {
+ case _: UIntType => arg
+ case _ => e
+ }
+ }
+ case AsSInt =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => litToSInt(v, w.toInt)
+ case arg =>
+ arg.tpe match {
+ case _: SIntType => arg
+ case _ => e
+ }
+ }
+ case AsClock =>
+ val arg = e.args.head
+ arg.tpe match {
+ case ClockType => arg
+ case _ => e
+ }
+ case AsAsyncReset =>
+ val arg = e.args.head
+ arg.tpe match {
+ case AsyncResetType => arg
+ case _ => e
+ }
+ }
+
+ private def litToSInt(unsignedValue: BigInt, w: Int): SIntLiteral =
+ SIntLiteral(unsignedValue - ((unsignedValue >> (w - 1)) << w), IntWidth(w))
+ private def litToUInt(signedValue: BigInt, w: Int): UIntLiteral =
+ UIntLiteral(signedValue + (if (signedValue < 0) BigInt(1) << w else 0), IntWidth(w))
+
def constPropBitExtract(e: DoPrim) = {
val arg = e.args.head
val (hi, lo) = e.op match {
@@ -68,11 +124,26 @@ object ConstantPropagation {
case 0 => e.args.head
case x =>
e.args.head match {
- // TODO when amount >= x.width, return a zero-width wire
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1)))
- // take sign bit if shift amount is larger than arg width
case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1)))
- case _ => e
+ // Handle non-literal arguments where shift is larger than width
+ case _ =>
+ val amount = e.consts.head.toInt
+ val width = bitWidth(e.args.head.tpe)
+ lazy val msb = width - 1
+ if (amount >= width) {
+ e.tpe match {
+ // When amount >= x.width, return a zero-width wire
+ case UIntType(_) => zero
+ // Take sign bit if shift amount is larger than arg width
+ case SIntType(_) =>
+ val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
+ DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
+ case t => error(s"Unsupported type $t for Primop Shift Right")
+ }
+ } else {
+ e
+ }
}
}
@@ -114,12 +185,13 @@ object ConstantPropagation {
class ConstantPropagation extends Transform with RegisteredTransform with DependencyAPIMigration {
import ConstantPropagation._
- override def prerequisites =
- ((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)).toSeq
+ override def prerequisites = firrtl.stage.Forms.LowForm
- override def optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf))
+ override def optionalPrerequisites = Seq(
+ // both passes allow constant prop to be more effective!
+ Dependency(firrtl.passes.RemoveValidIf),
+ Dependency(firrtl.passes.PadWidths)
+ )
override def optionalPrerequisiteOf =
Seq(
@@ -130,8 +202,7 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
)
override def invalidates(a: Transform): Boolean = a match {
- case firrtl.passes.Legalize => true
- case _ => false
+ case _ => false
}
val options = Seq(
@@ -353,11 +424,16 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
def <(that: Range) = this.max < that.min
def <=(that: Range) = this.max <= that.min
}
+ // Padding increases the width but doesn't increase the range of values
+ def trueType(e: Expression): Type = e match {
+ case DoPrim(Pad, Seq(a), _, _) => a.tpe
+ case other => other.tpe
+ }
def range(e: Expression): Range = e match {
case UIntLiteral(value, _) => Range(value, value)
case SIntLiteral(value, _) => Range(value, value)
case _ =>
- e.tpe match {
+ trueType(e) match {
case SIntType(IntWidth(width)) =>
Range(
min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
@@ -444,45 +520,10 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
case _ => e
}
- case AsUInt =>
- e.args.head match {
- case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
- case arg =>
- arg.tpe match {
- case _: UIntType => arg
- case _ => e
- }
- }
- case AsSInt =>
- e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w))
- case arg =>
- arg.tpe match {
- case _: SIntType => arg
- case _ => e
- }
- }
- case AsClock =>
- val arg = e.args.head
- arg.tpe match {
- case ClockType => arg
- case _ => e
- }
- case AsAsyncReset =>
- val arg = e.args.head
- arg.tpe match {
- case AsyncResetType => arg
- case _ => e
- }
- case Pad =>
- e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w)))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w)))
- case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
- case _ => e
- }
- case (Bits | Head | Tail) => constPropBitExtract(e)
- case _ => e
+ case AsUInt | AsSInt | AsClock | AsAsyncReset => constPropCasts(e)
+ case Pad => constPropPad(e)
+ case (Bits | Head | Tail) => constPropBitExtract(e)
+ case _ => e
}
private def constPropMuxCond(m: Mux) = m.cond match {
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index f72585d1..41ffd2be 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -33,15 +33,7 @@ import collection.mutable
*/
class DeadCodeElimination extends Transform with RegisteredTransform with DependencyAPIMigration {
- override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq(
- Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats],
- Dependency(passes.CommonSubexpressionElimination)
- )
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisites = Seq.empty
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala
index d7f59321..62b341cd 100644
--- a/src/main/scala/firrtl/transforms/RemoveReset.scala
+++ b/src/main/scala/firrtl/transforms/RemoveReset.scala
@@ -19,7 +19,7 @@ import scala.collection.{immutable, mutable}
object RemoveReset extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize))
+ Seq(Dependency(passes.LowerTypes))
override def optionalPrerequisites = Seq.empty
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
index 7500b386..4fa70002 100644
--- a/src/main/scala/firrtl/transforms/RemoveWires.scala
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -12,6 +12,7 @@ import firrtl.graph.{CyclicException, MutableDiGraph}
import firrtl.options.Dependency
import firrtl.Utils.getGroundZero
import firrtl.backends.experimental.smt.random.DefRandom
+import firrtl.passes.PadWidths
import scala.collection.mutable
import scala.util.{Failure, Success, Try}
@@ -27,10 +28,10 @@ class RemoveWires extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
Seq(
Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
Dependency(passes.ResolveKinds),
Dependency(transforms.RemoveReset),
- Dependency[transforms.CheckCombLoops]
+ Dependency[transforms.CheckCombLoops],
+ Dependency(passes.LegalizeConnects)
)
override def optionalPrerequisites = Seq(Dependency[checks.CheckResets])
@@ -131,10 +132,13 @@ class RemoveWires extends Transform with DependencyAPIMigration {
case con @ Connect(cinfo, lhs, rhs) =>
kind(lhs) match {
case WireKind =>
- // Be sure to pad the rhs since nodes get their type from the rhs
- val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
+ // be sure that connects have the same bit widths on rhs and lhs
+ assert(
+ bitWidth(lhs.tpe) == bitWidth(rhs.tpe),
+ "Connection widths should have been taken care of by LegalizeConnects!"
+ )
val dinfo = wireInfo(lhs)
- netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo))
+ netlist(we(lhs)) = (Seq(rhs), MultiInfo(dinfo, cinfo))
case _ => otherStmts += con // Other connections just pass through
}
case invalid @ IsInvalid(info, expr) =>