diff options
| author | Jack Koenig | 2017-06-29 16:04:13 -0700 |
|---|---|---|
| committer | Jack Koenig | 2017-06-29 21:16:13 -0700 |
| commit | a43486b65506620f89f3e171101353b2dde65351 (patch) | |
| tree | a46e85ea7ea56e80596b645f4619f6f74ae1ad56 /src | |
| parent | ad3c3a6fcb5bc374bd56c7dd2591fb1def1a5e1b (diff) | |
Preserve "better" names in Constant Propagation
Names that do not start with '_' are "better" than those that do
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 46 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 103 |
2 files changed, 142 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index efe06e9b..31a6a660 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -11,6 +11,7 @@ import firrtl.Mappers._ import firrtl.PrimOps._ import annotation.tailrec +import collection.mutable class ConstantPropagation extends Transform { def inputForm = LowForm @@ -239,18 +240,31 @@ class ConstantPropagation extends Transform { case _ => r } + // Is "a" a "better name" than "b"? + private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') + // Two pass process // 1. Propagate constants in expressions and forward propagate references // 2. Propagate references again for backwards reference (Wires) // TODO Replacing all wires with nodes makes the second pass unnecessary + // However, preserving decent names DOES require a second pass + // Replacing all wires with nodes makes it unnecessary for preserving decent names to trigger an + // extra iteration though @tailrec private def constPropModule(m: Module, dontTouches: Set[String]): Module = { var nPropagated = 0L - val nodeMap = collection.mutable.HashMap[String, Expression]() + val nodeMap = mutable.HashMap.empty[String, Expression] + // For cases where we are trying to constprop a bad name over a good one, we swap their names + // during the second pass + val swapMap = mutable.HashMap.empty[String, String] def backPropExpr(expr: Expression): Expression = { val old = expr map backPropExpr val propagated = old match { + // When swapping, we swap both rhs and lhs + case ref @ WRef(rname, _,_,_) if swapMap.contains(rname) => + ref.copy(name = swapMap(rname)) + // Only const prop on the rhs case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => constPropNodeRef(ref, nodeMap(rname)) case x => x @@ -260,7 +274,19 @@ class ConstantPropagation extends Transform { } propagated } - def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr + + def backPropStmt(stmt: Statement): Statement = stmt map backPropExpr match { + case decl: IsDeclaration if swapMap.contains(decl.name) => + val newName = swapMap(decl.name) + nPropagated += 1 + decl match { + case node: DefNode => node.copy(name = newName) + case wire: DefWire => wire.copy(name = newName) + case reg: DefRegister => reg.copy(name = newName) + case other => throwInternalError + } + case other => other map backPropStmt + } def constPropExpression(e: Expression): Expression = { val old = e map constPropExpression @@ -274,19 +300,29 @@ class ConstantPropagation extends Transform { propagated } + // When propagating a reference, check if we want to keep the name that would be deleted + def propagateRef(lname: String, value: Expression): Unit = { + value match { + case WRef(rname,_,_,_) if betterName(lname, rname) => + swapMap += (lname -> rname, rname -> lname) + case _ => + } + nodeMap(lname) = value + } + def constPropStmt(s: Statement): Statement = { val stmtx = s map constPropStmt map constPropExpression stmtx match { - case x: DefNode if !dontTouches.contains(x.name) => nodeMap(x.name) = x.value + case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) case Connect(_, WRef(wname, wtpe, WireKind, _), expr) if !dontTouches.contains(wname) => val exprx = constPropExpression(pad(expr, wtpe)) - nodeMap(wname) = exprx + propagateRef(wname, exprx) case _ => } stmtx } - val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) + val res = m.copy(body = backPropStmt(constPropStmt(m.body))) if (nPropagated > 0) constPropModule(res, dontTouches) else res } diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index f818f9c0..8f09ac9e 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -372,13 +372,92 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { """ (parse(exec(input))) should be (parse(check)) } + + // ============================= + "ConstProp" should "swap named nodes with temporary nodes that drive them" in { + val input = +"""circuit Top : + module Top : + input x : UInt<1> + input y : UInt<1> + output z : UInt<1> + node _T_1 = and(x, y) + node n = _T_1 + z <= n +""" + val check = +"""circuit Top : + module Top : + input x : UInt<1> + input y : UInt<1> + output z : UInt<1> + node n = and(x, y) + node _T_1 = n + z <= n +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "ConstProp" should "swap named nodes with temporary wires that drive them" in { + val input = +"""circuit Top : + module Top : + input x : UInt<1> + input y : UInt<1> + output z : UInt<1> + wire _T_1 : UInt<1> + node n = _T_1 + z <= n + _T_1 <= and(x, y) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<1> + input y : UInt<1> + output z : UInt<1> + wire n : UInt<1> + node _T_1 = n + z <= n + n <= and(x, y) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "ConstProp" should "swap named nodes with temporary registers that drive them" in { + val input = +"""circuit Top : + module Top : + input clock : Clock + input x : UInt<1> + output z : UInt<1> + reg _T_1 : UInt<1>, clock with : (reset => (UInt<1>(0), _T_1)) + node n = _T_1 + z <= n + _T_1 <= x +""" + val check = +"""circuit Top : + module Top : + input clock : Clock + input x : UInt<1> + output z : UInt<1> + reg n : UInt<1>, clock with : (reset => (UInt<1>(0), n)) + node _T_1 = n + z <= n + n <= x +""" + (parse(exec(input))) should be (parse(check)) + } } // More sophisticated tests of the full compiler class ConstantPropagationIntegrationSpec extends LowTransformSpec { def transform = new LowFirrtlOptimization - "ConstProp" should "should not optimize across dontTouch on nodes" in { + "ConstProp" should "NOT optimize across dontTouch on nodes" in { val input = """circuit Top : | module Top : @@ -396,7 +475,7 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input, check, Seq(dontTouch("Top.z"))) } - it should "should not optimize across dontTouch on wires" in { + it should "NOT optimize across dontTouch on wires" in { val input = """circuit Top : | module Top : @@ -415,4 +494,24 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { | z <= x""".stripMargin execute(input, check, Seq(dontTouch("Top.z"))) } + + it should "still propagate constants even when there is name swapping" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | node _T_1 = and(and(x, y), UInt<1>(0)) + | node n = _T_1 + | z <= n""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | z <= UInt<1>(0)""".stripMargin + execute(input, check, Seq.empty) + } } |
