aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/transforms/ManipulateNamesSpec.scala
blob: 82cc997dde8bf2cd0095305db1d343af906b315f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
// See LICENSE for license details.

package firrtlTests.transforms

import firrtl.{ir, CircuitState, FirrtlUserException, Namespace, Parser}
import firrtl.annotations.CircuitTarget
import firrtl.options.Dependency
import firrtl.testutils.FirrtlCheckers._
import firrtl.transforms.{
  ManipulateNames,
  ManipulateNamesBlocklistAnnotation,
  ManipulateNamesAllowlistAnnotation
}

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

object ManipulateNamesSpec {

  /** A [[firrtl.transforms.ManipulateNames]] used by the tests below.
    *
    * For every name it is asked to manipulate it always answers `Some(...)`, producing `prefix_<name>` passed through
    * the supplied `Namespace`'s `newName`, so renamed components are easy to spot in the output circuit.
    */
  class AddPrefix extends ManipulateNames {
    override def manipulate = (name: String, namespace: Namespace) => {
      val candidate = s"prefix_$name"
      Some(namespace.newName(candidate))
    }
  }

}

/** Tests for [[firrtl.transforms.ManipulateNames]] and the allowlist/blocklist annotations that scope it.
  *
  * Every test uses `AddPrefix` (defined in the companion object), which prepends `prefix_` to names, so a component
  * that was renamed is recognizable by that prefix in the output circuit.
  */
class ManipulateNamesSpec extends AnyFlatSpec with Matchers {

  import ManipulateNamesSpec._

  /** Shared test state: a two-module circuit in which `Foo` instantiates `Bar` twice (as `bar` and `bar2`), targets
    * naming components of that circuit, and a compiler whose target transform is `AddPrefix`.
    */
  class CircuitFixture {
    protected val input =
      """|circuit Foo:
         |  module Bar:
         |    node a = UInt<1>(0)
         |  module Foo:
         |    inst bar of Bar
         |    inst bar2 of Bar
         |""".stripMargin
    val `~Foo` = CircuitTarget("Foo")
    val `~Foo|Foo` = `~Foo`.module("Foo")
    val `~Foo|Foo/bar:Bar` = `~Foo|Foo`.instOf("bar", "Bar")
    val `~Foo|Foo/bar2:Bar` = `~Foo|Foo`.instOf("bar2", "Bar")
    val `~Foo|Bar` = `~Foo`.module("Bar")
    val `~Foo|Bar>a` = `~Foo|Bar`.ref("a")
    val tm = new firrtl.stage.transforms.Compiler(Seq(Dependency[AddPrefix]))

    /** Parses `input`, attaches `annotations`, and runs `tm` on the result.
      *
      * @param annotations annotations to attach to the parsed circuit (none when omitted)
      * @return the untransformed input state and the transformed output state, in that order
      */
    def compileWith(annotations: firrtl.annotations.Annotation*): (CircuitState, CircuitState) = {
      val state = CircuitState(Parser.parse(input), annotations)
      (state, tm.execute(state))
    }
  }

  behavior of "ManipulateNames"

  it should "rename everything by default" in new CircuitFixture {
    val (_, statex) = compileWith()
    val expected: Seq[PartialFunction[Any, Boolean]] = Seq(
      { case ir.Circuit(_, _, "prefix_Foo") => true },
      { case ir.Module(_, "prefix_Foo", _, _) => true },
      { case ir.Module(_, "prefix_Bar", _, _) => true }
    )
    expected.foreach(pf => statex should containTree(pf))
  }

  it should "do nothing if the circuit is blocklisted" in new CircuitFixture {
    val (state, statex) =
      compileWith(ManipulateNamesBlocklistAnnotation(Seq(Seq(`~Foo`)), Dependency[AddPrefix]))
    // Actual (transformed) on the left, expected (untouched input) on the right, so failures read correctly.
    statex.circuit.serialize should be(state.circuit.serialize)
  }

  it should "not rename the circuit if the top module is blocklisted" in new CircuitFixture {
    val (_, statex) =
      compileWith(ManipulateNamesBlocklistAnnotation(Seq(Seq(`~Foo|Foo`)), Dependency[AddPrefix]))
    val expected: Seq[PartialFunction[Any, Boolean]] = Seq(
      { case ir.Circuit(_, _, "Foo") => true },
      { case ir.Module(_, "Foo", _, _) => true },
      { case ir.Module(_, "prefix_Bar", _, _) => true }
    )
    expected.foreach(pf => statex should containTree(pf))
  }

