diff options
| author | Jack Koenig | 2017-11-09 19:02:43 -0800 |
|---|---|---|
| committer | Jack Koenig | 2017-12-12 15:34:43 -0800 |
| commit | e39609a2bfbbd108fa1e5044e9c270685d75a816 (patch) | |
| tree | 3f773ec4197c4a2ec9969c6e75db16afffe57f51 /src/test/scala/firrtlTests/RemoveWiresSpec.scala | |
| parent | 7cc075438aa8b67fb52f0556ac9a5bc07bcca232 (diff) | |
Add RemoveWires transform
This transform replaces all wires with nodes in a legal, flow-forward
order
Diffstat (limited to 'src/test/scala/firrtlTests/RemoveWiresSpec.scala')
| -rw-r--r-- | src/test/scala/firrtlTests/RemoveWiresSpec.scala | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala new file mode 100644 index 00000000..cfc03ad9 --- /dev/null +++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala @@ -0,0 +1,123 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import FirrtlCheckers._ + +import collection.mutable + +class RemoveWiresSpec extends FirrtlFlatSpec { + def compile(input: String): CircuitState = + (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) + def compileBody(body: String) = { + val str = """ + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + compile(str) + } + + def getNodesAndWires(circuit: Circuit): (Seq[DefNode], Seq[DefWire]) = { + require(circuit.modules.size == 1) + + val nodes = mutable.ArrayBuffer.empty[DefNode] + val wires = mutable.ArrayBuffer.empty[DefWire] + def onStmt(stmt: Statement): Statement = { + stmt map onStmt match { + case node: DefNode => nodes += node + case wire: DefWire => wires += wire + case _ => + } + stmt + } + + circuit.modules.head match { + case Module(_,_,_, body) => onStmt(body) + } + (nodes, wires) + } + + "Remove Wires" should "turn wires and their single connect into nodes" in { + val result = compileBody(s""" + |input a : UInt<8> + |output b : UInt<8> + |wire w : UInt<8> + |w <= a + |b <= w""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + + nodes.map(_.serialize) should be (Seq("node w = a")) + } + + it should "order nodes in a legal, flow-forward way" in { + val result = compileBody(s""" + |input a : UInt<8> + |output b : UInt<8> + |wire w : UInt<8> + |wire x : UInt<8> + |node y = x + |x <= w + |w <= a + |b <= y""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + nodes.map(_.serialize) should be ( + Seq("node w = a", + "node x = w", + "node y = x") + ) + } + + it should "properly pad rhs of introduced nodes if necessary" in { + val result = compileBody(s""" + |output b : UInt<8> + |wire w : UInt<8> + |w <= UInt(2) + |b <= w""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + nodes.map(_.serialize) should be ( + Seq("""node w = pad(UInt<2>("h2"), 8)""") + ) + } + + it should "support arbitrary expression for wire connection rhs" in { + val result = compileBody(s""" + |input a : UInt<8> + |input b : UInt<8> + |output c : UInt<8> + |wire w : UInt<8> + |w <= tail(add(a, b), 1) + |c <= w""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + nodes.map(_.serialize) should be ( + Seq("""node w = tail(add(a, b), 1)""") + ) + } + + it should "do a reasonable job preserving input order for unrelatd logic" in { + val result = compileBody(s""" + |input a : UInt<8> + |input b : UInt<8> + |output z : UInt<8> + |node x = not(a) + |node y = not(b) + |z <= and(x, y)""".stripMargin + ) + val (nodes, wires) = getNodesAndWires(result.circuit) + wires.size should be (0) + nodes.map(_.serialize) should be ( + Seq("node x = not(a)", + "node y = not(b)") + ) + } +} |
