diff options
| author | Jack Koenig | 2017-06-29 21:41:19 -0700 |
|---|---|---|
| committer | GitHub | 2017-06-29 21:41:19 -0700 |
| commit | b60c31806e9220d63ac2dae98ef4b54c37122491 (patch) | |
| tree | a46e85ea7ea56e80596b645f4619f6f74ae1ad56 /src/main | |
| parent | ad3c3a6fcb5bc374bd56c7dd2591fb1def1a5e1b (diff) | |
| parent | a43486b65506620f89f3e171101353b2dde65351 (diff) | |
Merge pull request #620 from freechipsproject/keep-names
Preserve "better" names in Constant Propagation
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 46 |
1 files changed, 41 insertions, 5 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index efe06e9b..31a6a660 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -11,6 +11,7 @@ import firrtl.Mappers._ import firrtl.PrimOps._ import annotation.tailrec +import collection.mutable class ConstantPropagation extends Transform { def inputForm = LowForm @@ -239,18 +240,31 @@ class ConstantPropagation extends Transform { case _ => r } + // Is "a" a "better name" than "b"? + private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') + // Two pass process // 1. Propagate constants in expressions and forward propagate references // 2. Propagate references again for backwards reference (Wires) // TODO Replacing all wires with nodes makes the second pass unnecessary + // However, preserving decent names DOES require a second pass + // Replacing all wires with nodes makes it unnecessary for preserving decent names to trigger an + // extra iteration though @tailrec private def constPropModule(m: Module, dontTouches: Set[String]): Module = { var nPropagated = 0L - val nodeMap = collection.mutable.HashMap[String, Expression]() + val nodeMap = mutable.HashMap.empty[String, Expression] + // For cases where we are trying to constprop a bad name over a good one, we swap their names + // during the second pass + val swapMap = mutable.HashMap.empty[String, String] def backPropExpr(expr: Expression): Expression = { val old = expr map backPropExpr val propagated = old match { + // When swapping, we swap both rhs and lhs + case ref @ WRef(rname, _,_,_) if swapMap.contains(rname) => + ref.copy(name = swapMap(rname)) + // Only const prop on the rhs case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => constPropNodeRef(ref, nodeMap(rname)) case x => x @@ -260,7 +274,19 @@ class ConstantPropagation extends Transform { } propagated } - def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr + + def backPropStmt(stmt: Statement): Statement = stmt map backPropExpr match { + case decl: IsDeclaration if swapMap.contains(decl.name) => + val newName = swapMap(decl.name) + nPropagated += 1 + decl match { + case node: DefNode => node.copy(name = newName) + case wire: DefWire => wire.copy(name = newName) + case reg: DefRegister => reg.copy(name = newName) + case other => throwInternalError + } + case other => other map backPropStmt + } def constPropExpression(e: Expression): Expression = { val old = e map constPropExpression @@ -274,19 +300,29 @@ class ConstantPropagation extends Transform { propagated } + // When propagating a reference, check if we want to keep the name that would be deleted + def propagateRef(lname: String, value: Expression): Unit = { + value match { + case WRef(rname,_,_,_) if betterName(lname, rname) => + swapMap += (lname -> rname, rname -> lname) + case _ => + } + nodeMap(lname) = value + } + def constPropStmt(s: Statement): Statement = { val stmtx = s map constPropStmt map constPropExpression stmtx match { - case x: DefNode if !dontTouches.contains(x.name) => nodeMap(x.name) = x.value + case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) case Connect(_, WRef(wname, wtpe, WireKind, _), expr) if !dontTouches.contains(wname) => val exprx = constPropExpression(pad(expr, wtpe)) - nodeMap(wname) = exprx + propagateRef(wname, exprx) case _ => } stmtx } - val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) + val res = m.copy(body = backPropStmt(constPropStmt(m.body))) if (nPropagated > 0) constPropModule(res, dontTouches) else res } |
