diff options
| author | Albert Chen | 2020-09-09 12:25:20 -0700 |
|---|---|---|
| committer | GitHub | 2020-09-09 19:25:20 +0000 |
| commit | bc7ac7013ecdf956b7cd61f0f0a60c7272d49cd6 (patch) | |
| tree | 5bc54ffb5d87fa363b2edfbc4a0b6e6cae7ad3fd /src/test | |
| parent | e420f99d87ece9f56504b3afc2e37d40b6e8c7b1 (diff) | |
Loosen inlining restrictions (#1882)
* test multiinfo comparison and mux cond inlining
* loosen inlining conditions
* fix typo
* include dshlw
* fix test
Diffstat (limited to 'src/test')
| -rw-r--r-- | src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala | 97 |
1 files changed, 88 insertions, 9 deletions
diff --git a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala index 15dffee6..d7c79836 100644 --- a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala +++ b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala @@ -3,6 +3,7 @@ package firrtlTests import firrtl._ +import firrtl.ir.{Circuit, Connect, FileInfo, MultiInfo, Statement} import firrtl.annotations.{Annotation, ReferenceTarget} import firrtl.options.Dependency import firrtl.passes._ @@ -16,9 +17,9 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { transform.prerequisites ).flattenedTransformOrder :+ transform - protected def exec(input: String, annos: Seq[Annotation] = Nil) = { + protected def exec(input: Circuit, annos: Seq[Annotation] = Nil) = { transforms - .foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) => + .foldLeft(CircuitState(input, UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) => t.runTransform(c) } .circuit @@ -48,7 +49,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _c = lt(x1, x2) | node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1)) | out <= mux(lt(x1, x2), head(x1, 1), head(x2, 1))""".stripMargin - val result = exec(input) + val result = exec(parse(input)) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } @@ -77,7 +78,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _y = mux(lt(x1, x2), _t, _f) | out <= mux(lt(x1, x2), _t, _f)""".stripMargin val result = exec( - input, + parse(input), Seq( DontTouchAnnotation(ReferenceTarget("Top", "Top", Seq.empty, "_t", Seq.empty)), DontTouchAnnotation(ReferenceTarget("Top", "Top", Seq.empty, "_f", Seq.empty)) @@ -122,11 +123,89 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | outA2 <= _y @[A 2:3] | | outB <= _y @[B]""".stripMargin - val result = exec(input) + val result = exec(parse(input)) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } + it should "inline if subexpression info is a subset of parent info" in { + val input = + parse("""circuit test : + | module test : + | input in_1 : UInt<1> + | input in_2 : UInt<1> + | input in_3 : UInt<1> + | output out : UInt<1> + | node _c = in_1 @[A 1:1] + | node _t = in_2 @[A 1:1] + | node _f = in_3 @[A 1:1] + | out <= mux(_c, _t, _f)""".stripMargin).mapModule { m => + // workaround to insert MultiInfo + def onStmt(stmt: Statement): Statement = stmt match { + case c: Connect => + c.mapInfo { _ => + MultiInfo( + Seq( + FileInfo("A 1:1"), + FileInfo("A 2:2"), + FileInfo("A 3:3") + ) + ) + } + case other => other.mapStmt(onStmt) + } + m.mapStmt(onStmt) + } + val check = + """circuit test : + | module test : + | input in_1 : UInt<1> + | input in_2 : UInt<1> + | input in_3 : UInt<1> + | output out : UInt<1> + | node _c = in_1 @[A 1:1] + | node _t = in_2 @[A 1:1] + | node _f = in_3 @[A 1:1] + | out <= mux(in_1, in_2, in_3) @[A 1:1 A 2:2 A 3:3]""".stripMargin + val result = exec(input) + (result) should be(parse(check).serialize) + } + + it should "inline mux condition and dshl/dhslr shamt args" in { + val input = + """circuit inline_mux_dshl_dshlr_args : + | module inline_mux_dshl_dshlr_args : + | input in_1 : UInt<3> + | input in_2 : UInt<3> + | input in_3 : UInt<3> + | output out_1 : UInt<3> + | output out_2 : UInt<3> + | output out_3 : UInt<4> + | node _c = head(in_1, 1) + | node _t = in_2 + | node _f = in_3 + | out_1 <= mux(_c, _t, _f) + | out_2 <= dshr(in_1, _c) + | out_3 <= dshl(in_1, _c)""".stripMargin + val check = + """circuit inline_mux_dshl_dshlr_args : + | module inline_mux_dshl_dshlr_args : + | input in_1 : UInt<3> + | input in_2 : UInt<3> + | input in_3 : UInt<3> + | output out_1 : UInt<3> + | output out_2 : UInt<3> + | output out_3 : UInt<4> + | node _c = head(in_1, 1) + | node _t = in_2 + | node _f = in_3 + | out_1 <= mux(head(in_1, 1), _t, _f) + | out_2 <= dshr(in_1, head(in_1, 1)) + | out_3 <= dshl(in_1, head(in_1, 1))""".stripMargin + val result = exec(parse(input)) + (result) should be(parse(check).serialize) + } + it should "inline boolean DoPrims" in { val input = """circuit Top : @@ -163,7 +242,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _f = lt(andr(head(_c, 1)), x2) | | outB <= lt(andr(head(_c, 1)), x2)""".stripMargin - val result = exec(input) + val result = exec(parse(input)) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } @@ -208,7 +287,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _h = geq(x1, gt(x1, leq(x1, lt(x1, x2)))) | | outB <= geq(x1, gt(x1, leq(x1, lt(x1, x2))))""".stripMargin - val result = exec(input) + val result = exec(parse(input)) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } @@ -254,7 +333,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _6 = or(or(or(_3, c_4), c_5), c_6) | | out <= or(or(or(_3, c_4), c_5), c_6)""".stripMargin - val result = exec(input, Seq(InlineBooleanExpressionsMax(3))) + val result = exec(parse(input), Seq(InlineBooleanExpressionsMax(3))) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } @@ -324,7 +403,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | | node _T_1 = and(a, b) | out <= and(_T_1, c)""".stripMargin - val result = exec(input, PrettyNoExprInlining :: Nil) + val result = exec(parse(input), PrettyNoExprInlining :: Nil) (result) should be(parse(input).serialize) } } |
