diff options
| author | Jack Koenig | 2017-11-09 19:02:43 -0800 |
|---|---|---|
| committer | Jack Koenig | 2017-12-12 15:34:43 -0800 |
| commit | e39609a2bfbbd108fa1e5044e9c270685d75a816 (patch) | |
| tree | 3f773ec4197c4a2ec9969c6e75db16afffe57f51 | |
| parent | 7cc075438aa8b67fb52f0556ac9a5bc07bcca232 (diff) | |
Add RemoveWires transform
This transform replaces all wires with nodes in a legal, flow-forward
order
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/RemoveWires.scala | 132 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 8 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/DCETests.scala | 5 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/FlattenTests.scala | 6 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InfoSpec.scala | 4 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/RemoveWiresSpec.scala | 123 |
7 files changed, 270 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 8dd9b180..f032868a 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -86,7 +86,8 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform { passes.InferWidths, passes.Legalize, new firrtl.transforms.RemoveReset, - new firrtl.transforms.CheckCombLoops) + new firrtl.transforms.CheckCombLoops, + new firrtl.transforms.RemoveWires) } /** Runs a series of optimization passes on LowFirrtl diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala new file mode 100644 index 00000000..a1fb32db --- /dev/null +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -0,0 +1,132 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.WrappedExpression._ +import firrtl.graph.{DiGraph, MutableDiGraph, CyclicException} + +import scala.collection.mutable +import scala.util.{Try, Success, Failure} + +/** Replace wires with nodes in a legal, flow-forward order + * + * This pass must run after LowerTypes because Aggregate-type + * wires have multiple connections that may be impossible to order in a + * flow-foward way + */ +class RemoveWires extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + // Extract all expressions that are references to a Wire or Node + // Since we are operating on LowForm, they can only be WRefs + private def extractNodeWireRefs(expr: Expression): Seq[WRef] = { + val refs = mutable.ArrayBuffer.empty[WRef] + def rec(e: Expression): Expression = { + e match { + case ref @ WRef(_,_, WireKind | NodeKind, _) => refs += ref + case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec + case _ => // Do nothing + } + e + } + rec(expr) + refs + } + + // Transform netlist into DefNodes + private def getOrderedNodes( + netlist: mutable.LinkedHashMap[WrappedExpression, (Expression, Info)]): Try[Seq[DefNode]] = { + val digraph = new MutableDiGraph[WrappedExpression] + for ((sink, (expr, _)) <- netlist) { + digraph.addVertex(sink) + for (source <- extractNodeWireRefs(expr)) { + digraph.addPairWithEdge(sink, source) + } + } + + // We could reverse edge directions and not have to do this reverse, but doing it this way does + // a MUCH better job of preserving the logic order as expressed by the designer + // See RemoveWireTests for illustration + Try { + val ordered = digraph.linearize.reverse + ordered.map { key => + val WRef(name, _,_,_) = key.e1 + val (rhs, info) = netlist(key) + DefNode(info, name, rhs) + } + } + } + + private def onModule(m: DefModule): DefModule = { + // Store all non-node declarations here (like reg, inst, and mem) + val decls = mutable.ArrayBuffer.empty[Statement] + // Store all "other" statements here, non-wire, non-node connections, printfs, etc. + val otherStmts = mutable.ArrayBuffer.empty[Statement] + // Add nodes and wire connection here + val netlist = mutable.LinkedHashMap.empty[WrappedExpression, (Expression, Info)] + // Info at definition of wires for combining into node + val wireInfo = mutable.HashMap.empty[WrappedExpression, Info] + + def onStmt(stmt: Statement): Statement = { + stmt match { + case DefNode(info, name, expr) => + netlist(we(WRef(name))) = (expr, info) + case wire: DefWire if !wire.tpe.isInstanceOf[AnalogType] => // Remove all non-Analog wires + wireInfo(WRef(wire)) = wire.info + case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires + decls += decl + case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match { + case WireKind => + // Be sure to pad the rhs since nodes get their type from the rhs + val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) + val dinfo = wireInfo(lhs) + netlist(we(lhs)) = (paddedRhs, MultiInfo(dinfo, cinfo)) + case _ => otherStmts += con // Other connections just pass through + } + case invalid @ IsInvalid(info, expr) => + kind(expr) match { + case WireKind => + val width = expr.tpe match { case GroundType(width) => width } // LowFirrtl + netlist(we(expr)) = (ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe), info) + case _ => otherStmts += invalid + } + case other @ (_: Print | _: Stop | _: Attach) => + otherStmts += other + case EmptyStmt => // Dont bother keeping EmptyStmts around + case block: Block => block map onStmt + case _ => throwInternalError + } + stmt + } + + m match { + case mod @ Module(info, name, ports, body) => + onStmt(body) + getOrderedNodes(netlist) match { + case Success(logic) => + Module(info, name, ports, Block(decls ++ logic ++ otherStmts)) + // If we hit a CyclicException, just abort removing wires + case Failure(_: CyclicException) => + logger.warn(s"Cycle found in module $name, " + + "wires will not be removed which can prevent optimizations!") + mod + case Failure(other) => throw other + } + case m: ExtModule => m + } + } + + private val cleanup = Seq( + passes.ResolveKinds + ) + + def execute(state: CircuitState): CircuitState = { + val result = state.copy(circuit = state.circuit map onModule) + cleanup.foldLeft(result) { case (in, xform) => xform.execute(in) } + } +} diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index e7bf7884..06e24b97 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -721,7 +721,13 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { | wire z : UInt<1> | y <= z | z <= x""".stripMargin - val check = input + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin execute(input, check, Seq(dontTouch("Top.z"))) } diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala index ea34d4be..d1848ab8 100644 --- a/src/test/scala/firrtlTests/DCETests.scala +++ b/src/test/scala/firrtlTests/DCETests.scala @@ -60,9 +60,8 @@ class DCETests extends FirrtlFlatSpec { | module Top : | input x : UInt<1> | output z : UInt<1> - | wire a : UInt<1> - | z <= x - | a <= x""".stripMargin + | node a = x + | z <= x""".stripMargin exec(input, check, Seq(dontTouch("Top.a"))) } "Unread register" should "be deleted" in { diff --git a/src/test/scala/firrtlTests/FlattenTests.scala b/src/test/scala/firrtlTests/FlattenTests.scala index 10988f8f..570d03bf 100644 --- a/src/test/scala/firrtlTests/FlattenTests.scala +++ b/src/test/scala/firrtlTests/FlattenTests.scala @@ -64,9 +64,8 @@ class FlattenTests extends LowTransformSpec { | output b : UInt<32> | inst i1 of Inline1 | inst i2 of Inline1 - | wire tmp : UInt<32> | i1.a <= a - | tmp <= i1.b + | node tmp = i1.b | i2.a <= tmp | b <= i2.b | module Inline1 : @@ -84,9 +83,8 @@ class FlattenTests extends LowTransformSpec { | wire i2$a : UInt<32> | wire i2$b : UInt<32> | i2$b <= i2$a - | wire tmp : UInt<32> + | node tmp = i1$b | b <= i2$b - | tmp <= i1$b | i1$a <= a | i2$a <= tmp | module Inline1 : diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala index 4cb25640..8d49d753 100644 --- a/src/test/scala/firrtlTests/InfoSpec.scala +++ b/src/test/scala/firrtlTests/InfoSpec.scala @@ -59,8 +59,8 @@ class InfoSpec extends FirrtlFlatSpec { ) result should containTree { case DefRegister(Info1, "r", _,_,_,_) => true } result should containLine (s"reg [7:0] r; //$Info1") - result should containTree { case DefWire(Info2, "w", _) => true } - result should containLine (s"wire [7:0] w; //$Info2") + result should containTree { case DefNode(Info2, "w", _) => true } + result should containLine (s"wire [7:0] w; //$Info2") // Node "w" declaration in Verilog result should containTree { case DefNode(Info3, "n", _) => true } result should containLine (s"wire [7:0] n; //$Info3") result should containLine (s"assign n = w | x; //$Info3") diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala new file mode 100644 index 00000000..cfc03ad9 --- /dev/null +++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala @@ -0,0 +1,123 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import FirrtlCheckers._ + +import collection.mutable + +class RemoveWiresSpec extends FirrtlFlatSpec { + def compile(input: String): CircuitState = + (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) + def compileBody(body: String) = { + val str = """ + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + compile(str) + } + + def getNodesAndWires(circuit: Circuit): (Seq[DefNode], Seq[DefWire]) = { + require(circuit.modules.size == 1) + + val nodes = mutable.ArrayBuffer.empty[DefNode] + val wires = mutable.ArrayBuffer.empty[DefWire] + def onStmt(stmt: Statement): Statement = { + stmt map onStmt match { + case node: DefNode => nodes += node + case wire: DefWire => wires += wire + case _ => + } + stmt + } + + circuit.modules.head match { + case Module(_,_,_, body) => onStmt(body) + } + (nodes, wires) + } + + "Remove Wires" should "turn wires and their single connect into nodes" in { + val result = compileBody(s""" + |input a : UInt<8> + |output b : UInt<8> + |wire w : UInt<8> + |w <= a + |b <= w""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + + nodes.map(_.serialize) should be (Seq("node w = a")) + } + + it should "order nodes in a legal, flow-forward way" in { + val result = compileBody(s""" + |input a : UInt<8> + |output b : UInt<8> + |wire w : UInt<8> + |wire x : UInt<8> + |node y = x + |x <= w + |w <= a + |b <= y""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + nodes.map(_.serialize) should be ( + Seq("node w = a", + "node x = w", + "node y = x") + ) + } + + it should "properly pad rhs of introduced nodes if necessary" in { + val result = compileBody(s""" + |output b : UInt<8> + |wire w : UInt<8> + |w <= UInt(2) + |b <= w""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + nodes.map(_.serialize) should be ( + Seq("""node w = pad(UInt<2>("h2"), 8)""") + ) + } + + it should "support arbitrary expression for wire connection rhs" in { + val result = compileBody(s""" + |input a : UInt<8> + |input b : UInt<8> + |output c : UInt<8> + |wire w : UInt<8> + |w <= tail(add(a, b), 1) + |c <= w""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + nodes.map(_.serialize) should be ( + Seq("""node w = tail(add(a, b), 1)""") + ) + } + + it should "do a reasonable job preserving input order for unrelatd logic" in { + val result = compileBody(s""" + |input a : UInt<8> + |input b : UInt<8> + |output z : UInt<8> + |node x = not(a) + |node y = not(b) + |z <= and(x, y)""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + nodes.map(_.serialize) should be ( + Seq("node x = not(a)", + "node y = not(b)") + ) + } +} |
