aboutsummaryrefslogtreecommitdiff
path: root/src/test
diff options
context:
space:
mode:
authorAdam Izraelevitz2019-10-18 19:01:19 -0700
committerGitHub2019-10-18 19:01:19 -0700
commitfd981848c7d2a800a15f9acfbf33b57dd1c6225b (patch)
tree3609a301cb0ec867deefea4a0d08425810b00418 /src/test
parent973ecf516c0ef2b222f2eb68dc8b514767db59af (diff)
Upstream intervals (#870)
Major features: - Added Interval type, as well as PrimOps asInterval, clip, wrap, and sqz. - Changed PrimOp names: bpset -> setp, bpshl -> incp, bpshr -> decp - Refactored width/bound inferencer into a separate constraint solver - Added transforms to infer, trim, and remove interval bounds - Tests for said features Plan to be released with 1.3
Diffstat (limited to 'src/test')
-rw-r--r--src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala4
-rw-r--r--src/test/scala/firrtlTests/AsyncResetSpec.scala29
-rw-r--r--src/test/scala/firrtlTests/ChirrtlSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/InfoSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala2
-rw-r--r--src/test/scala/firrtlTests/UniquifySpec.scala2
-rw-r--r--src/test/scala/firrtlTests/ZeroWidthTests.scala2
-rw-r--r--src/test/scala/firrtlTests/constraint/InequalitySpec.scala197
-rw-r--r--src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala44
-rw-r--r--src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala13
-rw-r--r--src/test/scala/firrtlTests/interval/IntervalMathSpec.scala183
-rw-r--r--src/test/scala/firrtlTests/interval/IntervalSpec.scala530
-rw-r--r--src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala2
13 files changed, 980 insertions, 36 deletions
diff --git a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala
index d06344af..89f2ec07 100644
--- a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala
+++ b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala
@@ -3,14 +3,12 @@
package firrtl.stage.phases.tests
import org.scalatest.{FlatSpec, Matchers, PrivateMethodTester}
-
import java.io.File
import firrtl._
-import firrtl.stage._
import firrtl.stage.phases.DriverCompatibility._
-
import firrtl.options.{InputAnnotationFileAnnotation, Phase, TargetDirAnnotation}
+import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation}
import firrtl.stage.phases.DriverCompatibility
class DriverCompatibilitySpec extends FlatSpec with Matchers with PrivateMethodTester {
diff --git a/src/test/scala/firrtlTests/AsyncResetSpec.scala b/src/test/scala/firrtlTests/AsyncResetSpec.scala
index 6fcb647a..8ad397b3 100644
--- a/src/test/scala/firrtlTests/AsyncResetSpec.scala
+++ b/src/test/scala/firrtlTests/AsyncResetSpec.scala
@@ -51,16 +51,19 @@ class AsyncResetSpec extends FirrtlFlatSpec {
it should "support casting to other types" in {
val result = compileBody(s"""
|input a : AsyncReset
+ |output u : Interval[0, 1].0
|output v : UInt<1>
|output w : SInt<1>
|output x : Clock
|output y : Fixed<1><<0>>
|output z : AsyncReset
+ |u <= asInterval(a, 0, 1, 0)
|v <= asUInt(a)
|w <= asSInt(a)
|x <= asClock(a)
|y <= asFixedPoint(a, 0)
- |z <= asAsyncReset(a)""".stripMargin
+ |z <= asAsyncReset(a)
+ |""".stripMargin
)
result should containLine ("assign v = $unsigned(a);")
result should containLine ("assign w = $signed(a);")
@@ -76,22 +79,26 @@ class AsyncResetSpec extends FirrtlFlatSpec {
|input c : Clock
|input d : Fixed<1><<0>>
|input e : AsyncReset
+ |input f : Interval[0, 1].0
+ |output u : AsyncReset
|output v : AsyncReset
|output w : AsyncReset
|output x : AsyncReset
|output y : AsyncReset
|output z : AsyncReset
- |v <= asAsyncReset(a)
- |w <= asAsyncReset(a)
- |x <= asAsyncReset(a)
- |y <= asAsyncReset(a)
- |z <= asAsyncReset(a)""".stripMargin
+ |u <= asAsyncReset(a)
+ |v <= asAsyncReset(b)
+ |w <= asAsyncReset(c)
+ |x <= asAsyncReset(d)
+ |y <= asAsyncReset(e)
+ |z <= asAsyncReset(f)""".stripMargin
)
- result should containLine ("assign v = a;")
- result should containLine ("assign w = a;")
- result should containLine ("assign x = a;")
- result should containLine ("assign y = a;")
- result should containLine ("assign z = a;")
+ result should containLine ("assign u = a;")
+ result should containLine ("assign v = b;")
+ result should containLine ("assign w = c;")
+ result should containLine ("assign x = d;")
+ result should containLine ("assign y = e;")
+ result should containLine ("assign z = f;")
}
"Non-literals" should "NOT be allowed as reset values for AsyncReset" in {
diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala
index fba81ec7..b82637b6 100644
--- a/src/test/scala/firrtlTests/ChirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala
@@ -70,8 +70,8 @@ class ChirrtlSpec extends FirrtlFlatSpec {
behavior of "Uniqueness"
for ((description, input) <- CheckSpec.nonUniqueExamples) {
it should s"be asserted for $description" in {
- assertThrows[CheckChirrtl.NotUniqueException] {
- Seq(ToWorkingIR, CheckChirrtl).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) }
+ assertThrows[CheckHighForm.NotUniqueException] {
+ Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) }
}
}
}
diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala
index dbc997cd..9d6206af 100644
--- a/src/test/scala/firrtlTests/InfoSpec.scala
+++ b/src/test/scala/firrtlTests/InfoSpec.scala
@@ -66,7 +66,7 @@ class InfoSpec extends FirrtlFlatSpec {
result should containLine (s"assign n = w | x; //$Info3")
}
- they should "be propagated on memories" in {
+ it should "be propagated on memories" in {
val result = compileBody(s"""
|input clock : Clock
|input addr : UInt<5>
@@ -102,7 +102,7 @@ class InfoSpec extends FirrtlFlatSpec {
result should containLine (s"m[m_w_addr] <= m_w_data; //$Info1")
}
- they should "be propagated on instances" in {
+ it should "be propagated on instances" in {
val result = compile(s"""
|circuit Test :
| module Child :
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index be9d738b..69379c51 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -12,7 +12,7 @@ import firrtl.transforms._
import firrtl._
class LowerTypesSpec extends FirrtlFlatSpec {
- private val transforms = Seq(
+ private def transforms = Seq(
ToWorkingIR,
CheckHighForm,
ResolveKinds,
diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala
index afb82384..e64e9105 100644
--- a/src/test/scala/firrtlTests/UniquifySpec.scala
+++ b/src/test/scala/firrtlTests/UniquifySpec.scala
@@ -16,7 +16,7 @@ import firrtl.util.TestOptions
class UniquifySpec extends FirrtlFlatSpec {
- private val transforms = Seq(
+ private def transforms = Seq(
ToWorkingIR,
CheckHighForm,
ResolveKinds,
diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala
index eb3d1a96..f1dadcee 100644
--- a/src/test/scala/firrtlTests/ZeroWidthTests.scala
+++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala
@@ -11,7 +11,7 @@ import firrtl.Parser
import firrtl.passes._
class ZeroWidthTests extends FirrtlFlatSpec {
- val transforms = Seq(
+ def transforms = Seq(
ToWorkingIR,
ResolveKinds,
InferTypes,
diff --git a/src/test/scala/firrtlTests/constraint/InequalitySpec.scala b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala
new file mode 100644
index 00000000..02a853cb
--- /dev/null
+++ b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala
@@ -0,0 +1,197 @@
+package firrtlTests.constraint
+
+import firrtl.constraint._
+import org.scalatest.{FlatSpec, Matchers}
+import firrtl.ir.Closed
+
+class InequalitySpec extends FlatSpec with Matchers {
+
+ behavior of "Constraints"
+
+ "IsConstraints" should "reduce properly" in {
+ IsMin(Closed(0), Closed(1)) should be (Closed(0))
+ IsMin(Closed(-1), Closed(1)) should be (Closed(-1))
+ IsMax(Closed(-1), Closed(1)) should be (Closed(1))
+ IsNeg(IsMul(Closed(-1), Closed(-2))) should be (Closed(-2))
+ val x = IsMin(IsMul(Closed(1), VarCon("a")), Closed(2))
+ x.children.toSet should be (IsMin(Closed(2), IsMul(Closed(1), VarCon("a"))).children.toSet)
+ }
+
+ "IsAdd" should "reduce properly" in {
+ // All constants
+ IsAdd(Closed(-1), Closed(1)) should be (Closed(0))
+
+ // Pull Out IsMax
+ IsAdd(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsAdd(VarCon("a"), Closed(1))))
+ IsAdd(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be (
+ IsMax(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1))))
+ )
+
+ // Pull Out IsMin
+ IsAdd(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsAdd(VarCon("a"), Closed(1))))
+ IsAdd(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be (
+ IsMin(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1))))
+ )
+
+ // Add Zero
+ IsAdd(Closed(0), VarCon("a")) should be (VarCon("a"))
+
+ // One argument
+ IsAdd(Seq(VarCon("a"))) should be (VarCon("a"))
+ }
+
+ "IsMax" should "reduce properly" in {
+ // All constants
+ IsMax(Closed(-1), Closed(1)) should be (Closed(1))
+
+ // Flatten nested IsMax
+ IsMax(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(1), VarCon("a")))
+ IsMax(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be (
+ IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))
+ )
+
+ // Eliminate IsMins if possible
+ IsMax(Closed(2), IsMin(Closed(1), VarCon("a"))) should be (Closed(2))
+ IsMax(Seq(
+ Closed(2),
+ IsMin(Closed(1), VarCon("a")),
+ IsMin(Closed(3), VarCon("b"))
+ )) should be (
+ IsMax(Seq(
+ Closed(2),
+ IsMin(Closed(3), VarCon("b"))
+ ))
+ )
+
+ // One argument
+ IsMax(Seq(VarCon("a"))) should be (VarCon("a"))
+ IsMax(Seq(Closed(0))) should be (Closed(0))
+ IsMax(Seq(IsMin(VarCon("a"), Closed(0)))) should be (IsMin(VarCon("a"), Closed(0)))
+ }
+
+ "IsMin" should "reduce properly" in {
+ // All constants
+ IsMin(Closed(-1), Closed(1)) should be (Closed(-1))
+
+ // Flatten nested IsMin
+ IsMin(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(1), VarCon("a")))
+ IsMin(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be (
+ IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))
+ )
+
+ // Eliminate IsMaxs if possible
+ IsMin(Closed(1), IsMax(Closed(2), VarCon("a"))) should be (Closed(1))
+ IsMin(Seq(
+ Closed(2),
+ IsMax(Closed(1), VarCon("a")),
+ IsMax(Closed(3), VarCon("b"))
+ )) should be (
+ IsMin(Seq(
+ Closed(2),
+ IsMax(Closed(1), VarCon("a"))
+ ))
+ )
+
+ // One argument
+ IsMin(Seq(VarCon("a"))) should be (VarCon("a"))
+ IsMin(Seq(Closed(0))) should be (Closed(0))
+ IsMin(Seq(IsMax(VarCon("a"), Closed(0)))) should be (IsMax(VarCon("a"), Closed(0)))
+ }
+
+ "IsMul" should "reduce properly" in {
+ // All constants
+ IsMul(Closed(2), Closed(3)) should be (Closed(6))
+
+ // Pull out max, if positive stays max
+ IsMul(Closed(2), IsMax(Closed(3), VarCon("a"))) should be(
+ IsMax(Closed(6), IsMul(Closed(2), VarCon("a")))
+ )
+
+ // Pull out max, if negative is min
+ IsMul(Closed(-2), IsMax(Closed(3), VarCon("a"))) should be(
+ IsMin(Closed(-6), IsMul(Closed(-2), VarCon("a")))
+ )
+
+ // Pull out min, if positive stays min
+ IsMul(Closed(2), IsMin(Closed(3), VarCon("a"))) should be(
+ IsMin(Closed(6), IsMul(Closed(2), VarCon("a")))
+ )
+
+ // Pull out min, if negative is max
+ IsMul(Closed(-2), IsMin(Closed(3), VarCon("a"))) should be(
+ IsMax(Closed(-6), IsMul(Closed(-2), VarCon("a")))
+ )
+
+ // Times zero
+ IsMul(Closed(0), VarCon("x")) should be (Closed(0))
+
+ // Times 1
+ IsMul(Closed(1), VarCon("x")) should be (VarCon("x"))
+
+ // One argument
+ IsMul(Seq(Closed(0))) should be (Closed(0))
+ IsMul(Seq(VarCon("a"))) should be (VarCon("a"))
+
+ // No optimizations
+ val isMax = IsMax(VarCon("x"), VarCon("y"))
+ val isMin = IsMin(VarCon("x"), VarCon("y"))
+ val a = VarCon("a")
+ IsMul(a, isMax).children should be (Vector(a, isMax)) //non-known multiply
+ IsMul(a, isMin).children should be (Vector(a, isMin)) //non-known multiply
+ IsMul(Seq(Closed(2), isMin, isMin)).children should be (Vector(Closed(2), isMin, isMin)) //>1 min
+ IsMul(Seq(Closed(2), isMax, isMax)).children should be (Vector(Closed(2), isMax, isMax)) //>1 max
+ IsMul(Seq(Closed(2), isMin, isMax)).children should be (Vector(Closed(2), isMin, isMax)) //mixed min/max
+ }
+
+ "IsNeg" should "reduce properly" in {
+ // All constants
+ IsNeg(Closed(1)) should be (Closed(-1))
+ // Pull out max
+ IsNeg(IsMax(Closed(1), VarCon("a"))) should be (IsMin(Closed(-1), IsNeg(VarCon("a"))))
+ // Pull out min
+ IsNeg(IsMin(Closed(1), VarCon("a"))) should be (IsMax(Closed(-1), IsNeg(VarCon("a"))))
+ // Pull out add
+ IsNeg(IsAdd(Closed(1), VarCon("a"))) should be (IsAdd(Closed(-1), IsNeg(VarCon("a"))))
+ // Pull out mul
+ IsNeg(IsMul(Closed(2), VarCon("a"))) should be (IsMul(Closed(-2), VarCon("a")))
+ // No optimizations
+ // (pow), (floor?)
+ IsNeg(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x"))))
+ IsNeg(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x"))))
+ }
+
+ "IsPow" should "reduce properly" in {
+ // All constants
+ IsPow(Closed(1)) should be (Closed(2))
+ // Pull out max
+ IsPow(IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsPow(VarCon("a"))))
+ // Pull out min
+ IsPow(IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsPow(VarCon("a"))))
+ // Pull out add
+ IsPow(IsAdd(Closed(1), VarCon("a"))) should be (IsMul(Closed(2), IsPow(VarCon("a"))))
+ // No optimizations
+ // (mul), (pow), (floor?)
+ IsPow(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x"))))
+ IsPow(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x"))))
+ IsPow(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x"))))
+ }
+
+ "IsFloor" should "reduce properly" in {
+ // All constants
+ IsFloor(Closed(1.9)) should be (Closed(1))
+ IsFloor(Closed(-1.9)) should be (Closed(-2))
+ // Pull out max
+ IsFloor(IsMax(Closed(1.9), VarCon("a"))) should be (IsMax(Closed(1), IsFloor(VarCon("a"))))
+ // Pull out min
+ IsFloor(IsMin(Closed(1.9), VarCon("a"))) should be (IsMin(Closed(1), IsFloor(VarCon("a"))))
+ // Cancel with another floor
+ IsFloor(IsFloor(VarCon("a"))) should be (IsFloor(VarCon("a")))
+ // No optimizations
+ // (add), (mul), (pow)
+ IsFloor(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x"))))
+ IsFloor(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x"))))
+ IsFloor(IsAdd(Closed(1), VarCon("x"))).children should be (Vector(IsAdd(Closed(1), VarCon("x"))))
+ }
+
+}
+
diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
index 6bf86479..a34145ac 100644
--- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
@@ -21,6 +21,36 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
}
}
+ "Fixed types" should "infer add correctly if only precision unspecified" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ ResolveFlows,
+ CheckFlows,
+ new InferWidths,
+ CheckWidths)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input a : Fixed<10><<2>>
+ | input b : Fixed<10><<0>>
+ | input c : Fixed<4><<3>>
+ | output d : Fixed<13>
+ | d <= add(a, add(b, c))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input a : Fixed<10><<2>>
+ | input b : Fixed<10><<0>>
+ | input c : Fixed<4><<3>>
+ | output d : Fixed<13><<3>>
+ | d <= add(a, add(b, c))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
"Fixed types" should "infer add correctly" in {
val passes = Seq(
ToWorkingIR,
@@ -36,7 +66,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
- | input b : Fixed<10>
+ | input b : Fixed<10><<0>>
| input c : Fixed<4><<3>>
| output d : Fixed
| d <= add(a, add(b, c))""".stripMargin
@@ -119,13 +149,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed
- | d <= bpshl(a, 2)""".stripMargin
+ | d <= incp(a, 2)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<12><<4>>
- | d <= bpshl(a, 2)""".stripMargin
+ | d <= incp(a, 2)""".stripMargin
executeTest(input, check.split("\n") map normalized, passes)
}
@@ -145,13 +175,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed
- | d <= bpshr(a, 2)""".stripMargin
+ | d <= decp(a, 2)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<8><<0>>
- | d <= bpshr(a, 2)""".stripMargin
+ | d <= decp(a, 2)""".stripMargin
executeTest(input, check.split("\n") map normalized, passes)
}
@@ -171,13 +201,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed
- | d <= bpset(a, 3)""".stripMargin
+ | d <= setp(a, 3)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<11><<3>>
- | d <= bpset(a, 3)""".stripMargin
+ | d <= setp(a, 3)""".stripMargin
executeTest(input, check.split("\n") map normalized, passes)
}
diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
index 8686bd0f..f5b16e45 100644
--- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala
@@ -14,7 +14,6 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
(c: CircuitState, p: Transform) => p.runTransform(c)
}.circuit
val lines = c.serialize.split("\n") map normalized
- println(c.serialize)
expected foreach { e =>
lines should contain(e)
@@ -37,7 +36,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
- | input b : Fixed<10>
+ | input b : Fixed<10><<0>>
| input c : Fixed<4><<3>>
| output d : Fixed<<5>>
| d <= add(a, add(b, c))""".stripMargin
@@ -67,7 +66,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
"""circuit Unit :
| module Unit :
| input a : Fixed<10><<2>>
- | input b : Fixed<10>
+ | input b : Fixed<10><<0>>
| input c : Fixed<4><<3>>
| output d : Fixed<<5>>
| d <- add(a, add(b, c))""".stripMargin
@@ -99,7 +98,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<12><<4>>
- | d <= bpshl(a, 2)""".stripMargin
+ | d <= incp(a, 2)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
@@ -126,7 +125,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed<9><<1>>
- | d <= bpshr(a, 1)""".stripMargin
+ | d <= decp(a, 1)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
@@ -153,7 +152,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| output d : Fixed
- | d <= bpset(a, 3)""".stripMargin
+ | d <= setp(a, 3)""".stripMargin
val check =
"""circuit Unit :
| module Unit :
@@ -181,7 +180,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec {
class CheckChirrtlTransform extends SeqTransform {
def inputForm = ChirrtlForm
def outputForm = ChirrtlForm
- val transforms = Seq(passes.CheckChirrtl)
+ def transforms = Seq(passes.CheckChirrtl)
}
val chirrtlTransform = new CheckChirrtlTransform
diff --git a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala
new file mode 100644
index 00000000..20fdeee1
--- /dev/null
+++ b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala
@@ -0,0 +1,183 @@
+// See LICENSE for license details.
+
+package firrtlTests.interval
+
+import firrtl.Implicits.constraint2bound
+import firrtl.{ChirrtlForm, CircuitState, LowFirrtlCompiler, Parser}
+import firrtl.ir._
+
+import scala.math.BigDecimal.RoundingMode._
+import firrtl.Parser.IgnoreInfo
+import firrtl.constraint._
+import firrtlTests.FirrtlFlatSpec
+
+class IntervalMathSpec extends FirrtlFlatSpec {
+ val SumPattern = """.*output sum.*<(\d+)>.*""".r
+ val ProductPattern = """.*output product.*<(\d+)>.*""".r
+ val DifferencePattern = """.*output difference.*<(\d+)>.*""".r
+ val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r
+ val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r
+ val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r
+ val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r
+ val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r
+ val ArithAssignPattern = """\s*(\w+) <= asSInt\(bits\((\w+)\((.*)\).*\)\)\s*""".r
+ def getBound(bound: String, value: Double): IsKnown = bound match {
+ case "[" => Closed(BigDecimal(value))
+ case "]" => Closed(BigDecimal(value))
+ case "(" => Open(BigDecimal(value))
+ case ")" => Open(BigDecimal(value))
+ }
+
+ val prec = 0.5
+
+ for {
+ lb1 <- Seq("[", "(")
+ lv1 <- Range.Double(-1.0, 1.0, prec)
+ uv1 <- if(lb1 == "[") Range.Double(lv1, 1.0, prec) else Range.Double(lv1 + prec, 1.0, prec)
+ ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")")
+ bp1 <- 0 to 1
+ lb2 <- Seq("[", "(")
+ lv2 <- Range.Double(-1.0, 1.0, prec)
+ uv2 <- if(lb2 == "[") Range.Double(lv2, 1.0, prec) else Range.Double(lv2 + prec, 1.0, prec)
+ ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")")
+ bp2 <- 0 to 1
+ } {
+ val it1 = IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1.toInt))
+ val it2 = IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2.toInt))
+ (it1.range, it2.range) match {
+ case (Some(Nil), _) =>
+ case (_, Some(Nil)) =>
+ case _ =>
+ def config = s"$lb1$lv1,$uv1$ub1.$bp1 and $lb2$lv2,$uv2$ub2.$bp2"
+
+ s"Configuration $config" should "pass" in {
+
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1
+ | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2
+ | input amt : UInt<3>
+ | output sum : Interval
+ | output difference : Interval
+ | output product : Interval
+ | output shl : Interval
+ | output shr : Interval
+ | output dshl : Interval
+ | output dshr : Interval
+ | output lt : UInt
+ | output leq : UInt
+ | output gt : UInt
+ | output geq : UInt
+ | output eq : UInt
+ | output neq : UInt
+ | output cat : UInt
+ | sum <= add(in1, in2)
+ | difference <= sub(in1, in2)
+ | product <= mul(in1, in2)
+ | shl <= shl(in1, 3)
+ | shr <= shr(in1, 3)
+ | dshl <= dshl(in1, amt)
+ | dshr <= dshr(in1, amt)
+ | lt <= lt(in1, in2)
+ | leq <= leq(in1, in2)
+ | gt <= gt(in1, in2)
+ | geq <= geq(in1, in2)
+ | eq <= eq(in1, in2)
+ | neq <= lt(in1, in2)
+ | cat <= cat(in1, in2)
+ | """.stripMargin
+
+ val lowerer = new LowFirrtlCompiler
+ val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm))
+ val output = res.getEmittedCircuit.value split "\n"
+ val min1 = Closed(it1.min.get)
+ val max1 = Closed(it1.max.get)
+ val min2 = Closed(it2.min.get)
+ val max2 = Closed(it2.max.get)
+ for (line <- output) {
+ line match {
+ case SumPattern(varWidth) =>
+ val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt))
+ val it = IntervalType(IsAdd(min1, min2), IsAdd(max1, max2), bp)
+ assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, s"$line,${it.range}")
+ case ProductPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt + bp2.toInt)
+ val lv = IsMin(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2)))
+ val uv = IsMax(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2)))
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "product")
+ case DifferencePattern(varWidth) =>
+ val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt))
+ val lv = min1 + max2.neg
+ val uv = max1 + min2.neg
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "diff")
+ case ShiftLeftPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt)
+ val lv = min1 * Closed(8)
+ val uv = max1 * Closed(8)
+ val it = IntervalType(lv, uv, bp)
+ assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, "shl")
+ case ShiftRightPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt)
+ val lv = min1 * Closed(1/3)
+ val uv = max1 * Closed(1/3)
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "shr")
+ case DShiftLeftPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt)
+ val lv = min1 * Closed(128)
+ val uv = max1 * Closed(128)
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshl")
+ case DShiftRightPattern(varWidth) =>
+ val bp = IntWidth(bp1.toInt)
+ val lv = min1
+ val uv = max1
+ assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshr")
+ case ComparisonPattern(varWidth) => assert(varWidth.toInt == 1, "==")
+ case ArithAssignPattern(varName, operation, args) =>
+ val arg1 = if(IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) """SInt<1>("h0")""" else "in1"
+ val arg2 = if(IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) """SInt<1>("h0")""" else "in2"
+ varName match {
+ case "sum" =>
+ assert(operation === "add", s"""var sum should be result of an add in ${output.mkString("\n")}""")
+ if (bp1 > bp2) {
+ if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(args.contains(s"shl($arg2, ${bp1 - bp2})"),
+ s"$config second arg incorrect in $line")
+ } else if (bp1 < bp2) {
+ assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"),
+ s"$config second arg incorrect in $line")
+ assert(!args.contains("shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ } else {
+ assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ }
+ case "product" =>
+ assert(operation === "mul", s"var sum should be result of an add in $line")
+ assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ case "difference" =>
+ assert(operation === "sub", s"var difference should be result of an sub in $line")
+ if (bp1 > bp2) {
+ if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(args.contains(s"shl($arg2, ${bp1 - bp2})"),
+ s"$config second arg incorrect in $line")
+ } else if (bp1 < bp2) {
+ assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"),
+ s"$config second arg incorrect in $line")
+ if (arg1 != arg2) assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ } else {
+ assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line")
+ assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line")
+ }
+ case _ =>
+ }
+ case _ =>
+ }
+ }
+ }
+ }
+ }
+}
+
+
+// vim: set ts=4 sw=4 et:
diff --git a/src/test/scala/firrtlTests/interval/IntervalSpec.scala b/src/test/scala/firrtlTests/interval/IntervalSpec.scala
new file mode 100644
index 00000000..37d79c84
--- /dev/null
+++ b/src/test/scala/firrtlTests/interval/IntervalSpec.scala
@@ -0,0 +1,530 @@
+package firrtlTests
+package interval
+
+import java.io._
+
+import firrtl._
+import firrtl.ir.Circuit
+import firrtl.passes._
+import firrtl.Parser.IgnoreInfo
+import firrtl.passes.CheckTypes.InvalidConnect
+import firrtl.passes.CheckWidths.DisjointSqueeze
+
+class IntervalSpec extends FirrtlFlatSpec {
+ private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = {
+ val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
+ (c: Circuit, p: Transform) =>
+ p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit
+ }
+ val lines = c.serialize.split("\n") map normalized
+
+ expected foreach { e =>
+ lines should contain(e)
+ }
+ }
+
+ "Interval types" should "parse correctly" in {
+ val passes = Seq(ToWorkingIR)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].4
+ | input in2 : Interval(-0.32, 10].4
+ | input in3 : Interval[-3, 10.1).4
+ | input in4 : Interval(-0.32, 10.1)
+ | input in5 : Interval.4
+ | input in6 : Interval
+ | output out0 : Interval.2
+ | output out1 : Interval
+ | out0 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))
+ | out1 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))""".stripMargin
+ executeTest(input, input.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "infer bp correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].3
+ | input in2 : Interval(-0.32, 10].2
+ | output out0 : Interval
+ | out0 <= add(in0, add(in1, in2))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].3
+ | input in2 : Interval(-0.32, 10].2
+ | output out0 : Interval.4
+ | out0 <= add(in0, add(in1, in2))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "trim known intervals correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].3
+ | input in2 : Interval(-0.32, 10].2
+ | output out0 : Interval
+ | out0 <= add(in0, add(in1, in2))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval[-0.3125, 10.0625].4
+ | input in1 : Interval[0, 10].3
+ | input in2 : Interval[-0.25, 10].2
+ | output out0 : Interval.4
+ | out0 <= add(in0, incp(add(in1, incp(in2, 1)), 1))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "infer intervals correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(0, 10).4
+ | input in1 : Interval(0, 10].3
+ | input in2 : Interval(-1, 3].2
+ | output out0 : Interval
+ | output out1 : Interval
+ | output out2 : Interval
+ | out0 <= add(in0, add(in1, in2))
+ | out1 <= mul(in0, mul(in1, in2))
+ | out2 <= sub(in0, sub(in1, in2))""".stripMargin
+ val check =
+ """output out0 : Interval[-0.5625, 22.9375].4
+ |output out1 : Interval[-74.53125, 298.125].9
+ |output out2 : Interval[-10.6875, 12.8125].4""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "be removed correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(0, 10).4
+ | input in1 : Interval(0, 10].3
+ | input in2 : Interval(-1, 3].2
+ | output out0 : Interval
+ | output out1 : Interval
+ | output out2 : Interval
+ | out0 <= add(in0, add(in1, in2))
+ | out1 <= mul(in0, mul(in1, in2))
+ | out2 <= sub(in0, sub(in1, in2))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : SInt<9>
+ | input in1 : SInt<8>
+ | input in2 : SInt<5>
+ | output out0 : SInt<10>
+ | output out1 : SInt<19>
+ | output out2 : SInt<9>
+ | out0 <= add(in0, shl(add(in1, shl(in2, 1)), 1))
+ | out1 <= mul(in0, mul(in1, in2))
+ | out2 <= sub(in0, shl(sub(in1, shl(in2, 1)), 1))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+"Interval types" should "infer multiplication by zero correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1 : Interval[0, 0.5].1
+ | input in2 : Interval[0, 0].1
+ | output mul : Interval
+ | mul <= mul(in2, in1)
+ | """.stripMargin
+ val check = s"""output mul : Interval[0, 0].2 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+}
+
+ "Interval types" should "infer muxes correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input p : UInt<1>
+ | input in1 : Interval[0, 0.5].1
+ | input in2 : Interval[0, 0].1
+ | output out : Interval
+ | out <= mux(p, in2, in1)
+ | """.stripMargin
+ val check = s"""output out : Interval[0, 0.5].1 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "infer dshl correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds, ResolveGenders, new InferBinaryPoints(), new TrimIntervals, new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input p : UInt<3>
+ | input in1 : Interval[-1, 1].0
+ | output out : Interval
+ | out <= dshl(in1, p)
+ | """.stripMargin
+ val check = s"""output out : Interval[-128, 128].0 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "infer asInterval correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input p : UInt<3>
+ | output out : Interval
+ | out <= asInterval(p, 0, 4, 1)
+ | """.stripMargin
+ val check = s"""output out : Interval[0, 2].1 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "do wrap/clip correctly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input s: SInt<2>
+ | input u: UInt<3>
+ | input in1: Interval[-3, 5].0
+ | output wrap3: Interval
+ | output wrap4: Interval
+ | output wrap5: Interval
+ | output wrap6: Interval
+ | output wrap7: Interval
+ | output clip3: Interval
+ | output clip4: Interval
+ | output clip5: Interval
+ | output clip6: Interval
+ | output clip7: Interval
+ | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0))
+ | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0))
+ | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0))
+ | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0))
+ | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0))
+ | clip3 <= clip(in1, asInterval(s, -2, 4, 0))
+ | clip4 <= clip(in1, asInterval(s, -1, 1, 0))
+ | clip5 <= clip(in1, asInterval(s, -4, 4, 0))
+ | clip6 <= clip(in1, asInterval(s, -1, 7, 0))
+ | clip7 <= clip(in1, asInterval(s, -4, 7, 0))
+ """.stripMargin
+ //| output wrap1: Interval
+ //| output wrap2: Interval
+ //| output clip1: Interval
+ //| output clip2: Interval
+ //| wrap1 <= wrap(in1, u, 0)
+ //| wrap2 <= wrap(in1, s, 0)
+ //| clip1 <= clip(in1, u)
+ //| clip2 <= clip(in1, s)
+ val check = s"""
+ | output wrap3 : Interval[-2, 4].0
+ | output wrap4 : Interval[-1, 1].0
+ | output wrap5 : Interval[-4, 4].0
+ | output wrap6 : Interval[-1, 7].0
+ | output wrap7 : Interval[-4, 7].0
+ | output clip3 : Interval[-2, 4].0
+ | output clip4 : Interval[-1, 1].0
+ | output clip5 : Interval[-3, 4].0
+ | output clip6 : Interval[-1, 5].0
+ | output clip7 : Interval[-3, 5].0 """.stripMargin
+ // TODO: this optimization
+ //| output wrap1 : Interval[0, 7].0
+ //| output wrap2 : Interval[-2, 1].0
+ //| output clip1 : Interval[0, 5].0
+ //| output clip2 : Interval[-2, 1].0
+ //| output wrap7 : Interval[-3, 5].0
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "remove wrap/clip correctly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck(), new RemoveIntervals())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input s: SInt<2>
+ | input u: UInt<3>
+ | input in1: Interval[-3, 5].0
+ | output wrap3: Interval
+ | output wrap5: Interval
+ | output wrap6: Interval
+ | output wrap7: Interval
+ | output clip3: Interval
+ | output clip4: Interval
+ | output clip5: Interval
+ | output clip6: Interval
+ | output clip7: Interval
+ | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0))
+ | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0))
+ | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0))
+ | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0))
+ | clip3 <= clip(in1, asInterval(s, -2, 4, 0))
+ | clip4 <= clip(in1, asInterval(s, -1, 1, 0))
+ | clip5 <= clip(in1, asInterval(s, -4, 4, 0))
+ | clip6 <= clip(in1, asInterval(s, -1, 7, 0))
+ | clip7 <= clip(in1, asInterval(s, -4, 7, 0))
+ | """.stripMargin
+ val check = s"""
+ | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1))
+ | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1)
+ | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1)
+ | wrap7 <= in1
+ | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1))
+ | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1))
+ | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1)
+ | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)
+ | clip7 <= in1
+ """.stripMargin
+ //| output wrap4: Interval
+ //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0)
+ //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1"))
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "shift wrap/clip correctly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input s: SInt<2>
+ | input in1: Interval[-3, 5].1
+ | output wrap1: Interval
+ | output clip1: Interval
+ | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0))
+ | clip1 <= clip(in1, asInterval(s, -2, 2, 0))
+ | """.stripMargin
+ val check = s"""
+ | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1))
+ | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1))
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "infer negative binary points" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[-2, 4].-1
+ | input in2: Interval[-4, 8].-2
+ | output out: Interval
+ | out <= add(in1, in2)
+ | """.stripMargin
+ val check = s"""
+ | output out : Interval[-6, 12].-1
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "remove negative binary points" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[-2, 4].-1
+ | input in2: Interval[-4, 8].-2
+ | output out: Interval.0
+ | out <= add(in1, in2)
+ | """.stripMargin
+ val check = s"""
+ | output out : SInt<5>
+ | out <= shl(add(in1, shl(in2, 1)), 1)
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "implement squz properly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input min: Interval[-1, 4].1
+ | input max: Interval[-3, 5].1
+ | input left: Interval[-3, 3].1
+ | input right: Interval[0, 5].1
+ | input off: Interval[-1, 4].2
+ | output minMax: Interval
+ | output maxMin: Interval
+ | output minLeft: Interval
+ | output leftMin: Interval
+ | output minRight: Interval
+ | output rightMin: Interval
+ | output minOff: Interval
+ | output offMin: Interval
+ |
+ | minMax <= squz(min, max)
+ | maxMin <= squz(max, min)
+ | minLeft <= squz(min, left)
+ | leftMin <= squz(left, min)
+ | minRight <= squz(min, right)
+ | rightMin <= squz(right, min)
+ | minOff <= squz(min, off)
+ | offMin <= squz(off, min)
+ | """.stripMargin
+ val check =
+ s"""
+ | output minMax : Interval[-1, 4].1
+ | output maxMin : Interval[-1, 4].1
+ | output minLeft : Interval[-1, 3].1
+ | output leftMin : Interval[-1, 3].1
+ | output minRight : Interval[0, 4].1
+ | output rightMin : Interval[0, 4].1
+ | output minOff : Interval[-1, 4].1
+ | output offMin : Interval[-1, 4].2
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "lower squz properly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input min: Interval[-1, 4].1
+ | input max: Interval[-3, 5].1
+ | input left: Interval[-3, 3].1
+ | input right: Interval[0, 5].1
+ | input off: Interval[-1, 4].2
+ | output minMax: Interval
+ | output maxMin: Interval
+ | output minLeft: Interval
+ | output leftMin: Interval
+ | output minRight: Interval
+ | output rightMin: Interval
+ | output minOff: Interval
+ | output offMin: Interval
+ |
+ | minMax <= squz(min, max)
+ | maxMin <= squz(max, min)
+ | minLeft <= squz(min, left)
+ | leftMin <= squz(left, min)
+ | minRight <= squz(min, right)
+ | rightMin <= squz(right, min)
+ | minOff <= squz(min, off)
+ | offMin <= squz(off, min)
+ | """.stripMargin
+ val check =
+ s"""
+ | minMax <= asSInt(bits(min, 4, 0))
+ | maxMin <= asSInt(bits(max, 4, 0))
+ | minLeft <= asSInt(bits(min, 3, 0))
+ | leftMin <= left
+ | minRight <= asSInt(bits(min, 4, 0))
+ | rightMin <= asSInt(bits(right, 4, 0))
+ | minOff <= asSInt(bits(min, 4, 0))
+ | offMin <= asSInt(bits(off, 5, 0))
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Assigning a larger interval to a smaller interval" should "error!" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in: Interval[1, 4].1
+ | output out: Interval[2, 3].1
+ | out <= in
+ | """.stripMargin
+ intercept[InvalidConnect]{
+ executeTest(input, Nil, passes)
+ }
+ }
+ "Assigning a more precise interval to a less precise interval" should "error!" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in: Interval[2, 3].3
+ | output out: Interval[2, 3].1
+ | out <= in
+ | """.stripMargin
+ intercept[InvalidConnect]{
+ executeTest(input, Nil, passes)
+ }
+ }
+ "Chick's example" should "work" in {
+ val input =
+ s"""circuit IntervalChainedSubTester :
+ | module IntervalChainedSubTester :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | node _GEN_0 = sub(SInt<6>("h11"), SInt<6>("h2")) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26]
+ | node _GEN_1 = bits(_GEN_0, 4, 0) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26]
+ | node intervalResult = asSInt(_GEN_1) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26]
+ | skip
+ | node _T_1 = asUInt(intervalResult) @[IntervalSpec.scala 338:50]
+ | skip
+ | node _T_3 = eq(reset, UInt<1>("h0")) @[IntervalSpec.scala 338:9]
+ | node _T_4 = eq(intervalResult, SInt<5>("hf")) @[IntervalSpec.scala 339:25]
+ | skip
+ | node _T_6 = or(_T_4, reset) @[IntervalSpec.scala 339:9]
+ | node _T_7 = eq(_T_6, UInt<1>("h0")) @[IntervalSpec.scala 339:9]
+ | skip
+ | skip
+ | printf(clock, _T_3, "Interval result: %d", _T_1) @[IntervalSpec.scala 338:9]
+ | printf(clock, _T_7, "Assertion failed at IntervalSpec.scala:339 assert(intervalResult === 15.I)") @[IntervalSpec.scala 339:9]
+ | stop(clock, _T_7, 1) @[IntervalSpec.scala 339:9]
+ | stop(clock, _T_3, 0) @[IntervalSpec.scala 340:7]
+ |
+ """.stripMargin
+ compileToVerilog(input)
+ }
+
+ "Squeeze with disjoint intervals" should "error" in {
+ intercept[DisjointSqueeze] {
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[3, 6].3
+ | node out = squz(in1, in2)
+ """.stripMargin
+ compileToVerilog(input)
+ }
+ intercept[DisjointSqueeze] {
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[3, 6].3
+ | node out = squz(in2, in1)
+ """.stripMargin
+ compileToVerilog(input)
+ }
+ }
+
+ "Clip with disjoint intervals" should "work" in {
+ compileToVerilog(
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[3, 6].3
+ | output out: Interval
+ | out <= clip(in1, in2)
+ """.stripMargin
+ )
+ compileToVerilog(
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[4, 6].3
+ | node out = clip(in1, in2)
+ """.stripMargin
+ )
+ }
+
+
+ "Wrap with remainder" should "error" in {
+ intercept[WrapWithRemainder] {
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[0, 300).3
+ | input in2: Interval[3, 6].3
+ | node out = wrap(in1, in2)
+ """.stripMargin
+ compileToVerilog(input)
+ }
+ }
+}
diff --git a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala
index 46fb310a..88095830 100644
--- a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala
+++ b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala
@@ -152,7 +152,7 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec {
}
"InferWidthsWithAnnos" should "work with WiringTransform" in {
- def transforms = Seq(
+ def transforms() = Seq(
ToWorkingIR,
ResolveKinds,
InferTypes,