1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
|
// SPDX-License-Identifier: Apache-2.0
package firrtlTests.transforms
import firrtl.PrimOps._
import firrtl._
import firrtl.ir.DoPrim
import firrtl.stage.PrettyNoExprInlining
import firrtl.transforms.{CombineCats, MaxCatLenAnnotation}
import firrtl.testutils.FirrtlFlatSpec
import firrtl.testutils.FirrtlCheckers._
/** Tests for [[firrtl.transforms.CombineCats]].
  *
  * Each test either checks that the transformed circuit is equivalent to the input, or inspects the
  * resulting expression trees to verify how nested `cat` expressions (including ones split across
  * `node`s) were combined, and that [[firrtl.transforms.MaxCatLenAnnotation]] bounds the combination.
  */
class CombineCatsSpec extends FirrtlFlatSpec {
  private val transforms = Seq(new IRToWorkingIR, new CombineCats)
  private val annotations = Seq(new MaxCatLenAnnotation(12))

  /** Runs `transforms` in order over the parsed `input`, starting from the given `annotations`.
    *
    * @return a fresh [[CircuitState]] holding only the transformed circuit; the annotations produced
    *         by the transforms are dropped (empty annotation list, `UnknownForm`, no rename map).
    */
  private def execute(input: String, transforms: Seq[Transform], annotations: AnnotationSeq): CircuitState = {
    val c = transforms
      .foldLeft(CircuitState(parse(input), UnknownForm, annotations)) { (c: CircuitState, t: Transform) =>
        t.runTransform(c)
      }
      .circuit
    CircuitState(c, UnknownForm, Seq(), None)
  }

  "circuit1 with combined cats" should "be equivalent to one without" in {
    val input =
      """circuit Test_CombinedCats1 :
        |  module Test_CombinedCats1 :
        |    input in1 : UInt<1>
        |    input in2 : UInt<2>
        |    input in3 : UInt<3>
        |    input in4 : UInt<4>
        |    output out : UInt<10>
        |    out <= cat(in4, cat(in3, cat(in2, in1)))
        |""".stripMargin
    firrtlEquivalenceTest(input, transforms, annotations)
  }

  "circuit2 with combined cats" should "be equivalent to one without" in {
    val input =
      """circuit Test_CombinedCats2 :
        |  module Test_CombinedCats2 :
        |    input in1 : UInt<1>
        |    input in2 : UInt<2>
        |    input in3 : UInt<3>
        |    input in4 : UInt<4>
        |    output out : UInt<10>
        |    out <= cat(cat(in4, in1), cat(cat(in4, in3), cat(in2, in1)))
        |""".stripMargin
    firrtlEquivalenceTest(input, transforms, annotations)
  }

  "circuit3 with combined cats" should "be equivalent to one without" in {
    val input =
      """circuit Test_CombinedCats3 :
        |  module Test_CombinedCats3 :
        |    input in1 : UInt<1>
        |    input in2 : UInt<2>
        |    input in3 : UInt<3>
        |    input in4 : UInt<4>
        |    output out : UInt<10>
        |    node temp1 = cat(cat(in4, in3), cat(in2, in1))
        |    node temp2 = cat(in4, cat(in3, cat(in2, in1)))
        |    out <= add(temp1, temp2)
        |""".stripMargin
    firrtlEquivalenceTest(input, transforms, annotations)
  }

