aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala25
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala43
2 files changed, 61 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 930fe45a..efe06e9b 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -4,6 +4,7 @@ package firrtl
package transforms
import firrtl._
+import firrtl.annotations._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
@@ -243,7 +244,7 @@ class ConstantPropagation extends Transform {
// 2. Propagate references again for backwards reference (Wires)
// TODO Replacing all wires with nodes makes the second pass unnecessary
@tailrec
- private def constPropModule(m: Module): Module = {
+ private def constPropModule(m: Module, dontTouches: Set[String]): Module = {
var nPropagated = 0L
val nodeMap = collection.mutable.HashMap[String, Expression]()
@@ -276,8 +277,8 @@ class ConstantPropagation extends Transform {
def constPropStmt(s: Statement): Statement = {
val stmtx = s map constPropStmt map constPropExpression
stmtx match {
- case x: DefNode => nodeMap(x.name) = x.value
- case Connect(_, WRef(wname, wtpe, WireKind, _), expr) =>
+ case x: DefNode if !dontTouches.contains(x.name) => nodeMap(x.name) = x.value
+ case Connect(_, WRef(wname, wtpe, WireKind, _), expr) if !dontTouches.contains(wname) =>
val exprx = constPropExpression(pad(expr, wtpe))
nodeMap(wname) = exprx
case _ =>
@@ -286,18 +287,28 @@ class ConstantPropagation extends Transform {
}
val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body)))
- if (nPropagated > 0) constPropModule(res) else res
+ if (nPropagated > 0) constPropModule(res, dontTouches) else res
}
- def run(c: Circuit): Circuit = {
+ private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = {
val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => constPropModule(m)
+ case m: Module => constPropModule(m, dontTouchMap.getOrElse(m.name, Set.empty))
}
Circuit(c.info, modulesx, c.main)
}
def execute(state: CircuitState): CircuitState = {
- state.copy(circuit = run(state.circuit))
+ val dontTouches: Seq[(String, String)] = state.annotations match {
+ case Some(aMap) => aMap.annotations.collect {
+ case DontTouchAnnotation(ComponentName(c, ModuleName(m, _))) => m -> c
+ }
+ case None => Seq.empty
+ }
+ // Map from module name to component names
+ val dontTouchMap: Map[String, Set[String]] =
+ dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet)
+
+ state.copy(circuit = run(state.circuit, dontTouchMap))
}
}
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index c94adbf6..f818f9c0 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -373,3 +373,46 @@ class ConstantPropagationSpec extends FirrtlFlatSpec {
(parse(exec(input))) should be (parse(check))
}
}
+
+// More sophisticated tests of the full compiler
+class ConstantPropagationIntegrationSpec extends LowTransformSpec {
+ def transform = new LowFirrtlOptimization
+
+ "ConstProp" should "should not optimize across dontTouch on nodes" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | node z = x
+ | y <= z""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | node z = x
+ | y <= z""".stripMargin
+ execute(input, check, Seq(dontTouch("Top.z")))
+ }
+
+ it should "should not optimize across dontTouch on wires" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | wire z : UInt<1>
+ | y <= z
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | wire z : UInt<1>
+ | y <= z
+ | z <= x""".stripMargin
+ execute(input, check, Seq(dontTouch("Top.z")))
+ }
+}