diff options
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 25 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 43 |
2 files changed, 61 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 930fe45a..efe06e9b 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -4,6 +4,7 @@ package firrtl package transforms import firrtl._ +import firrtl.annotations._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ @@ -243,7 +244,7 @@ class ConstantPropagation extends Transform { // 2. Propagate references again for backwards reference (Wires) // TODO Replacing all wires with nodes makes the second pass unnecessary @tailrec - private def constPropModule(m: Module): Module = { + private def constPropModule(m: Module, dontTouches: Set[String]): Module = { var nPropagated = 0L val nodeMap = collection.mutable.HashMap[String, Expression]() @@ -276,8 +277,8 @@ class ConstantPropagation extends Transform { def constPropStmt(s: Statement): Statement = { val stmtx = s map constPropStmt map constPropExpression stmtx match { - case x: DefNode => nodeMap(x.name) = x.value - case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => + case x: DefNode if !dontTouches.contains(x.name) => nodeMap(x.name) = x.value + case Connect(_, WRef(wname, wtpe, WireKind, _), expr) if !dontTouches.contains(wname) => val exprx = constPropExpression(pad(expr, wtpe)) nodeMap(wname) = exprx case _ => @@ -286,18 +287,28 @@ class ConstantPropagation extends Transform { } val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) - if (nPropagated > 0) constPropModule(res) else res + if (nPropagated > 0) constPropModule(res, dontTouches) else res } - def run(c: Circuit): Circuit = { + private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = { val modulesx = c.modules.map { case m: ExtModule => m - case m: Module => constPropModule(m) + case m: Module => constPropModule(m, dontTouchMap.getOrElse(m.name, Set.empty)) } Circuit(c.info, modulesx, c.main) } def execute(state: CircuitState): CircuitState = { - state.copy(circuit = run(state.circuit)) + val dontTouches: Seq[(String, String)] = state.annotations match { + case Some(aMap) => aMap.annotations.collect { + case DontTouchAnnotation(ComponentName(c, ModuleName(m, _))) => m -> c + } + case None => Seq.empty + } + // Map from module name to component names + val dontTouchMap: Map[String, Set[String]] = + dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet) + + state.copy(circuit = run(state.circuit, dontTouchMap)) } } diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index c94adbf6..f818f9c0 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -373,3 +373,46 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { (parse(exec(input))) should be (parse(check)) } } + +// More sophisticated tests of the full compiler +class ConstantPropagationIntegrationSpec extends LowTransformSpec { + def transform = new LowFirrtlOptimization + + "ConstProp" should "should not optimize across dontTouch on nodes" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + execute(input, check, Seq(dontTouch("Top.z"))) + } + + it should "should not optimize across dontTouch on wires" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | wire z : UInt<1> + | y <= z + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | wire z : UInt<1> + | y <= z + | z <= x""".stripMargin + execute(input, check, Seq(dontTouch("Top.z"))) + } +} |
