diff options
| -rw-r--r-- | fuzzer/src/main/scala/firrtl/ExprGenParams.scala | 65 | ||||
| -rw-r--r-- | fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala | 54 | ||||
| -rw-r--r-- | fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala | 59 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 288 | ||||
| -rw-r--r-- | src/main/scala/firrtl/stage/Forms.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala | 169 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala | 242 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/LoweringCompilersSpec.scala | 2 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/UnitTests.scala | 12 |
9 files changed, 740 insertions, 157 deletions
diff --git a/fuzzer/src/main/scala/firrtl/ExprGenParams.scala b/fuzzer/src/main/scala/firrtl/ExprGenParams.scala index ddaec00d..4c11b860 100644 --- a/fuzzer/src/main/scala/firrtl/ExprGenParams.scala +++ b/fuzzer/src/main/scala/firrtl/ExprGenParams.scala @@ -1,5 +1,8 @@ package firrtl.fuzzer +import com.pholser.junit.quickcheck.generator.{Generator, GenerationStatus} +import com.pholser.junit.quickcheck.random.SourceOfRandomness + import firrtl.{Namespace, Utils} import firrtl.ir._ @@ -17,15 +20,15 @@ sealed trait ExprGenParams { */ def maxWidth: Int - /** A list of frequency/expression generator pairs + /** A mapping of expression generator to frequency * * The frequency number determines the probability that the corresponding - * generator will be chosen. i.e. for sequece Seq(1 -> A, 2 -> B, 3 -> C), - * the probabilities for A, B, and C are 1/6, 2/6, and 3/6 respectively. - * This sequency must be non-empty and all frequency numbers must be greater - * than zero. + * generator will be chosen. i.g. for Map(A -> 1, B -> 2, C -> B), the + * probabilities for A, B, and C are 1/6, 2/6, and 3/6 respectively. This + * map must be non-empty and all frequency numbers must be greater than + * zero. */ - def generators: Seq[(Int, ExprGen[_ <: Expression])] + def generators: Map[ExprGen[_ <: Expression], Int] /** The set of generated references that don't have a corresponding declaration */ @@ -102,10 +105,46 @@ sealed trait ExprGenParams { object ExprGenParams { + val defaultGenerators: Map[ExprGen[_ <: Expression], Int] = { + import ExprGen._ + Map( + AddDoPrimGen -> 1, + SubDoPrimGen -> 1, + MulDoPrimGen -> 1, + DivDoPrimGen -> 1, + LtDoPrimGen -> 1, + LeqDoPrimGen -> 1, + GtDoPrimGen -> 1, + GeqDoPrimGen -> 1, + EqDoPrimGen -> 1, + NeqDoPrimGen -> 1, + PadDoPrimGen -> 1, + ShlDoPrimGen -> 1, + ShrDoPrimGen -> 1, + DshlDoPrimGen -> 1, + CvtDoPrimGen -> 1, + NegDoPrimGen -> 1, + NotDoPrimGen -> 1, + AndDoPrimGen -> 1, + OrDoPrimGen -> 1, + XorDoPrimGen -> 1, + AndrDoPrimGen -> 1, + OrrDoPrimGen -> 1, + XorrDoPrimGen -> 1, + CatDoPrimGen -> 1, + BitsDoPrimGen -> 1, + HeadDoPrimGen -> 1, + TailDoPrimGen -> 1, + AsUIntDoPrimGen -> 1, + AsSIntDoPrimGen -> 1, + MuxGen -> 1 + ) + } + private case class ExprGenParamsImp( maxDepth: Int, maxWidth: Int, - generators: Seq[(Int, ExprGen[_ <: Expression])], + generators: Map[ExprGen[_ <: Expression], Int], protected val unboundRefs: Set[Reference], protected val namespace: Namespace) extends ExprGenParams { @@ -119,7 +158,7 @@ object ExprGenParams { def apply( maxDepth: Int, maxWidth: Int, - generators: Seq[(Int, ExprGen[_ <: Expression])] + generators: Map[ExprGen[_ <: Expression], Int] ): ExprGenParams = { require(maxWidth > 0, "maxWidth must be greater than zero") ExprGenParamsImp( @@ -190,7 +229,8 @@ object ExprGenParams { ))(tpe).map(e => e.get) // should be safe because leaf generators are defined for all types val branchGen: Type => StateGen[ExprGenParams, G, Expression] = (tpe: Type) => { - combineExprGens(s.generators)(tpe).flatMap { + val gens = s.generators.toSeq.map { case (gen, freq) => (freq, gen) } + combineExprGens(gens)(tpe).flatMap { case None => leafGen(tpe) case Some(e) => StateGen.pure(e) } @@ -211,3 +251,10 @@ object ExprGenParams { } } } + +abstract class SingleExpressionCircuitGenerator(val params: ExprGenParams) extends Generator[Circuit](classOf[Circuit]) { + override def generate(random: SourceOfRandomness, status: GenerationStatus): Circuit = { + implicit val r = random + params.generateSingleExprCircuit[SourceOfRandomnessGen]() + } +} diff --git a/fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala b/fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala index 9ee8e52b..3091e4d6 100644 --- a/fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala +++ b/fuzzer/src/main/scala/firrtl/FirrtlCompileTests.scala @@ -53,50 +53,14 @@ object SourceOfRandomnessGen { } } - -class FirrtlSingleModuleGenerator extends Generator[Circuit](classOf[Circuit]) { - override def generate(random: SourceOfRandomness, status: GenerationStatus): Circuit = { - implicit val r = random - import ExprGen._ - - val params = ExprGenParams( - maxDepth = 50, - maxWidth = 31, - generators = Seq( - 1 -> AddDoPrimGen, - 1 -> SubDoPrimGen, - 1 -> MulDoPrimGen, - 1 -> DivDoPrimGen, - 1 -> LtDoPrimGen, - 1 -> LeqDoPrimGen, - 1 -> GtDoPrimGen, - 1 -> GeqDoPrimGen, - 1 -> EqDoPrimGen, - 1 -> NeqDoPrimGen, - 1 -> PadDoPrimGen, - 1 -> ShlDoPrimGen, - 1 -> ShrDoPrimGen, - 1 -> DshlDoPrimGen, - 1 -> CvtDoPrimGen, - 1 -> NegDoPrimGen, - 1 -> NotDoPrimGen, - 1 -> AndDoPrimGen, - 1 -> OrDoPrimGen, - 1 -> XorDoPrimGen, - 1 -> AndrDoPrimGen, - 1 -> OrrDoPrimGen, - 1 -> XorrDoPrimGen, - 1 -> CatDoPrimGen, - 1 -> BitsDoPrimGen, - 1 -> HeadDoPrimGen, - 1 -> TailDoPrimGen, - 1 -> AsUIntDoPrimGen, - 1 -> AsSIntDoPrimGen - ) - ) - params.generateSingleExprCircuit[SourceOfRandomnessGen]() - } -} +import ExprGen._ +class FirrtlCompileCircuitGenerator extends SingleExpressionCircuitGenerator ( + ExprGenParams( + maxDepth = 50, + maxWidth = 31, + generators = ExprGenParams.defaultGenerators + ) +) @RunWith(classOf[JQF]) class FirrtlCompileTests { @@ -112,7 +76,7 @@ class FirrtlCompileTests { } @Fuzz - def compileSingleModule(@From(value = classOf[FirrtlSingleModuleGenerator]) c: Circuit) = { + def compileSingleModule(@From(value = classOf[FirrtlCompileCircuitGenerator]) c: Circuit) = { compile(CircuitState(c, ChirrtlForm, Seq())) } diff --git a/fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala b/fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala index e0aa1707..822744c2 100644 --- a/fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala +++ b/fuzzer/src/main/scala/firrtl/FirrtlEquivalenceTest.scala @@ -1,17 +1,19 @@ package firrtl.fuzzer import com.pholser.junit.quickcheck.From +import com.pholser.junit.quickcheck.generator.{Generator, GenerationStatus} +import com.pholser.junit.quickcheck.random.SourceOfRandomness -import edu.berkeley.cs.jqf.fuzz.Fuzz; -import edu.berkeley.cs.jqf.fuzz.JQF; +import edu.berkeley.cs.jqf.fuzz.{Fuzz, JQF}; import firrtl._ import firrtl.annotations.{Annotation, CircuitTarget, ModuleTarget, Target} import firrtl.ir.Circuit +import firrtl.options.Dependency import firrtl.stage.{FirrtlCircuitAnnotation, InfoModeAnnotation, OutputFileAnnotation, TransformManager} import firrtl.stage.Forms.{VerilogMinimumOptimized, VerilogOptimized} import firrtl.stage.phases.WriteEmitted -import firrtl.transforms.ManipulateNames +import firrtl.transforms.{InlineBooleanExpressions, ManipulateNames} import firrtl.util.BackendCompilationUtilities import java.io.{File, FileWriter, PrintWriter, StringWriter} @@ -89,6 +91,32 @@ object FirrtlEquivalenceTestUtils { } } +import ExprGen._ +class InlineBooleanExprsCircuitGenerator extends SingleExpressionCircuitGenerator ( + ExprGenParams( + maxDepth = 50, + maxWidth = 31, + generators = ExprGenParams.defaultGenerators ++ Map( + LtDoPrimGen -> 10, + LeqDoPrimGen -> 10, + GtDoPrimGen -> 10, + GeqDoPrimGen -> 10, + EqDoPrimGen -> 10, + NeqDoPrimGen -> 10, + AndDoPrimGen -> 10, + OrDoPrimGen -> 10, + XorDoPrimGen -> 10, + AndrDoPrimGen -> 10, + OrrDoPrimGen -> 10, + XorrDoPrimGen -> 10, + BitsDoPrimGen -> 10, + HeadDoPrimGen -> 10, + TailDoPrimGen -> 10, + MuxGen -> 10 + ) + ) +) + @RunWith(classOf[JQF]) class FirrtlEquivalenceTests { private val lowFirrtlCompiler = new LowFirrtlCompiler() @@ -103,8 +131,7 @@ class FirrtlEquivalenceTests { } private val baseTestDir = new File("fuzzer/test_run_dir") - @Fuzz - def compileSingleModule(@From(value = classOf[FirrtlSingleModuleGenerator]) c: Circuit) = { + private def runTest(c: Circuit, referenceCompiler: TransformManager, customCompiler: TransformManager) = { val testDir = new File(baseTestDir, f"${c.hashCode}%08x") testDir.mkdirs() val fileWriter = new FileWriter(new File(testDir, s"${c.main}.fir")) @@ -113,9 +140,9 @@ class FirrtlEquivalenceTests { val passed = try { FirrtlEquivalenceTestUtils.firrtlEquivalenceTestPass( circuit = c, - referenceCompiler = new TransformManager(VerilogMinimumOptimized), + referenceCompiler = referenceCompiler, referenceAnnos = Seq(), - customCompiler = new TransformManager(VerilogOptimized), + customCompiler = customCompiler, customAnnos = Seq(), testDir = testDir ) @@ -131,4 +158,22 @@ class FirrtlEquivalenceTests { s"not equivalent to reference compiler on input ${testDir}:\n${c.serialize}\n", false) } } + + @Fuzz + def testOptimized(@From(value = classOf[FirrtlCompileCircuitGenerator]) c: Circuit) = { + runTest( + c = c, + referenceCompiler = new TransformManager(VerilogMinimumOptimized), + customCompiler = new TransformManager(VerilogOptimized) + ) + } + + @Fuzz + def testInlineBooleanExpressions(@From(value = classOf[InlineBooleanExprsCircuitGenerator]) c: Circuit) = { + runTest( + c = c, + referenceCompiler = new TransformManager(VerilogMinimumOptimized), + customCompiler = new TransformManager(VerilogMinimumOptimized :+ Dependency[InlineBooleanExpressions]) + ) + } } diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 843c76a4..19f2661a 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -230,7 +230,48 @@ case class VRandom(width: BigInt) extends Expression { def foreachWidth(f: Width => Unit): Unit = () } +object VerilogEmitter { + + /** Maps a [[PrimOp]] to a precedence number, lower number means higher precedence + * + * Only the [[PrimOp]]s contained in this map will be inlined. [[PrimOp]]s + * like [[PrimOp.Neg]] are not in this map because inlining them may result + * in illegal verilog like '--2sh1' + */ + private val precedenceMap: Map[PrimOp, Int] = { + val precedenceSeq = Seq( + Set(Head, Tail, Bits, Shr, Pad), // Shr and Pad emit as bit select + Set(Andr, Orr, Xorr, Neg, Not), + Set(Mul, Div, Rem), + Set(Add, Sub, Addw, Subw), + Set(Dshl, Dshlw, Dshr), + Set(Lt, Leq, Gt, Geq), + Set(Eq, Neq), + Set(And), + Set(Xor), + Set(Or) + ) + precedenceSeq.zipWithIndex.foldLeft(Map.empty[PrimOp, Int]) { + case (map, (ops, idx)) => map ++ ops.map(_ -> idx) + } + } + + /** true if op1 has greater or equal precendence than op2 + */ + private def precedenceGeq(op1: PrimOp, op2: PrimOp): Boolean = { + precedenceMap(op1) <= precedenceMap(op2) + } + + /** true if op1 has greater precendence than op2 + */ + private def precedenceGt(op1: PrimOp, op2: PrimOp): Boolean = { + precedenceMap(op1) < precedenceMap(op2) + } +} + class VerilogEmitter extends SeqTransform with Emitter { + import VerilogEmitter._ + def inputForm = LowForm def outputForm = LowForm @@ -280,8 +321,42 @@ class VerilogEmitter extends SeqTransform with Emitter { case ClockType | AsyncResetType => "" case _ => throwInternalError(s"trying to write unsupported type in the Verilog Emitter: $tpe") } - def emit(x: Any)(implicit w: Writer): Unit = { emit(x, 0) } + private def getLeadingTabs(x: Any): String = { + x match { + case seq: Seq[_] => + val head = seq.takeWhile(_ == tab).mkString + val tail = seq.dropWhile(_ == tab).lift(0).map(getLeadingTabs).getOrElse(tab) + head + tail + case _ => tab + } + } + def emit(x: Any)(implicit w: Writer): Unit = { + emitCol(x, 0, getLeadingTabs(x), 0) + } + private def emitCast(e: Expression): Any = e.tpe match { + case (t: UIntType) => e + case (t: SIntType) => Seq("$signed(", e, ")") + case ClockType => e + case AnalogType(_) => e + case _ => throwInternalError(s"unrecognized cast: $e") + } def emit(x: Any, top: Int)(implicit w: Writer): Unit = { + emitCol(x, top, "", 0) + } + private val maxCol = 120 + private def emitCol(x: Any, top: Int, tabs: String, colNum: Int)(implicit w: Writer): Int = { + def writeCol(contents: String): Int = { + if ((contents.size + colNum) > maxCol) { + w.write("\n") + w.write(tabs) + w.write(contents) + tabs.size + contents.size + } else { + w.write(contents) + colNum + contents.size + } + } + def cast(e: Expression): Any = e.tpe match { case (t: UIntType) => e case (t: SIntType) => Seq("$signed(", e, ")") @@ -290,7 +365,7 @@ class VerilogEmitter extends SeqTransform with Emitter { case _ => throwInternalError(s"unrecognized cast: $e") } x match { - case (e: DoPrim) => emit(op_stream(e), top + 1) + case (e: DoPrim) => emitCol(op_stream(e), top + 1, tabs, colNum) case (e: Mux) => { if (e.tpe == ClockType) { throw EmitterException("Cannot emit clock muxes directly") @@ -298,51 +373,64 @@ class VerilogEmitter extends SeqTransform with Emitter { if (e.tpe == AsyncResetType) { throw EmitterException("Cannot emit async reset muxes directly") } - emit(Seq(e.cond, " ? ", cast(e.tval), " : ", cast(e.fval)), top + 1) + emitCol(Seq(e.cond, " ? ", cast(e.tval), " : ", cast(e.fval)), top + 1, tabs, colNum) } - case (e: ValidIf) => emit(Seq(cast(e.value)), top + 1) - case (e: WRef) => w.write(e.serialize) - case (e: WSubField) => w.write(LowerTypes.loweredName(e)) - case (e: WSubAccess) => w.write(s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]") - case (e: WSubIndex) => w.write(e.serialize) - case (e: Literal) => v_print(e) - case (e: VRandom) => w.write(s"{${e.nWords}{`RANDOM}}") - case (t: GroundType) => w.write(stringify(t)) + case (e: ValidIf) => emitCol(Seq(cast(e.value)), top + 1, tabs, colNum) + case (e: WRef) => writeCol(e.serialize) + case (e: WSubField) => writeCol(LowerTypes.loweredName(e)) + case (e: WSubAccess) => writeCol(s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]") + case (e: WSubIndex) => writeCol(e.serialize) + case (e: Literal) => v_print(e, colNum) + case (e: VRandom) => writeCol(s"{${e.nWords}{`RANDOM}}") + case (t: GroundType) => writeCol(stringify(t)) case (t: VectorType) => emit(t.tpe, top + 1) - w.write(s"[${t.size - 1}:0]") - case (s: String) => w.write(s) - case (i: Int) => w.write(i.toString) - case (i: Long) => w.write(i.toString) - case (i: BigInt) => w.write(i.toString) + writeCol(s"[${t.size - 1}:0]") + case (s: String) => writeCol(s) + case (i: Int) => writeCol(i.toString) + case (i: Long) => writeCol(i.toString) + case (i: BigInt) => writeCol(i.toString) case (i: Info) => i match { - case NoInfo => // Do nothing + case NoInfo => colNum // Do nothing case f: FileInfo => val escaped = FileInfo.escapedToVerilog(f.escaped) w.write(s" // @[$escaped]") + colNum case m: MultiInfo => val escaped = FileInfo.escapedToVerilog(m.flatten.map(_.escaped).mkString(" ")) w.write(s" // @[$escaped]") + colNum } case (s: Seq[Any]) => - s.foreach(emit(_, top + 1)) - if (top == 0) w.write("\n") + val nextColNum = s.foldLeft(colNum) { + case (colNum, e) => emitCol(e, top + 1, tabs, colNum) + } + if (top == 0) { + w.write("\n") + 0 + } else { + nextColNum + } case x => throwInternalError(s"trying to emit unsupported operator: $x") } } //;------------- PASS ----------------- - def v_print(e: Expression)(implicit w: Writer) = e match { + def v_print(e: Expression, colNum: Int)(implicit w: Writer) = e match { case UIntLiteral(value, IntWidth(width)) => - w.write(s"$width'h${value.toString(16)}") + val contents = s"$width'h${value.toString(16)}" + w.write(contents) + colNum + contents.size case SIntLiteral(value, IntWidth(width)) => val stringLiteral = value.toString(16) - w.write(stringLiteral.head match { + val contents = stringLiteral.head match { case '-' if value == FixAddingNegativeLiterals.minNegValue(width) => s"$width'sh${stringLiteral.tail}" case '-' => s"-$width'sh${stringLiteral.tail}" case _ => s"$width'sh${stringLiteral}" - }) + } + w.write(contents) + colNum + contents.size case _ => throwInternalError(s"attempt to print unrecognized expression: $e") } @@ -350,29 +438,62 @@ class VerilogEmitter extends SeqTransform with Emitter { // reference is actually unsigned in the emitted Verilog. Thus we must cast refs as necessary // to ensure Verilog operations are signed. def op_stream(doprim: DoPrim): Seq[Any] = { + def parenthesize(e: Expression, isFirst: Boolean): Any = doprim.op match { + // these PrimOps emit either {..., a0, ...} or a0 so they never need parentheses + case Shl | Cat | Cvt | AsUInt | AsSInt | AsClock | AsAsyncReset => e + case _ => + e match { + case e: DoPrim => + op_stream(e) match { + /** DoPrims like AsUInt simply emit Seq(a0), so we need to + * recursively check whether a0 needs to be parenthesized + */ + case Seq(passthrough: Expression) => parenthesize(passthrough, isFirst) + + /** If the expression is the first argument then it does not need + * parens if it's precedence is greather than or equal to the + * enclosing doprim, because verilog operators are left + * associative. All other args do not need parens only if the + * precedence is greater. + */ + case other => + if (precedenceGt(e.op, doprim.op) || (precedenceGeq(e.op, doprim.op) && isFirst)) { + other + } else { + Seq("(", other, ")") + } + } + + /** Mux args should always have parens because Mux has the lowest precedence + */ + case _: Mux => Seq("(", e, ")") + case _ => e + } + } + // Cast to SInt, don't cast multiple times def doCast(e: Expression): Any = e match { case DoPrim(AsSInt, Seq(arg), _, _) => doCast(arg) case slit: SIntLiteral => slit case other => Seq("$signed(", other, ")") } - def castIf(e: Expression): Any = { + def castIf(e: Expression, isFirst: Boolean = false): Any = { if (doprim.args.exists(_.tpe.isInstanceOf[SIntType])) { e.tpe match { case _: SIntType => doCast(e) case _ => throwInternalError(s"Unexpected non-SInt type for $e in $doprim") } } else { - e + parenthesize(e, isFirst) } } - def cast(e: Expression): Any = doprim.tpe match { - case _: UIntType => e + def cast(e: Expression, isFirst: Boolean = false): Any = doprim.tpe match { + case _: UIntType => parenthesize(e, isFirst) case _: SIntType => doCast(e) case _ => throwInternalError(s"Unexpected type for $e in $doprim") } - def castAs(e: Expression): Any = e.tpe match { - case _: UIntType => e + def castAs(e: Expression, isFirst: Boolean = false): Any = e.tpe match { + case _: UIntType => parenthesize(e, isFirst) case _: SIntType => doCast(e) case _ => throwInternalError(s"Unexpected type for $e in $doprim") } @@ -381,19 +502,6 @@ class VerilogEmitter extends SeqTransform with Emitter { def c0: Int = doprim.consts.head.toInt def c1: Int = doprim.consts(1).toInt - def checkArgumentLegality(e: Expression): Unit = e match { - case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField => - case DoPrim(Not, args, _, _) => args.foreach(checkArgumentLegality) - case DoPrim(op, args, _, _) if isCast(op) => args.foreach(checkArgumentLegality) - case DoPrim(op, args, _, _) if isBitExtract(op) => args.foreach(checkArgumentLegality) - case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") - } - - def checkCatArgumentLegality(e: Expression): Unit = e match { - case DoPrim(Cat, args, _, _) => args.foreach(checkCatArgumentLegality) - case _ => checkArgumentLegality(e) - } - def castCatArgs(a0: Expression, a1: Expression): Seq[Any] = { val a0Seq = a0 match { case cat @ DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1)) @@ -407,24 +515,19 @@ class VerilogEmitter extends SeqTransform with Emitter { } doprim.op match { - case Cat => doprim.args.foreach(checkCatArgumentLegality) - case cast if isCast(cast) => // Casts are allowed to wrap any Expression - case other => doprim.args.foreach(checkArgumentLegality) - } - doprim.op match { - case Add => Seq(castIf(a0), " + ", castIf(a1)) - case Addw => Seq(castIf(a0), " + ", castIf(a1)) - case Sub => Seq(castIf(a0), " - ", castIf(a1)) - case Subw => Seq(castIf(a0), " - ", castIf(a1)) - case Mul => Seq(castIf(a0), " * ", castIf(a1)) - case Div => Seq(castIf(a0), " / ", castIf(a1)) - case Rem => Seq(castIf(a0), " % ", castIf(a1)) - case Lt => Seq(castIf(a0), " < ", castIf(a1)) - case Leq => Seq(castIf(a0), " <= ", castIf(a1)) - case Gt => Seq(castIf(a0), " > ", castIf(a1)) - case Geq => Seq(castIf(a0), " >= ", castIf(a1)) - case Eq => Seq(castIf(a0), " == ", castIf(a1)) - case Neq => Seq(castIf(a0), " != ", castIf(a1)) + case Add => Seq(castIf(a0, true), " + ", castIf(a1)) + case Addw => Seq(castIf(a0, true), " + ", castIf(a1)) + case Sub => Seq(castIf(a0, true), " - ", castIf(a1)) + case Subw => Seq(castIf(a0, true), " - ", castIf(a1)) + case Mul => Seq(castIf(a0, true), " * ", castIf(a1)) + case Div => Seq(castIf(a0, true), " / ", castIf(a1)) + case Rem => Seq(castIf(a0, true), " % ", castIf(a1)) + case Lt => Seq(castIf(a0, true), " < ", castIf(a1)) + case Leq => Seq(castIf(a0, true), " <= ", castIf(a1)) + case Gt => Seq(castIf(a0, true), " > ", castIf(a1)) + case Geq => Seq(castIf(a0, true), " >= ", castIf(a1)) + case Eq => Seq(castIf(a0, true), " == ", castIf(a1)) + case Neq => Seq(castIf(a0, true), " != ", castIf(a1)) case Pad => val w = bitWidth(a0.tpe) val diff = c0 - w @@ -434,7 +537,7 @@ class VerilogEmitter extends SeqTransform with Emitter { // Either sign extend or zero extend. // If width == BigInt(1), don't extract bit case (_: SIntType) if w == BigInt(1) => Seq("{", c0, "{", a0, "}}") - case (_: SIntType) => Seq("{{", diff, "{", a0, "[", w - 1, "]}},", a0, "}") + case (_: SIntType) => Seq("{{", diff, "{", parenthesize(a0, true), "[", w - 1, "]}},", a0, "}") case (_) => Seq("{{", diff, "'d0}, ", a0, "}") } // Because we don't support complex Expressions, all casts are ignored @@ -451,35 +554,35 @@ class VerilogEmitter extends SeqTransform with Emitter { case Shl => if (c0 > 0) Seq("{", cast(a0), s", $c0'h0}") else Seq(cast(a0)) case Shr if c0 >= bitWidth(a0.tpe) => error("Verilog emitter does not support SHIFT_RIGHT >= arg width") - case Shr if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]") - case Shr => Seq(a0, "[", bitWidth(a0.tpe) - 1, ":", c0, "]") - case Neg => Seq("-", cast(a0)) + case Shr if c0 == (bitWidth(a0.tpe) - 1) => Seq(parenthesize(a0, true), "[", bitWidth(a0.tpe) - 1, "]") + case Shr => Seq(parenthesize(a0, true), "[", bitWidth(a0.tpe) - 1, ":", c0, "]") + case Neg => Seq("-", cast(a0, true)) case Cvt => a0.tpe match { case (_: UIntType) => Seq("{1'b0,", cast(a0), "}") case (_: SIntType) => Seq(cast(a0)) } - case Not => Seq("~", a0) - case And => Seq(castAs(a0), " & ", castAs(a1)) - case Or => Seq(castAs(a0), " | ", castAs(a1)) - case Xor => Seq(castAs(a0), " ^ ", castAs(a1)) - case Andr => Seq("&", cast(a0)) - case Orr => Seq("|", cast(a0)) - case Xorr => Seq("^", cast(a0)) + case Not => Seq("~", parenthesize(a0, true)) + case And => Seq(castAs(a0, true), " & ", castAs(a1)) + case Or => Seq(castAs(a0, true), " | ", castAs(a1)) + case Xor => Seq(castAs(a0, true), " ^ ", castAs(a1)) + case Andr => Seq("&", cast(a0, true)) + case Orr => Seq("|", cast(a0, true)) + case Xorr => Seq("^", cast(a0, true)) case Cat => "{" +: (castCatArgs(a0, a1) :+ "}") // If selecting zeroth bit and single-bit wire, just emit the wire case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0) - case Bits if c0 == c1 => Seq(a0, "[", c0, "]") - case Bits => Seq(a0, "[", c0, ":", c1, "]") + case Bits if c0 == c1 => Seq(parenthesize(a0, true), "[", c0, "]") + case Bits => Seq(parenthesize(a0, true), "[", c0, ":", c1, "]") // If selecting zeroth bit and single-bit wire, just emit the wire case Head if c0 == 1 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0) case Head if c0 == 1 => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]") case Head => val msb = bitWidth(a0.tpe) - 1 val lsb = bitWidth(a0.tpe) - c0 - Seq(a0, "[", msb, ":", lsb, "]") - case Tail if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[0]") - case Tail => Seq(a0, "[", bitWidth(a0.tpe) - c0 - 1, ":0]") + Seq(parenthesize(a0, true), "[", msb, ":", lsb, "]") + case Tail if c0 == (bitWidth(a0.tpe) - 1) => Seq(parenthesize(a0, true), "[0]") + case Tail => Seq(parenthesize(a0, true), "[", bitWidth(a0.tpe) - c0 - 1, ":0]") } } @@ -804,7 +907,7 @@ class VerilogEmitter extends SeqTransform with Emitter { } def regUpdate(r: Expression, clk: Expression, reset: Expression, init: Expression) = { - def addUpdate(info: Info, expr: Expression, tabs: String): Seq[Seq[Any]] = expr match { + def addUpdate(info: Info, expr: Expression, tabs: Seq[String]): Seq[Seq[Any]] = expr match { case m: Mux => if (m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly") if (m.tpe == AsyncResetType) throw EmitterException("Cannot emit async reset muxes directly") @@ -814,8 +917,8 @@ class VerilogEmitter extends SeqTransform with Emitter { lazy val _else = Seq(tabs, "end else begin") lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin", eninfo) lazy val _end = Seq(tabs, "end") - lazy val _true = addUpdate(tinfo, m.tval, tabs + tab) - lazy val _false = addUpdate(finfo, m.fval, tabs + tab) + lazy val _true = addUpdate(tinfo, m.tval, tab +: tabs) + lazy val _false = addUpdate(finfo, m.fval, tab +: tabs) lazy val _elseIfFalse = { val _falsex = addUpdate(finfo, m.fval, tabs) // _false, but without an additional tab Seq(tabs, "end else ", _falsex.head.tail) +: _falsex.tail @@ -845,13 +948,19 @@ class VerilogEmitter extends SeqTransform with Emitter { } if (weq(init, r)) { // Synchronous Reset val InfoExpr(info, e) = netlist(r) - noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(info, e, "") + noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(info, e, Seq.empty) } else { // Asynchronous Reset assert(reset.tpe == AsyncResetType, "Error! Synchronous reset should have been removed!") val tv = init val InfoExpr(finfo, fv) = netlist(r) // TODO add register info argument and build a MultiInfo to pass - asyncResetAlwaysBlocks += ((clk, reset, addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), ""))) + asyncResetAlwaysBlocks += ( + ( + clk, + reset, + addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), Seq.empty) + ) + ) } } @@ -1367,11 +1476,18 @@ class VerilogEmitter extends SeqTransform with Emitter { } override def execute(state: CircuitState): CircuitState = { + val writerToString = + (writer: java.io.StringWriter) => writer.toString.replaceAll("""(?m) +$""", "") // trim trailing whitespace + val newAnnos = state.annotations.flatMap { case EmitCircuitAnnotation(a) if this.getClass == a => val writer = new java.io.StringWriter emit(state, writer) - Seq(EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit(state.circuit.main, writer.toString, outputSuffix))) + Seq( + EmittedVerilogCircuitAnnotation( + EmittedVerilogCircuit(state.circuit.main, writerToString(writer), outputSuffix) + ) + ) case EmitAllModulesAnnotation(a) if this.getClass == a => val cs = runTransforms(state) @@ -1383,12 +1499,16 @@ class VerilogEmitter extends SeqTransform with Emitter { val writer = new java.io.StringWriter val renderer = new VerilogRender(d, pds, module, moduleMap, cs.circuit.main, emissionOptions)(writer) renderer.emit_verilog() - Some(EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writer.toString, outputSuffix))) + Some( + EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writerToString(writer), outputSuffix)) + ) case module: Module => val writer = new java.io.StringWriter val renderer = new VerilogRender(module, moduleMap, cs.circuit.main, emissionOptions)(writer) renderer.emit_verilog() - Some(EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writer.toString, outputSuffix))) + Some( + EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writerToString(writer), outputSuffix)) + ) case _ => None } case _ => Seq() diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala index a0c5ea0c..db411325 100644 --- a/src/main/scala/firrtl/stage/Forms.scala +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -110,7 +110,11 @@ object Forms { Dependency[firrtl.AddDescriptionNodes] ) - val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogMinimumOptimized + val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ + Seq( + Dependency[firrtl.transforms.InlineBooleanExpressions] + ) ++ + VerilogMinimumOptimized val AssertsRemoved: Seq[TransformDependency] = Seq( diff --git a/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala new file mode 100644 index 00000000..7c52d6ef --- /dev/null +++ b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala @@ -0,0 +1,169 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.annotations.{NoTargetAnnotation, Target} +import firrtl.annotations.TargetToken.{fromStringToTargetToken, OfModule, Ref} +import firrtl.ir._ +import firrtl.passes.{InferTypes, LowerTypes, SplitExpressions} +import firrtl.options.Dependency +import firrtl.PrimOps._ +import firrtl.WrappedExpression._ + +import scala.collection.mutable + +case class InlineBooleanExpressionsMax(max: Int) extends NoTargetAnnotation + +object InlineBooleanExpressions { + val defaultMax = 30 +} + +/** Inline Bool expressions + * + * The following conditions must be satisfied to inline + * 1. has type [[Utils.BoolType]] + * 2. is bound to a [[firrtl.ir.DefNode DefNode]] with name starting with '_' + * 3. is bound to a [[firrtl.ir.DefNode DefNode]] with a source locator that + * points at the same file and line number. If it is a MultiInfo source + * locator, the set of file and line number pairs must be the same. Source + * locators may point to different column numbers. + * 4. [[InlineBooleanExpressionsMax]] has not been exceeded + * 5. is not a [[firrtl.ir.Mux Mux]] + */ +class InlineBooleanExpressions extends Transform with DependencyAPIMigration { + + override def prerequisites = Seq( + Dependency(InferTypes), + Dependency(LowerTypes) + ) + + override def optionalPrerequisites = Seq( + Dependency(SplitExpressions) + ) + + override def invalidates(a: Transform) = a match { + case _: DeadCodeElimination => true // this transform does not remove nodes that are unused after inlining + case _ => false + } + + type Netlist = mutable.HashMap[WrappedExpression, (Expression, Info)] + + private def isArgN(outerExpr: DoPrim, subExpr: Expression, n: Int): Boolean = { + outerExpr.args.lift(n) match { + case Some(arg) => arg eq subExpr + case _ => false + } + } + + private val fileLineRegex = """(.*) ([0-9]+):[0-9]+""".r + private def sameFileAndLineInfo(info1: Info, info2: Info): Boolean = { + (info1, info2) match { + case (FileInfo(fileLineRegex(file1, line1)), FileInfo(fileLineRegex(file2, line2))) => + (file1 == file2) && (line1 == line2) + case (MultiInfo(infos1), MultiInfo(infos2)) if infos1.size == infos2.size => + infos1.zip(infos2).forall { + case (i1, i2) => + sameFileAndLineInfo(i1, i2) + } + case (NoInfo, NoInfo) => true + case _ => false + } + } + + /** A helper class to initialize and store mutable state that the expression + * and statement map functions need access to. This makes it easier to pass + * information around without having to plump arguments through the onExpr + * and onStmt methods. + */ + private class MapMethods(maxInlineCount: Int, dontTouches: Set[Ref]) { + val netlist: Netlist = new Netlist + val inlineCounts = mutable.Map.empty[Ref, Int] + var inlineCount: Int = 1 + + /** Whether or not an can be inlined + * @param refExpr the expression to check for inlining + */ + def canInline(refExpr: Expression): Boolean = { + refExpr match { + case _: Mux => false + case _ => refExpr.tpe == Utils.BoolType + } + } + + /** Inlines [[Wref]]s if they are Boolean, have matching file line numbers, + * and would not raise inlineCounts past the maximum. + * + * @param info the [[Info]] of the enclosing [[Statement]] + * @param outerExpr the direct parent [[Expression]] of the current [[Expression]] + * @param expr the [[Expression]] to apply inlining to + */ + def onExpr(info: Info, outerExpr: Option[Expression])(expr: Expression): Expression = { + expr match { + case ref: WRef if !dontTouches.contains(ref.name.Ref) && ref.name.head == '_' => + val refKey = ref.name.Ref + netlist.get(we(ref)) match { + case Some((refExpr, refInfo)) if sameFileAndLineInfo(info, refInfo) => + val inlineNum = inlineCounts.getOrElse(refKey, 1) + if (!outerExpr.isDefined || canInline(refExpr) && ((inlineNum + inlineCount) <= maxInlineCount)) { + inlineCount += inlineNum + refExpr + } else { + ref + } + case other => ref + } + case other => other.mapExpr(onExpr(info, Some(other))) + } + } + + /** Applies onExpr and records metadata for every [[HasInfo]] in a [[Statement]] + * + * This resets inlineCount before inlining and records the resulting + * inline counts and inlined values in the inlineCounts and netlist maps + * after inlining. + */ + def onStmt(stmt: Statement): Statement = { + stmt.mapStmt(onStmt) match { + case hasInfo: HasInfo => + inlineCount = 1 + val stmtx = hasInfo.mapExpr(onExpr(hasInfo.info, None)) + stmtx match { + case node: DefNode => inlineCounts(node.name.Ref) = inlineCount + case _ => + } + stmtx match { + case node @ DefNode(info, name, value) => + netlist(we(WRef(name))) = (value, info) + case _ => + } + stmtx + case other => other + } + } + } + + def execute(state: CircuitState): CircuitState = { + val dontTouchMap: Map[OfModule, Set[Ref]] = { + val refTargets = state.annotations.flatMap { + case anno: HasDontTouches => anno.dontTouches + case o => Nil + } + val dontTouches: Seq[(OfModule, Ref)] = refTargets.map { + case r => Target.referringModule(r).module.OfModule -> r.ref.Ref + } + dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap + } + + val maxInlineCount = state.annotations.collectFirst { + case InlineBooleanExpressionsMax(max) => max + }.getOrElse(InlineBooleanExpressions.defaultMax) + + val modulesx = state.circuit.modules.map { m => + val mapMethods = new MapMethods(maxInlineCount, dontTouchMap.getOrElse(m.name.OfModule, Set.empty[Ref])) + m.mapStmt(mapMethods.onStmt(_)) + } + + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} diff --git a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala new file mode 100644 index 00000000..5fee87c9 --- /dev/null +++ b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala @@ -0,0 +1,242 @@ + +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.annotations.Annotation +import firrtl.options.Dependency +import firrtl.passes._ +import firrtl.transforms._ +import firrtl.testutils._ +import firrtl.stage.TransformManager + +class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { + val transform = new InlineBooleanExpressions + val transforms: Seq[Transform] = new TransformManager( + transform.prerequisites + ).flattenedTransformOrder :+ transform + + protected def exec(input: String, annos: Seq[Annotation] = Nil) = { + transforms.foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit.serialize + } + + it should "inline mux operands" in { + val input = + """circuit Top : + | module Top : + | output out : UInt<1> + | node x1 = UInt<1>(0) + | node x2 = UInt<1>(1) + | node _t = head(x1, 1) + | node _f = head(x2, 1) + | node _c = lt(x1, x2) + | node _y = mux(_c, _t, _f) + | out <= _y""".stripMargin + val check = + """circuit Top : + | module Top : + | output out : UInt<1> + | node x1 = UInt<1>(0) + | node x2 = UInt<1>(1) + | node _t = head(x1, 1) + | node _f = head(x2, 1) + | node _c = lt(x1, x2) + | node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1)) + | out <= mux(lt(x1, x2), head(x1, 1), head(x2, 1))""".stripMargin + val result = exec(input) + (result) should be (parse(check).serialize) + firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) + } + + it should "only inline expressions with the same file and line number" in { + val input = + """circuit Top : + | module Top : + | output outA1 : UInt<1> + | output outA2 : UInt<1> + | output outB : UInt<1> + | node x1 = UInt<1>(0) + | node x2 = UInt<1>(1) + | + | node _t = head(x1, 1) @[A 1:1] + | node _f = head(x2, 1) @[A 1:2] + | node _y = mux(lt(x1, x2), _t, _f) @[A 1:3] + | outA1 <= _y @[A 1:3] + | + | outA2 <= _y @[A 2:3] + | + | outB <= _y @[B]""".stripMargin + val check = + """circuit Top : + | module Top : + | output outA1 : UInt<1> + | output outA2 : UInt<1> + | output outB : UInt<1> + | node x1 = UInt<1>(0) + | node x2 = UInt<1>(1) + | + | node _t = head(x1, 1) @[A 1:1] + | node _f = head(x2, 1) @[A 1:2] + | node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1)) @[A 1:3] + | outA1 <= mux(lt(x1, x2), head(x1, 1), head(x2, 1)) @[A 1:3] + | + | outA2 <= _y @[A 2:3] + | + | outB <= _y @[B]""".stripMargin + val result = exec(input) + (result) should be (parse(check).serialize) + firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) + } + + it should "inline boolean DoPrims" in { + val input = + """circuit Top : + | module Top : + | output outA : UInt<1> + | output outB : UInt<1> + | node x1 = UInt<3>(0) + | node x2 = UInt<3>(1) + | + | node _a = lt(x1, x2) + | node _b = eq(_a, x2) + | node _c = and(_b, x2) + | outA <= _c + | + | node _d = head(_c, 1) + | node _e = andr(_d) + | node _f = lt(_e, x2) + | outB <= _f""".stripMargin + val check = + """circuit Top : + | module Top : + | output outA : UInt<1> + | output outB : UInt<1> + | node x1 = UInt<3>(0) + | node x2 = UInt<3>(1) + | + | node _a = lt(x1, x2) + | node _b = eq(lt(x1, x2), x2) + | node _c = and(eq(lt(x1, x2), x2), x2) + | outA <= and(eq(lt(x1, x2), x2), x2) + | + | node _d = head(_c, 1) + | node _e = andr(head(_c, 1)) + | node _f = lt(andr(head(_c, 1)), x2) + | + | outB <= lt(andr(head(_c, 1)), x2)""".stripMargin + val result = exec(input) + (result) should be (parse(check).serialize) + firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) + } + + it should "inline more boolean DoPrims" in { + val input = + """circuit Top : + | module Top : + | output outA : UInt<1> + | output outB : UInt<1> + | node x1 = UInt<3>(0) + | node x2 = UInt<3>(1) + | + | node _a = lt(x1, x2) + | node _b = leq(_a, x2) + | node _c = gt(_b, x2) + | node _d = geq(_c, x2) + | outA <= _d + | + | node _e = lt(x1, x2) + | node _f = leq(x1, _e) + | node _g = gt(x1, _f) + | node _h = geq(x1, _g) + | outB <= _h""".stripMargin + val check = + """circuit Top : + | module Top : + | output outA : UInt<1> + | output outB : UInt<1> + | node x1 = UInt<3>(0) + | node x2 = UInt<3>(1) + | + | node _a = lt(x1, x2) + | node _b = leq(lt(x1, x2), x2) + | node _c = gt(leq(lt(x1, x2), x2), x2) + | node _d = geq(gt(leq(lt(x1, x2), x2), x2), x2) + | outA <= geq(gt(leq(lt(x1, x2), x2), x2), x2) + | + | node _e = lt(x1, x2) + | node _f = leq(x1, lt(x1, x2)) + | node _g = gt(x1, leq(x1, lt(x1, x2))) + | node _h = geq(x1, gt(x1, leq(x1, lt(x1, x2)))) + | + | outB <= geq(x1, gt(x1, leq(x1, lt(x1, x2))))""".stripMargin + val result = exec(input) + (result) should be (parse(check).serialize) + firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) + } + + it should "limit the number of inlines" in { + val input = + s"""circuit Top : + | module Top : + | input c_0: UInt<1> + | input c_1: UInt<1> + | input c_2: UInt<1> + | input c_3: UInt<1> + | input c_4: UInt<1> + | input c_5: UInt<1> + | input c_6: UInt<1> + | output out : UInt<1> + | + | node _1 = or(c_0, c_1) + | node _2 = or(_1, c_2) + | node _3 = or(_2, c_3) + | node _4 = or(_3, c_4) + | node _5 = or(_4, c_5) + | node _6 = or(_5, c_6) + | + | out <= _6""".stripMargin + val check = + s"""circuit Top : + | module Top : + | input c_0: UInt<1> + | input c_1: UInt<1> + | input c_2: UInt<1> + | input c_3: UInt<1> + | input c_4: UInt<1> + | input c_5: UInt<1> + | input c_6: UInt<1> + | output out : UInt<1> + | + | node _1 = or(c_0, c_1) + | node _2 = or(or(c_0, c_1), c_2) + | node _3 = or(or(or(c_0, c_1), c_2), c_3) + | node _4 = or(_3, c_4) + | node _5 = or(or(_3, c_4), c_5) + | node _6 = or(or(or(_3, c_4), c_5), c_6) + | + | out <= or(or(or(_3, c_4), c_5), c_6)""".stripMargin + val result = exec(input, Seq(InlineBooleanExpressionsMax(3))) + (result) should be (parse(check).serialize) + firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) + } + + it should "be equivalent" in { + val input = + """circuit InlineBooleanExpressionsEquivalenceTest : + | module InlineBooleanExpressionsEquivalenceTest : + | input in : UInt<1>[6] + | output out : UInt<1> + | + | node _a = or(in[0], in[1]) + | node _b = and(in[2], _a) + | node _c = eq(in[3], _b) + | node _d = lt(in[4], _c) + | node _e = eq(in[5], _d) + | node _f = head(_e, 1) + | out <= _f""".stripMargin + firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) + } +} diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index 46416619..40f8f123 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -260,6 +260,8 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { it should "replicate the old order" in { val legacy = Seq( + new firrtl.transforms.InlineBooleanExpressions, + new firrtl.transforms.DeadCodeElimination, new firrtl.transforms.BlackBoxSourceHelper, new firrtl.transforms.FixAddingNegativeLiterals, new firrtl.transforms.ReplaceTruncatingArithmetic, diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 8f128274..a864bfe5 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -110,18 +110,8 @@ class UnitTests extends FirrtlFlatSpec { | out <= bits(mux(a, b, c), 0, 0) |""".stripMargin - "Emitting a nested expression" should "throw an exception" in { + "Emitting a nested expression" should "compile" in { val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds) - intercept[PassException] { - val c = Parser.parse(splitExpTestCode.split("\n").toIterator) - val c2 = passes.foldLeft(c)((c, p) => p.run(c)) - val writer = new StringWriter() - (new VerilogEmitter).emit(CircuitState(c2, LowForm), writer) - } - } - - "After splitting, emitting a nested expression" should "compile" in { - val passes = Seq(ToWorkingIR, SplitExpressions, InferTypes) val c = Parser.parse(splitExpTestCode.split("\n").toIterator) val c2 = passes.foldLeft(c)((c, p) => p.run(c)) val writer = new StringWriter() |
