diff options
| author | Schuyler Eldridge | 2018-08-07 18:38:28 -0400 |
|---|---|---|
| committer | GitHub | 2018-08-07 18:38:28 -0400 |
| commit | 0a2e2c2e4a97fb8d14e7259f779d8398e243e889 (patch) | |
| tree | 5a1286c11553d73541ef2b92ea7ee4d88afd1ca7 /src | |
| parent | b84cb05faba6d787cb599fac4ea687ce4249ef1d (diff) | |
| parent | adf66019948afc46f8818e6883f1bab4d200265d (diff) | |
Respect register references in RemoveWires (#868)
- makes RemoveWires properly include registers in dependency graph
- adds an apply method to WRef for DefNode
- adds a test case requiring register reordering
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/RemoveWires.scala | 14 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/RemoveWiresSpec.scala | 15 |
3 files changed, 24 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 29a14406..f61fa41e 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -35,6 +35,8 @@ object WRef { def apply(wire: DefWire): WRef = new WRef(wire.name, wire.tpe, WireKind, UNKNOWNGENDER) /** Creates a WRef from a Register */ def apply(reg: DefRegister): WRef = new WRef(reg.name, reg.tpe, RegKind, UNKNOWNGENDER) + /** Creates a WRef from a Node */ + def apply(node: DefNode): WRef = new WRef(node.name, node.value.tpe, NodeKind, MALE) def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind): WRef = new WRef(n, t, k, UNKNOWNGENDER) } case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender) extends Expression { diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index 8cc967dd..1b5b3e5f 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -22,13 +22,13 @@ class RemoveWires extends Transform { def inputForm = LowForm def outputForm = LowForm - // Extract all expressions that are references to a Wire or Node + // Extract all expressions that are references to a Node, Wire, or Reg // Since we are operating on LowForm, they can only be WRefs - private def extractNodeWireRefs(expr: Expression): Seq[WRef] = { + private def extractNodeWireRegRefs(expr: Expression): Seq[WRef] = { val refs = mutable.ArrayBuffer.empty[WRef] def rec(e: Expression): Expression = { e match { - case ref @ WRef(_,_, WireKind | NodeKind, _) => refs += ref + case ref @ WRef(_,_, WireKind | NodeKind | RegKind, _) => refs += ref case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec case _ => // Do nothing } @@ -45,7 +45,7 @@ class RemoveWires extends Transform { val digraph = new MutableDiGraph[WrappedExpression] for ((sink, (expr, _)) <- netlist) { digraph.addVertex(sink) - for (source <- extractNodeWireRefs(expr)) { + for (source <- extractNodeWireRegRefs(expr)) { digraph.addPairWithEdge(sink, source) } } @@ -60,7 +60,7 @@ class RemoveWires extends Transform { val (rhs, info) = netlist(key) kind match { case RegKind => regInfo(key) - case _ => DefNode(info, name, rhs) + case WireKind | NodeKind => DefNode(info, name, rhs) } } } @@ -80,8 +80,8 @@ class RemoveWires extends Transform { def onStmt(stmt: Statement): Statement = { stmt match { - case DefNode(info, name, expr) => - netlist(we(WRef(name))) = (expr, info) + case node: DefNode => + netlist(we(WRef(node))) = (node.value, node.info) case wire: DefWire if !wire.tpe.isInstanceOf[AnalogType] => // Remove all non-Analog wires wireInfo(WRef(wire)) = wire.info case reg: DefRegister => diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala index f162f32c..d15e6908 100644 --- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala +++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala @@ -150,4 +150,19 @@ class RemoveWiresSpec extends FirrtlFlatSpec { val names = orderedNames(result.circuit) names should be (Seq("a", "clock2", "b")) } + + it should "order registers correctly" in { + val result = compileBody(s""" + |input clock : Clock + |input a : UInt<8> + |output c : UInt<8> + |wire w : UInt<8> + |node n = tail(add(w, UInt(1)), 1) + |reg r : UInt<8>, clock + |w <= tail(add(r, a), 1) + |c <= n""".stripMargin + ) + // Check declaration before use is maintained + passes.CheckHighForm.execute(result) + } } |
