diff options
5 files changed, 94 insertions, 44 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index dc9b2bbe..10e99beb 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -100,7 +100,7 @@ object ConstantPropagation { } -class ConstantPropagation extends Transform with DependencyAPIMigration with ResolvedAnnotationPaths { +class ConstantPropagation extends Transform with DependencyAPIMigration { import ConstantPropagation._ override def prerequisites = @@ -124,8 +124,6 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case _ => false } - override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation]) - sealed trait SimplifyBinaryOp { def matchingArgsValue(e: DoPrim, arg: Expression): Expression def apply(e: DoPrim): Expression = { @@ -841,13 +839,15 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } def execute(state: CircuitState): CircuitState = { - val dontTouchRTs = state.annotations.flatMap { - case anno: HasDontTouches => anno.dontTouches + val dontTouches: Seq[(OfModule, String)] = state.annotations.flatMap { + case anno: HasDontTouches => + anno.dontTouches + // We treat all ReferenceTargets as if they were local because of limitations of + // EliminateTargetPaths + .map(rt => OfModule(rt.encapsulatingModule) -> rt.ref) case o => Nil } - val dontTouches: Seq[(OfModule, String)] = dontTouchRTs.map { - case Target(_, Some(m), Seq(Ref(c))) => m.OfModule -> c - } + // Map from module name to component names val dontTouchMap: Map[OfModule, Set[String]] = dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index fb1bd1f6..c9b42f8e 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -28,11 +28,7 @@ import collection.mutable * circumstances of their instantiation in their parent module, they will still not be removed. To * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication. */ -class DeadCodeElimination - extends Transform - with ResolvedAnnotationPaths - with RegisteredTransform - with DependencyAPIMigration { +class DeadCodeElimination extends Transform with RegisteredTransform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowForm ++ Seq( @@ -368,12 +364,13 @@ class DeadCodeElimination state.copy(circuit = newCircuit, renames = Some(renames)) } - override val annotationClasses: Traversable[Class[_]] = - Seq(classOf[DontTouchAnnotation], classOf[OptimizableExtModuleAnnotation]) - def execute(state: CircuitState): CircuitState = { val dontTouches: Seq[LogicNode] = state.annotations.flatMap { - case anno: HasDontTouches => anno.dontTouches.filter(_.isLocal).map(LogicNode(_)) + case anno: HasDontTouches => + anno.dontTouches + // We treat all ReferenceTargets as if they were local because of limitations of + // EliminateTargetPaths + .map(rt => LogicNode(rt.encapsulatingModule, rt.ref)) case o => Nil } val doTouchExtMods: Seq[String] = state.annotations.collect { diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 6ab54159..cc6377ee 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -1528,6 +1528,41 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input, check, Seq.empty) } + "ConstProp" should "compose with Dedup and not duplicate modules " in { + val input = + """circuit Top : + | module child : + | input x : UInt<1> + | output z : UInt<1> + | z <= not(x) + | module child_1 : + | input x : UInt<1> + | output z : UInt<1> + | z <= not(x) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of child + | inst c_1 of child_1 + | c.x <= x + | c_1.x <= x + | z <= and(c.z, c_1.z)""".stripMargin + val check = + """circuit Top : + | module child : + | input x : UInt<1> + | output z : UInt<1> + | z <= not(x) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of child + | inst c_1 of child + | z <= and(c.z, c_1.z) + | c.x <= x + | c_1.x <= x""".stripMargin + execute(input, check, Seq(dontTouch("child.z"), dontTouch("child_1.z"))) + } } class ConstantPropagationEquivalenceSpec extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala index a9084f0b..f1c0001a 100644 --- a/src/test/scala/firrtlTests/DCETests.scala +++ b/src/test/scala/firrtlTests/DCETests.scala @@ -491,6 +491,42 @@ class DCETests extends FirrtlFlatSpec { (verilog shouldNot include).regex("""fwrite""") (verilog shouldNot include).regex("""fatal""") } + + "DCE" should "not duplicate unnecessarily" in { + val input = + """circuit Top : + | module child : + | input x : UInt<1> + | output z : UInt<1> + | z <= not(x) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of child + | inst c_1 of child + | c.x <= x + | c_1.x <= x + | z <= and(c.z, c_1.z)""".stripMargin + val check = + """circuit Top : + | module child : + | input x : UInt<1> + | output z : UInt<1> + | z <= not(x) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of child + | inst c_1 of child + | z <= and(c.z, c_1.z) + | c.x <= x + | c_1.x <= x""".stripMargin + val top = CircuitTarget("Top").module("Top") + val annos = + Seq(top.instOf("c", "child").ref("z"), top.instOf("c_1", "child").ref("z")) + .map(DontTouchAnnotation(_)) + exec(input, check, annos) + } } class DCECommandLineSpec extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala index bb833f0b..56079c31 100644 --- a/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala @@ -121,55 +121,37 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputState = new LowFirrtlCompiler().compile(inputState, customTransforms) val check = """circuit Top : - | module Leaf___Top_m1_l1 : - | input i : UInt<1> - | output o : UInt<1> - | - | node a = i - | o <= i - | | module Leaf : | input i : UInt<1> | output o : UInt<1> - | - | skip + + | node a = i | o <= i | - | module Middle___Top_m1 : - | input i : UInt<1> - | output o : UInt<1> - | - | inst l1 of Leaf___Top_m1_l1 - | inst l2 of Leaf - | o <= l2.o - | l1.i <= i - | l2.i <= l1.o - | | module Middle : | input i : UInt<1> | output o : UInt<1> - | + | inst l1 of Leaf | inst l2 of Leaf | o <= l2.o | l1.i <= i | l2.i <= l1.o - | + | module Top : | input i : UInt<1> | output o : UInt<1> - | - | inst m1 of Middle___Top_m1 + + | inst m1 of Middle | inst m2 of Middle | o <= m2.o | m1.i <= i - | m2.i <= m1.o - | - """.stripMargin + | m2.i <= m1.o""".stripMargin + canonicalize(outputState.circuit).serialize should be(canonicalize(parse(check)).serialize) outputState.annotations.collect { case x: DontTouchAnnotation => x.target - } should be(Seq(Top.circuitTarget.module("Leaf___Top_m1_l1").ref("a"))) + } should be(Seq(Top.circuitTarget.module("Top").instOf("m1", "Middle").instOf("l1", "Leaf").ref("a"))) } property("No name conflicts between old and new modules") { |
