diff options
| author | Jack Koenig | 2018-08-23 18:19:47 -0700 |
|---|---|---|
| committer | GitHub | 2018-08-23 18:19:47 -0700 |
| commit | abaa38f4f63b35105796916e3df2ccf5c1639f65 (patch) | |
| tree | 1e3a90664f445d9ca6a8bddf148d691770a02430 /src | |
| parent | d7b96168f1c7244124ac258de174bf11d53092ab (diff) | |
Fix NoDedupMem to be cognizant of Module scope (#876)
Previously, mems marked no dedup would prevent mems with the same
instance name in other modules from deduping
Diffstat (limited to 'src')
3 files changed, 94 insertions, 19 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index e254dcc9..b552470d 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -75,12 +75,6 @@ object AnalysisUtils { } case _ => e } - - /** Checks whether the two memories are equivalent in all respects except name - */ - def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory, noDeDupeMems: Seq[String]): Boolean = - a == b.copy(info = a.info, name = a.name, memRef = a.memRef) && - !(noDeDupeMems.contains(a.name) || noDeDupeMems.contains(b.name)) } /** Determines if a write mask is needed (wmode/en and wmask are equivalent). diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala index d195ea55..b0d3731f 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -4,7 +4,6 @@ package firrtl.passes package memlib import firrtl._ import firrtl.ir._ -import AnalysisUtils.eqMems import firrtl.Mappers._ import firrtl.annotations._ @@ -19,30 +18,57 @@ class ResolveMemoryReference extends Transform { def inputForm = MidForm def outputForm = MidForm - type AnnotatedMemories = collection.mutable.ArrayBuffer[(String, DefAnnotatedMemory)] + /** Helper class for determining when two memories are equivalent while igoring + * irrelevant details like name and info + */ + private class WrappedDefAnnoMemory(val underlying: DefAnnotatedMemory) { + // Remove irrelevant details for comparison + private def generic = underlying.copy(info = NoInfo, name = "", memRef = None) + override def hashCode: Int = generic.hashCode + override def equals(that: Any): Boolean = that match { + case mem: WrappedDefAnnoMemory => this.generic == mem.generic + case _ => false + } + } + private def wrap(mem: DefAnnotatedMemory) = new WrappedDefAnnoMemory(mem) + + // Values are Tuple of Module Name and Memory Instance Name + private type AnnotatedMemories = collection.mutable.HashMap[WrappedDefAnnoMemory, (String, String)] + + private def dedupable(noDedups: Map[String, Set[String]], module: String, memory: String): Boolean = + noDedups.get(module).map(!_.contains(memory)).getOrElse(true) /** If a candidate memory is identical except for name to another, add an * annotation that references the name of the other memory. */ - def updateMemStmts(mname: String, uniqueMems: AnnotatedMemories, noDeDupeMems: Seq[String])(s: Statement): Statement = s match { - case m: DefAnnotatedMemory => - uniqueMems find (x => eqMems(x._2, m, noDeDupeMems)) match { + def updateMemStmts(mname: String, + existingMems: AnnotatedMemories, + noDedupMap: Map[String, Set[String]]) + (s: Statement): Statement = s match { + // If not dedupable, no need to add to existing (since nothing can dedup with it) + // We just return the DefAnnotatedMemory as is in the default case below + case m: DefAnnotatedMemory if dedupable(noDedupMap, mname, m.name) => + val wrapped = wrap(m) + existingMems.get(wrapped) match { + case proto @ Some(_) => + m.copy(memRef = proto) case None => - uniqueMems += (mname -> m) + existingMems(wrapped) = (mname, m.name) m - case Some((module, proto)) => m copy (memRef = Some(module -> proto.name)) } - case s => s map updateMemStmts(mname, uniqueMems, noDeDupeMems) + case s => s.map(updateMemStmts(mname, existingMems, noDedupMap)) } - def run(c: Circuit, noDeDupeMems: Seq[String]) = { - val uniqueMems = new AnnotatedMemories - c copy (modules = c.modules map (m => m map updateMemStmts(m.name, uniqueMems, noDeDupeMems))) + def run(c: Circuit, noDedupMap: Map[String, Set[String]]) = { + val existingMems = new AnnotatedMemories + val modulesx = c.modules.map(m => m.map(updateMemStmts(m.name, existingMems, noDedupMap))) + c.copy(modules = modulesx) } def execute(state: CircuitState): CircuitState = { val noDedups = state.annotations.collect { - case NoDedupMemAnnotation(ComponentName(cn, _)) => cn + case NoDedupMemAnnotation(ComponentName(cn, ModuleName(mn, _))) => mn -> cn } - state.copy(circuit=run(state.circuit, noDedups)) + val noDedupMap: Map[String, Set[String]] = noDedups.groupBy(_._1).mapValues(_.map(_._2).toSet) + state.copy(circuit = run(state.circuit, noDedupMap)) } } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 2eae3580..6cedd3f0 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -264,6 +264,61 @@ circuit CustomMemory : (new java.io.File(confLoc)).delete() } + "ReplSeqMem" should "dedup mems with the same instance name as other mems (in other modules) marked NoDedup" in { + val input = """ +circuit CustomMemory : + module ChildMemory : + input clock : Clock + input reset : UInt<1> + output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>} + + smem mem_0 : UInt<16>[7] + read mport r1 = mem_0[io.rAddr], clock + io.dO <= r1 + when io.wEn : + write mport w1 = mem_0[io.wAddr], clock + w1 <= io.dI + + module CustomMemory : + input clock : Clock + input reset : UInt<1> + output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>} + + inst child of ChildMemory + child.clock <= clock + child.reset <= reset + io <- child.io + + smem mem_0 : UInt<16>[7] + smem mem_1 : UInt<16>[7] + read mport r1 = mem_0[io.rAddr], clock + read mport r2 = mem_1[io.rAddr], clock + io.dO <= and(r1, and(r2, child.io.dO)) + when io.wEn : + write mport w1 = mem_0[io.wAddr], clock + write mport w2 = mem_1[io.wAddr], clock + w1 <= io.dI + w2 <= io.dI +""" + val confLoc = "ReplSeqMemTests.confTEMP" + val annos = Seq( + ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), + NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("ChildMemory",CircuitName("CustomMemory"))))) + val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) + // Check correctness of firrtl + val circuit = parse(res.getEmittedCircuit.value) + val numExtMods = circuit.modules.count { + case e: ExtModule => true + case _ => false + } + // Note that there are 3 identical SeqMems in this test + // If the NoDedupMemAnnotation were ignored, we'd end up with just 1 ExtModule + // If the NoDedupMemAnnotation were handled incorrectly as it was prior to this test, there + // would be 3 ExtModules + numExtMods should be (2) + (new java.io.File(confLoc)).delete() + } + "ReplSeqMem" should "de-duplicate memories without an annotation " in { val input = """ circuit CustomMemory : |
