1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
// See LICENSE for license details.
package firrtlTests
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import FirrtlCheckers._
import collection.mutable
/** Checks how wires are lowered when a circuit is run through the `LowFirrtlCompiler`.
  *
  * Every test compiles a small single-module circuit and then inspects the
  * `DefNode` and `DefWire` statements left in the result: no wires should
  * remain, and the nodes should appear in the expected order with the
  * expected right-hand sides.
  */
class RemoveWiresSpec extends FirrtlFlatSpec {

  /** Parses `input` as a CHIRRTL circuit and runs it through the `LowFirrtlCompiler`.
    *
    * @param input complete FIRRTL source text of a circuit
    * @return the circuit state produced by `compileAndEmit`
    */
  def compile(input: String): CircuitState =
    // The second argument is an empty list (presumably "no custom transforms" -- confirm).
    (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty)

  /** Wraps `body` in a `circuit Test` / `module Test` header and compiles it.
    *
    * FIRRTL source is indentation-sensitive, so each line of `body` is indented
    * deeper (4 spaces) than the `module` header (2 spaces); otherwise the ports
    * and statements would not be nested inside the module.
    *
    * @param body newline-separated ports and statements of the module, unindented
    * @return the compiled circuit state
    */
  def compileBody(body: String): CircuitState = {
    val header =
      """
        |circuit Test :
        |  module Test :
        |""".stripMargin
    val indentedBody = body.split("\n").mkString("    ", "\n    ", "")
    compile(header + indentedBody)
  }

  /** Collects every `DefNode` and `DefWire` in the single module of `circuit`.
    *
    * @param circuit a circuit containing exactly one module
    * @return the nodes and the wires, each in the order they were visited
    * @throws IllegalArgumentException if the circuit does not contain exactly one
    *                                  module, or if that module is not a `Module`
    */
  def getNodesAndWires(circuit: Circuit): (Seq[DefNode], Seq[DefWire]) = {
    require(circuit.modules.size == 1)
    val nodes = mutable.ArrayBuffer.empty[DefNode]
    val wires = mutable.ArrayBuffer.empty[DefWire]

    // Recurses into child statements first, then records the statement itself.
    // `map` is used purely for traversal; the original statement is returned unchanged.
    def onStmt(stmt: Statement): Statement = {
      stmt.map(onStmt) match {
        case node: DefNode => nodes += node
        case wire: DefWire => wires += wire
        case _ =>
      }
      stmt
    }

    circuit.modules.head match {
      case Module(_, _, _, body) => onStmt(body)
      case other =>
        // Fail with a readable message instead of an opaque MatchError.
        throw new IllegalArgumentException(
          s"Expected a Module with a body but found ${other.getClass.getSimpleName}")
    }

    // Return immutable snapshots so the local buffers do not escape.
    (nodes.toList, wires.toList)
  }

  "Remove Wires" should "turn wires and their single connect into nodes" in {
    val result = compileBody("""
      |input a : UInt<8>
      |output b : UInt<8>
      |wire w : UInt<8>
      |w <= a
      |b <= w""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)
    nodes.map(_.serialize) should be (Seq("node w = a"))
  }

  it should "order nodes in a legal, flow-forward way" in {
    // The input declares and connects things out of dependency order
    // (y reads x before x is connected, x reads w before w is connected);
    // the expected nodes are listed so each one follows what it reads.
    val result = compileBody("""
      |input a : UInt<8>
      |output b : UInt<8>
      |wire w : UInt<8>
      |wire x : UInt<8>
      |node y = x
      |x <= w
      |w <= a
      |b <= y""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)
    nodes.map(_.serialize) should be (
      Seq("node w = a",
          "node x = w",
          "node y = x")
    )
  }

  it should "properly pad rhs of introduced nodes if necessary" in {
    // The literal is narrower than the 8-bit wire, so the expected node pads it to 8 bits.
    val result = compileBody("""
      |output b : UInt<8>
      |wire w : UInt<8>
      |w <= UInt(2)
      |b <= w""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)
    nodes.map(_.serialize) should be (
      Seq("""node w = pad(UInt<2>("h2"), 8)""")
    )
  }

  it should "support arbitrary expression for wire connection rhs" in {
    val result = compileBody("""
      |input a : UInt<8>
      |input b : UInt<8>
      |output c : UInt<8>
      |wire w : UInt<8>
      |w <= tail(add(a, b), 1)
      |c <= w""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)
    nodes.map(_.serialize) should be (
      Seq("""node w = tail(add(a, b), 1)""")
    )
  }

  it should "do a reasonable job preserving input order for unrelated logic" in {
    // x and y do not depend on each other, so they are expected in source order.
    val result = compileBody("""
      |input a : UInt<8>
      |input b : UInt<8>
      |output z : UInt<8>
      |node x = not(a)
      |node y = not(b)
      |z <= and(x, y)""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)
    nodes.map(_.serialize) should be (
      Seq("node x = not(a)",
          "node y = not(b)")
    )
  }
}
|