aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2017-06-29 16:04:13 -0700
committerJack Koenig2017-06-29 21:16:13 -0700
commita43486b65506620f89f3e171101353b2dde65351 (patch)
treea46e85ea7ea56e80596b645f4619f6f74ae1ad56 /src
parentad3c3a6fcb5bc374bd56c7dd2591fb1def1a5e1b (diff)
Preserve "better" names in Constant Propagation
Names that do not start with '_' are "better" than those that do
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala46
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala103
2 files changed, 142 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index efe06e9b..31a6a660 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -11,6 +11,7 @@ import firrtl.Mappers._
import firrtl.PrimOps._
import annotation.tailrec
+import collection.mutable
class ConstantPropagation extends Transform {
def inputForm = LowForm
@@ -239,18 +240,31 @@ class ConstantPropagation extends Transform {
case _ => r
}
+ // Is "a" a "better name" than "b"?
+ private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
+
// Two pass process
// 1. Propagate constants in expressions and forward propagate references
// 2. Propagate references again for backwards reference (Wires)
// TODO Replacing all wires with nodes makes the second pass unnecessary
+ // However, preserving decent names DOES require a second pass
+ // Replacing all wires with nodes makes it unnecessary for preserving decent names to trigger an
+ // extra iteration though
@tailrec
private def constPropModule(m: Module, dontTouches: Set[String]): Module = {
var nPropagated = 0L
- val nodeMap = collection.mutable.HashMap[String, Expression]()
+ val nodeMap = mutable.HashMap.empty[String, Expression]
+ // For cases where we are trying to constprop a bad name over a good one, we swap their names
+ // during the second pass
+ val swapMap = mutable.HashMap.empty[String, String]
def backPropExpr(expr: Expression): Expression = {
val old = expr map backPropExpr
val propagated = old match {
+ // When swapping, we swap both rhs and lhs
+ case ref @ WRef(rname, _,_,_) if swapMap.contains(rname) =>
+ ref.copy(name = swapMap(rname))
+ // Only const prop on the rhs
case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
constPropNodeRef(ref, nodeMap(rname))
case x => x
@@ -260,7 +274,19 @@ class ConstantPropagation extends Transform {
}
propagated
}
- def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr
+
+ def backPropStmt(stmt: Statement): Statement = stmt map backPropExpr match {
+ case decl: IsDeclaration if swapMap.contains(decl.name) =>
+ val newName = swapMap(decl.name)
+ nPropagated += 1
+ decl match {
+ case node: DefNode => node.copy(name = newName)
+ case wire: DefWire => wire.copy(name = newName)
+ case reg: DefRegister => reg.copy(name = newName)
+ case other => throwInternalError
+ }
+ case other => other map backPropStmt
+ }
def constPropExpression(e: Expression): Expression = {
val old = e map constPropExpression
@@ -274,19 +300,29 @@ class ConstantPropagation extends Transform {
propagated
}
+ // When propagating a reference, check if we want to keep the name that would be deleted
+ def propagateRef(lname: String, value: Expression): Unit = {
+ value match {
+ case WRef(rname,_,_,_) if betterName(lname, rname) =>
+ swapMap += (lname -> rname, rname -> lname)
+ case _ =>
+ }
+ nodeMap(lname) = value
+ }
+
def constPropStmt(s: Statement): Statement = {
val stmtx = s map constPropStmt map constPropExpression
stmtx match {
- case x: DefNode if !dontTouches.contains(x.name) => nodeMap(x.name) = x.value
+ case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value)
case Connect(_, WRef(wname, wtpe, WireKind, _), expr) if !dontTouches.contains(wname) =>
val exprx = constPropExpression(pad(expr, wtpe))
- nodeMap(wname) = exprx
+ propagateRef(wname, exprx)
case _ =>
}
stmtx
}
- val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body)))
+ val res = m.copy(body = backPropStmt(constPropStmt(m.body)))
if (nPropagated > 0) constPropModule(res, dontTouches) else res
}
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index f818f9c0..8f09ac9e 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -372,13 +372,92 @@ class ConstantPropagationSpec extends FirrtlFlatSpec {
"""
(parse(exec(input))) should be (parse(check))
}
+
+ // =============================
+ "ConstProp" should "swap named nodes with temporary nodes that drive them" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<1>
+ input y : UInt<1>
+ output z : UInt<1>
+ node _T_1 = and(x, y)
+ node n = _T_1
+ z <= n
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<1>
+ input y : UInt<1>
+ output z : UInt<1>
+ node n = and(x, y)
+ node _T_1 = n
+ z <= n
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "ConstProp" should "swap named nodes with temporary wires that drive them" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<1>
+ input y : UInt<1>
+ output z : UInt<1>
+ wire _T_1 : UInt<1>
+ node n = _T_1
+ z <= n
+ _T_1 <= and(x, y)
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<1>
+ input y : UInt<1>
+ output z : UInt<1>
+ wire n : UInt<1>
+ node _T_1 = n
+ z <= n
+ n <= and(x, y)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "ConstProp" should "swap named nodes with temporary registers that drive them" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input clock : Clock
+ input x : UInt<1>
+ output z : UInt<1>
+ reg _T_1 : UInt<1>, clock with : (reset => (UInt<1>(0), _T_1))
+ node n = _T_1
+ z <= n
+ _T_1 <= x
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input clock : Clock
+ input x : UInt<1>
+ output z : UInt<1>
+ reg n : UInt<1>, clock with : (reset => (UInt<1>(0), n))
+ node _T_1 = n
+ z <= n
+ n <= x
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
}
// More sophisticated tests of the full compiler
class ConstantPropagationIntegrationSpec extends LowTransformSpec {
def transform = new LowFirrtlOptimization
- "ConstProp" should "should not optimize across dontTouch on nodes" in {
+ "ConstProp" should "NOT optimize across dontTouch on nodes" in {
val input =
"""circuit Top :
| module Top :
@@ -396,7 +475,7 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
execute(input, check, Seq(dontTouch("Top.z")))
}
- it should "should not optimize across dontTouch on wires" in {
+ it should "NOT optimize across dontTouch on wires" in {
val input =
"""circuit Top :
| module Top :
@@ -415,4 +494,24 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
| z <= x""".stripMargin
execute(input, check, Seq(dontTouch("Top.z")))
}
+
+ it should "still propagate constants even when there is name swapping" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | input y : UInt<1>
+ | output z : UInt<1>
+ | node _T_1 = and(and(x, y), UInt<1>(0))
+ | node n = _T_1
+ | z <= n""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | input y : UInt<1>
+ | output z : UInt<1>
+ | z <= UInt<1>(0)""".stripMargin
+ execute(input, check, Seq.empty)
+ }
}