aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes
diff options
context:
space:
mode:
authorKevin Laeufer2021-08-02 13:46:29 -0700
committerGitHub2021-08-02 20:46:29 +0000
commite04f1e7f303920ac1d1f865450d0e280aafb58b3 (patch)
tree73f26cd236ac8069d9c4877a3c42457d65d477fe /src/main/scala/firrtl/passes
parentff1cd28202fb423956a6803a889c3632487d8872 (diff)
add emitter for optimized low firrtl (#2304)
* rearrange passes to enable optimized firrtl emission * Support ConstProp on padded arguments to comparisons with literals * Move shr legalization logic into ConstProp Continue calling ConstProp of shr in Legalize. Co-authored-by: Jack Koenig <koenig@sifive.com> Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'src/main/scala/firrtl/passes')
-rw-r--r--src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala10
-rw-r--r--src/main/scala/firrtl/passes/Legalize.scala108
-rw-r--r--src/main/scala/firrtl/passes/LegalizeConnects.scala31
-rw-r--r--src/main/scala/firrtl/passes/PadWidths.scala84
-rw-r--r--src/main/scala/firrtl/passes/RemoveValidIf.scala2
-rw-r--r--src/main/scala/firrtl/passes/SplitExpressions.scala8
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala2
-rw-r--r--src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala4
8 files changed, 84 insertions, 165 deletions
diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
index e70346d4..70da011e 100644
--- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
+++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
@@ -9,15 +9,7 @@ import firrtl.options.Dependency
object CommonSubexpressionElimination extends Pass {
- override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq(
- Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats]
- )
-
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala
deleted file mode 100644
index e1a39fbe..00000000
--- a/src/main/scala/firrtl/passes/Legalize.scala
+++ /dev/null
@@ -1,108 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package firrtl.passes
-
-import firrtl.PrimOps._
-import firrtl.Utils.{error, getGroundZero, zero, BoolType}
-import firrtl.ir._
-import firrtl.options.Dependency
-import firrtl.transforms.ConstantPropagation
-import firrtl.{bitWidth, getWidth, Transform}
-import firrtl.Mappers._
-
-// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts
-// TODO replace UInt with zero-width wire instead
-object Legalize extends Pass {
-
- override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes)
-
- override def optionalPrerequisites = Seq.empty
-
- override def optionalPrerequisiteOf = Seq.empty
-
- override def invalidates(a: Transform) = false
-
- private def legalizeShiftRight(e: DoPrim): Expression = {
- require(e.op == Shr)
- e.args.head match {
- case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e)
- case _ =>
- val amount = e.consts.head.toInt
- val width = bitWidth(e.args.head.tpe)
- lazy val msb = width - 1
- if (amount >= width) {
- e.tpe match {
- case UIntType(_) => zero
- case SIntType(_) =>
- val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
- DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
- case t => error(s"Unsupported type $t for Primop Shift Right")
- }
- } else {
- e
- }
- }
- }
- private def legalizeBitExtract(expr: DoPrim): Expression = {
- expr.args.head match {
- case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr)
- case _ => expr
- }
- }
- private def legalizePad(expr: DoPrim): Expression = expr.args.head match {
- case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
- UIntLiteral(value, IntWidth(expr.consts.head))
- case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
- SIntLiteral(value, IntWidth(expr.consts.head))
- case _ => expr
- }
- // Convert `-x` to `0 - x`
- private def legalizeNeg(expr: DoPrim): Expression = {
- val arg = expr.args.head
- arg.tpe match {
- case tpe: SIntType =>
- val zero = getGroundZero(tpe)
- DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe)
- case tpe: UIntType =>
- val zero = getGroundZero(tpe)
- val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1)))
- DoPrim(AsSInt, Seq(sub), Nil, expr.tpe)
- }
- }
- private def legalizeConnect(c: Connect): Statement = {
- val t = c.loc.tpe
- val w = bitWidth(t)
- if (w >= bitWidth(c.expr.tpe)) {
- c
- } else {
- val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w)))
- val expr = t match {
- case UIntType(_) => bits
- case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
- case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t)
- }
- Connect(c.info, c.loc, expr)
- }
- }
- def run(c: Circuit): Circuit = {
- def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match {
- case prim: DoPrim =>
- prim.op match {
- case Shr => legalizeShiftRight(prim)
- case Pad => legalizePad(prim)
- case Bits | Head | Tail => legalizeBitExtract(prim)
- case Neg => legalizeNeg(prim)
- case _ => prim
- }
- case e => e // respect pre-order traversal
- }
- def legalizeS(s: Statement): Statement = {
- val legalizedStmt = s match {
- case c: Connect => legalizeConnect(c)
- case _ => s
- }
- legalizedStmt.map(legalizeS).map(legalizeE)
- }
- c.copy(modules = c.modules.map(_.map(legalizeS)))
- }
-}
diff --git a/src/main/scala/firrtl/passes/LegalizeConnects.scala b/src/main/scala/firrtl/passes/LegalizeConnects.scala
new file mode 100644
index 00000000..2f29de10
--- /dev/null
+++ b/src/main/scala/firrtl/passes/LegalizeConnects.scala
@@ -0,0 +1,31 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl.passes
+
+import firrtl.ir._
+import firrtl.options.Dependency
+import firrtl.{bitWidth, Transform}
+
+/** Ensures that all connects + register inits have the same bit-width on the rhs and the lhs.
+ * The rhs is padded or bit-extacted to fit the width of the lhs.
+ * @note technically, width(rhs) > width(lhs) is not legal firrtl, however, we do not error for historic reasons.
+ */
+object LegalizeConnects extends Pass {
+
+ override def prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes)
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Seq.empty
+ override def invalidates(a: Transform) = false
+
+ def onStmt(s: Statement): Statement = s match {
+ case c: Connect =>
+ c.copy(expr = PadWidths.forceWidth(bitWidth(c.loc.tpe).toInt)(c.expr))
+ case r: DefRegister =>
+ r.copy(init = PadWidths.forceWidth(bitWidth(r.tpe).toInt)(r.init))
+ case other => other.mapStmt(onStmt)
+ }
+
+ def run(c: Circuit): Circuit = {
+ c.copy(modules = c.modules.map(_.mapStmt(onStmt)))
+ }
+}
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala
index 1a430778..02e94975 100644
--- a/src/main/scala/firrtl/passes/PadWidths.scala
+++ b/src/main/scala/firrtl/passes/PadWidths.scala
@@ -7,63 +7,59 @@ import firrtl.ir._
import firrtl.PrimOps._
import firrtl.Mappers._
import firrtl.options.Dependency
-
-import scala.collection.mutable
+import firrtl.transforms.ConstantPropagation
// Makes all implicit width extensions and truncations explicit
object PadWidths extends Pass {
- override def prerequisites =
- ((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)
- + Dependency(firrtl.passes.RemoveValidIf)).toSeq
-
- override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation])
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case _: firrtl.transforms.ConstantPropagation | Legalize => true
- case _ => false
+ case SplitExpressions => true // we generate pad and bits operations inline which need to be split up
+ case _ => false
}
- private def width(t: Type): Int = bitWidth(t).toInt
- private def width(e: Expression): Int = width(e.tpe)
- // Returns an expression with the correct integer width
- private def fixup(i: Int)(e: Expression) = {
- def tx = e.tpe match {
- case t: UIntType => UIntType(IntWidth(i))
- case t: SIntType => SIntType(IntWidth(i))
- // default case should never be reached
- }
- width(e) match {
- case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx)
- case j if i < j =>
- val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i)))
- // Bit Select always returns UInt, cast if selecting from SInt
- e.tpe match {
- case UIntType(_) => e2
- case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i)))
- }
- case _ => e
+ /** Adds padding or a bit extract to ensure that the expression is of the with specified.
+ * @note only works on UInt and SInt type expressions, other expressions will yield a match error
+ */
+ private[firrtl] def forceWidth(width: Int)(e: Expression): Expression = {
+ val old = getWidth(e)
+ if (width == old) { e }
+ else if (width > old) {
+ // padding retains the signedness
+ val newType = e.tpe match {
+ case _: UIntType => UIntType(IntWidth(width))
+ case _: SIntType => SIntType(IntWidth(width))
+ case other => throw new RuntimeException(s"forceWidth does not support expressions of type $other")
+ }
+ ConstantPropagation.constPropPad(DoPrim(Pad, Seq(e), Seq(width), newType))
+ } else {
+ val extract = DoPrim(Bits, Seq(e), Seq(width - 1, 0), UIntType(IntWidth(width)))
+ val e2 = ConstantPropagation.constPropBitExtract(extract)
+ // Bit Select always returns UInt, cast if selecting from SInt
+ e.tpe match {
+ case UIntType(_) => e2
+ case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(width)))
+ }
}
}
+ private def getWidth(t: Type): Int = bitWidth(t).toInt
+ private def getWidth(e: Expression): Int = getWidth(e.tpe)
+
// Recursive, updates expression so children exp's have correct widths
private def onExp(e: Expression): Expression = e.map(onExp) match {
case Mux(cond, tval, fval, tpe) =>
- Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe)
- case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value))
+ Mux(cond, forceWidth(getWidth(tpe))(tval), forceWidth(getWidth(tpe))(fval), tpe)
+ case ex: ValidIf => ex.copy(value = forceWidth(getWidth(ex.tpe))(ex.value))
case ex: DoPrim =>
ex.op match {
- case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Rem | Shr =>
- // sensitive ops
- ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max)))
- case Dshl =>
- // special case as args aren't all same width
- ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1)))
+ // pad arguments to ops where the result width is determined as max(w_1, w_2) (+ const)?
+ case Lt | Leq | Gt | Geq | Eq | Neq | And | Or | Xor | Add | Sub =>
+ ex.map(forceWidth(ex.args.map(getWidth).max))
case _ => ex
}
case ex => ex
@@ -72,9 +68,17 @@ object PadWidths extends Pass {
// Recursive. Fixes assignments and register initialization widths
private def onStmt(s: Statement): Statement = s.map(onExp) match {
case sx: Connect =>
- sx.copy(expr = fixup(width(sx.loc))(sx.expr))
+ assert(
+ getWidth(sx.loc) == getWidth(sx.expr),
+ "Connection widths should have been taken care of by LegalizeConnects!"
+ )
+ sx
case sx: DefRegister =>
- sx.copy(init = fixup(width(sx.tpe))(sx.init))
+ assert(
+ getWidth(sx.tpe) == getWidth(sx.init),
+ "Register init widths should have been taken care of by LegalizeConnects!"
+ )
+ sx
case sx => sx.map(onStmt)
}
diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala
index 03214f83..dc4e70ff 100644
--- a/src/main/scala/firrtl/passes/RemoveValidIf.scala
+++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala
@@ -31,7 +31,7 @@ object RemoveValidIf extends Pass {
Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case Legalize | _: firrtl.transforms.ConstantPropagation => true
+ case _: firrtl.transforms.ConstantPropagation => true // switching out the validifs allows for more constant prop
case _ => false
}
diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala
index 1b4ed1cc..26088e9c 100644
--- a/src/main/scala/firrtl/passes/SplitExpressions.scala
+++ b/src/main/scala/firrtl/passes/SplitExpressions.scala
@@ -8,6 +8,7 @@ import firrtl.ir._
import firrtl.options.Dependency
import firrtl.Mappers._
import firrtl.Utils.{flow, get_info, kind}
+import firrtl.transforms.InlineBooleanExpressions
// Datastructures
import scala.collection.mutable
@@ -16,15 +17,14 @@ import scala.collection.mutable
// and named intermediate nodes
object SplitExpressions extends Pass {
- override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays))
-
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform) = a match {
case ResolveKinds => true
- case _ => false
+ case _: InlineBooleanExpressions => true // SplitExpressions undoes the inlining!
+ case _ => false
}
private def onModule(m: Module): Module = {
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index c7b0fbcd..331dd43e 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -153,7 +153,7 @@ class ReplSeqMem extends SeqTransform with HasShellOptions with DependencyAPIMig
val transforms: Seq[Transform] =
Seq(
- new SimpleMidTransform(Legalize),
+ new SimpleMidTransform(LegalizeConnects),
new SimpleMidTransform(ToMemIR),
new SimpleMidTransform(ResolveMaskGranularity),
new SimpleMidTransform(RenameAnnotatedMemoryPorts),
diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
index 11184e60..3778f4da 100644
--- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
+++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
@@ -235,8 +235,8 @@ object VerilogMemDelays extends Pass {
Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case _: transforms.ConstantPropagation | ResolveFlows => true
- case _ => false
+ case ResolveFlows => true
+ case _ => false
}
private def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed