aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/RemoveWiresSpec.scala
blob: cfc03ad997e29868e5939f01ec59bbfd6c650ad3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// See LICENSE for license details.

package firrtlTests

import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import FirrtlCheckers._

import collection.mutable

/** Tests for the RemoveWires pass, exercised through the full [[LowFirrtlCompiler]].
  *
  * Each test compiles a small single-module circuit and inspects which nodes and wires
  * remain in the lowered output.
  */
class RemoveWiresSpec extends FirrtlFlatSpec {

  /** Compiles CHIRRTL `input` with [[LowFirrtlCompiler]] and no extra annotations.
    *
    * @param input complete CHIRRTL circuit text
    * @return the compiled and emitted circuit state
    */
  def compile(input: String): CircuitState =
    (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty)

  /** Wraps `body` in module `Test` of circuit `Test`, then compiles it.
    *
    * Every line of `body` is prefixed with four spaces so it sits inside the module.
    *
    * @param body port declarations and statements for module `Test`, one per line
    * @return the compiled circuit state
    */
  def compileBody(body: String): CircuitState = {
    val header = """
      |circuit Test :
      |  module Test :
      |""".stripMargin
    compile(header + body.split("\n").mkString("    ", "\n    ", ""))
  }

  /** Collects every [[DefNode]] and [[DefWire]] in the body of the circuit's only module.
    *
    * Sub-statements are visited (via the Mappers `map`) before the enclosing statement is
    * inspected, so declarations nested in blocks are found. The results are in the order
    * the traversal reaches them.
    *
    * @param circuit a circuit that must contain exactly one module, which must not be external
    * @return the nodes and wires found, as immutable lists
    */
  def getNodesAndWires(circuit: Circuit): (Seq[DefNode], Seq[DefWire]) = {
    require(circuit.modules.size == 1,
      s"Expected exactly one module, found ${circuit.modules.size}")

    val nodes = mutable.ArrayBuffer.empty[DefNode]
    val wires = mutable.ArrayBuffer.empty[DefWire]
    def onStmt(stmt: Statement): Statement = {
      // Recurse into sub-statements first; the rebuilt statement is only inspected, not kept.
      (stmt map onStmt) match {
        case node: DefNode => nodes += node
        case wire: DefWire => wires += wire
        case _ =>
      }
      stmt
    }

    circuit.modules.head match {
      case Module(_, _, _, body) => onStmt(body)
      // Report an external module as a test failure rather than a bare MatchError.
      case other => fail(s"Expected a Module, but found external module ${other.name}")
    }
    // Return immutable snapshots so callers cannot observe or mutate the local buffers.
    (nodes.toList, wires.toList)
  }

  "Remove Wires" should "turn wires and their single connect into nodes" in {
    val result = compileBody("""
      |input a : UInt<8>
      |output b : UInt<8>
      |wire w : UInt<8>
      |w <= a
      |b <= w""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)

    nodes.map(_.serialize) should be (Seq("node w = a"))
  }

  it should "order nodes in a legal, flow-forward way" in {
    val result = compileBody("""
      |input a : UInt<8>
      |output b : UInt<8>
      |wire w : UInt<8>
      |wire x : UInt<8>
      |node y = x
      |x <= w
      |w <= a
      |b <= y""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)
    nodes.map(_.serialize) should be (
      Seq("node w = a",
          "node x = w",
          "node y = x")
    )
  }

  it should "properly pad rhs of introduced nodes if necessary" in {
    val result = compileBody("""
      |output b : UInt<8>
      |wire w : UInt<8>
      |w <= UInt(2)
      |b <= w""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)
    nodes.map(_.serialize) should be (
      Seq("""node w = pad(UInt<2>("h2"), 8)""")
    )
  }

  it should "support arbitrary expression for wire connection rhs" in {
    val result = compileBody("""
      |input a : UInt<8>
      |input b : UInt<8>
      |output c : UInt<8>
      |wire w : UInt<8>
      |w <= tail(add(a, b), 1)
      |c <= w""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)
    nodes.map(_.serialize) should be (
      Seq("""node w = tail(add(a, b), 1)""")
    )
  }

  it should "do a reasonable job preserving input order for unrelated logic" in {
    val result = compileBody("""
      |input a : UInt<8>
      |input b : UInt<8>
      |output z : UInt<8>
      |node x = not(a)
      |node y = not(b)
      |z <= and(x, y)""".stripMargin
    )
    val (nodes, wires) = getNodesAndWires(result.circuit)
    wires.size should be (0)
    nodes.map(_.serialize) should be (
      Seq("node x = not(a)",
          "node y = not(b)")
    )
  }
}