aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2018-08-23 18:19:47 -0700
committerGitHub2018-08-23 18:19:47 -0700
commitabaa38f4f63b35105796916e3df2ccf5c1639f65 (patch)
tree1e3a90664f445d9ca6a8bddf148d691770a02430 /src
parentd7b96168f1c7244124ac258de174bf11d53092ab (diff)
Fix NoDedupMem to be cognizant of Module scope (#876)
Previously, mems marked no dedup would prevent mems with the same instance name in other modules from deduping
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala6
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala52
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala55
3 files changed, 94 insertions, 19 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
index e254dcc9..b552470d 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -75,12 +75,6 @@ object AnalysisUtils {
}
case _ => e
}
-
- /** Checks whether the two memories are equivalent in all respects except name
- */
- def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory, noDeDupeMems: Seq[String]): Boolean =
- a == b.copy(info = a.info, name = a.name, memRef = a.memRef) &&
- !(noDeDupeMems.contains(a.name) || noDeDupeMems.contains(b.name))
}
/** Determines if a write mask is needed (wmode/en and wmask are equivalent).
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
index d195ea55..b0d3731f 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
@@ -4,7 +4,6 @@ package firrtl.passes
package memlib
import firrtl._
import firrtl.ir._
-import AnalysisUtils.eqMems
import firrtl.Mappers._
import firrtl.annotations._
@@ -19,30 +18,57 @@ class ResolveMemoryReference extends Transform {
def inputForm = MidForm
def outputForm = MidForm
- type AnnotatedMemories = collection.mutable.ArrayBuffer[(String, DefAnnotatedMemory)]
+ /** Helper class for determining when two memories are equivalent while igoring
+ * irrelevant details like name and info
+ */
+ private class WrappedDefAnnoMemory(val underlying: DefAnnotatedMemory) {
+ // Remove irrelevant details for comparison
+ private def generic = underlying.copy(info = NoInfo, name = "", memRef = None)
+ override def hashCode: Int = generic.hashCode
+ override def equals(that: Any): Boolean = that match {
+ case mem: WrappedDefAnnoMemory => this.generic == mem.generic
+ case _ => false
+ }
+ }
+ private def wrap(mem: DefAnnotatedMemory) = new WrappedDefAnnoMemory(mem)
+
+ // Values are Tuple of Module Name and Memory Instance Name
+ private type AnnotatedMemories = collection.mutable.HashMap[WrappedDefAnnoMemory, (String, String)]
+
+ private def dedupable(noDedups: Map[String, Set[String]], module: String, memory: String): Boolean =
+ noDedups.get(module).map(!_.contains(memory)).getOrElse(true)
/** If a candidate memory is identical except for name to another, add an
* annotation that references the name of the other memory.
*/
- def updateMemStmts(mname: String, uniqueMems: AnnotatedMemories, noDeDupeMems: Seq[String])(s: Statement): Statement = s match {
- case m: DefAnnotatedMemory =>
- uniqueMems find (x => eqMems(x._2, m, noDeDupeMems)) match {
+ def updateMemStmts(mname: String,
+ existingMems: AnnotatedMemories,
+ noDedupMap: Map[String, Set[String]])
+ (s: Statement): Statement = s match {
+ // If not dedupable, no need to add to existing (since nothing can dedup with it)
+ // We just return the DefAnnotatedMemory as is in the default case below
+ case m: DefAnnotatedMemory if dedupable(noDedupMap, mname, m.name) =>
+ val wrapped = wrap(m)
+ existingMems.get(wrapped) match {
+ case proto @ Some(_) =>
+ m.copy(memRef = proto)
case None =>
- uniqueMems += (mname -> m)
+ existingMems(wrapped) = (mname, m.name)
m
- case Some((module, proto)) => m copy (memRef = Some(module -> proto.name))
}
- case s => s map updateMemStmts(mname, uniqueMems, noDeDupeMems)
+ case s => s.map(updateMemStmts(mname, existingMems, noDedupMap))
}
- def run(c: Circuit, noDeDupeMems: Seq[String]) = {
- val uniqueMems = new AnnotatedMemories
- c copy (modules = c.modules map (m => m map updateMemStmts(m.name, uniqueMems, noDeDupeMems)))
+ def run(c: Circuit, noDedupMap: Map[String, Set[String]]) = {
+ val existingMems = new AnnotatedMemories
+ val modulesx = c.modules.map(m => m.map(updateMemStmts(m.name, existingMems, noDedupMap)))
+ c.copy(modules = modulesx)
}
def execute(state: CircuitState): CircuitState = {
val noDedups = state.annotations.collect {
- case NoDedupMemAnnotation(ComponentName(cn, _)) => cn
+ case NoDedupMemAnnotation(ComponentName(cn, ModuleName(mn, _))) => mn -> cn
}
- state.copy(circuit=run(state.circuit, noDedups))
+ val noDedupMap: Map[String, Set[String]] = noDedups.groupBy(_._1).mapValues(_.map(_._2).toSet)
+ state.copy(circuit = run(state.circuit, noDedupMap))
}
}
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 2eae3580..6cedd3f0 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -264,6 +264,61 @@ circuit CustomMemory :
(new java.io.File(confLoc)).delete()
}
+ "ReplSeqMem" should "dedup mems with the same instance name as other mems (in other modules) marked NoDedup" in {
+ val input = """
+circuit CustomMemory :
+ module ChildMemory :
+ input clock : Clock
+ input reset : UInt<1>
+ output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>}
+
+ smem mem_0 : UInt<16>[7]
+ read mport r1 = mem_0[io.rAddr], clock
+ io.dO <= r1
+ when io.wEn :
+ write mport w1 = mem_0[io.wAddr], clock
+ w1 <= io.dI
+
+ module CustomMemory :
+ input clock : Clock
+ input reset : UInt<1>
+ output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>}
+
+ inst child of ChildMemory
+ child.clock <= clock
+ child.reset <= reset
+ io <- child.io
+
+ smem mem_0 : UInt<16>[7]
+ smem mem_1 : UInt<16>[7]
+ read mport r1 = mem_0[io.rAddr], clock
+ read mport r2 = mem_1[io.rAddr], clock
+ io.dO <= and(r1, and(r2, child.io.dO))
+ when io.wEn :
+ write mport w1 = mem_0[io.wAddr], clock
+ write mport w2 = mem_1[io.wAddr], clock
+ w1 <= io.dI
+ w2 <= io.dI
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val annos = Seq(
+ ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
+ NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("ChildMemory",CircuitName("CustomMemory")))))
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
+ // Check correctness of firrtl
+ val circuit = parse(res.getEmittedCircuit.value)
+ val numExtMods = circuit.modules.count {
+ case e: ExtModule => true
+ case _ => false
+ }
+ // Note that there are 3 identical SeqMems in this test
+ // If the NoDedupMemAnnotation were ignored, we'd end up with just 1 ExtModule
+ // If the NoDedupMemAnnotation were handled incorrectly as it was prior to this test, there
+ // would be 3 ExtModules
+ numExtMods should be (2)
+ (new java.io.File(confLoc)).delete()
+ }
+
"ReplSeqMem" should "de-duplicate memories without an annotation " in {
val input = """
circuit CustomMemory :