diff options
| author | Angie | 2016-08-22 19:12:20 -0700 |
|---|---|---|
| committer | jackkoenig | 2016-09-06 00:17:18 -0700 |
| commit | 6bf15386079d862d042968f5d2ac30c9d092134c (patch) | |
| tree | adab518cec2141451793c7a99cd2a7ad1f9ff985 /src | |
| parent | d2ee373b9f5cfb5dd50953f680ddcb2f8d4eb582 (diff) | |
Made the connect origin function more powerful
* It analyzes through statements that ConstProp would've optimized
* Edge case wmask can be removed (pass tries harder to figure out that wmask = wen)
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/AnnotateMemMacros.scala | 48 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala | 1 |
2 files changed, 42 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala index 22fad253..8d7c3f65 100644 --- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala +++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala @@ -22,7 +22,7 @@ object AnalysisUtils { def getConnects(m: Module) = { val connects = HashMap[String, Expression]() def getConnects(s: Statement): Unit = s match { - case s: Connect => + case s: Connect => connects(s.loc.serialize) = s.expr case s: PartialConnect => connects(s.loc.serialize) = s.expr @@ -37,14 +37,49 @@ object AnalysisUtils { } // only works in a module - def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = + def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = { if (connects contains node) getConnectOrigin(connects,connects(node)) else EmptyExpression - + } + + def checkLit(e: Expression) = e match { + case l : Literal => true + case _ => false + } + + def getOrigin(connects: Map[String, Expression], e: Expression) = e match { + case DoPrim(_,_,_,_) => getConnectOrigin(connects,e) + case l if (checkLit(l)) => e + case _ => getConnectOrigin(connects,e.serialize) + } + // backward searches until PrimOp, Lit or non-trivial Mux appears - private def getConnectOrigin(connects: Map[String, Expression], e: Expression): Expression = e match { - case Mux(cond, tv, fv, _) if we(tv) == we(one) && we(fv) == we(zero) => - getConnectOrigin(connects,cond.serialize) + // technically, you should keep searching through PrimOp, because a node + 0 is still itself, + // a node shifted by 0 is still itself, etc. + // TODO: handle validif???, more thorough + private def getConnectOrigin(connects: Map[String, Expression], e: Expression): Expression = e match { + case Mux(cond, tv, fv, _) => + val fvOrigin = getOrigin(connects,fv) + val tvOrigin = getOrigin(connects,tv) + val condOrigin = getOrigin(connects,cond) + if (we(tvOrigin) == we(one) && we(fvOrigin) == we(zero)) condOrigin + else if (we(condOrigin) == we(one)) tvOrigin + else if (we(condOrigin) == we(zero)) fvOrigin + else if (we(tvOrigin) == we(fvOrigin)) tvOrigin + else if (we(fvOrigin) == we(zero) && we(condOrigin) == we(tvOrigin)) condOrigin + else e + case DoPrim(op, args, consts, tpe) if op == PrimOps.Or && args.contains(one) => one + case DoPrim(op, args, consts, tpe) if op == PrimOps.And && args.contains(zero) => zero + case DoPrim(op, args, consts, tpe) if op == PrimOps.Bits => + val msb = consts(0) + val lsb = consts(1) + val extractionWidth = (msb-lsb)+1 + val nodeWidth = bitWidth(args.head.tpe) + // if you're extracting the full bitwidth, then keep searching for origin + if (nodeWidth == extractionWidth) getOrigin(connects,args.head) + else e + case DoPrim(op, args, _, _) if (op == PrimOps.AsUInt || op == PrimOps.AsSInt || op == PrimOps.AsClock) => + getOrigin(connects,args.head) case _: WRef | _: SubField | _: SubIndex | _: SubAccess if connects contains e.serialize => getConnectOrigin(connects,e.serialize) case _ => e @@ -113,7 +148,6 @@ object AnnotateMemMacros extends Pass { case b: Block => Block(b.stmts map updateStmts) case s => s } - m.copy(body=updateStmts(m.body)) } diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala index 2301ad1b..a4a910fd 100644 --- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala +++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala @@ -47,6 +47,7 @@ object MemTransformUtils { } def updateStmtRefs(s: Statement): Statement = s map updateStmtRefs map updateRef match { case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt + case Connect(info, WSubIndex(EmptyExpression,_,_,_), exp) => EmptyStmt case s => s } updateStmtRefs(s) |
