diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/AnnotateMemMacros.scala | 48 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala | 1 |
2 files changed, 42 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala index 22fad253..8d7c3f65 100644 --- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala +++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala @@ -22,7 +22,7 @@ object AnalysisUtils { def getConnects(m: Module) = { val connects = HashMap[String, Expression]() def getConnects(s: Statement): Unit = s match { - case s: Connect => + case s: Connect => connects(s.loc.serialize) = s.expr case s: PartialConnect => connects(s.loc.serialize) = s.expr @@ -37,14 +37,49 @@ object AnalysisUtils { } // only works in a module - def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = + def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = { if (connects contains node) getConnectOrigin(connects,connects(node)) else EmptyExpression - + } + + def checkLit(e: Expression) = e match { + case l : Literal => true + case _ => false + } + + def getOrigin(connects: Map[String, Expression], e: Expression) = e match { + case DoPrim(_,_,_,_) => getConnectOrigin(connects,e) + case l if (checkLit(l)) => e + case _ => getConnectOrigin(connects,e.serialize) + } + // backward searches until PrimOp, Lit or non-trivial Mux appears - private def getConnectOrigin(connects: Map[String, Expression], e: Expression): Expression = e match { - case Mux(cond, tv, fv, _) if we(tv) == we(one) && we(fv) == we(zero) => - getConnectOrigin(connects,cond.serialize) + // technically, you should keep searching through PrimOp, because a node + 0 is still itself, + // a node shifted by 0 is still itself, etc. + // TODO: handle validif???, more thorough + private def getConnectOrigin(connects: Map[String, Expression], e: Expression): Expression = e match { + case Mux(cond, tv, fv, _) => + val fvOrigin = getOrigin(connects,fv) + val tvOrigin = getOrigin(connects,tv) + val condOrigin = getOrigin(connects,cond) + if (we(tvOrigin) == we(one) && we(fvOrigin) == we(zero)) condOrigin + else if (we(condOrigin) == we(one)) tvOrigin + else if (we(condOrigin) == we(zero)) fvOrigin + else if (we(tvOrigin) == we(fvOrigin)) tvOrigin + else if (we(fvOrigin) == we(zero) && we(condOrigin) == we(tvOrigin)) condOrigin + else e + case DoPrim(op, args, consts, tpe) if op == PrimOps.Or && args.contains(one) => one + case DoPrim(op, args, consts, tpe) if op == PrimOps.And && args.contains(zero) => zero + case DoPrim(op, args, consts, tpe) if op == PrimOps.Bits => + val msb = consts(0) + val lsb = consts(1) + val extractionWidth = (msb-lsb)+1 + val nodeWidth = bitWidth(args.head.tpe) + // if you're extracting the full bitwidth, then keep searching for origin + if (nodeWidth == extractionWidth) getOrigin(connects,args.head) + else e + case DoPrim(op, args, _, _) if (op == PrimOps.AsUInt || op == PrimOps.AsSInt || op == PrimOps.AsClock) => + getOrigin(connects,args.head) case _: WRef | _: SubField | _: SubIndex | _: SubAccess if connects contains e.serialize => getConnectOrigin(connects,e.serialize) case _ => e @@ -113,7 +148,6 @@ object AnnotateMemMacros extends Pass { case b: Block => Block(b.stmts map updateStmts) case s => s } - m.copy(body=updateStmts(m.body)) } diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala index 2301ad1b..a4a910fd 100644 --- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala +++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala @@ -47,6 +47,7 @@ object MemTransformUtils { } def updateStmtRefs(s: Statement): Statement = s map updateStmtRefs map updateRef match { case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt + case Connect(info, WSubIndex(EmptyExpression,_,_,_), exp) => EmptyStmt case s => s } updateStmtRefs(s) |
