aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlbert Chen2020-09-09 12:25:20 -0700
committerGitHub2020-09-09 19:25:20 +0000
commitbc7ac7013ecdf956b7cd61f0f0a60c7272d49cd6 (patch)
tree5bc54ffb5d87fa363b2edfbc4a0b6e6cae7ad3fd
parente420f99d87ece9f56504b3afc2e37d40b6e8c7b1 (diff)
Loosen inlining restrictions (#1882)
* test multiinfo comparison and mux cond inlining * loosen inlining conditions * fix typo * include dshlw * fix test
-rw-r--r--src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala57
-rw-r--r--src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala97
2 files changed, 123 insertions, 31 deletions
diff --git a/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala
index b405f353..29bdde0f 100644
--- a/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala
+++ b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala
@@ -61,20 +61,21 @@ class InlineBooleanExpressions extends Transform with DependencyAPIMigration {
}
private val fileLineRegex = """(.*) ([0-9]+):[0-9]+""".r
- private def sameFileAndLineInfo(info1: Info, info2: Info): Boolean = {
- (info1, info2) match {
- case (FileInfo(fileLineRegex(file1, line1)), FileInfo(fileLineRegex(file2, line2))) =>
- (file1 == file2) && (line1 == line2)
- case (MultiInfo(infos1), MultiInfo(infos2)) if infos1.size == infos2.size =>
- infos1.zip(infos2).forall {
- case (i1, i2) =>
- sameFileAndLineInfo(i1, i2)
- }
- case (NoInfo, NoInfo) => true
- case _ => false
+ private def getFileAndLineNumbers(info: Info): Set[(String, String)] = {
+ info match {
+ case FileInfo(fileLineRegex(file, line)) => Set(file -> line)
+ case FileInfo(file) => Set(file -> "0")
+ case MultiInfo(infos) => infos.flatMap(getFileAndLineNumbers).toSet
+ case NoInfo => Set.empty[(String, String)]
}
}
+ private def sameFileAndLineInfo(info1: Info, info2: Info): Boolean = {
+ val set1 = getFileAndLineNumbers(info1)
+ val set2 = getFileAndLineNumbers(info2)
+ set1.subsetOf(set2)
+ }
+
/** A helper class to initialize and store mutable state that the expression
* and statement map functions need access to. This makes it easier to pass
* information around without having to plump arguments through the onExpr
@@ -86,22 +87,34 @@ class InlineBooleanExpressions extends Transform with DependencyAPIMigration {
var inlineCount: Int = 1
/** Whether or not an can be inlined
+ * @param ref the WRef that references refExpr
* @param refExpr the expression to check for inlining
* @param outerExpr the parent expression of refExpr, if any
*/
- def canInline(refExpr: Expression, outerExpr: Option[Expression]): Boolean = {
+ def canInline(ref: WRef, refExpr: Expression, outerExpr: Option[Expression]): Boolean = {
val contextInsensitiveDetOps: Set[PrimOp] = Set(Lt, Leq, Gt, Geq, Eq, Neq, Andr, Orr, Xorr)
outerExpr match {
case None => true
- case Some(o) if (o.tpe == Utils.BoolType) =>
- refExpr match {
- case _: Mux => false
- case e => e.tpe == Utils.BoolType
- }
case Some(o) =>
- refExpr match {
- case DoPrim(op, _, _, Utils.BoolType) => contextInsensitiveDetOps(op)
- case _ => false
+ if ((refExpr.tpe != Utils.BoolType) || refExpr.isInstanceOf[Mux]) {
+ false
+ } else {
+ o match {
+ // if outer expression is also boolean context does not affect width
+ case o if o.tpe == Utils.BoolType => true
+
+ // mux condition argument is self-determined
+ case m: Mux if m.cond eq ref => true
+
+ // dshl/dshr second argument is self-determined
+ case DoPrim(Dshl | Dshlw | Dshr, Seq(_, shamt), _, _) if shamt eq ref => true
+
+ case o =>
+ refExpr match {
+ case DoPrim(op, _, _, _) => contextInsensitiveDetOps(op)
+ case _ => false
+ }
+ }
}
}
}
@@ -118,10 +131,10 @@ class InlineBooleanExpressions extends Transform with DependencyAPIMigration {
case ref: WRef if !dontTouches.contains(ref.name.Ref) && ref.name.head == '_' =>
val refKey = ref.name.Ref
netlist.get(we(ref)) match {
- case Some((refExpr, refInfo)) if sameFileAndLineInfo(info, refInfo) =>
+ case Some((refExpr, refInfo)) if sameFileAndLineInfo(refInfo, info) =>
val inlineNum = inlineCounts.getOrElse(refKey, 1)
val notTooDeep = !outerExpr.isDefined || ((inlineNum + inlineCount) <= maxInlineCount)
- if (canInline(refExpr, outerExpr) && notTooDeep) {
+ if (canInline(ref, refExpr, outerExpr) && notTooDeep) {
inlineCount += inlineNum
refExpr
} else {
diff --git a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
index 15dffee6..d7c79836 100644
--- a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
+++ b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala
@@ -3,6 +3,7 @@
package firrtlTests
import firrtl._
+import firrtl.ir.{Circuit, Connect, FileInfo, MultiInfo, Statement}
import firrtl.annotations.{Annotation, ReferenceTarget}
import firrtl.options.Dependency
import firrtl.passes._
@@ -16,9 +17,9 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
transform.prerequisites
).flattenedTransformOrder :+ transform
- protected def exec(input: String, annos: Seq[Annotation] = Nil) = {
+ protected def exec(input: Circuit, annos: Seq[Annotation] = Nil) = {
transforms
- .foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) =>
+ .foldLeft(CircuitState(input, UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) =>
t.runTransform(c)
}
.circuit
@@ -48,7 +49,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _c = lt(x1, x2)
| node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1))
| out <= mux(lt(x1, x2), head(x1, 1), head(x2, 1))""".stripMargin
- val result = exec(input)
+ val result = exec(parse(input))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
@@ -77,7 +78,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _y = mux(lt(x1, x2), _t, _f)
| out <= mux(lt(x1, x2), _t, _f)""".stripMargin
val result = exec(
- input,
+ parse(input),
Seq(
DontTouchAnnotation(ReferenceTarget("Top", "Top", Seq.empty, "_t", Seq.empty)),
DontTouchAnnotation(ReferenceTarget("Top", "Top", Seq.empty, "_f", Seq.empty))
@@ -122,11 +123,89 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| outA2 <= _y @[A 2:3]
|
| outB <= _y @[B]""".stripMargin
- val result = exec(input)
+ val result = exec(parse(input))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
+ it should "inline if subexpression info is a subset of parent info" in {
+ val input =
+ parse("""circuit test :
+ | module test :
+ | input in_1 : UInt<1>
+ | input in_2 : UInt<1>
+ | input in_3 : UInt<1>
+ | output out : UInt<1>
+ | node _c = in_1 @[A 1:1]
+ | node _t = in_2 @[A 1:1]
+ | node _f = in_3 @[A 1:1]
+ | out <= mux(_c, _t, _f)""".stripMargin).mapModule { m =>
+ // workaround to insert MultiInfo
+ def onStmt(stmt: Statement): Statement = stmt match {
+ case c: Connect =>
+ c.mapInfo { _ =>
+ MultiInfo(
+ Seq(
+ FileInfo("A 1:1"),
+ FileInfo("A 2:2"),
+ FileInfo("A 3:3")
+ )
+ )
+ }
+ case other => other.mapStmt(onStmt)
+ }
+ m.mapStmt(onStmt)
+ }
+ val check =
+ """circuit test :
+ | module test :
+ | input in_1 : UInt<1>
+ | input in_2 : UInt<1>
+ | input in_3 : UInt<1>
+ | output out : UInt<1>
+ | node _c = in_1 @[A 1:1]
+ | node _t = in_2 @[A 1:1]
+ | node _f = in_3 @[A 1:1]
+ | out <= mux(in_1, in_2, in_3) @[A 1:1 A 2:2 A 3:3]""".stripMargin
+ val result = exec(input)
+ (result) should be(parse(check).serialize)
+ }
+
+ it should "inline mux condition and dshl/dhslr shamt args" in {
+ val input =
+ """circuit inline_mux_dshl_dshlr_args :
+ | module inline_mux_dshl_dshlr_args :
+ | input in_1 : UInt<3>
+ | input in_2 : UInt<3>
+ | input in_3 : UInt<3>
+ | output out_1 : UInt<3>
+ | output out_2 : UInt<3>
+ | output out_3 : UInt<4>
+ | node _c = head(in_1, 1)
+ | node _t = in_2
+ | node _f = in_3
+ | out_1 <= mux(_c, _t, _f)
+ | out_2 <= dshr(in_1, _c)
+ | out_3 <= dshl(in_1, _c)""".stripMargin
+ val check =
+ """circuit inline_mux_dshl_dshlr_args :
+ | module inline_mux_dshl_dshlr_args :
+ | input in_1 : UInt<3>
+ | input in_2 : UInt<3>
+ | input in_3 : UInt<3>
+ | output out_1 : UInt<3>
+ | output out_2 : UInt<3>
+ | output out_3 : UInt<4>
+ | node _c = head(in_1, 1)
+ | node _t = in_2
+ | node _f = in_3
+ | out_1 <= mux(head(in_1, 1), _t, _f)
+ | out_2 <= dshr(in_1, head(in_1, 1))
+ | out_3 <= dshl(in_1, head(in_1, 1))""".stripMargin
+ val result = exec(parse(input))
+ (result) should be(parse(check).serialize)
+ }
+
it should "inline boolean DoPrims" in {
val input =
"""circuit Top :
@@ -163,7 +242,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _f = lt(andr(head(_c, 1)), x2)
|
| outB <= lt(andr(head(_c, 1)), x2)""".stripMargin
- val result = exec(input)
+ val result = exec(parse(input))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
@@ -208,7 +287,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _h = geq(x1, gt(x1, leq(x1, lt(x1, x2))))
|
| outB <= geq(x1, gt(x1, leq(x1, lt(x1, x2))))""".stripMargin
- val result = exec(input)
+ val result = exec(parse(input))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
@@ -254,7 +333,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
| node _6 = or(or(or(_3, c_4), c_5), c_6)
|
| out <= or(or(or(_3, c_4), c_5), c_6)""".stripMargin
- val result = exec(input, Seq(InlineBooleanExpressionsMax(3)))
+ val result = exec(parse(input), Seq(InlineBooleanExpressionsMax(3)))
(result) should be(parse(check).serialize)
firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions))
}
@@ -324,7 +403,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec {
|
| node _T_1 = and(a, b)
| out <= and(_T_1, c)""".stripMargin
- val result = exec(input, PrettyNoExprInlining :: Nil)
+ val result = exec(parse(input), PrettyNoExprInlining :: Nil)
(result) should be(parse(input).serialize)
}
}