diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms')
4 files changed, 31 insertions, 17 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 6618312a..fdaa7112 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -350,6 +350,8 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // Keep track of any submodule inputs we drive with a constant // (can have more than 1 of the same submodule) val constSubInputs = mutable.HashMap.empty[String, mutable.HashMap[String, Seq[Literal]]] + // AsyncReset registers don't have reset turned into a mux so we must be careful + val asyncResetRegs = mutable.HashSet.empty[String] // Copy constant mapping for constant inputs (except ones marked dontTouch!) nodeMap ++= constInputs.filterNot { case (pname, _) => dontTouches.contains(pname) } @@ -405,6 +407,8 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { // Record things that should be propagated stmtx match { case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) + case reg: DefRegister if reg.reset.tpe == AsyncResetType => + asyncResetRegs += reg.name case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe)) propagateRef(wname, exprx) @@ -414,7 +418,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { constOutputs(pname) = paddedLit // Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns // This requires that reset has been made explicit - case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches.contains(lname) => + case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches(lname) && !asyncResetRegs(lname) => /** Checks if an RHS expression e of a register assignment is convertible to a constant assignment. * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of cases (1) and (2). * In case (3), it also recursively checks that the two mux cases are convertible to constants and diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index f21e6b18..2bce124c 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -81,7 +81,8 @@ object FlattenRegUpdate { def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { case reg @ DefRegister(_, rname, _,_, resetCond, _) => - assert(resetCond == Utils.zero, "Register reset should have already been made explicit!") + assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero, + "Synchronous reset should have already been made explicit!") val ref = WRef(reg) val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref))) regUpdates += update diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index bfec76a2..0b8b907d 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -22,7 +22,8 @@ class RemoveReset extends Transform { val resets = mutable.HashMap.empty[String, Reset] def onStmt(stmt: Statement): Statement = { stmt match { - case reg @ DefRegister(_, rname, _, _, reset, init) if reset != Utils.zero => + case reg @ DefRegister(_, rname, _, _, reset, init) + if reset != Utils.zero && reset.tpe != AsyncResetType => // Add register reset to map resets(rname) = Reset(reset, init) reg.copy(reset = Utils.zero, init = WRef(reg)) diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index 1b5b3e5f..da79be8e 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -6,6 +6,7 @@ package transforms import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.WrappedExpression._ import firrtl.graph.{DiGraph, MutableDiGraph, CyclicException} @@ -29,7 +30,7 @@ class RemoveWires extends Transform { def rec(e: Expression): Expression = { e match { case ref @ WRef(_,_, WireKind | NodeKind | RegKind, _) => refs += ref - case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec + case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec) case _ => // Do nothing } e @@ -40,13 +41,15 @@ class RemoveWires extends Transform { // Transform netlist into DefNodes private def getOrderedNodes( - netlist: mutable.LinkedHashMap[WrappedExpression, (Expression, Info)], + netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)], regInfo: mutable.Map[WrappedExpression, DefRegister]): Try[Seq[Statement]] = { val digraph = new MutableDiGraph[WrappedExpression] - for ((sink, (expr, _)) <- netlist) { + for ((sink, (exprs, _)) <- netlist) { digraph.addVertex(sink) - for (source <- extractNodeWireRegRefs(expr)) { - digraph.addPairWithEdge(sink, source) + for (expr <- exprs) { + for (source <- extractNodeWireRegRefs(expr)) { + digraph.addPairWithEdge(sink, source) + } } } @@ -57,10 +60,11 @@ class RemoveWires extends Transform { val ordered = digraph.linearize.reverse ordered.map { key => val WRef(name, _, kind, _) = key.e1 - val (rhs, info) = netlist(key) kind match { case RegKind => regInfo(key) - case WireKind | NodeKind => DefNode(info, name, rhs) + case WireKind | NodeKind => + val (Seq(rhs), info) = netlist(key) + DefNode(info, name, rhs) } } } @@ -72,7 +76,7 @@ class RemoveWires extends Transform { // Store all "other" statements here, non-wire, non-node connections, printfs, etc. val otherStmts = mutable.ArrayBuffer.empty[Statement] // Add nodes and wire connection here - val netlist = mutable.LinkedHashMap.empty[WrappedExpression, (Expression, Info)] + val netlist = mutable.LinkedHashMap.empty[WrappedExpression, (Seq[Expression], Info)] // Info at definition of wires for combining into node val wireInfo = mutable.HashMap.empty[WrappedExpression, Info] // Additional info about registers @@ -81,12 +85,16 @@ class RemoveWires extends Transform { def onStmt(stmt: Statement): Statement = { stmt match { case node: DefNode => - netlist(we(WRef(node))) = (node.value, node.info) + netlist(we(WRef(node))) = (Seq(node.value), node.info) case wire: DefWire if !wire.tpe.isInstanceOf[AnalogType] => // Remove all non-Analog wires wireInfo(WRef(wire)) = wire.info case reg: DefRegister => + val resetDep = reg.reset.tpe match { + case AsyncResetType => reg.reset :: Nil + case _ => Nil + } regInfo(we(WRef(reg))) = reg - netlist(we(WRef(reg))) = (reg.clock, reg.info) + netlist(we(WRef(reg))) = (reg.clock :: resetDep, reg.info) case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires decls += decl case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match { @@ -94,20 +102,20 @@ class RemoveWires extends Transform { // Be sure to pad the rhs since nodes get their type from the rhs val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) val dinfo = wireInfo(lhs) - netlist(we(lhs)) = (paddedRhs, MultiInfo(dinfo, cinfo)) + netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo)) case _ => otherStmts += con // Other connections just pass through } case invalid @ IsInvalid(info, expr) => kind(expr) match { case WireKind => val width = expr.tpe match { case GroundType(width) => width } // LowFirrtl - netlist(we(expr)) = (ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe), info) + netlist(we(expr)) = (Seq(ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe)), info) case _ => otherStmts += invalid } case other @ (_: Print | _: Stop | _: Attach) => otherStmts += other case EmptyStmt => // Dont bother keeping EmptyStmts around - case block: Block => block map onStmt + case block: Block => block.foreach(onStmt) case _ => throwInternalError() } stmt @@ -136,7 +144,7 @@ class RemoveWires extends Transform { ) def execute(state: CircuitState): CircuitState = { - val result = state.copy(circuit = state.circuit map onModule) + val result = state.copy(circuit = state.circuit.map(onModule)) cleanup.foldLeft(result) { case (in, xform) => xform.execute(in) } } } |