  "nested cats" should "be combined" in {
    val input =
      """circuit Test_CombinedCats4 :
        |  module Test_CombinedCats4 :
        |    input in1 : UInt<1>
        |    input in2 : UInt<2>
        |    input in3 : UInt<3>
        |    input in4 : UInt<4>
        |    output out : UInt<10>
        |    node temp1 = cat(in2, in1)
        |    node temp2 = cat(in3, in2)
        |    node temp3 = cat(in4, in3)
        |    node temp4 = cat(temp1, temp2)
        |    node temp5 = cat(temp4, temp3)
        |    out <= temp5
        |""".stripMargin
    firrtlEquivalenceTest(input, transforms, annotations)
    val result = execute(input, transforms, Seq.empty)
    // temp5 should get cat(cat(cat(in2, in1), cat(in3, in2)), cat(in4, in3)):
    // temp4 expands to cat(temp1, temp2) and temp3 to cat(in4, in3).
    result should containTree {
      case DoPrim(
            Cat,
            Seq(
              DoPrim(
                Cat,
                Seq(
                  DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _),
                  DoPrim(Cat, Seq(WRef("in3", _, _, _), WRef("in2", _, _, _)), _, _)
                ),
                _,
                _
              ),
              DoPrim(Cat, Seq(WRef("in4", _, _, _), WRef("in3", _, _, _)), _, _)
            ),
            _,
            _
          ) =>
        true
    }
  }

  "cats" should "not be longer than maxCatLen" in {
    val input =
      """circuit Test_CombinedCats5 :
        |  module Test_CombinedCats5 :
        |    input in1 : UInt<1>
        |    input in2 : UInt<2>
        |    input in3 : UInt<3>
        |    input in4 : UInt<4>
        |    input in5 : UInt<5>
        |    output out : UInt<10>
        |    node temp1 = cat(in2, in1)
        |    node temp2 = cat(in3, temp1)
        |    node temp3 = cat(in4, temp2)
        |    node temp4 = cat(in5, temp3)
        |    out <= temp4
        |""".stripMargin
    val maxCatLenAnnotation3 = Seq(new MaxCatLenAnnotation(3))
    firrtlEquivalenceTest(input, transforms, maxCatLenAnnotation3)
    val result = execute(input, transforms, maxCatLenAnnotation3)
    // Should not contain any cat chain greater than 3. A single pattern is enough: any longer
    // right-nested chain (e.g. four nested cats) also matches this three-cat pattern at its root.
    result shouldNot containTree {
      case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true
    }
    // temp2 should get cat(in3, cat(in2, in1))
    result should containTree {
      case DoPrim(
            Cat,
            Seq(WRef("in3", _, _, _), DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _)),
            _,
            _
          ) =>
        true
    }
  }

  "nested nodes that are not cats" should "not be expanded" in {
    val input =
      """circuit Test_CombinedCats6 :
        |  module Test_CombinedCats6 :
        |    input in1 : UInt<1>
        |    input in2 : UInt<2>
        |    input in3 : UInt<3>
        |    input in4 : UInt<4>
        |    input in5 : UInt<5>
        |    output out : UInt<10>
        |    node temp1 = add(in2, in1)
        |    node temp2 = cat(in3, temp1)
        |    node temp3 = sub(in4, temp2)
        |    node temp4 = cat(in5, temp3)
        |    out <= temp4
        |""".stripMargin
    firrtlEquivalenceTest(input, transforms, annotations)
    val result = execute(input, transforms, Seq.empty)
    // The add/sub nodes must stay as references; they must not be substituted into a cat argument.
    result shouldNot containTree {
      case DoPrim(Cat, Seq(_, DoPrim(Add, _, _, _)), _, _)                             => true
      case DoPrim(Cat, Seq(_, DoPrim(Sub, _, _, _)), _, _)                             => true
      case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true
    }
  }

  "CombineCats" should s"respect --${PrettyNoExprInlining.longOption}" in {
    val input =
      """circuit test :
        |  module test :
        |    input in1 : UInt<1>
        |    input in2 : UInt<2>
        |    input in3 : UInt<3>
        |    input in4 : UInt<4>
        |    output out : UInt<10>
        |
        |    node _T_1 = cat(in1, in2)
        |    node _T_2 = cat(_T_1, in3)
        |    out <= cat(_T_2, in4)
        |""".stripMargin
    // With expression inlining disabled, the circuit must come out unchanged.
    val result = execute(input, transforms, PrettyNoExprInlining :: Nil)
    result.circuit.serialize should be(parse(input).serialize)
  }
}
|