  it should "not rename instances if blocklisted" in new CircuitFixture {
    val (_, statex) =
      compileWith(ManipulateNamesBlocklistAnnotation(Seq(Seq(`~Foo|Foo/bar:Bar`)), Dependency[AddPrefix]))
    val expected: Seq[PartialFunction[Any, Boolean]] = Seq(
      // The blocklisted instance keeps its name but still points at the renamed module...
      { case ir.DefInstance(_, "bar", "prefix_Bar", _) => true },
      // ...while the sibling instance, which is not blocklisted, is renamed as usual.
      { case ir.DefInstance(_, "prefix_bar2", "prefix_Bar", _) => true },
      { case ir.Module(_, "prefix_Bar", _, _) => true }
    )
    expected.foreach(pf => statex should containTree(pf))
  }

  it should "do nothing if the circuit is not allowlisted" in new CircuitFixture {
    // Only the top module is allowlisted, not the circuit itself, so nothing may be renamed.
    val (state, statex) =
      compileWith(ManipulateNamesAllowlistAnnotation(Seq(Seq(`~Foo|Foo`)), Dependency[AddPrefix]))
    statex.circuit.serialize should be(state.circuit.serialize)
  }

  it should "rename only the circuit if allowlisted" in new CircuitFixture {
    val (_, statex) = compileWith(
      ManipulateNamesAllowlistAnnotation(Seq(Seq(`~Foo`)), Dependency[AddPrefix]),
      ManipulateNamesAllowlistAnnotation(Seq(Seq(`~Foo|Foo`)), Dependency[AddPrefix])
    )
    val expected: Seq[PartialFunction[Any, Boolean]] = Seq(
      // The circuit name and the top-module name change together.
      { case ir.Circuit(_, _, "prefix_Foo") => true },
      { case ir.Module(_, "prefix_Foo", _, _) => true },
      { case ir.DefInstance(_, "bar", "Bar", _) => true },
      { case ir.DefInstance(_, "bar2", "Bar", _) => true },
      { case ir.Module(_, "Bar", _, _) => true },
      { case ir.DefNode(_, "a", _) => true }
    )
    expected.foreach(pf => statex should containTree(pf))
  }

  it should "rename an instance via allowlisting" in new CircuitFixture {
    val (_, statex) = compileWith(
      ManipulateNamesAllowlistAnnotation(Seq(Seq(`~Foo`)), Dependency[AddPrefix]),
      ManipulateNamesAllowlistAnnotation(Seq(Seq(`~Foo|Foo/bar:Bar`)), Dependency[AddPrefix])
    )
    val expected: Seq[PartialFunction[Any, Boolean]] = Seq(
      { case ir.Circuit(_, _, "Foo") => true },
      { case ir.Module(_, "Foo", _, _) => true },
      { case ir.DefInstance(_, "prefix_bar", "Bar", _) => true },
      { case ir.DefInstance(_, "bar2", "Bar", _) => true },
      { case ir.Module(_, "Bar", _, _) => true },
      { case ir.DefNode(_, "a", _) => true }
    )
    expected.foreach(pf => statex should containTree(pf))
  }

  it should "rename a node via allowlisting" in new CircuitFixture {
    val (_, statex) = compileWith(
      ManipulateNamesAllowlistAnnotation(Seq(Seq(`~Foo`)), Dependency[AddPrefix]),
      ManipulateNamesAllowlistAnnotation(Seq(Seq(`~Foo|Bar>a`)), Dependency[AddPrefix])
    )
    val expected: Seq[PartialFunction[Any, Boolean]] = Seq(
      { case ir.Circuit(_, _, "Foo") => true },
      { case ir.Module(_, "Foo", _, _) => true },
      { case ir.DefInstance(_, "bar", "Bar", _) => true },
      { case ir.DefInstance(_, "bar2", "Bar", _) => true },
      { case ir.Module(_, "Bar", _, _) => true },
      { case ir.DefNode(_, "prefix_a", _) => true }
    )
    expected.foreach(pf => statex should containTree(pf))
  }

  it should "throw user errors on circuits that haven't been run through LowerTypes" in {
    // Bypasses the fixture's compiler and calls the transform directly on a circuit with an aggregate wire.
    val input =
      """|circuit Foo:
         |  module Foo:
         |    wire bar: {a: UInt<1>, b: UInt<1>}
         |    node baz = bar.a
         |""".stripMargin
    val state = CircuitState(Parser.parse(input), Seq.empty)
    intercept[FirrtlUserException] {
      (new AddPrefix).transform(state)
    }.getMessage should include("LowerTypes")
  }

  behavior of "ManipulateNamesBlocklistAnnotation"

  it should "throw an exception if a non-local target is skipped" in new CircuitFixture {
    // `a` reached through the instance path `Foo/bar:Bar` is an instance-relative (non-local) reference.
    val barA = `~Foo|Foo/bar:Bar`.ref("a")
    assertThrows[IllegalArgumentException] {
      ManipulateNamesBlocklistAnnotation(Seq(Seq(barA)), Dependency[AddPrefix])
    }
  }

}