aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack2017-11-28 16:16:08 -0500
committerAdam Izraelevitz2017-11-28 18:16:57 -0800
commit87544d43760ab0698f63b25da2e3b3d342e89fd7 (patch)
tree22985d73c2f7da45f7a04e492d01b350cbbb7b0c
parentd8e9fc3d84c06c546440b1ef821cd1e3626b62e6 (diff)
Have DedupModules report renaming
-rw-r--r--src/main/scala/firrtl/transforms/Dedup.scala33
-rw-r--r--src/test/scala/firrtlTests/AnnotationTests.scala33
2 files changed, 52 insertions, 14 deletions
diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala
index 5fa2c036..496500ca 100644
--- a/src/main/scala/firrtl/transforms/Dedup.scala
+++ b/src/main/scala/firrtl/transforms/Dedup.scala
@@ -30,14 +30,6 @@ object NoDedupAnnotation {
class DedupModules extends Transform {
def inputForm = HighForm
def outputForm = HighForm
- def execute(state: CircuitState): CircuitState = {
- getMyAnnotations(state) match {
- case Nil => state.copy(circuit = run(state.circuit, Seq.empty))
- case annos =>
- val noDedups = annos.collect { case NoDedupAnnotation(ModuleName(m, c)) => m }
- state.copy(circuit = run(state.circuit, noDedups))
- }
- }
// Orders the modules of a circuit from leaves to root
// A module will appear *after* all modules it instantiates
private def buildModuleOrder(c: Circuit): Seq[String] = {
@@ -74,15 +66,17 @@ class DedupModules extends Transform {
// Finds duplicate Modules
// Also changes DefInstances to instantiate the deduplicated module
+ // Returns (Deduped Module name -> Seq of identical modules,
+ // Deuplicate Module name -> deduped module name)
private def findDups(
moduleOrder: Seq[String],
moduleMap: Map[String, DefModule],
- noDedups: Seq[String]): Map[String, Seq[DefModule]] = {
+ noDedups: Seq[String]): (Map[String, Seq[DefModule]], Map[String, String]) = {
// Module body -> Module name
val dedupModules = mutable.HashMap.empty[String, String]
// Old module name -> dup module name
val dedupMap = mutable.HashMap.empty[String, String]
- // Dup module name -> all old module names
+ // Deduplicated module name -> all identical modules
val oldModuleMap = mutable.HashMap.empty[String, Seq[DefModule]]
def onModule(m: DefModule): Unit = {
@@ -134,18 +128,29 @@ class DedupModules extends Transform {
}
}
moduleOrder.foreach(n => onModule(moduleMap(n)))
- oldModuleMap.toMap
+ (oldModuleMap.toMap, dedupMap.toMap)
}
- def run(c: Circuit, noDedups: Seq[String]): Circuit = {
+ def run(c: Circuit, noDedups: Seq[String]): (Circuit, RenameMap) = {
val moduleOrder = buildModuleOrder(c)
val moduleMap = c.modules.map(m => m.name -> m).toMap
- val oldModuleMap = findDups(moduleOrder, moduleMap, noDedups)
+ val (oldModuleMap, dedupMap) = findDups(moduleOrder, moduleMap, noDedups)
// Use old module list to preserve ordering
val dedupedModules = c.modules.flatMap(m => oldModuleMap.get(m.name).map(_.head))
- c.copy(modules = dedupedModules)
+ val cname = CircuitName(c.main)
+ val renameMap = RenameMap(dedupMap.map { case (from, to) =>
+ ModuleName(from, cname) -> List(ModuleName(to, cname))
+ })
+
+ (c.copy(modules = dedupedModules), renameMap)
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ val noDedups = getMyAnnotations(state).collect { case NoDedupAnnotation(ModuleName(m, c)) => m }
+ val (newC, renameMap) = run(state.circuit, noDedups)
+ state.copy(circuit = newC, renames = Some(renameMap))
}
}
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala
index aeefbbe3..c8a90729 100644
--- a/src/test/scala/firrtlTests/AnnotationTests.scala
+++ b/src/test/scala/firrtlTests/AnnotationTests.scala
@@ -463,4 +463,37 @@ class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should not contain (anno("foo", mod = "DeadExt"))
resultAnno should not contain (anno("bar", mod = "DeadExt"))
}
+
+ "Renaming" should "track deduplication" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<32>
+ | output y : UInt<32>
+ | y <= x
+ | module Child_1 :
+ | input x : UInt<32>
+ | output y : UInt<32>
+ | y <= x
+ | module Top :
+ | input in : UInt<32>[2]
+ | output out : UInt<32>
+ | inst a of Child
+ | inst b of Child_1
+ | a.x <= in[0]
+ | b.x <= in[1]
+ | out <= tail(add(a.y, b.y), 1)
+ |""".stripMargin
+ val annos = Seq(
+ anno("x", mod = "Child"), anno("y", mod = "Child_1"), manno("Child"), manno("Child_1")
+ )
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+ resultAnno should contain (anno("x", mod = "Child"))
+ resultAnno should contain (anno("y", mod = "Child"))
+ resultAnno should contain (manno("Child"))
+ resultAnno should not contain (anno("y", mod = "Child_1"))
+ resultAnno should not contain (manno("Child_1"))
+ }
}