diff options
| author | Donggyu Kim | 2016-09-16 03:47:11 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-21 13:18:45 -0700 |
| commit | 6fede8c92edd414ba63ed185fbad2cc48fd29d01 (patch) | |
| tree | 6b0b8ae102d9a6e89df726ce76ffcdec4bafe00b /src | |
| parent | ed95911d1b491ff3b122eb865f618dc8e65c767c (diff) | |
refactor UpdateDuplicateMemMacros
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplaceMemMacros.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala | 92 |
2 files changed, 48 insertions, 48 deletions
diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala index 7bb9c6c4..0ca1a32f 100644 --- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala @@ -22,7 +22,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { val uniqueMems = mutable.ArrayBuffer[DefMemory]() def updateMemMods(m: Module) = { - val memPortMap = mutable.HashMap[String, Expression]() + val memPortMap = new MemPortMap def updateMemStmts(s: Statement): Statement = s match { case m: DefMemory if containsInfo(m.info, "useMacro") => @@ -52,7 +52,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { } val updatedMems = updateMemStmts(m.body) - val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap) + val updatedConns = updateStmtRefs(memPortMap)(updatedMems) m.copy(body = updatedConns) } diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala index 0098fa5f..098d83f0 100644 --- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala +++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala @@ -2,23 +2,25 @@ package firrtl.passes -import scala.collection.mutable -import AnalysisUtils._ -import MemTransformUtils._ -import firrtl.ir._ import firrtl._ -import firrtl.Mappers._ +import firrtl.ir._ import firrtl.Utils._ +import firrtl.Mappers._ +import AnalysisUtils._ +import MemTransformUtils._ object MemTransformUtils { + type MemPortMap = collection.mutable.HashMap[String, Expression] + type Memories = collection.mutable.ArrayBuffer[DefMemory] + def createRef(n: String) = WRef(n, UnknownType, ExpKind, UNKNOWNGENDER) def createSubField(exp: Expression, n: String) = WSubField(exp, n, UnknownType, UNKNOWNGENDER) def connectFields(lref: Expression, lname: String, rref: Expression, rname: String) = Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname)) def getMemPortMap(m: DefMemory) = { - val memPortMap = mutable.HashMap[String, Expression]() + val memPortMap = new MemPortMap val defaultFields = Seq("addr", "en", "clk") val rFields = defaultFields :+ "data" val wFields = rFields :+ "mask" @@ -33,35 +35,41 @@ object MemTransformUtils { updateMemPortMap(m.readers, rFields, "R") updateMemPortMap(m.writers, wFields, "W") updateMemPortMap(m.readwriters, rwFields, "RW") - memPortMap.toMap + memPortMap } + def createMemProto(m: DefMemory) = { val rports = (0 until m.readers.length) map (i => s"R$i") val wports = (0 until m.writers.length) map (i => s"W$i") val rwports = (0 until m.readwriters.length) map (i => s"RW$i") - m.copy(readers = rports, writers = wports, readwriters = rwports) + m copy (readers = rports, writers = wports, readwriters = rwports) } - def updateStmtRefs(s: Statement, repl: Map[String, Expression]): Statement = { - def updateRef(e: Expression): Expression = e map updateRef match { - case e => repl getOrElse (e.serialize, e) + def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = { + def updateRef(e: Expression): Expression = { + val ex = e map updateRef + repl getOrElse (ex.serialize, ex) } + def hasEmptyExpr(stmt: Statement): Boolean = { var foundEmpty = false def testEmptyExpr(e: Expression): Expression = { - e map testEmptyExpr match { + e match { case EmptyExpression => foundEmpty = true case _ => } - e // map must return; no foreach + e map testEmptyExpr // map must return; no foreach } stmt map testEmptyExpr foundEmpty } - def updateStmtRefs(s: Statement): Statement = s map updateStmtRefs map updateRef match { - case c: Connect if hasEmptyExpr(c) => EmptyStmt - case s => s - } + + def updateStmtRefs(s: Statement): Statement = + s map updateStmtRefs map updateRef match { + case c: Connect if hasEmptyExpr(c) => EmptyStmt + case s => s + } + updateStmtRefs(s) } @@ -71,37 +79,29 @@ object UpdateDuplicateMemMacros extends Pass { def name = "Convert memory port names to be more meaningful and tag duplicate memories" - def run(c: Circuit) = { - val uniqueMems = mutable.ArrayBuffer[DefMemory]() - - def updateMemMods(m: Module) = { - val memPortMap = mutable.HashMap[String, Expression]() - - def updateMemStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info, "useMacro") => - val updatedMem = createMemProto(m) - memPortMap ++= getMemPortMap(m) - val proto = uniqueMems find (x => eqMems(x, updatedMem)) - if (proto == None) { - uniqueMems += updatedMem - updatedMem - } - else updatedMem.copy(info = appendInfo(updatedMem.info, "ref" -> proto.get.name)) - case b: Block => b map updateMemStmts - case s => s + def updateMemStmts(uniqueMems: Memories, + memPortMap: MemPortMap) + (s: Statement): Statement = s match { + case m: DefMemory if containsInfo(m.info, "useMacro") => + val updatedMem = createMemProto(m) + memPortMap ++= getMemPortMap(m) + uniqueMems find (x => eqMems(x, updatedMem)) match { + case None => + uniqueMems += updatedMem + updatedMem + case Some(proto) => + updatedMem copy (info = appendInfo(updatedMem.info, "ref" -> proto.name)) } + case s => s map updateMemStmts(uniqueMems, memPortMap) + } - val updatedMems = updateMemStmts(m.body) - val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap) - m.copy(body = updatedConns) - } - - val updatedMods = c.modules map { - case m: Module => updateMemMods(m) - case m: ExtModule => m - } - c.copy(modules = updatedMods) - } + def updateMemMods(m: DefModule) = { + val uniqueMems = new Memories + val memPortMap = new MemPortMap + (m map updateMemStmts(uniqueMems, memPortMap) + map updateStmtRefs(memPortMap)) + } + def run(c: Circuit) = c copy (modules = (c.modules map updateMemMods)) } // TODO: Module namespace? |
