aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala3
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala132
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala8
-rw-r--r--src/test/scala/firrtlTests/DCETests.scala5
-rw-r--r--src/test/scala/firrtlTests/FlattenTests.scala6
-rw-r--r--src/test/scala/firrtlTests/InfoSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/RemoveWiresSpec.scala123
7 files changed, 270 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 8dd9b180..f032868a 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -86,7 +86,8 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform {
passes.InferWidths,
passes.Legalize,
new firrtl.transforms.RemoveReset,
- new firrtl.transforms.CheckCombLoops)
+ new firrtl.transforms.CheckCombLoops,
+ new firrtl.transforms.RemoveWires)
}
/** Runs a series of optimization passes on LowFirrtl
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
new file mode 100644
index 00000000..a1fb32db
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -0,0 +1,132 @@
+// See LICENSE for license details.
+
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import firrtl.WrappedExpression._
+import firrtl.graph.{DiGraph, MutableDiGraph, CyclicException}
+
+import scala.collection.mutable
+import scala.util.{Try, Success, Failure}
+
+/** Replace wires with nodes in a legal, flow-forward order
+ *
+ * This pass must run after LowerTypes because Aggregate-type
+ * wires have multiple connections that may be impossible to order in a
+ * flow-foward way
+ */
+class RemoveWires extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
+ // Extract all expressions that are references to a Wire or Node
+ // Since we are operating on LowForm, they can only be WRefs
+ private def extractNodeWireRefs(expr: Expression): Seq[WRef] = {
+ val refs = mutable.ArrayBuffer.empty[WRef]
+ def rec(e: Expression): Expression = {
+ e match {
+ case ref @ WRef(_,_, WireKind | NodeKind, _) => refs += ref
+ case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec
+ case _ => // Do nothing
+ }
+ e
+ }
+ rec(expr)
+ refs
+ }
+
+ // Transform netlist into DefNodes
+ private def getOrderedNodes(
+ netlist: mutable.LinkedHashMap[WrappedExpression, (Expression, Info)]): Try[Seq[DefNode]] = {
+ val digraph = new MutableDiGraph[WrappedExpression]
+ for ((sink, (expr, _)) <- netlist) {
+ digraph.addVertex(sink)
+ for (source <- extractNodeWireRefs(expr)) {
+ digraph.addPairWithEdge(sink, source)
+ }
+ }
+
+ // We could reverse edge directions and not have to do this reverse, but doing it this way does
+ // a MUCH better job of preserving the logic order as expressed by the designer
+ // See RemoveWireTests for illustration
+ Try {
+ val ordered = digraph.linearize.reverse
+ ordered.map { key =>
+ val WRef(name, _,_,_) = key.e1
+ val (rhs, info) = netlist(key)
+ DefNode(info, name, rhs)
+ }
+ }
+ }
+
+ private def onModule(m: DefModule): DefModule = {
+ // Store all non-node declarations here (like reg, inst, and mem)
+ val decls = mutable.ArrayBuffer.empty[Statement]
+ // Store all "other" statements here, non-wire, non-node connections, printfs, etc.
+ val otherStmts = mutable.ArrayBuffer.empty[Statement]
+ // Add nodes and wire connection here
+ val netlist = mutable.LinkedHashMap.empty[WrappedExpression, (Expression, Info)]
+ // Info at definition of wires for combining into node
+ val wireInfo = mutable.HashMap.empty[WrappedExpression, Info]
+
+ def onStmt(stmt: Statement): Statement = {
+ stmt match {
+ case DefNode(info, name, expr) =>
+ netlist(we(WRef(name))) = (expr, info)
+ case wire: DefWire if !wire.tpe.isInstanceOf[AnalogType] => // Remove all non-Analog wires
+ wireInfo(WRef(wire)) = wire.info
+ case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires
+ decls += decl
+ case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match {
+ case WireKind =>
+ // Be sure to pad the rhs since nodes get their type from the rhs
+ val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
+ val dinfo = wireInfo(lhs)
+ netlist(we(lhs)) = (paddedRhs, MultiInfo(dinfo, cinfo))
+ case _ => otherStmts += con // Other connections just pass through
+ }
+ case invalid @ IsInvalid(info, expr) =>
+ kind(expr) match {
+ case WireKind =>
+ val width = expr.tpe match { case GroundType(width) => width } // LowFirrtl
+ netlist(we(expr)) = (ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe), info)
+ case _ => otherStmts += invalid
+ }
+ case other @ (_: Print | _: Stop | _: Attach) =>
+ otherStmts += other
+ case EmptyStmt => // Dont bother keeping EmptyStmts around
+ case block: Block => block map onStmt
+ case _ => throwInternalError
+ }
+ stmt
+ }
+
+ m match {
+ case mod @ Module(info, name, ports, body) =>
+ onStmt(body)
+ getOrderedNodes(netlist) match {
+ case Success(logic) =>
+ Module(info, name, ports, Block(decls ++ logic ++ otherStmts))
+ // If we hit a CyclicException, just abort removing wires
+ case Failure(_: CyclicException) =>
+ logger.warn(s"Cycle found in module $name, " +
+ "wires will not be removed which can prevent optimizations!")
+ mod
+ case Failure(other) => throw other
+ }
+ case m: ExtModule => m
+ }
+ }
+
+ private val cleanup = Seq(
+ passes.ResolveKinds
+ )
+
+ def execute(state: CircuitState): CircuitState = {
+ val result = state.copy(circuit = state.circuit map onModule)
+ cleanup.foldLeft(result) { case (in, xform) => xform.execute(in) }
+ }
+}
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index e7bf7884..06e24b97 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -721,7 +721,13 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
| wire z : UInt<1>
| y <= z
| z <= x""".stripMargin
- val check = input
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | node z = x
+ | y <= z""".stripMargin
execute(input, check, Seq(dontTouch("Top.z")))
}
diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala
index ea34d4be..d1848ab8 100644
--- a/src/test/scala/firrtlTests/DCETests.scala
+++ b/src/test/scala/firrtlTests/DCETests.scala
@@ -60,9 +60,8 @@ class DCETests extends FirrtlFlatSpec {
| module Top :
| input x : UInt<1>
| output z : UInt<1>
- | wire a : UInt<1>
- | z <= x
- | a <= x""".stripMargin
+ | node a = x
+ | z <= x""".stripMargin
exec(input, check, Seq(dontTouch("Top.a")))
}
"Unread register" should "be deleted" in {
diff --git a/src/test/scala/firrtlTests/FlattenTests.scala b/src/test/scala/firrtlTests/FlattenTests.scala
index 10988f8f..570d03bf 100644
--- a/src/test/scala/firrtlTests/FlattenTests.scala
+++ b/src/test/scala/firrtlTests/FlattenTests.scala
@@ -64,9 +64,8 @@ class FlattenTests extends LowTransformSpec {
| output b : UInt<32>
| inst i1 of Inline1
| inst i2 of Inline1
- | wire tmp : UInt<32>
| i1.a <= a
- | tmp <= i1.b
+ | node tmp = i1.b
| i2.a <= tmp
| b <= i2.b
| module Inline1 :
@@ -84,9 +83,8 @@ class FlattenTests extends LowTransformSpec {
| wire i2$a : UInt<32>
| wire i2$b : UInt<32>
| i2$b <= i2$a
- | wire tmp : UInt<32>
+ | node tmp = i1$b
| b <= i2$b
- | tmp <= i1$b
| i1$a <= a
| i2$a <= tmp
| module Inline1 :
diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala
index 4cb25640..8d49d753 100644
--- a/src/test/scala/firrtlTests/InfoSpec.scala
+++ b/src/test/scala/firrtlTests/InfoSpec.scala
@@ -59,8 +59,8 @@ class InfoSpec extends FirrtlFlatSpec {
)
result should containTree { case DefRegister(Info1, "r", _,_,_,_) => true }
result should containLine (s"reg [7:0] r; //$Info1")
- result should containTree { case DefWire(Info2, "w", _) => true }
- result should containLine (s"wire [7:0] w; //$Info2")
+ result should containTree { case DefNode(Info2, "w", _) => true }
+ result should containLine (s"wire [7:0] w; //$Info2") // Node "w" declaration in Verilog
result should containTree { case DefNode(Info3, "n", _) => true }
result should containLine (s"wire [7:0] n; //$Info3")
result should containLine (s"assign n = w | x; //$Info3")
diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
new file mode 100644
index 00000000..cfc03ad9
--- /dev/null
+++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
@@ -0,0 +1,123 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+import FirrtlCheckers._
+
+import collection.mutable
+
+class RemoveWiresSpec extends FirrtlFlatSpec {
+ def compile(input: String): CircuitState =
+ (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty)
+ def compileBody(body: String) = {
+ val str = """
+ |circuit Test :
+ | module Test :
+ |""".stripMargin + body.split("\n").mkString(" ", "\n ", "")
+ compile(str)
+ }
+
+ def getNodesAndWires(circuit: Circuit): (Seq[DefNode], Seq[DefWire]) = {
+ require(circuit.modules.size == 1)
+
+ val nodes = mutable.ArrayBuffer.empty[DefNode]
+ val wires = mutable.ArrayBuffer.empty[DefWire]
+ def onStmt(stmt: Statement): Statement = {
+ stmt map onStmt match {
+ case node: DefNode => nodes += node
+ case wire: DefWire => wires += wire
+ case _ =>
+ }
+ stmt
+ }
+
+ circuit.modules.head match {
+ case Module(_,_,_, body) => onStmt(body)
+ }
+ (nodes, wires)
+ }
+
+ "Remove Wires" should "turn wires and their single connect into nodes" in {
+ val result = compileBody(s"""
+ |input a : UInt<8>
+ |output b : UInt<8>
+ |wire w : UInt<8>
+ |w <= a
+ |b <= w""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+
+ nodes.map(_.serialize) should be (Seq("node w = a"))
+ }
+
+ it should "order nodes in a legal, flow-forward way" in {
+ val result = compileBody(s"""
+ |input a : UInt<8>
+ |output b : UInt<8>
+ |wire w : UInt<8>
+ |wire x : UInt<8>
+ |node y = x
+ |x <= w
+ |w <= a
+ |b <= y""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+ nodes.map(_.serialize) should be (
+ Seq("node w = a",
+ "node x = w",
+ "node y = x")
+ )
+ }
+
+ it should "properly pad rhs of introduced nodes if necessary" in {
+ val result = compileBody(s"""
+ |output b : UInt<8>
+ |wire w : UInt<8>
+ |w <= UInt(2)
+ |b <= w""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+ nodes.map(_.serialize) should be (
+ Seq("""node w = pad(UInt<2>("h2"), 8)""")
+ )
+ }
+
+ it should "support arbitrary expression for wire connection rhs" in {
+ val result = compileBody(s"""
+ |input a : UInt<8>
+ |input b : UInt<8>
+ |output c : UInt<8>
+ |wire w : UInt<8>
+ |w <= tail(add(a, b), 1)
+ |c <= w""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+ nodes.map(_.serialize) should be (
+ Seq("""node w = tail(add(a, b), 1)""")
+ )
+ }
+
+ it should "do a reasonable job preserving input order for unrelatd logic" in {
+ val result = compileBody(s"""
+ |input a : UInt<8>
+ |input b : UInt<8>
+ |output z : UInt<8>
+ |node x = not(a)
+ |node y = not(b)
+ |z <= and(x, y)""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+ nodes.map(_.serialize) should be (
+ Seq("node x = not(a)",
+ "node y = not(b)")
+ )
+ }
+}