aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-09-06 20:57:03 -0700
committerDonggyu Kim2016-09-21 13:13:03 -0700
commit350ffd7bbc1b014b9d9b256da4181c59bf0419e3 (patch)
treed9cabc2ec866799cbfba892e6b69fbcffe08d3b2 /src
parent726c808375fe513c70376bf05e76dd938e578bf9 (diff)
generalize Analysis.getConnects for code resuse
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/AnnotateMemMacros.scala22
-rw-r--r--src/main/scala/firrtl/passes/InferReadWrite.scala17
2 files changed, 13 insertions, 26 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
index 58e10a66..7da290b7 100644
--- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
@@ -21,22 +21,22 @@ case class AppendableInfo(fields: Map[String, Any]) extends Info {
}
object AnalysisUtils {
-
- def getConnects(m: Module) = {
- val connects = mutable.HashMap[String, Expression]()
- def getConnects(s: Statement): Statement = {
- s map getConnects match {
+ type Connects = collection.mutable.HashMap[String, Expression]
+ def getConnects(m: DefModule): Connects = {
+ def getConnects(connects: Connects)(s: Statement): Statement = {
+ s match {
case Connect(_, loc, expr) =>
connects(loc.serialize) = expr
case DefNode(_, name, value) =>
connects(name) = value
case _ => // do nothing
}
- s // return because we only have map and not foreach
+ s map getConnects(connects)
}
- getConnects(m.body)
- connects.toMap
- }
+ val connects = new Connects
+ m map getConnects(connects)
+ connects
+ }
// takes in a list of node-to-node connections in a given module and looks to find the origin of the LHS.
// if the source is a trivial primop/mux, etc. that has yet to be optimized via constant propagation,
@@ -44,12 +44,12 @@ object AnalysisUtils {
// use case: compare if two nodes have the same origin
// limitation: only works in a module (stops @ module inputs)
// TODO: more thorough (i.e. a + 0 = a)
- def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = {
+ def getConnectOrigin(connects: Connects, node: String): Expression = {
if (connects contains node) getOrigin(connects, connects(node))
else EmptyExpression
}
- private def getOrigin(connects: Map[String, Expression], e: Expression): Expression = e match {
+ private def getOrigin(connects: Connects, e: Expression): Expression = e match {
case Mux(cond, tv, fv, _) =>
val fvOrigin = getOrigin(connects, fv)
val tvOrigin = getOrigin(connects, tv)
diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala
index 9fbd6ab3..38933103 100644
--- a/src/main/scala/firrtl/passes/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/InferReadWrite.scala
@@ -51,26 +51,14 @@ object InferReadWritePass extends Pass {
def name = "Infer ReadWrite Ports"
def inferReadWrite(m: Module) = {
+ import AnalysisUtils._
import WrappedExpression.we
- val connects = HashMap[String, Expression]()
+ val connects = getConnects(m)
val repl = HashMap[String, Expression]()
val stmts = ArrayBuffer[Statement]()
val zero = we(UIntLiteral(0, IntWidth(1)))
val one = we(UIntLiteral(1, IntWidth(1)))
- // find all wire connections
- def analyze(s: Statement): Unit = s match {
- case s: Connect =>
- connects(s.loc.serialize) = s.expr
- case s: PartialConnect =>
- connects(s.loc.serialize) = s.expr
- case s: DefNode =>
- connects(s.name) = s.value
- case s: Block =>
- s.stmts foreach analyze
- case _ =>
- }
-
def getProductTermsFromExp(e: Expression): Seq[Expression] =
e match {
// No ConstProp yet...
@@ -169,7 +157,6 @@ object InferReadWritePass extends Pass {
case s => s
}
- analyze(m.body)
Module(m.info, m.name, m.ports, Block((m.body map inferReadWrite map replaceStmt) +: stmts.toSeq))
}