aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/RemoveWiresSpec.scala
diff options
context:
space:
mode:
authorJack Koenig2017-11-09 19:02:43 -0800
committerJack Koenig2017-12-12 15:34:43 -0800
commite39609a2bfbbd108fa1e5044e9c270685d75a816 (patch)
tree3f773ec4197c4a2ec9969c6e75db16afffe57f51 /src/test/scala/firrtlTests/RemoveWiresSpec.scala
parent7cc075438aa8b67fb52f0556ac9a5bc07bcca232 (diff)
Add RemoveWires transform
This transform replaces all wires with nodes in a legal, flow-forward order
Diffstat (limited to 'src/test/scala/firrtlTests/RemoveWiresSpec.scala')
-rw-r--r--src/test/scala/firrtlTests/RemoveWiresSpec.scala123
1 files changed, 123 insertions, 0 deletions
diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
new file mode 100644
index 00000000..cfc03ad9
--- /dev/null
+++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
@@ -0,0 +1,123 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+import FirrtlCheckers._
+
+import collection.mutable
+
+class RemoveWiresSpec extends FirrtlFlatSpec {
+ def compile(input: String): CircuitState =
+ (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty)
+ def compileBody(body: String) = {
+ val str = """
+ |circuit Test :
+ | module Test :
+ |""".stripMargin + body.split("\n").mkString(" ", "\n ", "")
+ compile(str)
+ }
+
+ def getNodesAndWires(circuit: Circuit): (Seq[DefNode], Seq[DefWire]) = {
+ require(circuit.modules.size == 1)
+
+ val nodes = mutable.ArrayBuffer.empty[DefNode]
+ val wires = mutable.ArrayBuffer.empty[DefWire]
+ def onStmt(stmt: Statement): Statement = {
+ stmt map onStmt match {
+ case node: DefNode => nodes += node
+ case wire: DefWire => wires += wire
+ case _ =>
+ }
+ stmt
+ }
+
+ circuit.modules.head match {
+ case Module(_,_,_, body) => onStmt(body)
+ }
+ (nodes, wires)
+ }
+
+ "Remove Wires" should "turn wires and their single connect into nodes" in {
+ val result = compileBody(s"""
+ |input a : UInt<8>
+ |output b : UInt<8>
+ |wire w : UInt<8>
+ |w <= a
+ |b <= w""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+
+ nodes.map(_.serialize) should be (Seq("node w = a"))
+ }
+
+ it should "order nodes in a legal, flow-forward way" in {
+ val result = compileBody(s"""
+ |input a : UInt<8>
+ |output b : UInt<8>
+ |wire w : UInt<8>
+ |wire x : UInt<8>
+ |node y = x
+ |x <= w
+ |w <= a
+ |b <= y""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+ nodes.map(_.serialize) should be (
+ Seq("node w = a",
+ "node x = w",
+ "node y = x")
+ )
+ }
+
+ it should "properly pad rhs of introduced nodes if necessary" in {
+ val result = compileBody(s"""
+ |output b : UInt<8>
+ |wire w : UInt<8>
+ |w <= UInt(2)
+ |b <= w""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+ nodes.map(_.serialize) should be (
+ Seq("""node w = pad(UInt<2>("h2"), 8)""")
+ )
+ }
+
+ it should "support arbitrary expression for wire connection rhs" in {
+ val result = compileBody(s"""
+ |input a : UInt<8>
+ |input b : UInt<8>
+ |output c : UInt<8>
+ |wire w : UInt<8>
+ |w <= tail(add(a, b), 1)
+ |c <= w""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+ nodes.map(_.serialize) should be (
+ Seq("""node w = tail(add(a, b), 1)""")
+ )
+ }
+
+ it should "do a reasonable job preserving input order for unrelatd logic" in {
+ val result = compileBody(s"""
+ |input a : UInt<8>
+ |input b : UInt<8>
+ |output z : UInt<8>
+ |node x = not(a)
+ |node y = not(b)
+ |z <= and(x, y)""".stripMargin
+ )
+ val (nodes, wires) = getNodesAndWires(result.circuit)
+ wires.size should be (0)
+ nodes.map(_.serialize) should be (
+ Seq("node x = not(a)",
+ "node y = not(b)")
+ )
+ }
+}