diff options
| author | Albert Chen | 2020-09-09 12:25:20 -0700 |
|---|---|---|
| committer | GitHub | 2020-09-09 19:25:20 +0000 |
| commit | bc7ac7013ecdf956b7cd61f0f0a60c7272d49cd6 (patch) | |
| tree | 5bc54ffb5d87fa363b2edfbc4a0b6e6cae7ad3fd | |
| parent | e420f99d87ece9f56504b3afc2e37d40b6e8c7b1 (diff) | |
Loosen inlining restrictions (#1882)
* test multiinfo comparison and mux cond inlining
* loosen inlining conditions
* fix typo
* include dshlw
* fix test
| -rw-r--r-- | src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala | 57 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala | 97 |
2 files changed, 123 insertions, 31 deletions
diff --git a/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala index b405f353..29bdde0f 100644 --- a/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala +++ b/src/main/scala/firrtl/transforms/InlineBooleanExpressions.scala @@ -61,20 +61,21 @@ class InlineBooleanExpressions extends Transform with DependencyAPIMigration { } private val fileLineRegex = """(.*) ([0-9]+):[0-9]+""".r - private def sameFileAndLineInfo(info1: Info, info2: Info): Boolean = { - (info1, info2) match { - case (FileInfo(fileLineRegex(file1, line1)), FileInfo(fileLineRegex(file2, line2))) => - (file1 == file2) && (line1 == line2) - case (MultiInfo(infos1), MultiInfo(infos2)) if infos1.size == infos2.size => - infos1.zip(infos2).forall { - case (i1, i2) => - sameFileAndLineInfo(i1, i2) - } - case (NoInfo, NoInfo) => true - case _ => false + private def getFileAndLineNumbers(info: Info): Set[(String, String)] = { + info match { + case FileInfo(fileLineRegex(file, line)) => Set(file -> line) + case FileInfo(file) => Set(file -> "0") + case MultiInfo(infos) => infos.flatMap(getFileAndLineNumbers).toSet + case NoInfo => Set.empty[(String, String)] } } + private def sameFileAndLineInfo(info1: Info, info2: Info): Boolean = { + val set1 = getFileAndLineNumbers(info1) + val set2 = getFileAndLineNumbers(info2) + set1.subsetOf(set2) + } + /** A helper class to initialize and store mutable state that the expression * and statement map functions need access to. This makes it easier to pass * information around without having to plump arguments through the onExpr @@ -86,22 +87,34 @@ class InlineBooleanExpressions extends Transform with DependencyAPIMigration { var inlineCount: Int = 1 /** Whether or not an can be inlined + * @param ref the WRef that references refExpr * @param refExpr the expression to check for inlining * @param outerExpr the parent expression of refExpr, if any */ - def canInline(refExpr: Expression, outerExpr: Option[Expression]): Boolean = { + def canInline(ref: WRef, refExpr: Expression, outerExpr: Option[Expression]): Boolean = { val contextInsensitiveDetOps: Set[PrimOp] = Set(Lt, Leq, Gt, Geq, Eq, Neq, Andr, Orr, Xorr) outerExpr match { case None => true - case Some(o) if (o.tpe == Utils.BoolType) => - refExpr match { - case _: Mux => false - case e => e.tpe == Utils.BoolType - } case Some(o) => - refExpr match { - case DoPrim(op, _, _, Utils.BoolType) => contextInsensitiveDetOps(op) - case _ => false + if ((refExpr.tpe != Utils.BoolType) || refExpr.isInstanceOf[Mux]) { + false + } else { + o match { + // if outer expression is also boolean context does not affect width + case o if o.tpe == Utils.BoolType => true + + // mux condition argument is self-determined + case m: Mux if m.cond eq ref => true + + // dshl/dshr second argument is self-determined + case DoPrim(Dshl | Dshlw | Dshr, Seq(_, shamt), _, _) if shamt eq ref => true + + case o => + refExpr match { + case DoPrim(op, _, _, _) => contextInsensitiveDetOps(op) + case _ => false + } + } } } } @@ -118,10 +131,10 @@ class InlineBooleanExpressions extends Transform with DependencyAPIMigration { case ref: WRef if !dontTouches.contains(ref.name.Ref) && ref.name.head == '_' => val refKey = ref.name.Ref netlist.get(we(ref)) match { - case Some((refExpr, refInfo)) if sameFileAndLineInfo(info, refInfo) => + case Some((refExpr, refInfo)) if sameFileAndLineInfo(refInfo, info) => val inlineNum = inlineCounts.getOrElse(refKey, 1) val notTooDeep = !outerExpr.isDefined || ((inlineNum + inlineCount) <= maxInlineCount) - if (canInline(refExpr, outerExpr) && notTooDeep) { + if (canInline(ref, refExpr, outerExpr) && notTooDeep) { inlineCount += inlineNum refExpr } else { diff --git a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala index 15dffee6..d7c79836 100644 --- a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala +++ b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala @@ -3,6 +3,7 @@ package firrtlTests import firrtl._ +import firrtl.ir.{Circuit, Connect, FileInfo, MultiInfo, Statement} import firrtl.annotations.{Annotation, ReferenceTarget} import firrtl.options.Dependency import firrtl.passes._ @@ -16,9 +17,9 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { transform.prerequisites ).flattenedTransformOrder :+ transform - protected def exec(input: String, annos: Seq[Annotation] = Nil) = { + protected def exec(input: Circuit, annos: Seq[Annotation] = Nil) = { transforms - .foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) => + .foldLeft(CircuitState(input, UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) => t.runTransform(c) } .circuit @@ -48,7 +49,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _c = lt(x1, x2) | node _y = mux(lt(x1, x2), head(x1, 1), head(x2, 1)) | out <= mux(lt(x1, x2), head(x1, 1), head(x2, 1))""".stripMargin - val result = exec(input) + val result = exec(parse(input)) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } @@ -77,7 +78,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _y = mux(lt(x1, x2), _t, _f) | out <= mux(lt(x1, x2), _t, _f)""".stripMargin val result = exec( - input, + parse(input), Seq( DontTouchAnnotation(ReferenceTarget("Top", "Top", Seq.empty, "_t", Seq.empty)), DontTouchAnnotation(ReferenceTarget("Top", "Top", Seq.empty, "_f", Seq.empty)) @@ -122,11 +123,89 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | outA2 <= _y @[A 2:3] | | outB <= _y @[B]""".stripMargin - val result = exec(input) + val result = exec(parse(input)) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } + it should "inline if subexpression info is a subset of parent info" in { + val input = + parse("""circuit test : + | module test : + | input in_1 : UInt<1> + | input in_2 : UInt<1> + | input in_3 : UInt<1> + | output out : UInt<1> + | node _c = in_1 @[A 1:1] + | node _t = in_2 @[A 1:1] + | node _f = in_3 @[A 1:1] + | out <= mux(_c, _t, _f)""".stripMargin).mapModule { m => + // workaround to insert MultiInfo + def onStmt(stmt: Statement): Statement = stmt match { + case c: Connect => + c.mapInfo { _ => + MultiInfo( + Seq( + FileInfo("A 1:1"), + FileInfo("A 2:2"), + FileInfo("A 3:3") + ) + ) + } + case other => other.mapStmt(onStmt) + } + m.mapStmt(onStmt) + } + val check = + """circuit test : + | module test : + | input in_1 : UInt<1> + | input in_2 : UInt<1> + | input in_3 : UInt<1> + | output out : UInt<1> + | node _c = in_1 @[A 1:1] + | node _t = in_2 @[A 1:1] + | node _f = in_3 @[A 1:1] + | out <= mux(in_1, in_2, in_3) @[A 1:1 A 2:2 A 3:3]""".stripMargin + val result = exec(input) + (result) should be(parse(check).serialize) + } + + it should "inline mux condition and dshl/dhslr shamt args" in { + val input = + """circuit inline_mux_dshl_dshlr_args : + | module inline_mux_dshl_dshlr_args : + | input in_1 : UInt<3> + | input in_2 : UInt<3> + | input in_3 : UInt<3> + | output out_1 : UInt<3> + | output out_2 : UInt<3> + | output out_3 : UInt<4> + | node _c = head(in_1, 1) + | node _t = in_2 + | node _f = in_3 + | out_1 <= mux(_c, _t, _f) + | out_2 <= dshr(in_1, _c) + | out_3 <= dshl(in_1, _c)""".stripMargin + val check = + """circuit inline_mux_dshl_dshlr_args : + | module inline_mux_dshl_dshlr_args : + | input in_1 : UInt<3> + | input in_2 : UInt<3> + | input in_3 : UInt<3> + | output out_1 : UInt<3> + | output out_2 : UInt<3> + | output out_3 : UInt<4> + | node _c = head(in_1, 1) + | node _t = in_2 + | node _f = in_3 + | out_1 <= mux(head(in_1, 1), _t, _f) + | out_2 <= dshr(in_1, head(in_1, 1)) + | out_3 <= dshl(in_1, head(in_1, 1))""".stripMargin + val result = exec(parse(input)) + (result) should be(parse(check).serialize) + } + it should "inline boolean DoPrims" in { val input = """circuit Top : @@ -163,7 +242,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _f = lt(andr(head(_c, 1)), x2) | | outB <= lt(andr(head(_c, 1)), x2)""".stripMargin - val result = exec(input) + val result = exec(parse(input)) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } @@ -208,7 +287,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _h = geq(x1, gt(x1, leq(x1, lt(x1, x2)))) | | outB <= geq(x1, gt(x1, leq(x1, lt(x1, x2))))""".stripMargin - val result = exec(input) + val result = exec(parse(input)) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } @@ -254,7 +333,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | node _6 = or(or(or(_3, c_4), c_5), c_6) | | out <= or(or(or(_3, c_4), c_5), c_6)""".stripMargin - val result = exec(input, Seq(InlineBooleanExpressionsMax(3))) + val result = exec(parse(input), Seq(InlineBooleanExpressionsMax(3))) (result) should be(parse(check).serialize) firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } @@ -324,7 +403,7 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { | | node _T_1 = and(a, b) | out <= and(_T_1, c)""".stripMargin - val result = exec(input, PrettyNoExprInlining :: Nil) + val result = exec(parse(input), PrettyNoExprInlining :: Nil) (result) should be(parse(input).serialize) } } |
