aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorColin Schmidt2017-02-21 11:48:04 -0800
committerJack Koenig2017-02-21 11:48:04 -0800
commita02750f379b266b76febc58ef0351b56d21e9fcf (patch)
treec7ad03dfbe8415e3f061cfa1fecd2f5ea2ba960a
parentb69e787c0a698b7fb703ccd8d24003f83207e296 (diff)
Implementation of nodedupe mem (#447)
This allows the replseqmem transform to not deduplicate some memories, based on their name.
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala2
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala5
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala36
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala114
4 files changed, 145 insertions, 12 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index f8f76a49..1659cf22 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -124,7 +124,7 @@ class ReplSeqMem extends Transform with SimpleRun {
new SimpleMidTransform(ToMemIR),
new SimpleMidTransform(ResolveMaskGranularity),
new SimpleMidTransform(RenameAnnotatedMemoryPorts),
- new SimpleMidTransform(ResolveMemoryReference),
+ new ResolveMemoryReference,
new CreateMemoryAnnotations(inConfigFile),
new ReplaceMemMacros(outConfigFile),
new WiringTransform,
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
index 53d7234f..956bdd3c 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -77,8 +77,9 @@ object AnalysisUtils {
/** Checks whether the two memories are equivalent in all respects except name
*/
- def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory) =
- a == b.copy(info = a.info, name = a.name, memRef = a.memRef)
+ def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory, noDeDupeMems: Seq[String]) =
+ a == b.copy(info = a.info, name = a.name, memRef = a.memRef) &&
+ !(noDeDupeMems.contains(a.name) || noDeDupeMems.contains(b.name))
}
/** Determines if a write mask is needed (wmode/en and wmask are equivalent).
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
index 2112ca27..df555e57 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
@@ -2,35 +2,55 @@
package firrtl.passes
package memlib
+import firrtl._
import firrtl.ir._
import AnalysisUtils.eqMems
import firrtl.Mappers._
+import firrtl.annotations._
+/** A component, e.g. register etc. Must be declared only once under the TopAnnotation
+ */
+object NoDedupMemAnnotation {
+ def apply(target: ComponentName): Annotation = Annotation(target, classOf[ResolveMemoryReference], s"nodedupmem!")
+
+ def unapply(a: Annotation): Option[ComponentName] = a match {
+ case Annotation(ComponentName(n, mn), _, "nodedupmem!") => Some(ComponentName(n, mn))
+ case _ => None
+ }
+}
/** Resolves annotation ref to memories that exactly match (except name) another memory
*/
-object ResolveMemoryReference extends Pass {
-
- def name = "Resolve Memory Reference"
+class ResolveMemoryReference extends Transform {
+ def inputForm = MidForm
+ def outputForm = MidForm
type AnnotatedMemories = collection.mutable.ArrayBuffer[(String, DefAnnotatedMemory)]
/** If a candidate memory is identical except for name to another, add an
* annotation that references the name of the other memory.
*/
- def updateMemStmts(mname: String, uniqueMems: AnnotatedMemories)(s: Statement): Statement = s match {
+ def updateMemStmts(mname: String, uniqueMems: AnnotatedMemories, noDeDupeMems: Seq[String])(s: Statement): Statement = s match {
case m: DefAnnotatedMemory =>
- uniqueMems find (x => eqMems(x._2, m)) match {
+ uniqueMems find (x => eqMems(x._2, m, noDeDupeMems)) match {
case None =>
uniqueMems += (mname -> m)
m
case Some((module, proto)) => m copy (memRef = Some(module -> proto.name))
}
- case s => s map updateMemStmts(mname, uniqueMems)
+ case s => s map updateMemStmts(mname, uniqueMems, noDeDupeMems)
}
- def run(c: Circuit) = {
+ def run(c: Circuit, noDeDupeMems: Seq[String]) = {
val uniqueMems = new AnnotatedMemories
- c copy (modules = c.modules map (m => m map updateMemStmts(m.name, uniqueMems)))
+ c copy (modules = c.modules map (m => m map updateMemStmts(m.name, uniqueMems, noDeDupeMems)))
+ }
+ def execute(state: CircuitState): CircuitState = {
+ val noDedups = getMyAnnotations(state) match {
+ case Nil => Seq.empty
+ case annos =>
+ annos.collect { case NoDedupMemAnnotation(ComponentName(cn, _)) => cn }
+ }
+ CircuitState(run(state.circuit, noDedups), state.form)
}
}
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index fc3bfe8e..01a4501b 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -3,6 +3,7 @@
package firrtlTests
import firrtl._
+import firrtl.ir._
import firrtl.passes._
import firrtl.passes.memlib._
import annotations._
@@ -187,10 +188,121 @@ circuit Top :
tests foreach { case(hurdle, origin) => checkConnectOrigin(hurdle, origin) }
}
+ "ReplSeqMem" should "not de-duplicate memories with the nodedupe annotation " in {
+ val input = """
+circuit CustomMemory :
+ module CustomMemory :
+ input clock : Clock
+ input reset : UInt<1>
+ output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>}
+
+ io is invalid
+ smem mem_0 : UInt<16>[7]
+ smem mem_1 : UInt<16>[7]
+ read mport _T_17 = mem_0[io.rAddr], clock
+ read mport _T_19 = mem_1[io.rAddr], clock
+ io.dO <= _T_17
+ when io.wEn :
+ write mport _T_18 = mem_0[io.wAddr], clock
+ write mport _T_20 = mem_1[io.wAddr], clock
+ _T_18 <= io.dI
+ _T_20 <= io.dI
+ skip
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val aMap = AnnotationMap(Seq(
+ ReplSeqMemAnnotation("-c:CustomMemory:-o:"+confLoc),
+ NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("CustomMemory",CircuitName("CustomMemory"))))))
+ val writer = new java.io.StringWriter
+ compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer)
+ // Check correctness of firrtl
+ val circuit = parse(writer.toString)
+ val numExtMods = circuit.modules.count {
+ case e: ExtModule => true
+ case _ => false
+ }
+ require(numExtMods == 2)
+ (new java.io.File(confLoc)).delete()
+ }
+
+ "ReplSeqMem" should "only not de-duplicate memories with the nodedupe annotation " in {
+ val input = """
+circuit CustomMemory :
+ module CustomMemory :
+ input clock : Clock
+ input reset : UInt<1>
+ output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>}
+
+ io is invalid
+ smem mem_0 : UInt<16>[7]
+ smem mem_1 : UInt<16>[7]
+ smem mem_2 : UInt<16>[7]
+ read mport _T_17 = mem_0[io.rAddr], clock
+ read mport _T_19 = mem_1[io.rAddr], clock
+ read mport _T_21 = mem_2[io.rAddr], clock
+ io.dO <= _T_17
+ when io.wEn :
+ write mport _T_18 = mem_0[io.wAddr], clock
+ write mport _T_20 = mem_1[io.wAddr], clock
+ write mport _T_22 = mem_2[io.wAddr], clock
+ _T_18 <= io.dI
+ _T_20 <= io.dI
+ _T_22 <= io.dI
+ skip
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val aMap = AnnotationMap(Seq(
+ ReplSeqMemAnnotation("-c:CustomMemory:-o:"+confLoc),
+ NoDedupMemAnnotation(ComponentName("mem_1", ModuleName("CustomMemory",CircuitName("CustomMemory"))))))
+ val writer = new java.io.StringWriter
+ compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer)
+ // Check correctness of firrtl
+ val circuit = parse(writer.toString)
+ val numExtMods = circuit.modules.count {
+ case e: ExtModule => true
+ case _ => false
+ }
+ require(numExtMods == 2)
+ (new java.io.File(confLoc)).delete()
+ }
+
+ "ReplSeqMem" should "de-duplicate memories without an annotation " in {
+ val input = """
+circuit CustomMemory :
+ module CustomMemory :
+ input clock : Clock
+ input reset : UInt<1>
+ output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>}
+
+ io is invalid
+ smem mem_0 : UInt<16>[7]
+ smem mem_1 : UInt<16>[7]
+ read mport _T_17 = mem_0[io.rAddr], clock
+ read mport _T_19 = mem_1[io.rAddr], clock
+ io.dO <= _T_17
+ when io.wEn :
+ write mport _T_18 = mem_0[io.wAddr], clock
+ write mport _T_20 = mem_1[io.wAddr], clock
+ _T_18 <= io.dI
+ _T_20 <= io.dI
+ skip
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:CustomMemory:-o:"+confLoc)))
+ val writer = new java.io.StringWriter
+ compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer)
+ // Check correctness of firrtl
+ val circuit = parse(writer.toString)
+ val numExtMods = circuit.modules.count {
+ case e: ExtModule => true
+ case _ => false
+ }
+ require(numExtMods == 1)
+ (new java.io.File(confLoc)).delete()
+ }
}
// TODO: make more checks
// readwrite vs. no readwrite
-// redundant memories (multiple instances of the same type of memory)
// mask + no mask
// conf