diff options
| author | Colin Schmidt | 2017-02-21 11:48:04 -0800 |
|---|---|---|
| committer | Jack Koenig | 2017-02-21 11:48:04 -0800 |
| commit | a02750f379b266b76febc58ef0351b56d21e9fcf (patch) | |
| tree | c7ad03dfbe8415e3f061cfa1fecd2f5ea2ba960a | |
| parent | b69e787c0a698b7fb703ccd8d24003f83207e296 (diff) | |
Implementation of nodedupe mem (#447)
This allows the replseqmem transform to not deduplicate
some memories, based on their name.
4 files changed, 145 insertions, 12 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index f8f76a49..1659cf22 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -124,7 +124,7 @@ class ReplSeqMem extends Transform with SimpleRun { new SimpleMidTransform(ToMemIR), new SimpleMidTransform(ResolveMaskGranularity), new SimpleMidTransform(RenameAnnotatedMemoryPorts), - new SimpleMidTransform(ResolveMemoryReference), + new ResolveMemoryReference, new CreateMemoryAnnotations(inConfigFile), new ReplaceMemMacros(outConfigFile), new WiringTransform, diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 53d7234f..956bdd3c 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -77,8 +77,9 @@ object AnalysisUtils { /** Checks whether the two memories are equivalent in all respects except name */ - def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory) = - a == b.copy(info = a.info, name = a.name, memRef = a.memRef) + def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory, noDeDupeMems: Seq[String]) = + a == b.copy(info = a.info, name = a.name, memRef = a.memRef) && + !(noDeDupeMems.contains(a.name) || noDeDupeMems.contains(b.name)) } /** Determines if a write mask is needed (wmode/en and wmask are equivalent). diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala index 2112ca27..df555e57 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -2,35 +2,55 @@ package firrtl.passes package memlib +import firrtl._ import firrtl.ir._ import AnalysisUtils.eqMems import firrtl.Mappers._ +import firrtl.annotations._ +/** A component, e.g. register etc. Must be declared only once under the TopAnnotation + */ +object NoDedupMemAnnotation { + def apply(target: ComponentName): Annotation = Annotation(target, classOf[ResolveMemoryReference], s"nodedupmem!") + + def unapply(a: Annotation): Option[ComponentName] = a match { + case Annotation(ComponentName(n, mn), _, "nodedupmem!") => Some(ComponentName(n, mn)) + case _ => None + } +} /** Resolves annotation ref to memories that exactly match (except name) another memory */ -object ResolveMemoryReference extends Pass { - - def name = "Resolve Memory Reference" +class ResolveMemoryReference extends Transform { + def inputForm = MidForm + def outputForm = MidForm type AnnotatedMemories = collection.mutable.ArrayBuffer[(String, DefAnnotatedMemory)] /** If a candidate memory is identical except for name to another, add an * annotation that references the name of the other memory. */ - def updateMemStmts(mname: String, uniqueMems: AnnotatedMemories)(s: Statement): Statement = s match { + def updateMemStmts(mname: String, uniqueMems: AnnotatedMemories, noDeDupeMems: Seq[String])(s: Statement): Statement = s match { case m: DefAnnotatedMemory => - uniqueMems find (x => eqMems(x._2, m)) match { + uniqueMems find (x => eqMems(x._2, m, noDeDupeMems)) match { case None => uniqueMems += (mname -> m) m case Some((module, proto)) => m copy (memRef = Some(module -> proto.name)) } - case s => s map updateMemStmts(mname, uniqueMems) + case s => s map updateMemStmts(mname, uniqueMems, noDeDupeMems) } - def run(c: Circuit) = { + def run(c: Circuit, noDeDupeMems: Seq[String]) = { val uniqueMems = new AnnotatedMemories - c copy (modules = c.modules map (m => m map updateMemStmts(m.name, uniqueMems))) + c copy (modules = c.modules map (m => m map updateMemStmts(m.name, uniqueMems, noDeDupeMems))) + } + def execute(state: CircuitState): CircuitState = { + val noDedups = getMyAnnotations(state) match { + case Nil => Seq.empty + case annos => + annos.collect { case NoDedupMemAnnotation(ComponentName(cn, _)) => cn } + } + CircuitState(run(state.circuit, noDedups), state.form) } } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index fc3bfe8e..01a4501b 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -3,6 +3,7 @@ package firrtlTests import firrtl._ +import firrtl.ir._ import firrtl.passes._ import firrtl.passes.memlib._ import annotations._ @@ -187,10 +188,121 @@ circuit Top : tests foreach { case(hurdle, origin) => checkConnectOrigin(hurdle, origin) } } + "ReplSeqMem" should "not de-duplicate memories with the nodedupe annotation " in { + val input = """ +circuit CustomMemory : + module CustomMemory : + input clock : Clock + input reset : UInt<1> + output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>} + + io is invalid + smem mem_0 : UInt<16>[7] + smem mem_1 : UInt<16>[7] + read mport _T_17 = mem_0[io.rAddr], clock + read mport _T_19 = mem_1[io.rAddr], clock + io.dO <= _T_17 + when io.wEn : + write mport _T_18 = mem_0[io.wAddr], clock + write mport _T_20 = mem_1[io.wAddr], clock + _T_18 <= io.dI + _T_20 <= io.dI + skip +""" + val confLoc = "ReplSeqMemTests.confTEMP" + val aMap = AnnotationMap(Seq( + ReplSeqMemAnnotation("-c:CustomMemory:-o:"+confLoc), + NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("CustomMemory",CircuitName("CustomMemory")))))) + val writer = new java.io.StringWriter + compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer) + // Check correctness of firrtl + val circuit = parse(writer.toString) + val numExtMods = circuit.modules.count { + case e: ExtModule => true + case _ => false + } + require(numExtMods == 2) + (new java.io.File(confLoc)).delete() + } + + "ReplSeqMem" should "only not de-duplicate memories with the nodedupe annotation " in { + val input = """ +circuit CustomMemory : + module CustomMemory : + input clock : Clock + input reset : UInt<1> + output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>} + + io is invalid + smem mem_0 : UInt<16>[7] + smem mem_1 : UInt<16>[7] + smem mem_2 : UInt<16>[7] + read mport _T_17 = mem_0[io.rAddr], clock + read mport _T_19 = mem_1[io.rAddr], clock + read mport _T_21 = mem_2[io.rAddr], clock + io.dO <= _T_17 + when io.wEn : + write mport _T_18 = mem_0[io.wAddr], clock + write mport _T_20 = mem_1[io.wAddr], clock + write mport _T_22 = mem_2[io.wAddr], clock + _T_18 <= io.dI + _T_20 <= io.dI + _T_22 <= io.dI + skip +""" + val confLoc = "ReplSeqMemTests.confTEMP" + val aMap = AnnotationMap(Seq( + ReplSeqMemAnnotation("-c:CustomMemory:-o:"+confLoc), + NoDedupMemAnnotation(ComponentName("mem_1", ModuleName("CustomMemory",CircuitName("CustomMemory")))))) + val writer = new java.io.StringWriter + compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer) + // Check correctness of firrtl + val circuit = parse(writer.toString) + val numExtMods = circuit.modules.count { + case e: ExtModule => true + case _ => false + } + require(numExtMods == 2) + (new java.io.File(confLoc)).delete() + } + + "ReplSeqMem" should "de-duplicate memories without an annotation " in { + val input = """ +circuit CustomMemory : + module CustomMemory : + input clock : Clock + input reset : UInt<1> + output io : {flip rClk : Clock, flip rAddr : UInt<3>, dO : UInt<16>, flip wClk : Clock, flip wAddr : UInt<3>, flip wEn : UInt<1>, flip dI : UInt<16>} + + io is invalid + smem mem_0 : UInt<16>[7] + smem mem_1 : UInt<16>[7] + read mport _T_17 = mem_0[io.rAddr], clock + read mport _T_19 = mem_1[io.rAddr], clock + io.dO <= _T_17 + when io.wEn : + write mport _T_18 = mem_0[io.wAddr], clock + write mport _T_20 = mem_1[io.wAddr], clock + _T_18 <= io.dI + _T_20 <= io.dI + skip +""" + val confLoc = "ReplSeqMemTests.confTEMP" + val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:CustomMemory:-o:"+confLoc))) + val writer = new java.io.StringWriter + compile(CircuitState(parse(input), ChirrtlForm, Some(aMap)), writer) + // Check correctness of firrtl + val circuit = parse(writer.toString) + val numExtMods = circuit.modules.count { + case e: ExtModule => true + case _ => false + } + require(numExtMods == 1) + (new java.io.File(confLoc)).delete() + } } // TODO: make more checks // readwrite vs. no readwrite -// redundant memories (multiple instances of the same type of memory) // mask + no mask // conf |
