aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--fuzzer/src/main/scala/firrtl/ExprGenParams.scala65
-rw-r--r--fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala54
-rw-r--r--fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala59
-rw-r--r--src/main/scala/firrtl/Emitter.scala288
-rw-r--r--src/main/scala/firrtl/stage/Forms.scala6
-rw-r--r--src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala169
-rw-r--r--src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala242
-rw-r--r--src/test/scala/firrtlTests/LoweringCompilersSpec.scala2
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala12
9 files changed, 740 insertions, 157 deletions
diff --git a/fuzzer/src/main/scala/firrtl/ExprGenParams.scala b/fuzzer/src/main/scala/firrtl/ExprGenParams.scala
index ddaec00d..4c11b860 100644
--- a/fuzzer/src/main/scala/firrtl/ExprGenParams.scala
+++ b/fuzzer/src/main/scala/firrtl/ExprGenParams.scala
@@ -1,5 +1,8 @@
package firrtl.fuzzer
+import com.pholser.junit.quickcheck.generator.{Generator, GenerationStatus}
+import com.pholser.junit.quickcheck.random.SourceOfRandomness
+
import firrtl.{Namespace, Utils}
import firrtl.ir._
@@ -17,15 +20,15 @@ sealed trait ExprGenParams {
*/
def maxWidth: Int
- /** A list of frequency/expression generator pairs
+ /** A mapping of expression generator to frequency
*
* The frequency number determines the probability that the corresponding
- * generator will be chosen. i.e. for sequece Seq(1 -> A, 2 -> B, 3 -> C),
- * the probabilities for A, B, and C are 1/6, 2/6, and 3/6 respectively.
- * This sequency must be non-empty and all frequency numbers must be greater
- * than zero.
+ * generator will be chosen. i.g. for Map(A -> 1, B -> 2, C -> B), the
+ * probabilities for A, B, and C are 1/6, 2/6, and 3/6 respectively. This
+ * map must be non-empty and all frequency numbers must be greater than
+ * zero.
*/
- def generators: Seq[(Int, ExprGen[_ <: Expression])]
+ def generators: Map[ExprGen[_ <: Expression], Int]
/** The set of generated references that don't have a corresponding declaration
*/
@@ -102,10 +105,46 @@ sealed trait ExprGenParams {
object ExprGenParams {
+ val defaultGenerators: Map[ExprGen[_ <: Expression], Int] = {
+ import ExprGen._
+ Map(
+ AddDoPrimGen -> 1,
+ SubDoPrimGen -> 1,
+ MulDoPrimGen -> 1,
+ DivDoPrimGen -> 1,
+ LtDoPrimGen -> 1,
+ LeqDoPrimGen -> 1,
+ GtDoPrimGen -> 1,
+ GeqDoPrimGen -> 1,
+ EqDoPrimGen -> 1,
+ NeqDoPrimGen -> 1,
+ PadDoPrimGen -> 1,
+ ShlDoPrimGen -> 1,
+ ShrDoPrimGen -> 1,
+ DshlDoPrimGen -> 1,
+ CvtDoPrimGen -> 1,
+ NegDoPrimGen -> 1,
+ NotDoPrimGen -> 1,
+ AndDoPrimGen -> 1,
+ OrDoPrimGen -> 1,
+ XorDoPrimGen -> 1,
+ AndrDoPrimGen -> 1,
+ OrrDoPrimGen -> 1,
+ XorrDoPrimGen -> 1,
+ CatDoPrimGen -> 1,
+ BitsDoPrimGen -> 1,
+ HeadDoPrimGen -> 1,
+ TailDoPrimGen -> 1,
+ AsUIntDoPrimGen -> 1,
+ AsSIntDoPrimGen -> 1,
+ MuxGen -> 1
+ )
+ }
+
private case class ExprGenParamsImp(
maxDepth: Int,
maxWidth: Int,
- generators: Seq[(Int, ExprGen[_ <: Expression])],
+ generators: Map[ExprGen[_ <: Expression], Int],
protected val unboundRefs: Set[Reference],
protected val namespace: Namespace) extends ExprGenParams {
@@ -119,7 +158,7 @@ object ExprGenParams {
def apply(
maxDepth: Int,
maxWidth: Int,
- generators: Seq[(Int, ExprGen[_ <: Expression])]
+ generators: Map[ExprGen[_ <: Expression], Int]
): ExprGenParams = {
require(maxWidth > 0, "maxWidth must be greater than zero")
ExprGenParamsImp(
@@ -190,7 +229,8 @@ object ExprGenParams {
))(tpe).map(e => e.get) // should be safe because leaf generators are defined for all types
val branchGen: Type => StateGen[ExprGenParams, G, Expression] = (tpe: Type) => {
- combineExprGens(s.generators)(tpe).flatMap {
+ val gens = s.generators.toSeq.map { case (gen, freq) => (freq, gen) }
+ combineExprGens(gens)(tpe).flatMap {
case None => leafGen(tpe)
case Some(e) => StateGen.pure(e)
}
@@ -211,3 +251,10 @@ object ExprGenParams {
}
}
}
+
+abstract class SingleExpressionCircuitGenerator(val params: ExprGenParams) extends Generator[Circuit](classOf[Circuit]) {
+ override def generate(random: SourceOfRandomness, status: GenerationStatus): Circuit = {
+ implicit val r = random
+ params.generateSingleExprCircuit[SourceOfRandomnessGen]()
+ }
+}
diff --git a/fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala b/fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala
index 9ee8e52b..3091e4d6 100644
--- a/fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala
+++ b/fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala
@@ -53,50 +53,14 @@ object SourceOfRandomnessGen {
}
}
-
-class FirrtlSingleModuleGenerator extends Generator[Circuit](classOf[Circuit]) {
- override def generate(random: SourceOfRandomness, status: GenerationStatus): Circuit = {
- implicit val r = random
- import ExprGen._
-
- val params = ExprGenParams(
- maxDepth = 50,
- maxWidth = 31,
- generators = Seq(
- 1 -> AddDoPrimGen,
- 1 -> SubDoPrimGen,
- 1 -> MulDoPrimGen,
- 1 -> DivDoPrimGen,
- 1 -> LtDoPrimGen,
- 1 -> LeqDoPrimGen,
- 1 -> GtDoPrimGen,
- 1 -> GeqDoPrimGen,
- 1 -> EqDoPrimGen,
- 1 -> NeqDoPrimGen,
- 1 -> PadDoPrimGen,
- 1 -> ShlDoPrimGen,
- 1 -> ShrDoPrimGen,
- 1 -> DshlDoPrimGen,
- 1 -> CvtDoPrimGen,
- 1 -> NegDoPrimGen,
- 1 -> NotDoPrimGen,
- 1 -> AndDoPrimGen,
- 1 -> OrDoPrimGen,
- 1 -> XorDoPrimGen,
- 1 -> AndrDoPrimGen,
- 1 -> OrrDoPrimGen,
- 1 -> XorrDoPrimGen,
- 1 -> CatDoPrimGen,
- 1 -> BitsDoPrimGen,
- 1 -> HeadDoPrimGen,
- 1 -> TailDoPrimGen,
- 1 -> AsUIntDoPrimGen,
- 1 -> AsSIntDoPrimGen
- )
- )
- params.generateSingleExprCircuit[SourceOfRandomnessGen]()
- }
-}
+import ExprGen._
+class FirrtlCompileCircuitGenerator extends SingleExpressionCircuitGenerator (
+ ExprGenParams(
+ maxDepth = 50,
+ maxWidth = 31,
+ generators = ExprGenParams.defaultGenerators
+ )
+)
@RunWith(classOf[JQF])
class FirrtlCompileTests {
@@ -112,7 +76,7 @@ class FirrtlCompileTests {
}
@Fuzz
- def compileSingleModule(@From(value = classOf[FirrtlSingleModuleGenerator]) c: Circuit) = {
+ def compileSingleModule(@From(value = classOf[FirrtlCompileCircuitGenerator]) c: Circuit) = {
compile(CircuitState(c, ChirrtlForm, Seq()))
}
diff --git a/fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala b/fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala
index e0aa1707..822744c2 100644
--- a/fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala
+++ b/fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala
@@ -1,17 +1,19 @@
package firrtl.fuzzer
import com.pholser.junit.quickcheck.From
+import com.pholser.junit.quickcheck.generator.{Generator, GenerationStatus}
+import com.pholser.junit.quickcheck.random.SourceOfRandomness
-import edu.berkeley.cs.jqf.fuzz.Fuzz;
-import edu.berkeley.cs.jqf.fuzz.JQF;
+import edu.berkeley.cs.jqf.fuzz.{Fuzz, JQF};
import firrtl._
import firrtl.annotations.{Annotation, CircuitTarget, ModuleTarget, Target}
import firrtl.ir.Circuit
+import firrtl.options.Dependency
import firrtl.stage.{FirrtlCircuitAnnotation, InfoModeAnnotation, OutputFileAnnotation, TransformManager}
import firrtl.stage.Forms.{VerilogMinimumOptimized, VerilogOptimized}
import firrtl.stage.phases.WriteEmitted
-import firrtl.transforms.ManipulateNames
+import firrtl.transforms.{InlineBooleanExpressions, ManipulateNames}
import firrtl.util.BackendCompilationUtilities
import java.io.{File, FileWriter, PrintWriter, StringWriter}
@@ -89,6 +91,32 @@ object FirrtlEquivalenceTestUtils {
}
}
+import ExprGen._
+class InlineBooleanExprsCircuitGenerator extends SingleExpressionCircuitGenerator (
+ ExprGenParams(
+ maxDepth = 50,
+ maxWidth = 31,
+ generators = ExprGenParams.defaultGenerators ++ Map(
+ LtDoPrimGen -> 10,
+ LeqDoPrimGen -> 10,
+ GtDoPrimGen -> 10,
+ GeqDoPrimGen -> 10,
+ EqDoPrimGen -> 10,
+ NeqDoPrimGen -> 10,
+ AndDoPrimGen -> 10,
+ OrDoPrimGen -> 10,
+ XorDoPrimGen -> 10,
+ AndrDoPrimGen -> 10,
+ OrrDoPrimGen -> 10,
+ XorrDoPrimGen -> 10,
+ BitsDoPrimGen -> 10,
+ HeadDoPrimGen -> 10,
+ TailDoPrimGen -> 10,
+ MuxGen -> 10
+ )
+ )
+)
+
@RunWith(classOf[JQF])
class FirrtlEquivalenceTests {
private val lowFirrtlCompiler = new LowFirrtlCompiler()
@@ -103,8 +131,7 @@ class FirrtlEquivalenceTests {
}
private val baseTestDir = new File("fuzzer/test_run_dir")
- @Fuzz
- def compileSingleModule(@From(value = classOf[FirrtlSingleModuleGenerator]) c: Circuit) = {
+ private def runTest(c: Circuit, referenceCompiler: TransformManager, customCompiler: TransformManager) = {
val testDir = new File(baseTestDir, f"${c.hashCode}%08x")
testDir.mkdirs()
val fileWriter = new FileWriter(new File(testDir, s"${c.main}.fir"))
@@ -113,9 +140,9 @@ class FirrtlEquivalenceTests {
val passed = try {
FirrtlEquivalenceTestUtils.firrtlEquivalenceTestPass(
circuit = c,
- referenceCompiler = new TransformManager(VerilogMinimumOptimized),
+ referenceCompiler = referenceCompiler,
referenceAnnos = Seq(),
- customCompiler = new TransformManager(VerilogOptimized),
+ customCompiler = customCompiler,
customAnnos = Seq(),
testDir = testDir
)
@@ -131,4 +158,22 @@ class FirrtlEquivalenceTests {
s"not equivalent to reference compiler on input ${testDir}:\n${c.serialize}\n", false)
}
}
+
+ @Fuzz
+ def testOptimized(@From(value = classOf[FirrtlCompileCircuitGenerator]) c: Circuit) = {
+ runTest(
+ c = c,
+ referenceCompiler = new TransformManager(VerilogMinimumOptimized),
+ customCompiler = new TransformManager(VerilogOptimized)
+ )
+ }
+
+ @Fuzz
+ def testInlineBooleanExpressions(@From(value = classOf[InlineBooleanExprsCircuitGenerator]) c: Circuit) = {
+ runTest(
+ c = c,
+ referenceCompiler = new TransformManager(VerilogMinimumOptimized),
+ customCompiler = new TransformManager(VerilogMinimumOptimized :+ Dependency[InlineBooleanExpressions])
+ )
+ }
}
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 843c76a4..19f2661a 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -230,7 +230,48 @@ case class VRandom(width: BigInt) extends Expression {
def foreachWidth(f: Width => Unit): Unit = ()
}
+object VerilogEmitter {
+
+ /** Maps a [[PrimOp]] to a precedence number, lower number means higher precedence
+ *
+ * Only the [[PrimOp]]s contained in this map will be inlined. [[PrimOp]]s
+ * like [[PrimOp.Neg]] are not in this map because inlining them may result
+ * in illegal verilog like '--2sh1'
+ */
+ private val precedenceMap: Map[PrimOp, Int] = {
+ val precedenceSeq = Seq(
+ Set(Head, Tail, Bits, Shr, Pad), // Shr and Pad emit as bit select
+ Set(Andr, Orr, Xorr, Neg, Not),
+ Set(Mul, Div, Rem),
+ Set(Add, Sub, Addw, Subw),
+ Set(Dshl, Dshlw, Dshr),
+ Set(Lt, Leq, Gt, Geq),
+ Set(Eq, Neq),
+ Set(And),
+ Set(Xor),
+ Set(Or)
+ )
+ precedenceSeq.zipWithIndex.foldLeft(Map.empty[PrimOp, Int]) {
+ case (map, (ops, idx)) => map ++ ops.map(_ -> idx)
+ }
+ }
+
+ /** true if op1 has greater or equal precendence than op2
+ */
+ private def precedenceGeq(op1: PrimOp, op2: PrimOp): Boolean = {
+ precedenceMap(op1) <= precedenceMap(op2)
+ }
+
+ /** true if op1 has greater precendence than op2
+ */
+ private def precedenceGt(op1: PrimOp, op2: PrimOp): Boolean = {
+ precedenceMap(op1) < precedenceMap(op2)
+ }
+}
+
class VerilogEmitter extends SeqTransform with Emitter {
+ import VerilogEmitter._
+
def inputForm = LowForm
def outputForm = LowForm
@@ -280,8 +321,42 @@ class VerilogEmitter extends SeqTransform with Emitter {
case ClockType | AsyncResetType => ""
case _ => throwInternalError(s"trying to write unsupported type in the Verilog Emitter: $tpe")
}
- def emit(x: Any)(implicit w: Writer): Unit = { emit(x, 0) }
+ private def getLeadingTabs(x: Any): String = {
+ x match {
+ case seq: Seq[_] =>
+ val head = seq.takeWhile(_ == tab).mkString
+ val tail = seq.dropWhile(_ == tab).lift(0).map(getLeadingTabs).getOrElse(tab)
+ head + tail
+ case _ => tab
+ }
+ }
+ def emit(x: Any)(implicit w: Writer): Unit = {
+ emitCol(x, 0, getLeadingTabs(x), 0)
+ }
+ private def emitCast(e: Expression): Any = e.tpe match {
+ case (t: UIntType) => e
+ case (t: SIntType) => Seq("$signed(", e, ")")
+ case ClockType => e
+ case AnalogType(_) => e
+ case _ => throwInternalError(s"unrecognized cast: $e")
+ }
def emit(x: Any, top: Int)(implicit w: Writer): Unit = {
+ emitCol(x, top, "", 0)
+ }
+ private val maxCol = 120
+ private def emitCol(x: Any, top: Int, tabs: String, colNum: Int)(implicit w: Writer): Int = {
+ def writeCol(contents: String): Int = {
+ if ((contents.size + colNum) > maxCol) {
+ w.write("\n")
+ w.write(tabs)
+ w.write(contents)
+ tabs.size + contents.size
+ } else {
+ w.write(contents)
+ colNum + contents.size
+ }
+ }
+
def cast(e: Expression): Any = e.tpe match {
case (t: UIntType) => e
case (t: SIntType) => Seq("$signed(", e, ")")
@@ -290,7 +365,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
case _ => throwInternalError(s"unrecognized cast: $e")
}
x match {
- case (e: DoPrim) => emit(op_stream(e), top + 1)
+ case (e: DoPrim) => emitCol(op_stream(e), top + 1, tabs, colNum)
case (e: Mux) => {
if (e.tpe == ClockType) {
throw EmitterException("Cannot emit clock muxes directly")
@@ -298,51 +373,64 @@ class VerilogEmitter extends SeqTransform with Emitter {
if (e.tpe == AsyncResetType) {
throw EmitterException("Cannot emit async reset muxes directly")
}
- emit(Seq(e.cond, " ? ", cast(e.tval), " : ", cast(e.fval)), top + 1)
+ emitCol(Seq(e.cond, " ? ", cast(e.tval), " : ", cast(e.fval)), top + 1, tabs, colNum)
}
- case (e: ValidIf) => emit(Seq(cast(e.value)), top + 1)
- case (e: WRef) => w.write(e.serialize)
- case (e: WSubField) => w.write(LowerTypes.loweredName(e))
- case (e: WSubAccess) => w.write(s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]")
- case (e: WSubIndex) => w.write(e.serialize)
- case (e: Literal) => v_print(e)
- case (e: VRandom) => w.write(s"{${e.nWords}{`RANDOM}}")
- case (t: GroundType) => w.write(stringify(t))
+ case (e: ValidIf) => emitCol(Seq(cast(e.value)), top + 1, tabs, colNum)
+ case (e: WRef) => writeCol(e.serialize)
+ case (e: WSubField) => writeCol(LowerTypes.loweredName(e))
+ case (e: WSubAccess) => writeCol(s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]")
+ case (e: WSubIndex) => writeCol(e.serialize)
+ case (e: Literal) => v_print(e, colNum)
+ case (e: VRandom) => writeCol(s"{${e.nWords}{`RANDOM}}")
+ case (t: GroundType) => writeCol(stringify(t))
case (t: VectorType) =>
emit(t.tpe, top + 1)
- w.write(s"[${t.size - 1}:0]")
- case (s: String) => w.write(s)
- case (i: Int) => w.write(i.toString)
- case (i: Long) => w.write(i.toString)
- case (i: BigInt) => w.write(i.toString)
+ writeCol(s"[${t.size - 1}:0]")
+ case (s: String) => writeCol(s)
+ case (i: Int) => writeCol(i.toString)
+ case (i: Long) => writeCol(i.toString)
+ case (i: BigInt) => writeCol(i.toString)
case (i: Info) =>
i match {
- case NoInfo => // Do nothing
+ case NoInfo => colNum // Do nothing
case f: FileInfo =>
val escaped = FileInfo.escapedToVerilog(f.escaped)
w.write(s" // @[$escaped]")
+ colNum
case m: MultiInfo =>
val escaped = FileInfo.escapedToVerilog(m.flatten.map(_.escaped).mkString(" "))
w.write(s" // @[$escaped]")
+ colNum
}
case (s: Seq[Any]) =>
- s.foreach(emit(_, top + 1))
- if (top == 0) w.write("\n")
+ val nextColNum = s.foldLeft(colNum) {
+ case (colNum, e) => emitCol(e, top + 1, tabs, colNum)
+ }
+ if (top == 0) {
+ w.write("\n")
+ 0
+ } else {
+ nextColNum
+ }
case x => throwInternalError(s"trying to emit unsupported operator: $x")
}
}
//;------------- PASS -----------------
- def v_print(e: Expression)(implicit w: Writer) = e match {
+ def v_print(e: Expression, colNum: Int)(implicit w: Writer) = e match {
case UIntLiteral(value, IntWidth(width)) =>
- w.write(s"$width'h${value.toString(16)}")
+ val contents = s"$width'h${value.toString(16)}"
+ w.write(contents)
+ colNum + contents.size
case SIntLiteral(value, IntWidth(width)) =>
val stringLiteral = value.toString(16)
- w.write(stringLiteral.head match {
+ val contents = stringLiteral.head match {
case '-' if value == FixAddingNegativeLiterals.minNegValue(width) => s"$width'sh${stringLiteral.tail}"
case '-' => s"-$width'sh${stringLiteral.tail}"
case _ => s"$width'sh${stringLiteral}"
- })
+ }
+ w.write(contents)
+ colNum + contents.size
case _ => throwInternalError(s"attempt to print unrecognized expression: $e")
}
@@ -350,29 +438,62 @@ class VerilogEmitter extends SeqTransform with Emitter {
// reference is actually unsigned in the emitted Verilog. Thus we must cast refs as necessary
// to ensure Verilog operations are signed.
def op_stream(doprim: DoPrim): Seq[Any] = {
+ def parenthesize(e: Expression, isFirst: Boolean): Any = doprim.op match {
+ // these PrimOps emit either {..., a0, ...} or a0 so they never need parentheses
+ case Shl | Cat | Cvt | AsUInt | AsSInt | AsClock | AsAsyncReset => e
+ case _ =>
+ e match {
+ case e: DoPrim =>
+ op_stream(e) match {
+ /** DoPrims like AsUInt simply emit Seq(a0), so we need to
+ * recursively check whether a0 needs to be parenthesized
+ */
+ case Seq(passthrough: Expression) => parenthesize(passthrough, isFirst)
+
+ /** If the expression is the first argument then it does not need
+ * parens if it's precedence is greather than or equal to the
+ * enclosing doprim, because verilog operators are left
+ * associative. All other args do not need parens only if the
+ * precedence is greater.
+ */
+ case other =>
+ if (precedenceGt(e.op, doprim.op) || (precedenceGeq(e.op, doprim.op) && isFirst)) {
+ other
+ } else {
+ Seq("(", other, ")")
+ }
+ }
+
+ /** Mux args should always have parens because Mux has the lowest precedence
+ */
+ case _: Mux => Seq("(", e, ")")
+ case _ => e
+ }
+ }
+
// Cast to SInt, don't cast multiple times
def doCast(e: Expression): Any = e match {
case DoPrim(AsSInt, Seq(arg), _, _) => doCast(arg)
case slit: SIntLiteral => slit
case other => Seq("$signed(", other, ")")
}
- def castIf(e: Expression): Any = {
+ def castIf(e: Expression, isFirst: Boolean = false): Any = {
if (doprim.args.exists(_.tpe.isInstanceOf[SIntType])) {
e.tpe match {
case _: SIntType => doCast(e)
case _ => throwInternalError(s"Unexpected non-SInt type for $e in $doprim")
}
} else {
- e
+ parenthesize(e, isFirst)
}
}
- def cast(e: Expression): Any = doprim.tpe match {
- case _: UIntType => e
+ def cast(e: Expression, isFirst: Boolean = false): Any = doprim.tpe match {
+ case _: UIntType => parenthesize(e, isFirst)
case _: SIntType => doCast(e)
case _ => throwInternalError(s"Unexpected type for $e in $doprim")
}
- def castAs(e: Expression): Any = e.tpe match {
- case _: UIntType => e
+ def castAs(e: Expression, isFirst: Boolean = false): Any = e.tpe match {
+ case _: UIntType => parenthesize(e, isFirst)
case _: SIntType => doCast(e)
case _ => throwInternalError(s"Unexpected type for $e in $doprim")
}
@@ -381,19 +502,6 @@ class VerilogEmitter extends SeqTransform with Emitter {
def c0: Int = doprim.consts.head.toInt
def c1: Int = doprim.consts(1).toInt
- def checkArgumentLegality(e: Expression): Unit = e match {
- case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
- case DoPrim(Not, args, _, _) => args.foreach(checkArgumentLegality)
- case DoPrim(op, args, _, _) if isCast(op) => args.foreach(checkArgumentLegality)
- case DoPrim(op, args, _, _) if isBitExtract(op) => args.foreach(checkArgumentLegality)
- case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
- }
-
- def checkCatArgumentLegality(e: Expression): Unit = e match {
- case DoPrim(Cat, args, _, _) => args.foreach(checkCatArgumentLegality)
- case _ => checkArgumentLegality(e)
- }
-
def castCatArgs(a0: Expression, a1: Expression): Seq[Any] = {
val a0Seq = a0 match {
case cat @ DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1))
@@ -407,24 +515,19 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
doprim.op match {
- case Cat => doprim.args.foreach(checkCatArgumentLegality)
- case cast if isCast(cast) => // Casts are allowed to wrap any Expression
- case other => doprim.args.foreach(checkArgumentLegality)
- }
- doprim.op match {
- case Add => Seq(castIf(a0), " + ", castIf(a1))
- case Addw => Seq(castIf(a0), " + ", castIf(a1))
- case Sub => Seq(castIf(a0), " - ", castIf(a1))
- case Subw => Seq(castIf(a0), " - ", castIf(a1))
- case Mul => Seq(castIf(a0), " * ", castIf(a1))
- case Div => Seq(castIf(a0), " / ", castIf(a1))
- case Rem => Seq(castIf(a0), " % ", castIf(a1))
- case Lt => Seq(castIf(a0), " < ", castIf(a1))
- case Leq => Seq(castIf(a0), " <= ", castIf(a1))
- case Gt => Seq(castIf(a0), " > ", castIf(a1))
- case Geq => Seq(castIf(a0), " >= ", castIf(a1))
- case Eq => Seq(castIf(a0), " == ", castIf(a1))
- case Neq => Seq(castIf(a0), " != ", castIf(a1))
+ case Add => Seq(castIf(a0, true), " + ", castIf(a1))
+ case Addw => Seq(castIf(a0, true), " + ", castIf(a1))
+ case Sub => Seq(castIf(a0, true), " - ", castIf(a1))
+ case Subw => Seq(castIf(a0, true), " - ", castIf(a1))
+ case Mul => Seq(castIf(a0, true), " * ", castIf(a1))
+ case Div => Seq(castIf(a0, true), " / ", castIf(a1))
+ case Rem => Seq(castIf(a0, true), " % ", castIf(a1))
+ case Lt => Seq(castIf(a0, true), " < ", castIf(a1))
+ case Leq => Seq(castIf(a0, true), " <= ", castIf(a1))
+ case Gt => Seq(castIf(a0, true), " > ", castIf(a1))
+ case Geq => Seq(castIf(a0, true), " >= ", castIf(a1))
+ case Eq => Seq(castIf(a0, true), " == ", castIf(a1))
+ case Neq => Seq(castIf(a0, true), " != ", castIf(a1))
case Pad =>
val w = bitWidth(a0.tpe)
val diff = c0 - w
@@ -434,7 +537,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
// Either sign extend or zero extend.
// If width == BigInt(1), don't extract bit
case (_: SIntType) if w == BigInt(1) => Seq("{", c0, "{", a0, "}}")
- case (_: SIntType) => Seq("{{", diff, "{", a0, "[", w - 1, "]}},", a0, "}")
+ case (_: SIntType) => Seq("{{", diff, "{", parenthesize(a0, true), "[", w - 1, "]}},", a0, "}")
case (_) => Seq("{{", diff, "'d0}, ", a0, "}")
}
// Because we don't support complex Expressions, all casts are ignored
@@ -451,35 +554,35 @@ class VerilogEmitter extends SeqTransform with Emitter {
case Shl => if (c0 > 0) Seq("{", cast(a0), s", $c0'h0}") else Seq(cast(a0))
case Shr if c0 >= bitWidth(a0.tpe) =>
error("Verilog emitter does not support SHIFT_RIGHT >= arg width")
- case Shr if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]")
- case Shr => Seq(a0, "[", bitWidth(a0.tpe) - 1, ":", c0, "]")
- case Neg => Seq("-", cast(a0))
+ case Shr if c0 == (bitWidth(a0.tpe) - 1) => Seq(parenthesize(a0, true), "[", bitWidth(a0.tpe) - 1, "]")
+ case Shr => Seq(parenthesize(a0, true), "[", bitWidth(a0.tpe) - 1, ":", c0, "]")
+ case Neg => Seq("-", cast(a0, true))
case Cvt =>
a0.tpe match {
case (_: UIntType) => Seq("{1'b0,", cast(a0), "}")
case (_: SIntType) => Seq(cast(a0))
}
- case Not => Seq("~", a0)
- case And => Seq(castAs(a0), " & ", castAs(a1))
- case Or => Seq(castAs(a0), " | ", castAs(a1))
- case Xor => Seq(castAs(a0), " ^ ", castAs(a1))
- case Andr => Seq("&", cast(a0))
- case Orr => Seq("|", cast(a0))
- case Xorr => Seq("^", cast(a0))
+ case Not => Seq("~", parenthesize(a0, true))
+ case And => Seq(castAs(a0, true), " & ", castAs(a1))
+ case Or => Seq(castAs(a0, true), " | ", castAs(a1))
+ case Xor => Seq(castAs(a0, true), " ^ ", castAs(a1))
+ case Andr => Seq("&", cast(a0, true))
+ case Orr => Seq("|", cast(a0, true))
+ case Xorr => Seq("^", cast(a0, true))
case Cat => "{" +: (castCatArgs(a0, a1) :+ "}")
// If selecting zeroth bit and single-bit wire, just emit the wire
case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
- case Bits if c0 == c1 => Seq(a0, "[", c0, "]")
- case Bits => Seq(a0, "[", c0, ":", c1, "]")
+ case Bits if c0 == c1 => Seq(parenthesize(a0, true), "[", c0, "]")
+ case Bits => Seq(parenthesize(a0, true), "[", c0, ":", c1, "]")
// If selecting zeroth bit and single-bit wire, just emit the wire
case Head if c0 == 1 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
case Head if c0 == 1 => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]")
case Head =>
val msb = bitWidth(a0.tpe) - 1
val lsb = bitWidth(a0.tpe) - c0
- Seq(a0, "[", msb, ":", lsb, "]")
- case Tail if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[0]")
- case Tail => Seq(a0, "[", bitWidth(a0.tpe) - c0 - 1, ":0]")
+ Seq(parenthesize(a0, true), "[", msb, ":", lsb, "]")
+ case Tail if c0 == (bitWidth(a0.tpe) - 1) => Seq(parenthesize(a0, true), "[0]")
+ case Tail => Seq(parenthesize(a0, true), "[", bitWidth(a0.tpe) - c0 - 1, ":0]")
}
}
@@ -804,7 +907,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
def regUpdate(r: Expression, clk: Expression, reset: Expression, init: Expression) = {
- def addUpdate(info: Info, expr: Expression, tabs: String): Seq[Seq[Any]] = expr match {
+ def addUpdate(info: Info, expr: Expression, tabs: Seq[String]): Seq[Seq[Any]] = expr match {
case m: Mux =>
if (m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly")
if (m.tpe == AsyncResetType) throw EmitterException("Cannot emit async reset muxes directly")
@@ -814,8 +917,8 @@ class VerilogEmitter extends SeqTransform with Emitter {
lazy val _else = Seq(tabs, "end else begin")
lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin", eninfo)
lazy val _end = Seq(tabs, "end")
- lazy val _true = addUpdate(tinfo, m.tval, tabs + tab)
- lazy val _false = addUpdate(finfo, m.fval, tabs + tab)
+ lazy val _true = addUpdate(tinfo, m.tval, tab +: tabs)
+ lazy val _false = addUpdate(finfo, m.fval, tab +: tabs)
lazy val _elseIfFalse = {
val _falsex = addUpdate(finfo, m.fval, tabs) // _false, but without an additional tab
Seq(tabs, "end else ", _falsex.head.tail) +: _falsex.tail
@@ -845,13 +948,19 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
if (weq(init, r)) { // Synchronous Reset
val InfoExpr(info, e) = netlist(r)
- noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(info, e, "")
+ noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(info, e, Seq.empty)
} else { // Asynchronous Reset
assert(reset.tpe == AsyncResetType, "Error! Synchronous reset should have been removed!")
val tv = init
val InfoExpr(finfo, fv) = netlist(r)
// TODO add register info argument and build a MultiInfo to pass
- asyncResetAlwaysBlocks += ((clk, reset, addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), "")))
+ asyncResetAlwaysBlocks += (
+ (
+ clk,
+ reset,
+ addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), Seq.empty)
+ )
+ )
}
}
@@ -1367,11 +1476,18 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
override def execute(state: CircuitState): CircuitState = {
+ val writerToString =
+ (writer: java.io.StringWriter) => writer.toString.replaceAll("""(?m) +$""", "") // trim trailing whitespace
+
val newAnnos = state.annotations.flatMap {
case EmitCircuitAnnotation(a) if this.getClass == a =>
val writer = new java.io.StringWriter
emit(state, writer)
- Seq(EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit(state.circuit.main, writer.toString, outputSuffix)))
+ Seq(
+ EmittedVerilogCircuitAnnotation(
+ EmittedVerilogCircuit(state.circuit.main, writerToString(writer), outputSuffix)
+ )
+ )
case EmitAllModulesAnnotation(a) if this.getClass == a =>
val cs = runTransforms(state)
@@ -1383,12 +1499,16 @@ class VerilogEmitter extends SeqTransform with Emitter {
val writer = new java.io.StringWriter
val renderer = new VerilogRender(d, pds, module, moduleMap, cs.circuit.main, emissionOptions)(writer)
renderer.emit_verilog()
- Some(EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writer.toString, outputSuffix)))
+ Some(
+ EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writerToString(writer), outputSuffix))
+ )
case module: Module =>
val writer = new java.io.StringWriter
val renderer = new VerilogRender(module, moduleMap, cs.circuit.main, emissionOptions)(writer)
renderer.emit_verilog()
- Some(EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writer.toString, outputSuffix)))
+ Some(
+ EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writerToString(writer), outputSuffix))
+ )
case _ => None
}
case _ => Seq()
diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala
index a0c5ea0c..db411325 100644
--- a/src/main/scala/firrtl/stage/Forms.scala
+++ b/src/main/scala/firrtl/stage/Forms.scala
@@ -110,7 +110,11 @@ object Forms {
Dependency[firrtl.AddDescriptionNodes]
)
- val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogMinimumOptimized
+ val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++
+ Seq(
+ Dependency[firrtl.transforms.InlineBooleanExpressions]
+ ) ++
+ VerilogMinimumOptimized
val AssertsRemoved: Seq[TransformDependency] =
Seq(
diff --git a/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala
new file mode 100644
index 00000000..7c52d6ef
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala
@@ -0,0 +1,169 @@
+// See LICENSE for license details.
+
+package firrtl
+package transforms
+
+import firrtl.annotations.{NoTargetAnnotation, Target}
+import firrtl.annotations.TargetToken.{fromStringToTargetToken, OfModule, Ref}
+import firrtl.ir._
+import firrtl.passes.{InferTypes, LowerTypes, SplitExpressions}
+import firrtl.options.Dependency
+import firrtl.PrimOps._
+import firrtl.WrappedExpression._
+
+import scala.collection.mutable
+
+case class InlineBooleanExpressionsMax(max: Int) extends NoTargetAnnotation
+
+object InlineBooleanExpressions {
+ val defaultMax = 30
+}
+
+/** Inline Bool expressions
+ *
+ * The following conditions must be satisfied to inline
+ * 1. has type [[Utils.BoolType]]
+ * 2. is bound to a [[firrtl.ir.DefNode DefNode]] with name starting with '_'
+ * 3. is bound to a [[firrtl.ir.DefNode DefNode]] with a source locator that
+ * points at the same file and line number. If it is a MultiInfo source
+ * locator, the set of file and line number pairs must be the same. Source
+ * locators may point to different column numbers.
+ * 4. [[InlineBooleanExpressionsMax]] has not been exceeded
+ * 5. is not a [[firrtl.ir.Mux Mux]]
+ */
+class InlineBooleanExpressions extends Transform with DependencyAPIMigration {
+
+ override def prerequisites = Seq(
+ Dependency(InferTypes),
+ Dependency(LowerTypes)
+ )
+
+ override def optionalPrerequisites = Seq(
+ Dependency(SplitExpressions)
+ )
+
+ override def invalidates(a: Transform) = a match {
+ case _: DeadCodeElimination => true // this transform does not remove nodes that are unused after inlining
+ case _ => false
+ }
+
+ type Netlist = mutable.HashMap[WrappedExpression, (Expression, Info)]
+
+ private def isArgN(outerExpr: DoPrim, subExpr: Expression, n: Int): Boolean = {
+ outerExpr.args.lift(n) match {
+ case Some(arg) => arg eq subExpr
+ case _ => false
+ }
+ }
+
+ private val fileLineRegex = """(.*) ([0-9]+):[0-9]+""".r
+ private def sameFileAndLineInfo(info1: Info, info2: Info): Boolean = {
+ (info1, info2) match {
+ case (FileInfo(fileLineRegex(file1, line1)), FileInfo(fileLineRegex(file2, line2))) =>
+ (file1 == file2) && (line1 == line2)
+ case (MultiInfo(infos1), MultiInfo(infos2)) if infos1.size == infos2.size =>
+ infos1.zip(infos2).forall {
+ case (i1, i2) =>
+ sameFileAndLineInfo(i1, i2)
+ }
+ case (NoInfo, NoInfo) => true
+ case _ => false
+ }
+ }
+
+ /** A helper class to initialize and store mutable state that the expression
+ * and statement map functions need access to. This makes it easier to pass
+ * information around without having to plump arguments through the onExpr
+ * and onStmt methods.
+ */
+ private class MapMethods(maxInlineCount: Int, dontTouches: Set[Ref]) {
+ val netlist: Netlist = new Netlist
+ val inlineCounts = mutable.Map.empty[Ref, Int]
+ var inlineCount: Int = 1
+
+ /** Whether or not an can be inlined
+ * @param refExpr the expression to check for inlining
+ */
+ def canInline(refExpr: Expression): Boolean = {
+ refExpr match {
+ case _: Mux => false
+ case _ => refExpr.tpe == Utils.BoolType
+ }
+ }
+
+ /** Inlines [[Wref]]s if they are Boolean, have matching file line numbers,
+ * and would not raise inlineCounts past the maximum.
+ *
+ * @param info the [[Info]] of the enclosing [[Statement]]
+ * @param outerExpr the direct parent [[Expression]] of the current [[Expression]]
+ * @param expr the [[Expression]] to apply inlining to
+ */
+ def onExpr(info: Info, outerExpr: Option[Expression])(expr: Expression): Expression = {
+ expr match {
+ case ref: WRef if !dontTouches.contains(ref.name.Ref) && ref.name.head == '_' =>
+ val refKey = ref.name.Ref
+ netlist.get(we(ref)) match {
+ case Some((refExpr, refInfo)) if sameFileAndLineInfo(info, refInfo) =>
+ val inlineNum = inlineCounts.getOrElse(refKey, 1)
+ if (!outerExpr.isDefined || canInline(refExpr) && ((inlineNum + inlineCount) <= maxInlineCount)) {
+ inlineCount += inlineNum
+ refExpr
+ } else {
+ ref
+ }
+ case other => ref
+ }
+ case other => other.mapExpr(onExpr(info, Some(other)))
+ }
+ }
+
+ /** Applies onExpr and records metadata for every [[HasInfo]] in a [[Statement]]
+ *
+ * This resets inlineCount before inlining and records the resulting
+ * inline counts and inlined values in the inlineCounts and netlist maps
+ * after inlining.
+ */
+ def onStmt(stmt: Statement): Statement = {
+ stmt.mapStmt(onStmt) match {
+ case hasInfo: HasInfo =>
+ inlineCount = 1
+ val stmtx = hasInfo.mapExpr(onExpr(hasInfo.info, None))
+ stmtx match {
+ case node: DefNode => inlineCounts(node.name.Ref) = inlineCount
+ case _ =>
+ }
+ stmtx match {
+ case node @ DefNode(info, name, value) =>
+ netlist(we(WRef(name))) = (value, info)
+ case _ =>
+ }
+ stmtx
+ case other => other
+ }
+ }
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ val dontTouchMap: Map[OfModule, Set[Ref]] = {
+ val refTargets = state.annotations.flatMap {
+ case anno: HasDontTouches => anno.dontTouches
+ case o => Nil
+ }
+ val dontTouches: Seq[(OfModule, Ref)] = refTargets.map {
+ case r => Target.referringModule(r).module.OfModule -> r.ref.Ref
+ }
+ dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap
+ }
+
+ val maxInlineCount = state.annotations.collectFirst {
+ case InlineBooleanExpressionsMax(max) => max
+ }.getOrElse(InlineBooleanExpressions.defaultMax)
+
+ val modulesx = state.circuit.modules.map { m =>
+ val mapMethods = new MapMethods(maxInlineCount, dontTouchMap.getOrElse(m.name.OfModule, Set.empty[Ref]))
+ m.mapStmt(mapMethods.onStmt(_))
+ }
+
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}
diff --git a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
new file mode 100644
index 00000000..5fee87c9
--- /dev/null
+++ b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
@@ -0,0 +1,242 @@
+
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.annotations.Annotation
+import firrtl.options.Dependency
+import firrtl.passes._
+import firrtl.transforms._
+import firrtl.testutils._
+import firrtl.stage.TransformManager
+
+class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
+ val transform = new InlineBooleanExpressions
+ val transforms: Seq[Transform] = new TransformManager(
+ transform.prerequisites
+ ).flattenedTransformOrder :+ transform
+
+ protected def exec(input: String, annos: Seq[Annotation] = Nil) = {
+ transforms.foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) {
+ (c: CircuitState, t: Transform) => t.runTransform(c)
+ }.circuit.serialize
+ }
+
+ it should "inline mux operands" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output out : UInt<1>
+ | node x1 = UInt<1>(0)
+ | node x2 = UInt<1>(1)
+ | node _t = head(x1, 1)
+ | node _f = head(x2, 1)
+ | node _c = lt(x1, x2)
+ | node _y = mux(_c, _t, _f)
+ | out <= _y""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output out : UInt<1>
+ | node x1 = UInt<1>(0)
+ | node x2 = UInt<1>(1)
+ | node _t = head(x1, 1)
+ | node _f = head(x2, 1)
+ | node _c = lt(x1, x2)
+ | node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1))
+ | out <= mux(lt(x1, x2), head(x1, 1), head(x2, 1))""".stripMargin
+ val result = exec(input)
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "only inline expressions with the same file and line number" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output outA1 : UInt<1>
+ | output outA2 : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<1>(0)
+ | node x2 = UInt<1>(1)
+ |
+ | node _t = head(x1, 1) @[A 1:1]
+ | node _f = head(x2, 1) @[A 1:2]
+ | node _y = mux(lt(x1, x2), _t, _f) @[A 1:3]
+ | outA1 <= _y @[A 1:3]
+ |
+ | outA2 <= _y @[A 2:3]
+ |
+ | outB <= _y @[B]""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output outA1 : UInt<1>
+ | output outA2 : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<1>(0)
+ | node x2 = UInt<1>(1)
+ |
+ | node _t = head(x1, 1) @[A 1:1]
+ | node _f = head(x2, 1) @[A 1:2]
+ | node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1)) @[A 1:3]
+ | outA1 <= mux(lt(x1, x2), head(x1, 1), head(x2, 1)) @[A 1:3]
+ |
+ | outA2 <= _y @[A 2:3]
+ |
+ | outB <= _y @[B]""".stripMargin
+ val result = exec(input)
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "inline boolean DoPrims" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output outA : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<3>(0)
+ | node x2 = UInt<3>(1)
+ |
+ | node _a = lt(x1, x2)
+ | node _b = eq(_a, x2)
+ | node _c = and(_b, x2)
+ | outA <= _c
+ |
+ | node _d = head(_c, 1)
+ | node _e = andr(_d)
+ | node _f = lt(_e, x2)
+ | outB <= _f""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output outA : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<3>(0)
+ | node x2 = UInt<3>(1)
+ |
+ | node _a = lt(x1, x2)
+ | node _b = eq(lt(x1, x2), x2)
+ | node _c = and(eq(lt(x1, x2), x2), x2)
+ | outA <= and(eq(lt(x1, x2), x2), x2)
+ |
+ | node _d = head(_c, 1)
+ | node _e = andr(head(_c, 1))
+ | node _f = lt(andr(head(_c, 1)), x2)
+ |
+ | outB <= lt(andr(head(_c, 1)), x2)""".stripMargin
+ val result = exec(input)
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "inline more boolean DoPrims" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output outA : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<3>(0)
+ | node x2 = UInt<3>(1)
+ |
+ | node _a = lt(x1, x2)
+ | node _b = leq(_a, x2)
+ | node _c = gt(_b, x2)
+ | node _d = geq(_c, x2)
+ | outA <= _d
+ |
+ | node _e = lt(x1, x2)
+ | node _f = leq(x1, _e)
+ | node _g = gt(x1, _f)
+ | node _h = geq(x1, _g)
+ | outB <= _h""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output outA : UInt<1>
+ | output outB : UInt<1>
+ | node x1 = UInt<3>(0)
+ | node x2 = UInt<3>(1)
+ |
+ | node _a = lt(x1, x2)
+ | node _b = leq(lt(x1, x2), x2)
+ | node _c = gt(leq(lt(x1, x2), x2), x2)
+ | node _d = geq(gt(leq(lt(x1, x2), x2), x2), x2)
+ | outA <= geq(gt(leq(lt(x1, x2), x2), x2), x2)
+ |
+ | node _e = lt(x1, x2)
+ | node _f = leq(x1, lt(x1, x2))
+ | node _g = gt(x1, leq(x1, lt(x1, x2)))
+ | node _h = geq(x1, gt(x1, leq(x1, lt(x1, x2))))
+ |
+ | outB <= geq(x1, gt(x1, leq(x1, lt(x1, x2))))""".stripMargin
+ val result = exec(input)
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "limit the number of inlines" in {
+ val input =
+ s"""circuit Top :
+ | module Top :
+ | input c_0: UInt<1>
+ | input c_1: UInt<1>
+ | input c_2: UInt<1>
+ | input c_3: UInt<1>
+ | input c_4: UInt<1>
+ | input c_5: UInt<1>
+ | input c_6: UInt<1>
+ | output out : UInt<1>
+ |
+ | node _1 = or(c_0, c_1)
+ | node _2 = or(_1, c_2)
+ | node _3 = or(_2, c_3)
+ | node _4 = or(_3, c_4)
+ | node _5 = or(_4, c_5)
+ | node _6 = or(_5, c_6)
+ |
+ | out <= _6""".stripMargin
+ val check =
+ s"""circuit Top :
+ | module Top :
+ | input c_0: UInt<1>
+ | input c_1: UInt<1>
+ | input c_2: UInt<1>
+ | input c_3: UInt<1>
+ | input c_4: UInt<1>
+ | input c_5: UInt<1>
+ | input c_6: UInt<1>
+ | output out : UInt<1>
+ |
+ | node _1 = or(c_0, c_1)
+ | node _2 = or(or(c_0, c_1), c_2)
+ | node _3 = or(or(or(c_0, c_1), c_2), c_3)
+ | node _4 = or(_3, c_4)
+ | node _5 = or(or(_3, c_4), c_5)
+ | node _6 = or(or(or(_3, c_4), c_5), c_6)
+ |
+ | out <= or(or(or(_3, c_4), c_5), c_6)""".stripMargin
+ val result = exec(input, Seq(InlineBooleanExpressionsMax(3)))
+ (result) should be (parse(check).serialize)
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+
+ it should "be equivalent" in {
+ val input =
+ """circuit InlineBooleanExpressionsEquivalenceTest :
+ | module InlineBooleanExpressionsEquivalenceTest :
+ | input in : UInt<1>[6]
+ | output out : UInt<1>
+ |
+ | node _a = or(in[0], in[1])
+ | node _b = and(in[2], _a)
+ | node _c = eq(in[3], _b)
+ | node _d = lt(in[4], _c)
+ | node _e = eq(in[5], _d)
+ | node _f = head(_e, 1)
+ | out <= _f""".stripMargin
+ firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
+ }
+}
diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
index 46416619..40f8f123 100644
--- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
+++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
@@ -260,6 +260,8 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
it should "replicate the old order" in {
val legacy = Seq(
+ new firrtl.transforms.InlineBooleanExpressions,
+ new firrtl.transforms.DeadCodeElimination,
new firrtl.transforms.BlackBoxSourceHelper,
new firrtl.transforms.FixAddingNegativeLiterals,
new firrtl.transforms.ReplaceTruncatingArithmetic,
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index 8f128274..a864bfe5 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -110,18 +110,8 @@ class UnitTests extends FirrtlFlatSpec {
| out <= bits(mux(a, b, c), 0, 0)
|""".stripMargin
- "Emitting a nested expression" should "throw an exception" in {
+ "Emitting a nested expression" should "compile" in {
val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds)
- intercept[PassException] {
- val c = Parser.parse(splitExpTestCode.split("\n").toIterator)
- val c2 = passes.foldLeft(c)((c, p) => p.run(c))
- val writer = new StringWriter()
- (new VerilogEmitter).emit(CircuitState(c2, LowForm), writer)
- }
- }
-
- "After splitting, emitting a nested expression" should "compile" in {
- val passes = Seq(ToWorkingIR, SplitExpressions, InferTypes)
val c = Parser.parse(splitExpTestCode.split("\n").toIterator)
val c2 = passes.foldLeft(c)((c, p) => p.run(c))
val writer = new StringWriter()