aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSchuyler Eldridge2018-08-07 18:38:28 -0400
committerGitHub2018-08-07 18:38:28 -0400
commit0a2e2c2e4a97fb8d14e7259f779d8398e243e889 (patch)
tree5a1286c11553d73541ef2b92ea7ee4d88afd1ca7
parentb84cb05faba6d787cb599fac4ea687ce4249ef1d (diff)
parentadf66019948afc46f8818e6883f1bab4d200265d (diff)
Respect register references in RemoveWires (#868)
- makes RemoveWires properly include registers in dependency graph - adds an apply method to WRef for DefNode - adds a test case requiring register reordering
-rw-r--r--src/main/scala/firrtl/WIR.scala2
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala14
-rw-r--r--src/test/scala/firrtlTests/RemoveWiresSpec.scala15
3 files changed, 24 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 29a14406..f61fa41e 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -35,6 +35,8 @@ object WRef {
def apply(wire: DefWire): WRef = new WRef(wire.name, wire.tpe, WireKind, UNKNOWNGENDER)
/** Creates a WRef from a Register */
def apply(reg: DefRegister): WRef = new WRef(reg.name, reg.tpe, RegKind, UNKNOWNGENDER)
+ /** Creates a WRef from a Node */
+ def apply(node: DefNode): WRef = new WRef(node.name, node.value.tpe, NodeKind, MALE)
def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind): WRef = new WRef(n, t, k, UNKNOWNGENDER)
}
case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender) extends Expression {
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
index 8cc967dd..1b5b3e5f 100644
--- a/src/main/scala/firrtl/transforms/RemoveWires.scala
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -22,13 +22,13 @@ class RemoveWires extends Transform {
def inputForm = LowForm
def outputForm = LowForm
- // Extract all expressions that are references to a Wire or Node
+ // Extract all expressions that are references to a Node, Wire, or Reg
// Since we are operating on LowForm, they can only be WRefs
- private def extractNodeWireRefs(expr: Expression): Seq[WRef] = {
+ private def extractNodeWireRegRefs(expr: Expression): Seq[WRef] = {
val refs = mutable.ArrayBuffer.empty[WRef]
def rec(e: Expression): Expression = {
e match {
- case ref @ WRef(_,_, WireKind | NodeKind, _) => refs += ref
+ case ref @ WRef(_,_, WireKind | NodeKind | RegKind, _) => refs += ref
case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec
case _ => // Do nothing
}
@@ -45,7 +45,7 @@ class RemoveWires extends Transform {
val digraph = new MutableDiGraph[WrappedExpression]
for ((sink, (expr, _)) <- netlist) {
digraph.addVertex(sink)
- for (source <- extractNodeWireRefs(expr)) {
+ for (source <- extractNodeWireRegRefs(expr)) {
digraph.addPairWithEdge(sink, source)
}
}
@@ -60,7 +60,7 @@ class RemoveWires extends Transform {
val (rhs, info) = netlist(key)
kind match {
case RegKind => regInfo(key)
- case _ => DefNode(info, name, rhs)
+ case WireKind | NodeKind => DefNode(info, name, rhs)
}
}
}
@@ -80,8 +80,8 @@ class RemoveWires extends Transform {
def onStmt(stmt: Statement): Statement = {
stmt match {
- case DefNode(info, name, expr) =>
- netlist(we(WRef(name))) = (expr, info)
+ case node: DefNode =>
+ netlist(we(WRef(node))) = (node.value, node.info)
case wire: DefWire if !wire.tpe.isInstanceOf[AnalogType] => // Remove all non-Analog wires
wireInfo(WRef(wire)) = wire.info
case reg: DefRegister =>
diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
index f162f32c..d15e6908 100644
--- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala
+++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
@@ -150,4 +150,19 @@ class RemoveWiresSpec extends FirrtlFlatSpec {
val names = orderedNames(result.circuit)
names should be (Seq("a", "clock2", "b"))
}
+
+ it should "order registers correctly" in {
+ val result = compileBody(s"""
+ |input clock : Clock
+ |input a : UInt<8>
+ |output c : UInt<8>
+ |wire w : UInt<8>
+ |node n = tail(add(w, UInt(1)), 1)
+ |reg r : UInt<8>, clock
+ |w <= tail(add(r, a), 1)
+ |c <= n""".stripMargin
+ )
+ // Check declaration before use is maintained
+ passes.CheckHighForm.execute(result)
+ }
}