aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/PadWidths.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/PadWidths.scala')
-rw-r--r--src/main/scala/firrtl/passes/PadWidths.scala84
1 files changed, 44 insertions, 40 deletions
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala
index 1a430778..02e94975 100644
--- a/src/main/scala/firrtl/passes/PadWidths.scala
+++ b/src/main/scala/firrtl/passes/PadWidths.scala
@@ -7,63 +7,59 @@ import firrtl.ir._
import firrtl.PrimOps._
import firrtl.Mappers._
import firrtl.options.Dependency
-
-import scala.collection.mutable
+import firrtl.transforms.ConstantPropagation
// Makes all implicit width extensions and truncations explicit
object PadWidths extends Pass {
- override def prerequisites =
- ((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)
- + Dependency(firrtl.passes.RemoveValidIf)).toSeq
-
- override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation])
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
- case _: firrtl.transforms.ConstantPropagation | Legalize => true
- case _ => false
+ case SplitExpressions => true // we generate pad and bits operations inline which need to be split up
+ case _ => false
}
- private def width(t: Type): Int = bitWidth(t).toInt
- private def width(e: Expression): Int = width(e.tpe)
- // Returns an expression with the correct integer width
- private def fixup(i: Int)(e: Expression) = {
- def tx = e.tpe match {
- case t: UIntType => UIntType(IntWidth(i))
- case t: SIntType => SIntType(IntWidth(i))
- // default case should never be reached
- }
- width(e) match {
- case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx)
- case j if i < j =>
- val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i)))
- // Bit Select always returns UInt, cast if selecting from SInt
- e.tpe match {
- case UIntType(_) => e2
- case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i)))
- }
- case _ => e
+ /** Adds padding or a bit extract to ensure that the expression is of the with specified.
+ * @note only works on UInt and SInt type expressions, other expressions will yield a match error
+ */
+ private[firrtl] def forceWidth(width: Int)(e: Expression): Expression = {
+ val old = getWidth(e)
+ if (width == old) { e }
+ else if (width > old) {
+ // padding retains the signedness
+ val newType = e.tpe match {
+ case _: UIntType => UIntType(IntWidth(width))
+ case _: SIntType => SIntType(IntWidth(width))
+ case other => throw new RuntimeException(s"forceWidth does not support expressions of type $other")
+ }
+ ConstantPropagation.constPropPad(DoPrim(Pad, Seq(e), Seq(width), newType))
+ } else {
+ val extract = DoPrim(Bits, Seq(e), Seq(width - 1, 0), UIntType(IntWidth(width)))
+ val e2 = ConstantPropagation.constPropBitExtract(extract)
+ // Bit Select always returns UInt, cast if selecting from SInt
+ e.tpe match {
+ case UIntType(_) => e2
+ case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(width)))
+ }
}
}
+ private def getWidth(t: Type): Int = bitWidth(t).toInt
+ private def getWidth(e: Expression): Int = getWidth(e.tpe)
+
// Recursive, updates expression so children exp's have correct widths
private def onExp(e: Expression): Expression = e.map(onExp) match {
case Mux(cond, tval, fval, tpe) =>
- Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe)
- case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value))
+ Mux(cond, forceWidth(getWidth(tpe))(tval), forceWidth(getWidth(tpe))(fval), tpe)
+ case ex: ValidIf => ex.copy(value = forceWidth(getWidth(ex.tpe))(ex.value))
case ex: DoPrim =>
ex.op match {
- case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Rem | Shr =>
- // sensitive ops
- ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max)))
- case Dshl =>
- // special case as args aren't all same width
- ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1)))
+ // pad arguments to ops where the result width is determined as max(w_1, w_2) (+ const)?
+ case Lt | Leq | Gt | Geq | Eq | Neq | And | Or | Xor | Add | Sub =>
+ ex.map(forceWidth(ex.args.map(getWidth).max))
case _ => ex
}
case ex => ex
@@ -72,9 +68,17 @@ object PadWidths extends Pass {
// Recursive. Fixes assignments and register initialization widths
private def onStmt(s: Statement): Statement = s.map(onExp) match {
case sx: Connect =>
- sx.copy(expr = fixup(width(sx.loc))(sx.expr))
+ assert(
+ getWidth(sx.loc) == getWidth(sx.expr),
+ "Connection widths should have been taken care of by LegalizeConnects!"
+ )
+ sx
case sx: DefRegister =>
- sx.copy(init = fixup(width(sx.tpe))(sx.init))
+ assert(
+ getWidth(sx.tpe) == getWidth(sx.init),
+ "Register init widths should have been taken care of by LegalizeConnects!"
+ )
+ sx
case sx => sx.map(onStmt)
}