aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack Koenig2021-03-18 23:31:51 -0700
committerGitHub2021-03-18 23:31:51 -0700
commitb274b319d4a4014c154f06bfc174beba461d6fce (patch)
tree36f3c83f5ceb3d820bc6d6073d8ad2de202c8773
parent94d1bee4c23bd3d8f99dae3ca431ffaa5dc1410d (diff)
Ensure InlineCasts does not inline complex Expressions (#2130)
Previously, InlineCasts could inline complex (ie. non-cast) Expressions into other complex Expressions. Now it will only inline so long as there no more than 1 complex Expression in the current nested Expression. Co-authored-by: Albert Magyar <albert.magyar@gmail.com>
-rw-r--r--src/main/scala/firrtl/transforms/InlineCasts.scala41
-rw-r--r--src/test/scala/firrtl/testutils/FirrtlSpec.scala8
-rw-r--r--src/test/scala/firrtlTests/InlineCastsSpec.scala68
3 files changed, 88 insertions, 29 deletions
diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala
index 71318eee..761252c1 100644
--- a/src/main/scala/firrtl/transforms/InlineCasts.scala
+++ b/src/main/scala/firrtl/transforms/InlineCasts.scala
@@ -28,23 +28,30 @@ object InlineCastsTransform {
* @param expr the Expression being transformed
* @return Returns expr with [[WRef]]s replaced by values found in replace
*/
- def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match {
- // Anything that may generate a part-select should not be inlined!
- case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr
- case e =>
- e.map(onExpr(replace)) match {
- case e @ WRef(name, _, _, _) =>
- replace
- .get(name)
- .filter(isSimpleCast(castSeen = false))
- .getOrElse(e)
- case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) =>
- replace
- .get(name)
- .map(value => e.copy(args = Seq(value)))
- .getOrElse(e)
- case other => other // Not a candidate
- }
+ def onExpr(replace: NodeMap)(expr: Expression): Expression = {
+ // Keep track if we've seen any non-cast expressions while recursing
+ def rec(hasNonCastParent: Boolean)(expr: Expression): Expression = expr match {
+ // Skip pads to avoid inlining literals into pads which results in invalid Verilog
+ case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr
+ case e =>
+ e.map(rec(hasNonCastParent || !isCast(e))) match {
+ case e @ WRef(name, _, _, _) =>
+ replace
+ .get(name)
+ .filter(isSimpleCast(castSeen = false))
+ .getOrElse(e)
+ case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) =>
+ replace
+ .get(name)
+ // Only inline the Expression if there is no non-cast parent in the expression tree OR
+ // if the subtree contains only casts and references.
+ .filter(x => !hasNonCastParent || isSimpleCast(castSeen = true)(x))
+ .map(value => e.copy(args = Seq(value)))
+ .getOrElse(e)
+ case other => other // Not a candidate
+ }
+ }
+ rec(false)(expr)
}
/** Inline casts in a Statement
diff --git a/src/test/scala/firrtl/testutils/FirrtlSpec.scala b/src/test/scala/firrtl/testutils/FirrtlSpec.scala
index 24793437..63def26a 100644
--- a/src/test/scala/firrtl/testutils/FirrtlSpec.scala
+++ b/src/test/scala/firrtl/testutils/FirrtlSpec.scala
@@ -165,10 +165,14 @@ trait FirrtlRunners extends BackendCompilationUtilities {
/** Compiles input Firrtl to Verilog */
def compileToVerilog(input: String, annotations: AnnotationSeq = Seq.empty): String = {
+ compileToVerilogCircuitState(input, annotations).getEmittedCircuit.value
+ }
+
+ /** Compiles input Firrtl to Verilog */
+ def compileToVerilogCircuitState(input: String, annotations: AnnotationSeq = Seq.empty): CircuitState = {
val circuit = Parser.parse(input.split("\n").toIterator)
val compiler = new VerilogCompiler
- val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms)
- res.getEmittedCircuit.value
+ compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms)
}
/** Compile a Firrtl file
diff --git a/src/test/scala/firrtlTests/InlineCastsSpec.scala b/src/test/scala/firrtlTests/InlineCastsSpec.scala
index e27020e5..7a248def 100644
--- a/src/test/scala/firrtlTests/InlineCastsSpec.scala
+++ b/src/test/scala/firrtlTests/InlineCastsSpec.scala
@@ -4,18 +4,19 @@ package firrtlTests
import firrtl.transforms.InlineCastsTransform
import firrtl.testutils.FirrtlFlatSpec
+import firrtl.testutils.FirrtlCheckers._
-/*
- * Note: InlineCasts is still part of mverilog, so this test must both:
- * - Test that the InlineCasts fix is effective given the current mverilog
- * - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog
- *
- * This is why the test passes InlineCasts as a custom transform: to future-proof it so that
- * it can do real LEC against no-InlineCasts. It currently is just a sanity check that the
- * emitted Verilog is legal, but it will automatically become a more meaningful test when
- * InlineCasts is not run in mverilog.
- */
class InlineCastsEquivalenceSpec extends FirrtlFlatSpec {
+ /*
+ * Note: InlineCasts is still part of mverilog, so this test must both:
+ * - Test that the InlineCasts fix is effective given the current mverilog
+ * - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog
+ *
+ * This is why the test passes InlineCasts as a custom transform: to future-proof it so that
+ * it can do real LEC against no-InlineCasts. It currently is just a sanity check that the
+ * emitted Verilog is legal, but it will automatically become a more meaningful test when
+ * InlineCasts is not run in mverilog.
+ */
"InlineCastsTransform" should "not produce broken Verilog" in {
val input =
s"""circuit literalsel_fir:
@@ -26,4 +27,51 @@ class InlineCastsEquivalenceSpec extends FirrtlFlatSpec {
|""".stripMargin
firrtlEquivalenceTest(input, Seq(new InlineCastsTransform))
}
+
+ it should "not inline complex expressions into other complex expressions" in {
+ val input =
+ """circuit NeverInlineComplexIntoComplex :
+ | module NeverInlineComplexIntoComplex :
+ | input a : SInt<3>
+ | input b : UInt<2>
+ | input c : UInt<2>
+ | input sel : UInt<1>
+ | output out : SInt<3>
+ | node diff = sub(b, c)
+ | out <= mux(sel, a, asSInt(diff))
+ |""".stripMargin
+ val expected =
+ """module NeverInlineComplexIntoComplexRef(
+ | input [2:0] a,
+ | input [1:0] b,
+ | input [1:0] c,
+ | input sel,
+ | output [2:0] out
+ |);
+ | wire [2:0] diff = b - c;
+ | assign out = sel ? $signed(a) : $signed(diff);
+ |endmodule
+ |""".stripMargin
+ firrtlEquivalenceWithVerilog(input, expected)
+ }
+
+ it should "inline casts on both sides of a more complex expression" in {
+ val input =
+ """circuit test :
+ | module test :
+ | input clock : Clock
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | node _T_1 = asUInt(clock)
+ | node _T_2 = not(_T_1)
+ | node clock_n = asClock(_T_2)
+ | reg r : UInt<8>, clock_n
+ | r <= in
+ | out <= r
+ |""".stripMargin
+ val verilog = compileToVerilogCircuitState(input)
+ verilog should containLine("always @(posedge clock_n) begin")
+
+ }
}