aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Chen2020-08-25 16:54:25 -0700
committerGitHub2020-08-25 16:54:25 -0700
commit40cb49f9237e23608da454a194f5c55e33f19375 (patch)
treeef29fe2f44c9927f15fd1591285ffd008bd8d750 /src
parentd7a3741909edb72cda2b768e2f8fae4f3c2fd6e2 (diff)
Inline Boolean Expressions (#1817)
The following conditions must be satisfied to inline: 1. has type Utils.BoolType 2. is bound to a DefNode with name starting with '_' 3. is bound to a DefNode with a source locator that points at the same file and line number. If it is a MultiInfo source locator, the set of file and line number pairs must be the same. Source locators may point to different column numbers. 4. InlineBooleanExpressionsMax has not been exceeded 5. is not a Mux Also updates the Verilog emitter to break up lines greater than 120 characters
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Emitter.scala288
-rw-r--r--src/main/scala/firrtl/stage/Forms.scala6
-rw-r--r--src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala169
-rw-r--r--src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala242
-rw-r--r--src/test/scala/firrtlTests/LoweringCompilersSpec.scala2
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala12
6 files changed, 623 insertions, 96 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 843c76a4..19f2661a 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -230,7 +230,48 @@ case class VRandom(width: BigInt) extends Expression {
def foreachWidth(f: Width => Unit): Unit = ()
}
+object VerilogEmitter {
+
+ /** Maps a [[PrimOp]] to a precedence number, lower number means higher precedence
+ *
+ * Only the [[PrimOp]]s contained in this map will be inlined. [[PrimOp]]s
+ * like [[PrimOp.Neg]] are not in this map because inlining them may result
+ * in illegal verilog like '--2sh1'
+ */
+ private val precedenceMap: Map[PrimOp, Int] = {
+ val precedenceSeq = Seq(
+ Set(Head, Tail, Bits, Shr, Pad), // Shr and Pad emit as bit select
+ Set(Andr, Orr, Xorr, Neg, Not),
+ Set(Mul, Div, Rem),
+ Set(Add, Sub, Addw, Subw),
+ Set(Dshl, Dshlw, Dshr),
+ Set(Lt, Leq, Gt, Geq),
+ Set(Eq, Neq),
+ Set(And),
+ Set(Xor),
+ Set(Or)
+ )
+ precedenceSeq.zipWithIndex.foldLeft(Map.empty[PrimOp, Int]) {
+ case (map, (ops, idx)) => map ++ ops.map(_ -> idx)
+ }
+ }
+
+ /** true if op1 has greater or equal precendence than op2
+ */
+ private def precedenceGeq(op1: PrimOp, op2: PrimOp): Boolean = {
+ precedenceMap(op1) <= precedenceMap(op2)
+ }
+
+ /** true if op1 has greater precendence than op2
+ */
+ private def precedenceGt(op1: PrimOp, op2: PrimOp): Boolean = {
+ precedenceMap(op1) < precedenceMap(op2)
+ }
+}
+
class VerilogEmitter extends SeqTransform with Emitter {
+ import VerilogEmitter._
+
def inputForm = LowForm
def outputForm = LowForm
@@ -280,8 +321,42 @@ class VerilogEmitter extends SeqTransform with Emitter {
case ClockType | AsyncResetType => ""
case _ => throwInternalError(s"trying to write unsupported type in the Verilog Emitter: $tpe")
}
- def emit(x: Any)(implicit w: Writer): Unit = { emit(x, 0) }
+ private def getLeadingTabs(x: Any): String = {
+ x match {
+ case seq: Seq[_] =>
+ val head = seq.takeWhile(_ == tab).mkString
+ val tail = seq.dropWhile(_ == tab).lift(0).map(getLeadingTabs).getOrElse(tab)
+ head + tail
+ case _ => tab
+ }
+ }
+ def emit(x: Any)(implicit w: Writer): Unit = {
+ emitCol(x, 0, getLeadingTabs(x), 0)
+ }
+ private def emitCast(e: Expression): Any = e.tpe match {
+ case (t: UIntType) => e
+ case (t: SIntType) => Seq("$signed(", e, ")")
+ case ClockType => e
+ case AnalogType(_) => e
+ case _ => throwInternalError(s"unrecognized cast: $e")
+ }
def emit(x: Any, top: Int)(implicit w: Writer): Unit = {
+ emitCol(x, top, "", 0)
+ }
+ private val maxCol = 120
+ private def emitCol(x: Any, top: Int, tabs: String, colNum: Int)(implicit w: Writer): Int = {
+ def writeCol(contents: String): Int = {
+ if ((contents.size + colNum) > maxCol) {
+ w.write("\n")
+ w.write(tabs)
+ w.write(contents)
+ tabs.size + contents.size
+ } else {
+ w.write(contents)
+ colNum + contents.size
+ }
+ }
+
def cast(e: Expression): Any = e.tpe match {
case (t: UIntType) => e
case (t: SIntType) => Seq("$signed(", e, ")")
@@ -290,7 +365,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
case _ => throwInternalError(s"unrecognized cast: $e")
}
x match {
- case (e: DoPrim) => emit(op_stream(e), top + 1)
+ case (e: DoPrim) => emitCol(op_stream(e), top + 1, tabs, colNum)
case (e: Mux) => {
if (e.tpe == ClockType) {
throw EmitterException("Cannot emit clock muxes directly")
@@ -298,51 +373,64 @@ class VerilogEmitter extends SeqTransform with Emitter {
if (e.tpe == AsyncResetType) {
throw EmitterException("Cannot emit async reset muxes directly")
}
- emit(Seq(e.cond, " ? ", cast(e.tval), " : ", cast(e.fval)), top + 1)
+ emitCol(Seq(e.cond, " ? ", cast(e.tval), " : ", cast(e.fval)), top + 1, tabs, colNum)
}
- case (e: ValidIf) => emit(Seq(cast(e.value)), top + 1)
- case (e: WRef) => w.write(e.serialize)
- case (e: WSubField) => w.write(LowerTypes.loweredName(e))
- case (e: WSubAccess) => w.write(s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]")
- case (e: WSubIndex) => w.write(e.serialize)
- case (e: Literal) => v_print(e)
- case (e: VRandom) => w.write(s"{${e.nWords}{`RANDOM}}")
- case (t: GroundType) => w.write(stringify(t))
+ case (e: ValidIf) => emitCol(Seq(cast(e.value)), top + 1, tabs, colNum)
+ case (e: WRef) => writeCol(e.serialize)
+ case (e: WSubField) => writeCol(LowerTypes.loweredName(e))
+ case (e: WSubAccess) => writeCol(s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]")
+ case (e: WSubIndex) => writeCol(e.serialize)
+ case (e: Literal) => v_print(e, colNum)
+ case (e: VRandom) => writeCol(s"{${e.nWords}{`RANDOM}}")
+ case (t: GroundType) => writeCol(stringify(t))
case (t: VectorType) =>
emit(t.tpe, top + 1)
- w.write(s"[${t.size - 1}:0]")
- case (s: String) => w.write(s)
- case (i: Int) => w.write(i.toString)
- case (i: Long) => w.write(i.toString)
- case (i: BigInt) => w.write(i.toString)
+ writeCol(s"[${t.size - 1}:0]")
+ case (s: String) => writeCol(s)
+ case (i: Int) => writeCol(i.toString)
+ case (i: Long) => writeCol(i.toString)
+ case (i: BigInt) => writeCol(i.toString)
case (i: Info) =>
i match {
- case NoInfo => // Do nothing
+ case NoInfo => colNum // Do nothing
case f: FileInfo =>
val escaped = FileInfo.escapedToVerilog(f.escaped)
w.write(s" // @[$escaped]")
+ colNum
case m: MultiInfo =>
val escaped = FileInfo.escapedToVerilog(m.flatten.map(_.escaped).mkString(" "))
w.write(s" // @[$escaped]")
+ colNum
}
case (s: Seq[Any]) =>
- s.foreach(emit(_, top + 1))
- if (top == 0) w.write("\n")
+ val nextColNum = s.foldLeft(colNum) {
+ case (colNum, e) => emitCol(e, top + 1, tabs, colNum)
+ }
+ if (top == 0) {
+ w.write("\n")
+ 0
+ } else {
+ nextColNum
+ }
case x => throwInternalError(s"trying to emit unsupported operator: $x")
}
}
//;------------- PASS -----------------
- def v_print(e: Expression)(implicit w: Writer) = e match {
+ def v_print(e: Expression, colNum: Int)(implicit w: Writer) = e match {
case UIntLiteral(value, IntWidth(width)) =>
- w.write(s"$width'h${value.toString(16)}")
+ val contents = s"$width'h${value.toString(16)}"
+ w.write(contents)
+ colNum + contents.size
case SIntLiteral(value, IntWidth(width)) =>
val stringLiteral = value.toString(16)
- w.write(stringLiteral.head match {
+ val contents = stringLiteral.head match {
case '-' if value == FixAddingNegativeLiterals.minNegValue(width) => s"$width'sh${stringLiteral.tail}"
case '-' => s"-$width'sh${stringLiteral.tail}"
case _ => s"$width'sh${stringLiteral}"
- })
+ }
+ w.write(contents)
+ colNum + contents.size
case _ => throwInternalError(s"attempt to print unrecognized expression: $e")
}
@@ -350,29 +438,62 @@ class VerilogEmitter extends SeqTransform with Emitter {
// reference is actually unsigned in the emitted Verilog. Thus we must cast refs as necessary
// to ensure Verilog operations are signed.
def op_stream(doprim: DoPrim): Seq[Any] = {
+ def parenthesize(e: Expression, isFirst: Boolean): Any = doprim.op match {
+ // these PrimOps emit either {..., a0, ...} or a0 so they never need parentheses
+ case Shl | Cat | Cvt | AsUInt | AsSInt | AsClock | AsAsyncReset => e
+ case _ =>
+ e match {
+ case e: DoPrim =>
+ op_stream(e) match {
+ /** DoPrims like AsUInt simply emit Seq(a0), so we need to
+ * recursively check whether a0 needs to be parenthesized
+ */
+ case Seq(passthrough: Expression) => parenthesize(passthrough, isFirst)
+
+ /** If the expression is the first argument then it does not need
+ * parens if it's precedence is greather than or equal to the
+ * enclosing doprim, because verilog operators are left
+ * associative. All other args do not need parens only if the
+ * precedence is greater.
+ */
+ case other =>
+ if (precedenceGt(e.op, doprim.op) || (precedenceGeq(e.op, doprim.op) && isFirst)) {
+ other
+ } else {
+ Seq("(", other, ")")
+ }
+ }
+
+ /** Mux args should always have parens because Mux has the lowest precedence
+ */
+ case _: Mux => Seq("(", e, ")")
+ case _ => e
+ }
+ }
+
// Cast to SInt, don't cast multiple times
def doCast(e: Expression): Any = e match {
case DoPrim(AsSInt, Seq(arg), _, _) => doCast(arg)
case slit: SIntLiteral => slit
case other => Seq("$signed(", other, ")")
}
- def castIf(e: Expression): Any = {
+ def castIf(e: Expression, isFirst: Boolean = false): Any = {
if (doprim.args.exists(_.tpe.isInstanceOf[SIntType])) {
e.tpe match {
case _: SIntType => doCast(e)
case _ => throwInternalError(s"Unexpected non-SInt type for $e in $doprim")
}
} else {
- e
+ parenthesize(e, isFirst)
}
}
- def cast(e: Expression): Any = doprim.tpe match {
- case _: UIntType => e
+ def cast(e: Expression, isFirst: Boolean = false): Any = doprim.tpe match {
+ case _: UIntType => parenthesize(e, isFirst)
case _: SIntType => doCast(e)
case _ => throwInternalError(s"Unexpected type for $e in $doprim")
}
- def castAs(e: Expression): Any = e.tpe match {
- case _: UIntType => e
+ def castAs(e: Expression, isFirst: Boolean = false): Any = e.tpe match {
+ case _: UIntType => parenthesize(e, isFirst)
case _: SIntType => doCast(e)
case _ => throwInternalError(s"Unexpected type for $e in $doprim")
}
@@ -381,19 +502,6 @@ class VerilogEmitter extends SeqTransform with Emitter {
def c0: Int = doprim.consts.head.toInt
def c1: Int = doprim.consts(1).toInt
- def checkArgumentLegality(e: Expression): Unit = e match {
- case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
- case DoPrim(Not, args, _, _) => args.foreach(checkArgumentLegality)
- case DoPrim(op, args, _, _) if isCast(op) => args.foreach(checkArgumentLegality)
- case DoPrim(op, args, _, _) if isBitExtract(op) => args.foreach(checkArgumentLegality)
- case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
- }
-
- def checkCatArgumentLegality(e: Expression): Unit = e match {
- case DoPrim(Cat, args, _, _) => args.foreach(checkCatArgumentLegality)
- case _ => checkArgumentLegality(e)
- }
-
def castCatArgs(a0: Expression, a1: Expression): Seq[Any] = {
val a0Seq = a0 match {
case cat @ DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1))
@@ -407,24 +515,19 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
doprim.op match {
- case Cat => doprim.args.foreach(checkCatArgumentLegality)
- case cast if isCast(cast) => // Casts are allowed to wrap any Expression
- case other => doprim.args.foreach(checkArgumentLegality)
- }
- doprim.op match {
- case Add => Seq(castIf(a0), " + ", castIf(a1))
- case Addw => Seq(castIf(a0), " + ", castIf(a1))
- case Sub => Seq(castIf(a0), " - ", castIf(a1))
- case Subw => Seq(castIf(a0), " - ", castIf(a1))
- case Mul => Seq(castIf(a0), " * ", castIf(a1))
- case Div => Seq(castIf(a0), " / ", castIf(a1))
- case Rem => Seq(castIf(a0), " % ", castIf(a1))
- case Lt => Seq(castIf(a0), " < ", castIf(a1))
- case Leq => Seq(castIf(a0), " <= ", castIf(a1))
- case Gt => Seq(castIf(a0), " > ", castIf(a1))
- case Geq => Seq(castIf(a0), " >= ", castIf(a1))
- case Eq => Seq(castIf(a0), " == ", castIf(a1))
- case Neq => Seq(castIf(a0), " != ", castIf(a1))
+ case Add => Seq(castIf(a0, true), " + ", castIf(a1))
+ case Addw => Seq(castIf(a0, true), " + ", castIf(a1))
+ case Sub => Seq(castIf(a0, true), " - ", castIf(a1))
+ case Subw => Seq(castIf(a0, true), " - ", castIf(a1))
+ case Mul => Seq(castIf(a0, true), " * ", castIf(a1))
+ case Div => Seq(castIf(a0, true), " / ", castIf(a1))
+ case Rem => Seq(castIf(a0, true), " % ", castIf(a1))
+ case Lt => Seq(castIf(a0, true), " < ", castIf(a1))
+ case Leq => Seq(castIf(a0, true), " <= ", castIf(a1))
+ case Gt => Seq(castIf(a0, true), " > ", castIf(a1))
+ case Geq => Seq(castIf(a0, true), " >= ", castIf(a1))
+ case Eq => Seq(castIf(a0, true), " == ", castIf(a1))
+ case Neq => Seq(castIf(a0, true), " != ", castIf(a1))
case Pad =>
val w = bitWidth(a0.tpe)
val diff = c0 - w
@@ -434,7 +537,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
// Either sign extend or zero extend.
// If width == BigInt(1), don't extract bit
case (_: SIntType) if w == BigInt(1) => Seq("{", c0, "{", a0, "}}")
- case (_: SIntType) => Seq("{{", diff, "{", a0, "[", w - 1, "]}},", a0, "}")
+ case (_: SIntType) => Seq("{{", diff, "{", parenthesize(a0, true), "[", w - 1, "]}},", a0, "}")
case (_) => Seq("{{", diff, "'d0}, ", a0, "}")
}
// Because we don't support complex Expressions, all casts are ignored
@@ -451,35 +554,35 @@ class VerilogEmitter extends SeqTransform with Emitter {
case Shl => if (c0 > 0) Seq("{", cast(a0), s", $c0'h0}") else Seq(cast(a0))
case Shr if c0 >= bitWidth(a0.tpe) =>
error("Verilog emitter does not support SHIFT_RIGHT >= arg width")
- case Shr if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]")
- case Shr => Seq(a0, "[", bitWidth(a0.tpe) - 1, ":", c0, "]")
- case Neg => Seq("-", cast(a0))
+ case Shr if c0 == (bitWidth(a0.tpe) - 1) => Seq(parenthesize(a0, true), "[", bitWidth(a0.tpe) - 1, "]")
+ case Shr => Seq(parenthesize(a0, true), "[", bitWidth(a0.tpe) - 1, ":", c0, "]")
+ case Neg => Seq("-", cast(a0, true))
case Cvt =>
a0.tpe match {
case (_: UIntType) => Seq("{1'b0,", cast(a0), "}")
case (_: SIntType) => Seq(cast(a0))
}
- case Not => Seq("~", a0)
- case And => Seq(castAs(a0), " & ", castAs(a1))
- case Or => Seq(castAs(a0), " | ", castAs(a1))
- case Xor => Seq(castAs(a0), " ^ ", castAs(a1))
- case Andr => Seq("&", cast(a0))
- case Orr => Seq("|", cast(a0))
- case Xorr => Seq("^", cast(a0))
+ case Not => Seq("~", parenthesize(a0, true))
+ case And => Seq(castAs(a0, true), " & ", castAs(a1))
+ case Or => Seq(castAs(a0, true), " | ", castAs(a1))
+ case Xor => Seq(castAs(a0, true), " ^ ", castAs(a1))
+ case Andr => Seq("&", cast(a0, true))
+ case Orr => Seq("|", cast(a0, true))
+ case Xorr => Seq("^", cast(a0, true))
case Cat => "{" +: (castCatArgs(a0, a1) :+ "}")
// If selecting zeroth bit and single-bit wire, just emit the wire
case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
- case Bits if c0 == c1 => Seq(a0, "[", c0, "]")
- case Bits => Seq(a0, "[", c0, ":", c1, "]")
+ case Bits if c0 == c1 => Seq(parenthesize(a0, true), "[", c0, "]")
+ case Bits => Seq(parenthesize(a0, true), "[", c0, ":", c1, "]")
// If selecting zeroth bit and single-bit wire, just emit the wire
case Head if c0 == 1 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
case Head if c0 == 1 => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]")
case Head =>
val msb = bitWidth(a0.tpe) - 1
val lsb = bitWidth(a0.tpe) - c0
- Seq(a0, "[", msb, ":", lsb, "]")
- case Tail if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[0]")
- case Tail => Seq(a0, "[", bitWidth(a0.tpe) - c0 - 1, ":0]")
+ Seq(parenthesize(a0, true), "[", msb, ":", lsb, "]")
+ case Tail if c0 == (bitWidth(a0.tpe) - 1) => Seq(parenthesize(a0, true), "[0]")
+ case Tail => Seq(parenthesize(a0, true), "[", bitWidth(a0.tpe) - c0 - 1, ":0]")
}
}
@@ -804,7 +907,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
def regUpdate(r: Expression, clk: Expression, reset: Expression, init: Expression) = {
- def addUpdate(info: Info, expr: Expression, tabs: String): Seq[Seq[Any]] = expr match {
+ def addUpdate(info: Info, expr: Expression, tabs: Seq[String]): Seq[Seq[Any]] = expr match {
case m: Mux =>
if (m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly")
if (m.tpe == AsyncResetType) throw EmitterException("Cannot emit async reset muxes directly")
@@ -814,8 +917,8 @@ class VerilogEmitter extends SeqTransform with Emitter {
lazy val _else = Seq(tabs, "end else begin")
lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin", eninfo)
lazy val _end = Seq(tabs, "end")
- lazy val _true = addUpdate(tinfo, m.tval, tabs + tab)
- lazy val _false = addUpdate(finfo, m.fval, tabs + tab)
+ lazy val _true = addUpdate(tinfo, m.tval, tab +: tabs)
+ lazy val _false = addUpdate(finfo, m.fval, tab +: tabs)
lazy val _elseIfFalse = {
val _falsex = addUpdate(finfo, m.fval, tabs) // _false, but without an additional tab
Seq(tabs, "end else ", _falsex.head.tail) +: _falsex.tail
@@ -845,13 +948,19 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
if (weq(init, r)) { // Synchronous Reset
val InfoExpr(info, e) = netlist(r)
- noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(info, e, "")
+ noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(info, e, Seq.empty)
} else { // Asynchronous Reset
assert(reset.tpe == AsyncResetType, "Error! Synchronous reset should have been removed!")
val tv = init
val InfoExpr(finfo, fv) = netlist(r)
// TODO add register info argument and build a MultiInfo to pass
- asyncResetAlwaysBlocks += ((clk, reset, addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), "")))
+ asyncResetAlwaysBlocks += (
+ (
+ clk,
+ reset,
+ addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), Seq.empty)
+ )
+ )
}
}
@@ -1367,11 +1476,18 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
override def execute(state: CircuitState): CircuitState = {
+ val writerToString =
+ (writer: java.io.StringWriter) => writer.toString.replaceAll("""(?m) +$""", "") // trim trailing whitespace
+
val newAnnos = state.annotations.flatMap {
case EmitCircuitAnnotation(a) if this.getClass == a =>
val writer = new java.io.StringWriter
emit(state, writer)
- Seq(EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit(state.circuit.main, writer.toString, outputSuffix)))
+ Seq(
+ EmittedVerilogCircuitAnnotation(
+ EmittedVerilogCircuit(state.circuit.main, writerToString(writer), outputSuffix)
+ )
+ )
case EmitAllModulesAnnotation(a) if this.getClass == a =>
val cs = runTransforms(state)
@@ -1383,12 +1499,16 @@ class VerilogEmitter extends SeqTransform with Emitter {
val writer = new java.io.StringWriter
val renderer = new VerilogRender(d, pds, module, moduleMap, cs.circuit.main, emissionOptions)(writer)
renderer.emit_verilog()
- Some(EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writer.toString, outputSuffix)))
+ Some(
+ EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writerToString(writer), outputSuffix))
+ )
case module: Module =>
val writer = new java.io.StringWriter
val renderer = new VerilogRender(module, moduleMap, cs.circuit.main, emissionOptions)(writer)
renderer.emit_verilog()
- Some(EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writer.toString, outputSuffix)))
+ Some(
+ EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writerToString(writer), outputSuffix))
+ )
case _ => None
}
case _ => Seq()
diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala
index a0c5ea0c..db411325 100644
--- a/src/main/scala/firrtl/stage/Forms.scala
+++ b/src/main/scala/firrtl/stage/Forms.scala
@@ -110,7 +110,11 @@ object Forms {
Dependency[firrtl.AddDescriptionNodes]
)
- val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogMinimumOptimized
+ val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++
+ Seq(
+ Dependency[firrtl.transforms.InlineBooleanExpressions]
+ ) ++
+ VerilogMinimumOptimized
val AssertsRemoved: Seq[TransformDependency] =
Seq(
diff --git a/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala
new file mode 100644
index 00000000..7c52d6ef
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala
@@ -0,0 +1,169 @@
+// See LICENSE for license details.
+
+package firrtl
+package transforms
+
+import firrtl.annotations.{NoTargetAnnotation, Target}
+import firrtl.annotations.TargetToken.{fromStringToTargetToken, OfModule, Ref}
+import firrtl.ir._
+import firrtl.passes.{InferTypes, LowerTypes, SplitExpressions}
+import firrtl.options.Dependency
+import firrtl.PrimOps._
+import firrtl.WrappedExpression._
+
+import scala.collection.mutable
+
+case class InlineBooleanExpressionsMax(max: Int) extends NoTargetAnnotation
+
+object InlineBooleanExpressions {
+ val defaultMax = 30
+}
+
+/** Inline Bool expressions
+ *
+ * The following conditions must be satisfied to inline
+ * 1. has type [[Utils.BoolType]]
+ * 2. is bound to a [[firrtl.ir.DefNode DefNode]] with name starting with '_'
+ * 3. is bound to a [[firrtl.ir.DefNode DefNode]] with a source locator that
+ * points at the same file and line number. If it is a MultiInfo source
+ * locator, the set of file and line number pairs must be the same. Source
+ * locators may point to different column numbers.
+ * 4. [[InlineBooleanExpressionsMax]] has not been exceeded
+ * 5. is not a [[firrtl.ir.Mux Mux]]
+ */
+class InlineBooleanExpressions extends Transform with DependencyAPIMigration {
+
+ override def prerequisites = Seq(
+ Dependency(InferTypes),
+ Dependency(LowerTypes)
+ )
+
+ override def optionalPrerequisites = Seq(
+ Dependency(SplitExpressions)
+ )
+
+ override def invalidates(a: Transform) = a match {
+ case _: DeadCodeElimination => true // this transform does not remove nodes that are unused after inlining
+ case _ => false
+ }
+
+ type Netlist = mutable.HashMap[WrappedExpression, (Expression, Info)]
+
+ private def isArgN(outerExpr: DoPrim, subExpr: Expression, n: Int): Boolean = {
+ outerExpr.args.lift(n) match {
+ case Some(arg) => arg eq subExpr
+ case _ => false
+ }
+ }
+
+ private val fileLineRegex = """(.*) ([0-9]+):[0-9]+""".r
+ private def sameFileAndLineInfo(info1: Info, info2: Info): Boolean = {
+ (info1, info2) match {
+ case (FileInfo(fileLineRegex(file1, line1)), FileInfo(fileLineRegex(file2, line2))) =>
+ (file1 == file2) && (line1 == line2)
+ case (MultiInfo(infos1), MultiInfo(infos2)) if infos1.size == infos2.size =>
+ infos1.zip(infos2).forall {
+ case (i1, i2) =>
+ sameFileAndLineInfo(i1, i2)
+ }
+ case (NoInfo, NoInfo) => true
+ case _ => false
+ }
+ }
+
+ /** A helper class to initialize and store mutable state that the expression
+ * and statement map functions need access to. This makes it easier to pass
+ * information around without having to plump arguments through the onExpr
+ * and onStmt methods.
+ */
+ private class MapMethods(maxInlineCount: Int, dontTouches: Set[Ref]) {
+ val netlist: Netlist = new Netlist
+ val inlineCounts = mutable.Map.empty[Ref, Int]
+ var inlineCount: Int = 1
+
+ /** Whether or not an can be inlined
+ * @param refExpr the expression to check for inlining
+ */
+ def canInline(refExpr: Expression): Boolean = {
+ refExpr match {
+ case _: Mux => false
+ case _ => refExpr.tpe == Utils.BoolType
+ }
+ }
+
+ /** Inlines [[Wref]]s if they are Boolean, have matching file line numbers,
+ * and would not raise inlineCounts past the maximum.
+ *
+ * @param info the [[Info]] of the enclosing [[Statement]]
+ * @param outerExpr the direct parent [[Expression]] of the current [[Expression]]
+ * @param expr the [[Expression]] to apply inlining to
+ */
+ def onExpr(info: Info, outerExpr: Option[Expression])(expr: Expression): Expression = {
+ expr match {
+ case ref: WRef if !dontTouches.contains(ref.name.Ref) && ref.name.head == '_' =>
+ val refKey = ref.name.Ref
+ netlist.get(we(ref)) match {
+ case Some((refExpr, refInfo)) if sameFileAndLineInfo(info, refInfo) =>
+ val inlineNum = inlineCounts.getOrElse(refKey, 1)
+ if (!outerExpr.isDefined || canInline(refExpr) && ((inlineNum + inlineCount) <= maxInlineCount)) {
+ inlineCount += inlineNum
+ refExpr
+ } else {
+ ref
+ }
+ case other => ref
+ }
+ case other => other.mapExpr(onExpr(info, Some(other)))
+ }
+ }
+
+ /** Applies onExpr and records metadata for every [[HasInfo]] in a [[Statement]]
+ *
+ * This resets inlineCount before inlining and records the resulting
+ * inline counts and inlined values in the inlineCounts and netlist maps
+ * after inlining.
+ */
+ def onStmt(stmt: Statement): Statement = {
+ stmt.mapStmt(onStmt) match {
+ case hasInfo: HasInfo =>
+ inlineCount = 1
+ val stmtx = hasInfo.mapExpr(onExpr(hasInfo.info, None))
+ stmtx match {
+ case node: DefNode => inlineCounts(node.name.Ref) = inlineCount
+ case _ =>
+ }
+ stmtx match {
+ case node @ DefNode(info, name, value) =>
+ netlist(we(WRef(name))) = (value, info)
+ case _ =>
+ }
+ stmtx
+ case other => other
+ }
+ }
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ val dontTouchMap: Map[OfModule, Set[Ref]] = {
+ val refTargets = state.annotations.flatMap {
+ case anno: HasDontTouches => anno.dontTouches
+ case o => Nil
+ }
+ val dontTouches: Seq[(OfModule, Ref)] = refTargets.map {
+ case r => Target.referringModule(r).module.OfModule -> r.ref.Ref
+ }
+ dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap
+ }
+
+ val maxInlineCount = state.annotations.collectFirst {
+ case InlineBooleanExpressionsMax(max) => max
+ }.getOrElse(InlineBooleanExpressions.defaultMax)
+
+ val modulesx = state.circuit.modules.map { m =>
+ val mapMethods = new MapMethods(maxInlineCount, dontTouchMap.getOrElse(m.name.OfModule, Set.empty[Ref]))
+ m.mapStmt(mapMethods.onStmt(_))
+ }
+
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}
diff --git a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
new file mode 100644
index 00000000..5fee87c9
--- /dev/null
+++ b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
@@ -0,0 +1,242 @@
+
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.annotations.Annotation
+import firrtl.options.Dependency
+import firrtl.passes._
+import firrtl.transforms._
+import firrtl.testutils._
+import firrtl.stage.TransformManager
+
+class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
+ val transform = new InlineBooleanExpressions
+ val transforms: Seq[Transform] = new TransformManager(
+ transform.prerequisites
+ ).flattenedTransformOrder :+ transform
+
+ protected def exec(input: String, annos: Seq[Annotation] = Nil) = {
+ transforms.foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) {
+ (c: CircuitState, t: Transform) => t.runTransform(c)
+ }.circuit.serialize
+ }
+
+ it should "inline mux operands" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output out : UInt<1>
+ | node x1 = UInt<1>(0)
+ | node x2 = UInt<1>(1)
+ | node _t = head(x1, 1)
+ | node _f = head(x2, 1)
+ | node _c = lt(x1, x2)
+ | node _y = mux(_c, _t, _f)
+ | out <= _y""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output out : UInt<1>
+ | node x1 = UInt<1>(0)
+ | node x2 = UInt<1>(1)
+ | node _t = head(x1, 1)
+ | node _f = head(x2, 1)
+ | node _c = lt(x1, x2)
+ | node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1))
+ | out <= mux(lt(x1, x2), head(x1, 1), head(x2, 1))""".stripMargin
+ val result = exec(input)
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "only inline expressions with the same file and line number" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output outA1 : UInt<1>
+ | output outA2 : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<1>(0)
+ | node x2 = UInt<1>(1)
+ |
+ | node _t = head(x1, 1) @[A 1:1]
+ | node _f = head(x2, 1) @[A 1:2]
+ | node _y = mux(lt(x1, x2), _t, _f) @[A 1:3]
+ | outA1 <= _y @[A 1:3]
+ |
+ | outA2 <= _y @[A 2:3]
+ |
+ | outB <= _y @[B]""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output outA1 : UInt<1>
+ | output outA2 : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<1>(0)
+ | node x2 = UInt<1>(1)
+ |
+ | node _t = head(x1, 1) @[A 1:1]
+ | node _f = head(x2, 1) @[A 1:2]
+ | node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1)) @[A 1:3]
+ | outA1 <= mux(lt(x1, x2), head(x1, 1), head(x2, 1)) @[A 1:3]
+ |
+ | outA2 <= _y @[A 2:3]
+ |
+ | outB <= _y @[B]""".stripMargin
+ val result = exec(input)
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "inline boolean DoPrims" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output outA : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<3>(0)
+ | node x2 = UInt<3>(1)
+ |
+ | node _a = lt(x1, x2)
+ | node _b = eq(_a, x2)
+ | node _c = and(_b, x2)
+ | outA <= _c
+ |
+ | node _d = head(_c, 1)
+ | node _e = andr(_d)
+ | node _f = lt(_e, x2)
+ | outB <= _f""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output outA : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<3>(0)
+ | node x2 = UInt<3>(1)
+ |
+ | node _a = lt(x1, x2)
+ | node _b = eq(lt(x1, x2), x2)
+ | node _c = and(eq(lt(x1, x2), x2), x2)
+ | outA <= and(eq(lt(x1, x2), x2), x2)
+ |
+ | node _d = head(_c, 1)
+ | node _e = andr(head(_c, 1))
+ | node _f = lt(andr(head(_c, 1)), x2)
+ |
+ | outB <= lt(andr(head(_c, 1)), x2)""".stripMargin
+ val result = exec(input)
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "inline more boolean DoPrims" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output outA : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<3>(0)
+ | node x2 = UInt<3>(1)
+ |
+ | node _a = lt(x1, x2)
+ | node _b = leq(_a, x2)
+ | node _c = gt(_b, x2)
+ | node _d = geq(_c, x2)
+ | outA <= _d
+ |
+ | node _e = lt(x1, x2)
+ | node _f = leq(x1, _e)
+ | node _g = gt(x1, _f)
+ | node _h = geq(x1, _g)
+ | outB <= _h""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output outA : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<3>(0)
+ | node x2 = UInt<3>(1)
+ |
+ | node _a = lt(x1, x2)
+ | node _b = leq(lt(x1, x2), x2)
+ | node _c = gt(leq(lt(x1, x2), x2), x2)
+ | node _d = geq(gt(leq(lt(x1, x2), x2), x2), x2)
+ | outA <= geq(gt(leq(lt(x1, x2), x2), x2), x2)
+ |
+ | node _e = lt(x1, x2)
+ | node _f = leq(x1, lt(x1, x2))
+ | node _g = gt(x1, leq(x1, lt(x1, x2)))
+ | node _h = geq(x1, gt(x1, leq(x1, lt(x1, x2))))
+ |
+ | outB <= geq(x1, gt(x1, leq(x1, lt(x1, x2))))""".stripMargin
+ val result = exec(input)
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "limit the number of inlines" in {
+ val input =
+ s"""circuit Top :
+ | module Top :
+ | input c_0: UInt<1>
+ | input c_1: UInt<1>
+ | input c_2: UInt<1>
+ | input c_3: UInt<1>
+ | input c_4: UInt<1>
+ | input c_5: UInt<1>
+ | input c_6: UInt<1>
+ | output out : UInt<1>
+ |
+ | node _1 = or(c_0, c_1)
+ | node _2 = or(_1, c_2)
+ | node _3 = or(_2, c_3)
+ | node _4 = or(_3, c_4)
+ | node _5 = or(_4, c_5)
+ | node _6 = or(_5, c_6)
+ |
+ | out <= _6""".stripMargin
+ val check =
+ s"""circuit Top :
+ | module Top :
+ | input c_0: UInt<1>
+ | input c_1: UInt<1>
+ | input c_2: UInt<1>
+ | input c_3: UInt<1>
+ | input c_4: UInt<1>
+ | input c_5: UInt<1>
+ | input c_6: UInt<1>
+ | output out : UInt<1>
+ |
+ | node _1 = or(c_0, c_1)
+ | node _2 = or(or(c_0, c_1), c_2)
+ | node _3 = or(or(or(c_0, c_1), c_2), c_3)
+ | node _4 = or(_3, c_4)
+ | node _5 = or(or(_3, c_4), c_5)
+ | node _6 = or(or(or(_3, c_4), c_5), c_6)
+ |
+ | out <= or(or(or(_3, c_4), c_5), c_6)""".stripMargin
+ val result = exec(input, Seq(InlineBooleanExpressionsMax(3)))
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "be equivalent" in {
+ val input =
+ """circuit InlineBooleanExpressionsEquivalenceTest :
+ | module InlineBooleanExpressionsEquivalenceTest :
+ | input in : UInt<1>[6]
+ | output out : UInt<1>
+ |
+ | node _a = or(in[0], in[1])
+ | node _b = and(in[2], _a)
+ | node _c = eq(in[3], _b)
+ | node _d = lt(in[4], _c)
+ | node _e = eq(in[5], _d)
+ | node _f = head(_e, 1)
+ | out <= _f""".stripMargin
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+}
diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
index 46416619..40f8f123 100644
--- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
+++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
@@ -260,6 +260,8 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
it should "replicate the old order" in {
val legacy = Seq(
+ new firrtl.transforms.InlineBooleanExpressions,
+ new firrtl.transforms.DeadCodeElimination,
new firrtl.transforms.BlackBoxSourceHelper,
new firrtl.transforms.FixAddingNegativeLiterals,
new firrtl.transforms.ReplaceTruncatingArithmetic,
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index 8f128274..a864bfe5 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -110,18 +110,8 @@ class UnitTests extends FirrtlFlatSpec {
| out <= bits(mux(a, b, c), 0, 0)
|""".stripMargin
- "Emitting a nested expression" should "throw an exception" in {
+ "Emitting a nested expression" should "compile" in {
val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds)
- intercept[PassException] {
- val c = Parser.parse(splitExpTestCode.split("\n").toIterator)
- val c2 = passes.foldLeft(c)((c, p) => p.run(c))
- val writer = new StringWriter()
- (new VerilogEmitter).emit(CircuitState(c2, LowForm), writer)
- }
- }
-
- "After splitting, emitting a nested expression" should "compile" in {
- val passes = Seq(ToWorkingIR, SplitExpressions, InferTypes)
val c = Parser.parse(splitExpTestCode.split("\n").toIterator)
val c2 = passes.foldLeft(c)((c, p) => p.run(c))
val writer = new StringWriter()