diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/PadWidths.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/PadWidths.scala | 84 |
1 files changed, 44 insertions, 40 deletions
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 1a430778..02e94975 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -7,63 +7,59 @@ import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Mappers._ import firrtl.options.Dependency - -import scala.collection.mutable +import firrtl.transforms.ConstantPropagation // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { - override def prerequisites = - ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq - - override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation]) + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { - case _: firrtl.transforms.ConstantPropagation | Legalize => true - case _ => false + case SplitExpressions => true // we generate pad and bits operations inline which need to be split up + case _ => false } - private def width(t: Type): Int = bitWidth(t).toInt - private def width(e: Expression): Int = width(e.tpe) - // Returns an expression with the correct integer width - private def fixup(i: Int)(e: Expression) = { - def tx = e.tpe match { - case t: UIntType => UIntType(IntWidth(i)) - case t: SIntType => SIntType(IntWidth(i)) - // default case should never be reached - } - width(e) match { - case j if i > j => DoPrim(Pad, Seq(e), Seq(i), tx) - case j if i < j => - val e2 = DoPrim(Bits, Seq(e), Seq(i - 1, 0), UIntType(IntWidth(i))) - // Bit Select always returns UInt, cast if selecting from SInt - e.tpe match { - case UIntType(_) => e2 - case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(i))) - } - case _ => e + /** Adds padding or a bit extract to ensure that the expression is of the with specified. + * @note only works on UInt and SInt type expressions, other expressions will yield a match error + */ + private[firrtl] def forceWidth(width: Int)(e: Expression): Expression = { + val old = getWidth(e) + if (width == old) { e } + else if (width > old) { + // padding retains the signedness + val newType = e.tpe match { + case _: UIntType => UIntType(IntWidth(width)) + case _: SIntType => SIntType(IntWidth(width)) + case other => throw new RuntimeException(s"forceWidth does not support expressions of type $other") + } + ConstantPropagation.constPropPad(DoPrim(Pad, Seq(e), Seq(width), newType)) + } else { + val extract = DoPrim(Bits, Seq(e), Seq(width - 1, 0), UIntType(IntWidth(width))) + val e2 = ConstantPropagation.constPropBitExtract(extract) + // Bit Select always returns UInt, cast if selecting from SInt + e.tpe match { + case UIntType(_) => e2 + case SIntType(_) => DoPrim(AsSInt, Seq(e2), Seq.empty, SIntType(IntWidth(width))) + } } } + private def getWidth(t: Type): Int = bitWidth(t).toInt + private def getWidth(e: Expression): Int = getWidth(e.tpe) + // Recursive, updates expression so children exp's have correct widths private def onExp(e: Expression): Expression = e.map(onExp) match { case Mux(cond, tval, fval, tpe) => - Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe) - case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value)) + Mux(cond, forceWidth(getWidth(tpe))(tval), forceWidth(getWidth(tpe))(fval), tpe) + case ex: ValidIf => ex.copy(value = forceWidth(getWidth(ex.tpe))(ex.value)) case ex: DoPrim => ex.op match { - case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Rem | Shr => - // sensitive ops - ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max))) - case Dshl => - // special case as args aren't all same width - ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1))) + // pad arguments to ops where the result width is determined as max(w_1, w_2) (+ const)? + case Lt | Leq | Gt | Geq | Eq | Neq | And | Or | Xor | Add | Sub => + ex.map(forceWidth(ex.args.map(getWidth).max)) case _ => ex } case ex => ex @@ -72,9 +68,17 @@ object PadWidths extends Pass { // Recursive. Fixes assignments and register initialization widths private def onStmt(s: Statement): Statement = s.map(onExp) match { case sx: Connect => - sx.copy(expr = fixup(width(sx.loc))(sx.expr)) + assert( + getWidth(sx.loc) == getWidth(sx.expr), + "Connection widths should have been taken care of by LegalizeConnects!" + ) + sx case sx: DefRegister => - sx.copy(init = fixup(width(sx.tpe))(sx.init)) + assert( + getWidth(sx.tpe) == getWidth(sx.init), + "Register init widths should have been taken care of by LegalizeConnects!" + ) + sx case sx => sx.map(onStmt) } |
