aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAngie2016-08-22 19:12:20 -0700
committerjackkoenig2016-09-06 00:17:18 -0700
commit6bf15386079d862d042968f5d2ac30c9d092134c (patch)
treeadab518cec2141451793c7a99cd2a7ad1f9ff985 /src
parentd2ee373b9f5cfb5dd50953f680ddcb2f8d4eb582 (diff)
Made the connect origin function more powerful
* It analyzes through statements that ConstProp would've optimized * Edge case wmask can be removed (pass tries harder to figure out that wmask = wen)
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/AnnotateMemMacros.scala48
-rw-r--r--src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala1
2 files changed, 42 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
index 22fad253..8d7c3f65 100644
--- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
@@ -22,7 +22,7 @@ object AnalysisUtils {
def getConnects(m: Module) = {
val connects = HashMap[String, Expression]()
def getConnects(s: Statement): Unit = s match {
- case s: Connect =>
+ case s: Connect =>
connects(s.loc.serialize) = s.expr
case s: PartialConnect =>
connects(s.loc.serialize) = s.expr
@@ -37,14 +37,49 @@ object AnalysisUtils {
}
// only works in a module
- def getConnectOrigin(connects: Map[String, Expression], node: String): Expression =
+ def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = {
if (connects contains node) getConnectOrigin(connects,connects(node))
else EmptyExpression
-
+ }
+
+ def checkLit(e: Expression) = e match {
+ case l : Literal => true
+ case _ => false
+ }
+
+ def getOrigin(connects: Map[String, Expression], e: Expression) = e match {
+ case DoPrim(_,_,_,_) => getConnectOrigin(connects,e)
+ case l if (checkLit(l)) => e
+ case _ => getConnectOrigin(connects,e.serialize)
+ }
+
// backward searches until PrimOp, Lit or non-trivial Mux appears
- private def getConnectOrigin(connects: Map[String, Expression], e: Expression): Expression = e match {
- case Mux(cond, tv, fv, _) if we(tv) == we(one) && we(fv) == we(zero) =>
- getConnectOrigin(connects,cond.serialize)
+ // technically, you should keep searching through PrimOp, because a node + 0 is still itself,
+ // a node shifted by 0 is still itself, etc.
+ // TODO: handle validif???, more thorough
+ private def getConnectOrigin(connects: Map[String, Expression], e: Expression): Expression = e match {
+ case Mux(cond, tv, fv, _) =>
+ val fvOrigin = getOrigin(connects,fv)
+ val tvOrigin = getOrigin(connects,tv)
+ val condOrigin = getOrigin(connects,cond)
+ if (we(tvOrigin) == we(one) && we(fvOrigin) == we(zero)) condOrigin
+ else if (we(condOrigin) == we(one)) tvOrigin
+ else if (we(condOrigin) == we(zero)) fvOrigin
+ else if (we(tvOrigin) == we(fvOrigin)) tvOrigin
+ else if (we(fvOrigin) == we(zero) && we(condOrigin) == we(tvOrigin)) condOrigin
+ else e
+ case DoPrim(op, args, consts, tpe) if op == PrimOps.Or && args.contains(one) => one
+ case DoPrim(op, args, consts, tpe) if op == PrimOps.And && args.contains(zero) => zero
+ case DoPrim(op, args, consts, tpe) if op == PrimOps.Bits =>
+ val msb = consts(0)
+ val lsb = consts(1)
+ val extractionWidth = (msb-lsb)+1
+ val nodeWidth = bitWidth(args.head.tpe)
+ // if you're extracting the full bitwidth, then keep searching for origin
+ if (nodeWidth == extractionWidth) getOrigin(connects,args.head)
+ else e
+ case DoPrim(op, args, _, _) if (op == PrimOps.AsUInt || op == PrimOps.AsSInt || op == PrimOps.AsClock) =>
+ getOrigin(connects,args.head)
case _: WRef | _: SubField | _: SubIndex | _: SubAccess if connects contains e.serialize =>
getConnectOrigin(connects,e.serialize)
case _ => e
@@ -113,7 +148,6 @@ object AnnotateMemMacros extends Pass {
case b: Block => Block(b.stmts map updateStmts)
case s => s
}
-
m.copy(body=updateStmts(m.body))
}
diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
index 2301ad1b..a4a910fd 100644
--- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
@@ -47,6 +47,7 @@ object MemTransformUtils {
}
def updateStmtRefs(s: Statement): Statement = s map updateStmtRefs map updateRef match {
case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt
+ case Connect(info, WSubIndex(EmptyExpression,_,_,_), exp) => EmptyStmt
case s => s
}
updateStmtRefs(s)