diff options
| author | Donggyu Kim | 2016-09-06 20:57:03 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-21 13:13:03 -0700 |
| commit | 350ffd7bbc1b014b9d9b256da4181c59bf0419e3 (patch) | |
| tree | d9cabc2ec866799cbfba892e6b69fbcffe08d3b2 | |
| parent | 726c808375fe513c70376bf05e76dd938e578bf9 (diff) | |
generalize Analysis.getConnects for code resuse
| -rw-r--r-- | src/main/scala/firrtl/passes/AnnotateMemMacros.scala | 22 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferReadWrite.scala | 17 |
2 files changed, 13 insertions, 26 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala index 58e10a66..7da290b7 100644 --- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala +++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala @@ -21,22 +21,22 @@ case class AppendableInfo(fields: Map[String, Any]) extends Info { } object AnalysisUtils { - - def getConnects(m: Module) = { - val connects = mutable.HashMap[String, Expression]() - def getConnects(s: Statement): Statement = { - s map getConnects match { + type Connects = collection.mutable.HashMap[String, Expression] + def getConnects(m: DefModule): Connects = { + def getConnects(connects: Connects)(s: Statement): Statement = { + s match { case Connect(_, loc, expr) => connects(loc.serialize) = expr case DefNode(_, name, value) => connects(name) = value case _ => // do nothing } - s // return because we only have map and not foreach + s map getConnects(connects) } - getConnects(m.body) - connects.toMap - } + val connects = new Connects + m map getConnects(connects) + connects + } // takes in a list of node-to-node connections in a given module and looks to find the origin of the LHS. // if the source is a trivial primop/mux, etc. that has yet to be optimized via constant propagation, @@ -44,12 +44,12 @@ object AnalysisUtils { // use case: compare if two nodes have the same origin // limitation: only works in a module (stops @ module inputs) // TODO: more thorough (i.e. a + 0 = a) - def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = { + def getConnectOrigin(connects: Connects, node: String): Expression = { if (connects contains node) getOrigin(connects, connects(node)) else EmptyExpression } - private def getOrigin(connects: Map[String, Expression], e: Expression): Expression = e match { + private def getOrigin(connects: Connects, e: Expression): Expression = e match { case Mux(cond, tv, fv, _) => val fvOrigin = getOrigin(connects, fv) val tvOrigin = getOrigin(connects, tv) diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala index 9fbd6ab3..38933103 100644 --- a/src/main/scala/firrtl/passes/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/InferReadWrite.scala @@ -51,26 +51,14 @@ object InferReadWritePass extends Pass { def name = "Infer ReadWrite Ports" def inferReadWrite(m: Module) = { + import AnalysisUtils._ import WrappedExpression.we - val connects = HashMap[String, Expression]() + val connects = getConnects(m) val repl = HashMap[String, Expression]() val stmts = ArrayBuffer[Statement]() val zero = we(UIntLiteral(0, IntWidth(1))) val one = we(UIntLiteral(1, IntWidth(1))) - // find all wire connections - def analyze(s: Statement): Unit = s match { - case s: Connect => - connects(s.loc.serialize) = s.expr - case s: PartialConnect => - connects(s.loc.serialize) = s.expr - case s: DefNode => - connects(s.name) = s.value - case s: Block => - s.stmts foreach analyze - case _ => - } - def getProductTermsFromExp(e: Expression): Seq[Expression] = e match { // No ConstProp yet... @@ -169,7 +157,6 @@ object InferReadWritePass extends Pass { case s => s } - analyze(m.body) Module(m.info, m.name, m.ports, Block((m.body map inferReadWrite map replaceStmt) +: stmts.toSeq)) } |
