diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/transforms/Dedup.scala | 57 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/transforms/DedupTests.scala | 67 |
2 files changed, 118 insertions, 6 deletions
diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index 04ac968d..dc182858 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -5,15 +5,18 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.analyses.InstanceGraph import firrtl.annotations._ import firrtl.passes.{InferTypes, MemPortUtils} -import firrtl.Utils.throwInternalError +import firrtl.Utils.{kind, splitRef, throwInternalError} import firrtl.annotations.transforms.DupedResult import firrtl.annotations.TargetToken.{OfModule, Instance} import firrtl.options.{HasShellOptions, ShellOption} import logger.LazyLogging +import scala.annotation.tailrec + // Datastructures import scala.collection.mutable @@ -446,6 +449,48 @@ object DedupModules extends LazyLogging { changeInternals({n => n}, retype, {i => i}, renameOfModule)(module) } + @tailrec + private def hasBundleType(tpe: Type): Boolean = tpe match { + case _: BundleType => true + case _: GroundType => false + case VectorType(t, _) => hasBundleType(t) + } + + // Find modules that should not have their ports agnostified to avoid bug in + // https://github.com/freechipsproject/firrtl/issues/1703 + // Marks modules that have a port of BundleType that are connected via an aggregate connect or + // partial connect in an instantiating parent + // Order of modules does not matter + private def modsToNotAgnostifyPorts(modules: Seq[DefModule]): Set[String] = { + val dontDedup = mutable.HashSet.empty[String] + def onModule(mod: DefModule): Unit = { + val instToModule = mutable.HashMap.empty[String, String] + def markAggregatePorts(expr: Expression): Unit = { + if (kind(expr) == InstanceKind && hasBundleType(expr.tpe)) { + val (WRef(inst, _, _, _), _) = splitRef(expr) + dontDedup += instToModule(inst) + } + } + def onStmt(stmt: Statement): Unit = { + stmt.foreach(onStmt) + stmt match { + case inst: DefInstance => + instToModule(inst.name) = inst.module + case Connect(_, lhs, rhs) => + markAggregatePorts(lhs) + markAggregatePorts(rhs) + case PartialConnect(_, lhs, rhs) => + markAggregatePorts(lhs) + markAggregatePorts(rhs) + case _ => + } + } + mod.foreach(onStmt) + } + modules.foreach(onModule) + dontDedup.toSet + } + //scalastyle:off /** Returns * 1) map of tag to all matching module names, @@ -470,6 +515,8 @@ object DedupModules extends LazyLogging { val agnosticRename = RenameMap() + val dontAgnostifyPorts = modsToNotAgnostifyPorts(moduleLinearization) + moduleLinearization.foreach { originalModule => // Replace instance references to new deduped modules val dontcare = RenameMap() @@ -487,7 +534,13 @@ object DedupModules extends LazyLogging { // Build tag val builder = new mutable.ArrayBuffer[Any]() - agnosticModule.ports.foreach { builder ++= _.serialize } + + // It may seem weird to use non-agnostified ports with an agnostified body because + // technically it would be invalid FIRRTL, but it is logically sound for the purpose of + // calculating deduplication tags + val ports = + if (dontAgnostifyPorts(originalModule.name)) originalModule.ports else agnosticModule.ports + ports.foreach { builder ++= _.serialize } agnosticModule match { case Module(i, n, ps, b) => builder ++= fastSerializedHash(b).toString()//.serialize diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala index bb12c759..5776db31 100644 --- a/src/test/scala/firrtlTests/transforms/DedupTests.scala +++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala @@ -253,6 +253,61 @@ class DedupModuleTests extends HighTransformSpec { val diff_params = mkfir(("BB", "BB"), ("0", "1")) execute(diff_params, diff_params, Seq.empty) } + + "Modules with aggregate ports that are bulk connected" should "NOT dedup if their port names differ" in { + val input = + """ + |circuit FooAndBarModule : + | module FooModule : + | output io : {flip foo : UInt<1>, fuzz : UInt<1>} + | io.fuzz <= io.foo + | module BarModule : + | output io : {flip bar : UInt<1>, buzz : UInt<1>} + | io.buzz <= io.bar + | module FooAndBarModule : + | output io : {foo : {flip foo : UInt<1>, fuzz : UInt<1>}, bar : {flip bar : UInt<1>, buzz : UInt<1>}} + | inst foo of FooModule + | inst bar of BarModule + | io.foo <- foo.io + | io.bar <- bar.io + |""".stripMargin + val check = input + execute(input, check, Seq.empty) + } + + "Modules with aggregate ports that are bulk connected" should "dedup if their port names are the same" in { + val input = + """ + |circuit FooAndBarModule : + | module FooModule : + | output io : {flip foo : UInt<1>, fuzz : UInt<1>} + | io.fuzz <= io.foo + | module BarModule : + | output io : {flip foo : UInt<1>, fuzz : UInt<1>} + | io.fuzz <= io.foo + | module FooAndBarModule : + | output io : {foo : {flip foo : UInt<1>, fuzz : UInt<1>}, bar : {flip bar : UInt<1>, buzz : UInt<1>}} + | inst foo of FooModule + | inst bar of BarModule + | io.foo <- foo.io + | io.bar <- bar.io + |""".stripMargin + val check = + """ + |circuit FooAndBarModule : + | module FooModule : + | output io : {flip foo : UInt<1>, fuzz : UInt<1>} + | io.fuzz <= io.foo + | module FooAndBarModule : + | output io : {foo : {flip foo : UInt<1>, fuzz : UInt<1>}, bar : {flip bar : UInt<1>, buzz : UInt<1>}} + | inst foo of FooModule + | inst bar of FooModule + | io.foo <- foo.io + | io.bar <- bar.io + |""".stripMargin + execute(input, check, Seq.empty) + } + "The module A and B" should "be deduped with the first module in order" in { val input = """circuit Top : @@ -772,11 +827,15 @@ class DedupModuleTests extends HighTransformSpec { | output oa: {z: {y: {x: UInt<1>}}, a: UInt<1>} | output ob: {a: {b: {c: UInt<1>}}, z: UInt<1>} | inst a of a - | a.i <= ia - | oa <= a.o + | a.i.z.y.x <= ia.z.y.x + | a.i.a <= ia.a + | oa.z.y.x <= a.o.z.y.x + | oa.a <= a.o.a | inst b of b - | b.q <= ib - | ob <= b.r + | b.q.a.b.c <= ib.a.b.c + | b.q.z <= ib.z + | ob.a.b.c <= b.r.a.b.c + | ob.z <= b.r.z | module a: | input i: {z: {y: {x: UInt<1>}}, a: UInt<1>} | output o: {z: {y: {x: UInt<1>}}, a: UInt<1>} |
