aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/transforms')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala6
-rw-r--r--src/main/scala/firrtl/transforms/FlattenRegUpdate.scala3
-rw-r--r--src/main/scala/firrtl/transforms/RemoveReset.scala3
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala36
4 files changed, 31 insertions, 17 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 6618312a..fdaa7112 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -350,6 +350,8 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// Keep track of any submodule inputs we drive with a constant
// (can have more than 1 of the same submodule)
val constSubInputs = mutable.HashMap.empty[String, mutable.HashMap[String, Seq[Literal]]]
+ // AsyncReset registers don't have reset turned into a mux so we must be careful
+ val asyncResetRegs = mutable.HashSet.empty[String]
// Copy constant mapping for constant inputs (except ones marked dontTouch!)
nodeMap ++= constInputs.filterNot { case (pname, _) => dontTouches.contains(pname) }
@@ -405,6 +407,8 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
// Record things that should be propagated
stmtx match {
case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value)
+ case reg: DefRegister if reg.reset.tpe == AsyncResetType =>
+ asyncResetRegs += reg.name
case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) =>
val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe))
propagateRef(wname, exprx)
@@ -414,7 +418,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
constOutputs(pname) = paddedLit
// Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns
// This requires that reset has been made explicit
- case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches.contains(lname) =>
+ case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches(lname) && !asyncResetRegs(lname) =>
/** Checks if an RHS expression e of a register assignment is convertible to a constant assignment.
* Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of cases (1) and (2).
* In case (3), it also recursively checks that the two mux cases are convertible to constants and
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
index f21e6b18..2bce124c 100644
--- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
+++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
@@ -81,7 +81,8 @@ object FlattenRegUpdate {
def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
case reg @ DefRegister(_, rname, _,_, resetCond, _) =>
- assert(resetCond == Utils.zero, "Register reset should have already been made explicit!")
+ assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
+ "Synchronous reset should have already been made explicit!")
val ref = WRef(reg)
val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref)))
regUpdates += update
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala
index bfec76a2..0b8b907d 100644
--- a/src/main/scala/firrtl/transforms/RemoveReset.scala
+++ b/src/main/scala/firrtl/transforms/RemoveReset.scala
@@ -22,7 +22,8 @@ class RemoveReset extends Transform {
val resets = mutable.HashMap.empty[String, Reset]
def onStmt(stmt: Statement): Statement = {
stmt match {
- case reg @ DefRegister(_, rname, _, _, reset, init) if reset != Utils.zero =>
+ case reg @ DefRegister(_, rname, _, _, reset, init)
+ if reset != Utils.zero && reset.tpe != AsyncResetType =>
// Add register reset to map
resets(rname) = Reset(reset, init)
reg.copy(reset = Utils.zero, init = WRef(reg))
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
index 1b5b3e5f..da79be8e 100644
--- a/src/main/scala/firrtl/transforms/RemoveWires.scala
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -6,6 +6,7 @@ package transforms
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
+import firrtl.traversals.Foreachers._
import firrtl.WrappedExpression._
import firrtl.graph.{DiGraph, MutableDiGraph, CyclicException}
@@ -29,7 +30,7 @@ class RemoveWires extends Transform {
def rec(e: Expression): Expression = {
e match {
case ref @ WRef(_,_, WireKind | NodeKind | RegKind, _) => refs += ref
- case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec
+ case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec)
case _ => // Do nothing
}
e
@@ -40,13 +41,15 @@ class RemoveWires extends Transform {
// Transform netlist into DefNodes
private def getOrderedNodes(
- netlist: mutable.LinkedHashMap[WrappedExpression, (Expression, Info)],
+ netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)],
regInfo: mutable.Map[WrappedExpression, DefRegister]): Try[Seq[Statement]] = {
val digraph = new MutableDiGraph[WrappedExpression]
- for ((sink, (expr, _)) <- netlist) {
+ for ((sink, (exprs, _)) <- netlist) {
digraph.addVertex(sink)
- for (source <- extractNodeWireRegRefs(expr)) {
- digraph.addPairWithEdge(sink, source)
+ for (expr <- exprs) {
+ for (source <- extractNodeWireRegRefs(expr)) {
+ digraph.addPairWithEdge(sink, source)
+ }
}
}
@@ -57,10 +60,11 @@ class RemoveWires extends Transform {
val ordered = digraph.linearize.reverse
ordered.map { key =>
val WRef(name, _, kind, _) = key.e1
- val (rhs, info) = netlist(key)
kind match {
case RegKind => regInfo(key)
- case WireKind | NodeKind => DefNode(info, name, rhs)
+ case WireKind | NodeKind =>
+ val (Seq(rhs), info) = netlist(key)
+ DefNode(info, name, rhs)
}
}
}
@@ -72,7 +76,7 @@ class RemoveWires extends Transform {
// Store all "other" statements here, non-wire, non-node connections, printfs, etc.
val otherStmts = mutable.ArrayBuffer.empty[Statement]
// Add nodes and wire connection here
- val netlist = mutable.LinkedHashMap.empty[WrappedExpression, (Expression, Info)]
+ val netlist = mutable.LinkedHashMap.empty[WrappedExpression, (Seq[Expression], Info)]
// Info at definition of wires for combining into node
val wireInfo = mutable.HashMap.empty[WrappedExpression, Info]
// Additional info about registers
@@ -81,12 +85,16 @@ class RemoveWires extends Transform {
def onStmt(stmt: Statement): Statement = {
stmt match {
case node: DefNode =>
- netlist(we(WRef(node))) = (node.value, node.info)
+ netlist(we(WRef(node))) = (Seq(node.value), node.info)
case wire: DefWire if !wire.tpe.isInstanceOf[AnalogType] => // Remove all non-Analog wires
wireInfo(WRef(wire)) = wire.info
case reg: DefRegister =>
+ val resetDep = reg.reset.tpe match {
+ case AsyncResetType => reg.reset :: Nil
+ case _ => Nil
+ }
regInfo(we(WRef(reg))) = reg
- netlist(we(WRef(reg))) = (reg.clock, reg.info)
+ netlist(we(WRef(reg))) = (reg.clock :: resetDep, reg.info)
case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires
decls += decl
case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match {
@@ -94,20 +102,20 @@ class RemoveWires extends Transform {
// Be sure to pad the rhs since nodes get their type from the rhs
val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
val dinfo = wireInfo(lhs)
- netlist(we(lhs)) = (paddedRhs, MultiInfo(dinfo, cinfo))
+ netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo))
case _ => otherStmts += con // Other connections just pass through
}
case invalid @ IsInvalid(info, expr) =>
kind(expr) match {
case WireKind =>
val width = expr.tpe match { case GroundType(width) => width } // LowFirrtl
- netlist(we(expr)) = (ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe), info)
+ netlist(we(expr)) = (Seq(ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe)), info)
case _ => otherStmts += invalid
}
case other @ (_: Print | _: Stop | _: Attach) =>
otherStmts += other
case EmptyStmt => // Dont bother keeping EmptyStmts around
- case block: Block => block map onStmt
+ case block: Block => block.foreach(onStmt)
case _ => throwInternalError()
}
stmt
@@ -136,7 +144,7 @@ class RemoveWires extends Transform {
)
def execute(state: CircuitState): CircuitState = {
- val result = state.copy(circuit = state.circuit map onModule)
+ val result = state.copy(circuit = state.circuit.map(onModule))
cleanup.foldLeft(result) { case (in, xform) => xform.execute(in) }
}
}