diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms/ConstantPropagation.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 56 |
1 files changed, 42 insertions, 14 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 29410c7f..000adc15 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -458,9 +458,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case _ => constPropMuxCond(m) } - private def constPropNodeRef(r: WRef, e: Expression) = e match { - case _: UIntLiteral | _: SIntLiteral | _: WRef => e - case _ => r + private def constPropNodeRef(r: WRef, e: Expression): Expression = { + def doit(ex: Expression) = ex match { + case _: UIntLiteral | _: SIntLiteral | _: WRef => ex + case _ => r + } + InfoExpr.map(e)(doit) } // Is "a" a "better name" than "b"? @@ -475,7 +478,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case p: DoPrim => constPropPrim(p) case m: Mux => constPropMux(m) case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) + constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) => val module = instMap(inst.Instance) // Check constSubOutputs to see if the submodule is driving a constant @@ -487,6 +490,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) } + /** Hacky way of propagating source locators across nodes and connections that have just a + * reference on the right-hand side + * + * @todo generalize source locator propagation across Expressions and delete this method + * @todo is the `orElse` the way we want to do propagation here? + */ + private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String]) + (stmt: Statement): Statement = stmt match { + // We check rname because inlining it would cause the original declaration to go away + case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + node.copy(info = InfoExpr.orElse(info1, info0)) + case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + con.copy(info = InfoExpr.orElse(info1, info0)) + case other => other + } + /* Constant propagate a Module * * Two pass process @@ -547,7 +568,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res ref.copy(name = swapMap(rname)) // Only const prop on the rhs case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) + constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case x => x } if (old ne propagated) { @@ -577,7 +598,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } // When propagating a reference, check if we want to keep the name that would be deleted - def propagateRef(lname: String, value: Expression): Unit = { + def propagateRef(lname: String, value: Expression, info: Info): Unit = { value match { case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a @@ -585,19 +606,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res swapMap += (lname -> rname, rname -> lname) case _ => } - nodeMap(lname) = value + nodeMap(lname) = InfoExpr.wrap(info, value) } def constPropStmt(s: Statement): Statement = { - val stmtx = s map constPropStmt map constPropExpression(nodeMap, instMap, constSubOutputs) + val s0 = s.map(constPropStmt) // Statement recurse + val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation + val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions // Record things that should be propagated stmtx match { - case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) + case DefNode(info, name, value) if !dontTouches.contains(name) => + propagateRef(name, value, info) case reg: DefRegister if reg.reset.tpe == AsyncResetType => asyncResetRegs(reg.name) = reg - case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => + case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe)) - propagateRef(wname, exprx) + propagateRef(wname, exprx, info) // Record constants driving outputs case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] @@ -633,7 +657,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case lit: Literal => baseCase.resolve(RegCPEntry(UnboundConstant, BoundConstant(lit))) case WRef(regName, _, RegKind, _) => baseCase.resolve(RegCPEntry(BoundConstant(regName), UnboundConstant)) case WRef(nodeName, _, NodeKind, _) if nodeMap.contains(nodeName) => - val cached = nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(nodeMap(nodeName), unbound) }) + val (_, expr) = InfoExpr.unwrap(nodeMap(nodeName)) + val cached = nodeRegCPEntries.getOrElseUpdate(nodeName, { regConstant(expr, unbound) }) baseCase.resolve(cached) case Mux(_, tval, fval, _) => regConstant(tval, baseCase).resolve(regConstant(fval, baseCase)) @@ -676,8 +701,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Actually transform some statements stmtx match { // Propagate connections to references - case Connect(info, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) => - Connect(info, lhs, nodeMap(rname)) + case Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouches.contains(rname) => + val (info1, value) = InfoExpr.unwrap(nodeMap(rname)) + // Is this the right info combination/propagation function? + // See propagateDirectConnectionInfoOnly + Connect(InfoExpr.orElse(info1, info0), lhs, value) // If an Attach has at least 1 port, any wires are redundant and can be removed case Attach(info, exprs) if exprs.exists(kind(_) == PortKind) => Attach(info, exprs.filterNot(kind(_) == WireKind)) |
