From 87544d43760ab0698f63b25da2e3b3d342e89fd7 Mon Sep 17 00:00:00 2001 From: Jack Date: Tue, 28 Nov 2017 16:16:08 -0500 Subject: Have DedupModules report renaming --- src/main/scala/firrtl/transforms/Dedup.scala | 33 ++++++++++++++---------- src/test/scala/firrtlTests/AnnotationTests.scala | 33 ++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 14 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index 5fa2c036..496500ca 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -30,14 +30,6 @@ object NoDedupAnnotation { class DedupModules extends Transform { def inputForm = HighForm def outputForm = HighForm - def execute(state: CircuitState): CircuitState = { - getMyAnnotations(state) match { - case Nil => state.copy(circuit = run(state.circuit, Seq.empty)) - case annos => - val noDedups = annos.collect { case NoDedupAnnotation(ModuleName(m, c)) => m } - state.copy(circuit = run(state.circuit, noDedups)) - } - } // Orders the modules of a circuit from leaves to root // A module will appear *after* all modules it instantiates private def buildModuleOrder(c: Circuit): Seq[String] = { @@ -74,15 +66,17 @@ class DedupModules extends Transform { // Finds duplicate Modules // Also changes DefInstances to instantiate the deduplicated module + // Returns (Deduped Module name -> Seq of identical modules, + // Deuplicate Module name -> deduped module name) private def findDups( moduleOrder: Seq[String], moduleMap: Map[String, DefModule], - noDedups: Seq[String]): Map[String, Seq[DefModule]] = { + noDedups: Seq[String]): (Map[String, Seq[DefModule]], Map[String, String]) = { // Module body -> Module name val dedupModules = mutable.HashMap.empty[String, String] // Old module name -> dup module name val dedupMap = mutable.HashMap.empty[String, String] - // Dup module name -> all old module names + // Deduplicated module name -> all identical modules val oldModuleMap = mutable.HashMap.empty[String, Seq[DefModule]] def onModule(m: DefModule): Unit = { @@ -134,18 +128,29 @@ class DedupModules extends Transform { } } moduleOrder.foreach(n => onModule(moduleMap(n))) - oldModuleMap.toMap + (oldModuleMap.toMap, dedupMap.toMap) } - def run(c: Circuit, noDedups: Seq[String]): Circuit = { + def run(c: Circuit, noDedups: Seq[String]): (Circuit, RenameMap) = { val moduleOrder = buildModuleOrder(c) val moduleMap = c.modules.map(m => m.name -> m).toMap - val oldModuleMap = findDups(moduleOrder, moduleMap, noDedups) + val (oldModuleMap, dedupMap) = findDups(moduleOrder, moduleMap, noDedups) // Use old module list to preserve ordering val dedupedModules = c.modules.flatMap(m => oldModuleMap.get(m.name).map(_.head)) - c.copy(modules = dedupedModules) + val cname = CircuitName(c.main) + val renameMap = RenameMap(dedupMap.map { case (from, to) => + ModuleName(from, cname) -> List(ModuleName(to, cname)) + }) + + (c.copy(modules = dedupedModules), renameMap) + } + + def execute(state: CircuitState): CircuitState = { + val noDedups = getMyAnnotations(state).collect { case NoDedupAnnotation(ModuleName(m, c)) => m } + val (newC, renameMap) = run(state.circuit, noDedups) + state.copy(circuit = newC, renames = Some(renameMap)) } } diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index aeefbbe3..c8a90729 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -463,4 +463,37 @@ class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should not contain (anno("foo", mod = "DeadExt")) resultAnno should not contain (anno("bar", mod = "DeadExt")) } + + "Renaming" should "track deduplication" in { + val compiler = new VerilogCompiler + val input = + """circuit Top : + | module Child : + | input x : UInt<32> + | output y : UInt<32> + | y <= x + | module Child_1 : + | input x : UInt<32> + | output y : UInt<32> + | y <= x + | module Top : + | input in : UInt<32>[2] + | output out : UInt<32> + | inst a of Child + | inst b of Child_1 + | a.x <= in[0] + | b.x <= in[1] + | out <= tail(add(a.y, b.y), 1) + |""".stripMargin + val annos = Seq( + anno("x", mod = "Child"), anno("y", mod = "Child_1"), manno("Child"), manno("Child_1") + ) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + resultAnno should contain (anno("x", mod = "Child")) + resultAnno should contain (anno("y", mod = "Child")) + resultAnno should contain (manno("Child")) + resultAnno should not contain (anno("y", mod = "Child_1")) + resultAnno should not contain (manno("Child_1")) + } } -- cgit v1.2.3