aboutsummaryrefslogtreecommitdiff
path: root/src/test
diff options
context:
space:
mode:
authorAlbert Chen2020-09-09 12:25:20 -0700
committerGitHub2020-09-09 19:25:20 +0000
commitbc7ac7013ecdf956b7cd61f0f0a60c7272d49cd6 (patch)
tree5bc54ffb5d87fa363b2edfbc4a0b6e6cae7ad3fd /src/test
parente420f99d87ece9f56504b3afc2e37d40b6e8c7b1 (diff)
Loosen inlining restrictions (#1882)
* test multiinfo comparison and mux cond inlining * loosen inlining conditions * fix typo * include dshlw * fix test
Diffstat (limited to 'src/test')
-rw-r--r--src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala97
1 files changed, 88 insertions, 9 deletions
diff --git a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
index 15dffee6..d7c79836 100644
--- a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
+++ b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
@@ -3,6 +3,7 @@
package firrtlTests
import firrtl._
+import firrtl.ir.{Circuit, Connect, FileInfo, MultiInfo, Statement}
import firrtl.annotations.{Annotation, ReferenceTarget}
import firrtl.options.Dependency
import firrtl.passes._
@@ -16,9 +17,9 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
transform.prerequisites
).flattenedTransformOrder :+ transform
- protected def exec(input: String, annos: Seq[Annotation] = Nil) = {
+ protected def exec(input: Circuit, annos: Seq[Annotation] = Nil) = {
transforms
- .foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) =>
+ .foldLeft(CircuitState(input, UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) =>
t.runTransform(c)
}
.circuit
@@ -48,7 +49,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _c = lt(x1, x2)
| node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1))
| out <= mux(lt(x1, x2), head(x1, 1), head(x2, 1))""".stripMargin
- val result = exec(input)
+ val result = exec(parse(input))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
@@ -77,7 +78,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _y = mux(lt(x1, x2), _t, _f)
| out <= mux(lt(x1, x2), _t, _f)""".stripMargin
val result = exec(
- input,
+ parse(input),
Seq(
DontTouchAnnotation(ReferenceTarget("Top", "Top", Seq.empty, "_t", Seq.empty)),
DontTouchAnnotation(ReferenceTarget("Top", "Top", Seq.empty, "_f", Seq.empty))
@@ -122,11 +123,89 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| outA2 <= _y @[A 2:3]
|
| outB <= _y @[B]""".stripMargin
- val result = exec(input)
+ val result = exec(parse(input))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
+ it should "inline if subexpression info is a subset of parent info" in {
+ val input =
+ parse("""circuit test :
+ | module test :
+ | input in_1 : UInt<1>
+ | input in_2 : UInt<1>
+ | input in_3 : UInt<1>
+ | output out : UInt<1>
+ | node _c = in_1 @[A 1:1]
+ | node _t = in_2 @[A 1:1]
+ | node _f = in_3 @[A 1:1]
+ | out <= mux(_c, _t, _f)""".stripMargin).mapModule { m =>
+ // workaround to insert MultiInfo
+ def onStmt(stmt: Statement): Statement = stmt match {
+ case c: Connect =>
+ c.mapInfo { _ =>
+ MultiInfo(
+ Seq(
+ FileInfo("A 1:1"),
+ FileInfo("A 2:2"),
+ FileInfo("A 3:3")
+ )
+ )
+ }
+ case other => other.mapStmt(onStmt)
+ }
+ m.mapStmt(onStmt)
+ }
+ val check =
+ """circuit test :
+ | module test :
+ | input in_1 : UInt<1>
+ | input in_2 : UInt<1>
+ | input in_3 : UInt<1>
+ | output out : UInt<1>
+ | node _c = in_1 @[A 1:1]
+ | node _t = in_2 @[A 1:1]
+ | node _f = in_3 @[A 1:1]
+ | out <= mux(in_1, in_2, in_3) @[A 1:1 A 2:2 A 3:3]""".stripMargin
+ val result = exec(input)
+ (result) should be(parse(check).serialize)
+ }
+
+ it should "inline mux condition and dshl/dhslr shamt args" in {
+ val input =
+ """circuit inline_mux_dshl_dshlr_args :
+ | module inline_mux_dshl_dshlr_args :
+ | input in_1 : UInt<3>
+ | input in_2 : UInt<3>
+ | input in_3 : UInt<3>
+ | output out_1 : UInt<3>
+ | output out_2 : UInt<3>
+ | output out_3 : UInt<4>
+ | node _c = head(in_1, 1)
+ | node _t = in_2
+ | node _f = in_3
+ | out_1 <= mux(_c, _t, _f)
+ | out_2 <= dshr(in_1, _c)
+ | out_3 <= dshl(in_1, _c)""".stripMargin
+ val check =
+ """circuit inline_mux_dshl_dshlr_args :
+ | module inline_mux_dshl_dshlr_args :
+ | input in_1 : UInt<3>
+ | input in_2 : UInt<3>
+ | input in_3 : UInt<3>
+ | output out_1 : UInt<3>
+ | output out_2 : UInt<3>
+ | output out_3 : UInt<4>
+ | node _c = head(in_1, 1)
+ | node _t = in_2
+ | node _f = in_3
+ | out_1 <= mux(head(in_1, 1), _t, _f)
+ | out_2 <= dshr(in_1, head(in_1, 1))
+ | out_3 <= dshl(in_1, head(in_1, 1))""".stripMargin
+ val result = exec(parse(input))
+ (result) should be(parse(check).serialize)
+ }
+
it should "inline boolean DoPrims" in {
val input =
"""circuit Top :
@@ -163,7 +242,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _f = lt(andr(head(_c, 1)), x2)
|
| outB <= lt(andr(head(_c, 1)), x2)""".stripMargin
- val result = exec(input)
+ val result = exec(parse(input))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
@@ -208,7 +287,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _h = geq(x1, gt(x1, leq(x1, lt(x1, x2))))
|
| outB <= geq(x1, gt(x1, leq(x1, lt(x1, x2))))""".stripMargin
- val result = exec(input)
+ val result = exec(parse(input))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
@@ -254,7 +333,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _6 = or(or(or(_3, c_4), c_5), c_6)
|
| out <= or(or(or(_3, c_4), c_5), c_6)""".stripMargin
- val result = exec(input, Seq(InlineBooleanExpressionsMax(3)))
+ val result = exec(parse(input), Seq(InlineBooleanExpressionsMax(3)))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
@@ -324,7 +403,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
|
| node _T_1 = and(a, b)
| out <= and(_T_1, c)""".stripMargin
- val result = exec(input, PrettyNoExprInlining :: Nil)
+ val result = exec(parse(input), PrettyNoExprInlining :: Nil)
(result) should be(parse(input).serialize)
}
}