diff options
| author | Adam Izraelevitz | 2019-10-18 19:01:19 -0700 |
|---|---|---|
| committer | GitHub | 2019-10-18 19:01:19 -0700 |
| commit | fd981848c7d2a800a15f9acfbf33b57dd1c6225b (patch) | |
| tree | 3609a301cb0ec867deefea4a0d08425810b00418 /src/test | |
| parent | 973ecf516c0ef2b222f2eb68dc8b514767db59af (diff) | |
Upstream intervals (#870)
Major features:
- Added Interval type, as well as PrimOps asInterval, clip, wrap, and sqz.
- Changed PrimOp names: bpset -> setp, bpshl -> incp, bpshr -> decp
- Refactored width/bound inferencer into a separate constraint solver
- Added transforms to infer, trim, and remove interval bounds
- Tests for said features
Plan to be released with 1.3
Diffstat (limited to 'src/test')
13 files changed, 980 insertions, 36 deletions
diff --git a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala index d06344af..89f2ec07 100644 --- a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala +++ b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala @@ -3,14 +3,12 @@ package firrtl.stage.phases.tests import org.scalatest.{FlatSpec, Matchers, PrivateMethodTester} - import java.io.File import firrtl._ -import firrtl.stage._ import firrtl.stage.phases.DriverCompatibility._ - import firrtl.options.{InputAnnotationFileAnnotation, Phase, TargetDirAnnotation} +import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation} import firrtl.stage.phases.DriverCompatibility class DriverCompatibilitySpec extends FlatSpec with Matchers with PrivateMethodTester { diff --git a/src/test/scala/firrtlTests/AsyncResetSpec.scala b/src/test/scala/firrtlTests/AsyncResetSpec.scala index 6fcb647a..8ad397b3 100644 --- a/src/test/scala/firrtlTests/AsyncResetSpec.scala +++ b/src/test/scala/firrtlTests/AsyncResetSpec.scala @@ -51,16 +51,19 @@ class AsyncResetSpec extends FirrtlFlatSpec { it should "support casting to other types" in { val result = compileBody(s""" |input a : AsyncReset + |output u : Interval[0, 1].0 |output v : UInt<1> |output w : SInt<1> |output x : Clock |output y : Fixed<1><<0>> |output z : AsyncReset + |u <= asInterval(a, 0, 1, 0) |v <= asUInt(a) |w <= asSInt(a) |x <= asClock(a) |y <= asFixedPoint(a, 0) - |z <= asAsyncReset(a)""".stripMargin + |z <= asAsyncReset(a) + |""".stripMargin ) result should containLine ("assign v = $unsigned(a);") result should containLine ("assign w = $signed(a);") @@ -76,22 +79,26 @@ class AsyncResetSpec extends FirrtlFlatSpec { |input c : Clock |input d : Fixed<1><<0>> |input e : AsyncReset + |input f : Interval[0, 1].0 + |output u : AsyncReset |output v : AsyncReset |output w : AsyncReset |output x : AsyncReset |output y : AsyncReset |output z : AsyncReset - |v <= asAsyncReset(a) - |w <= asAsyncReset(a) - |x <= asAsyncReset(a) - |y <= asAsyncReset(a) - |z <= asAsyncReset(a)""".stripMargin + |u <= asAsyncReset(a) + |v <= asAsyncReset(b) + |w <= asAsyncReset(c) + |x <= asAsyncReset(d) + |y <= asAsyncReset(e) + |z <= asAsyncReset(f)""".stripMargin ) - result should containLine ("assign v = a;") - result should containLine ("assign w = a;") - result should containLine ("assign x = a;") - result should containLine ("assign y = a;") - result should containLine ("assign z = a;") + result should containLine ("assign u = a;") + result should containLine ("assign v = b;") + result should containLine ("assign w = c;") + result should containLine ("assign x = d;") + result should containLine ("assign y = e;") + result should containLine ("assign z = f;") } "Non-literals" should "NOT be allowed as reset values for AsyncReset" in { diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala index fba81ec7..b82637b6 100644 --- a/src/test/scala/firrtlTests/ChirrtlSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala @@ -70,8 +70,8 @@ class ChirrtlSpec extends FirrtlFlatSpec { behavior of "Uniqueness" for ((description, input) <- CheckSpec.nonUniqueExamples) { it should s"be asserted for $description" in { - assertThrows[CheckChirrtl.NotUniqueException] { - Seq(ToWorkingIR, CheckChirrtl).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) } + assertThrows[CheckHighForm.NotUniqueException] { + Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) } } } } diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala index dbc997cd..9d6206af 100644 --- a/src/test/scala/firrtlTests/InfoSpec.scala +++ b/src/test/scala/firrtlTests/InfoSpec.scala @@ -66,7 +66,7 @@ class InfoSpec extends FirrtlFlatSpec { result should containLine (s"assign n = w | x; //$Info3") } - they should "be propagated on memories" in { + it should "be propagated on memories" in { val result = compileBody(s""" |input clock : Clock |input addr : UInt<5> @@ -102,7 +102,7 @@ class InfoSpec extends FirrtlFlatSpec { result should containLine (s"m[m_w_addr] <= m_w_data; //$Info1") } - they should "be propagated on instances" in { + it should "be propagated on instances" in { val result = compile(s""" |circuit Test : | module Child : diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index be9d738b..69379c51 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -12,7 +12,7 @@ import firrtl.transforms._ import firrtl._ class LowerTypesSpec extends FirrtlFlatSpec { - private val transforms = Seq( + private def transforms = Seq( ToWorkingIR, CheckHighForm, ResolveKinds, diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala index afb82384..e64e9105 100644 --- a/src/test/scala/firrtlTests/UniquifySpec.scala +++ b/src/test/scala/firrtlTests/UniquifySpec.scala @@ -16,7 +16,7 @@ import firrtl.util.TestOptions class UniquifySpec extends FirrtlFlatSpec { - private val transforms = Seq( + private def transforms = Seq( ToWorkingIR, CheckHighForm, ResolveKinds, diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala index eb3d1a96..f1dadcee 100644 --- a/src/test/scala/firrtlTests/ZeroWidthTests.scala +++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala @@ -11,7 +11,7 @@ import firrtl.Parser import firrtl.passes._ class ZeroWidthTests extends FirrtlFlatSpec { - val transforms = Seq( + def transforms = Seq( ToWorkingIR, ResolveKinds, InferTypes, diff --git a/src/test/scala/firrtlTests/constraint/InequalitySpec.scala b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala new file mode 100644 index 00000000..02a853cb --- /dev/null +++ b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala @@ -0,0 +1,197 @@ +package firrtlTests.constraint + +import firrtl.constraint._ +import org.scalatest.{FlatSpec, Matchers} +import firrtl.ir.Closed + +class InequalitySpec extends FlatSpec with Matchers { + + behavior of "Constraints" + + "IsConstraints" should "reduce properly" in { + IsMin(Closed(0), Closed(1)) should be (Closed(0)) + IsMin(Closed(-1), Closed(1)) should be (Closed(-1)) + IsMax(Closed(-1), Closed(1)) should be (Closed(1)) + IsNeg(IsMul(Closed(-1), Closed(-2))) should be (Closed(-2)) + val x = IsMin(IsMul(Closed(1), VarCon("a")), Closed(2)) + x.children.toSet should be (IsMin(Closed(2), IsMul(Closed(1), VarCon("a"))).children.toSet) + } + + "IsAdd" should "reduce properly" in { + // All constants + IsAdd(Closed(-1), Closed(1)) should be (Closed(0)) + + // Pull Out IsMax + IsAdd(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsAdd(VarCon("a"), Closed(1)))) + IsAdd(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMax(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1)))) + ) + + // Pull Out IsMin + IsAdd(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsAdd(VarCon("a"), Closed(1)))) + IsAdd(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMin(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1)))) + ) + + // Add Zero + IsAdd(Closed(0), VarCon("a")) should be (VarCon("a")) + + // One argument + IsAdd(Seq(VarCon("a"))) should be (VarCon("a")) + } + + "IsMax" should "reduce properly" in { + // All constants + IsMax(Closed(-1), Closed(1)) should be (Closed(1)) + + // Flatten nested IsMax + IsMax(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(1), VarCon("a"))) + IsMax(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMax(Seq(Closed(1), VarCon("a"), VarCon("b"))) + ) + + // Eliminate IsMins if possible + IsMax(Closed(2), IsMin(Closed(1), VarCon("a"))) should be (Closed(2)) + IsMax(Seq( + Closed(2), + IsMin(Closed(1), VarCon("a")), + IsMin(Closed(3), VarCon("b")) + )) should be ( + IsMax(Seq( + Closed(2), + IsMin(Closed(3), VarCon("b")) + )) + ) + + // One argument + IsMax(Seq(VarCon("a"))) should be (VarCon("a")) + IsMax(Seq(Closed(0))) should be (Closed(0)) + IsMax(Seq(IsMin(VarCon("a"), Closed(0)))) should be (IsMin(VarCon("a"), Closed(0))) + } + + "IsMin" should "reduce properly" in { + // All constants + IsMin(Closed(-1), Closed(1)) should be (Closed(-1)) + + // Flatten nested IsMin + IsMin(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(1), VarCon("a"))) + IsMin(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMin(Seq(Closed(1), VarCon("a"), VarCon("b"))) + ) + + // Eliminate IsMaxs if possible + IsMin(Closed(1), IsMax(Closed(2), VarCon("a"))) should be (Closed(1)) + IsMin(Seq( + Closed(2), + IsMax(Closed(1), VarCon("a")), + IsMax(Closed(3), VarCon("b")) + )) should be ( + IsMin(Seq( + Closed(2), + IsMax(Closed(1), VarCon("a")) + )) + ) + + // One argument + IsMin(Seq(VarCon("a"))) should be (VarCon("a")) + IsMin(Seq(Closed(0))) should be (Closed(0)) + IsMin(Seq(IsMax(VarCon("a"), Closed(0)))) should be (IsMax(VarCon("a"), Closed(0))) + } + + "IsMul" should "reduce properly" in { + // All constants + IsMul(Closed(2), Closed(3)) should be (Closed(6)) + + // Pull out max, if positive stays max + IsMul(Closed(2), IsMax(Closed(3), VarCon("a"))) should be( + IsMax(Closed(6), IsMul(Closed(2), VarCon("a"))) + ) + + // Pull out max, if negative is min + IsMul(Closed(-2), IsMax(Closed(3), VarCon("a"))) should be( + IsMin(Closed(-6), IsMul(Closed(-2), VarCon("a"))) + ) + + // Pull out min, if positive stays min + IsMul(Closed(2), IsMin(Closed(3), VarCon("a"))) should be( + IsMin(Closed(6), IsMul(Closed(2), VarCon("a"))) + ) + + // Pull out min, if negative is max + IsMul(Closed(-2), IsMin(Closed(3), VarCon("a"))) should be( + IsMax(Closed(-6), IsMul(Closed(-2), VarCon("a"))) + ) + + // Times zero + IsMul(Closed(0), VarCon("x")) should be (Closed(0)) + + // Times 1 + IsMul(Closed(1), VarCon("x")) should be (VarCon("x")) + + // One argument + IsMul(Seq(Closed(0))) should be (Closed(0)) + IsMul(Seq(VarCon("a"))) should be (VarCon("a")) + + // No optimizations + val isMax = IsMax(VarCon("x"), VarCon("y")) + val isMin = IsMin(VarCon("x"), VarCon("y")) + val a = VarCon("a") + IsMul(a, isMax).children should be (Vector(a, isMax)) //non-known multiply + IsMul(a, isMin).children should be (Vector(a, isMin)) //non-known multiply + IsMul(Seq(Closed(2), isMin, isMin)).children should be (Vector(Closed(2), isMin, isMin)) //>1 min + IsMul(Seq(Closed(2), isMax, isMax)).children should be (Vector(Closed(2), isMax, isMax)) //>1 max + IsMul(Seq(Closed(2), isMin, isMax)).children should be (Vector(Closed(2), isMin, isMax)) //mixed min/max + } + + "IsNeg" should "reduce properly" in { + // All constants + IsNeg(Closed(1)) should be (Closed(-1)) + // Pull out max + IsNeg(IsMax(Closed(1), VarCon("a"))) should be (IsMin(Closed(-1), IsNeg(VarCon("a")))) + // Pull out min + IsNeg(IsMin(Closed(1), VarCon("a"))) should be (IsMax(Closed(-1), IsNeg(VarCon("a")))) + // Pull out add + IsNeg(IsAdd(Closed(1), VarCon("a"))) should be (IsAdd(Closed(-1), IsNeg(VarCon("a")))) + // Pull out mul + IsNeg(IsMul(Closed(2), VarCon("a"))) should be (IsMul(Closed(-2), VarCon("a"))) + // No optimizations + // (pow), (floor?) + IsNeg(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) + IsNeg(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x")))) + } + + "IsPow" should "reduce properly" in { + // All constants + IsPow(Closed(1)) should be (Closed(2)) + // Pull out max + IsPow(IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsPow(VarCon("a")))) + // Pull out min + IsPow(IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsPow(VarCon("a")))) + // Pull out add + IsPow(IsAdd(Closed(1), VarCon("a"))) should be (IsMul(Closed(2), IsPow(VarCon("a")))) + // No optimizations + // (mul), (pow), (floor?) + IsPow(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x")))) + IsPow(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) + IsPow(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x")))) + } + + "IsFloor" should "reduce properly" in { + // All constants + IsFloor(Closed(1.9)) should be (Closed(1)) + IsFloor(Closed(-1.9)) should be (Closed(-2)) + // Pull out max + IsFloor(IsMax(Closed(1.9), VarCon("a"))) should be (IsMax(Closed(1), IsFloor(VarCon("a")))) + // Pull out min + IsFloor(IsMin(Closed(1.9), VarCon("a"))) should be (IsMin(Closed(1), IsFloor(VarCon("a")))) + // Cancel with another floor + IsFloor(IsFloor(VarCon("a"))) should be (IsFloor(VarCon("a"))) + // No optimizations + // (add), (mul), (pow) + IsFloor(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x")))) + IsFloor(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) + IsFloor(IsAdd(Closed(1), VarCon("x"))).children should be (Vector(IsAdd(Closed(1), VarCon("x")))) + } + +} + diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala index 6bf86479..a34145ac 100644 --- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala @@ -21,6 +21,36 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { } } + "Fixed types" should "infer add correctly if only precision unspecified" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveFlows, + CheckFlows, + new InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<10><<0>> + | input c : Fixed<4><<3>> + | output d : Fixed<13> + | d <= add(a, add(b, c))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<10><<0>> + | input c : Fixed<4><<3>> + | output d : Fixed<13><<3>> + | d <= add(a, add(b, c))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Fixed types" should "infer add correctly" in { val passes = Seq( ToWorkingIR, @@ -36,7 +66,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { """circuit Unit : | module Unit : | input a : Fixed<10><<2>> - | input b : Fixed<10> + | input b : Fixed<10><<0>> | input c : Fixed<4><<3>> | output d : Fixed | d <= add(a, add(b, c))""".stripMargin @@ -119,13 +149,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed - | d <= bpshl(a, 2)""".stripMargin + | d <= incp(a, 2)""".stripMargin val check = """circuit Unit : | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<12><<4>> - | d <= bpshl(a, 2)""".stripMargin + | d <= incp(a, 2)""".stripMargin executeTest(input, check.split("\n") map normalized, passes) } @@ -145,13 +175,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed - | d <= bpshr(a, 2)""".stripMargin + | d <= decp(a, 2)""".stripMargin val check = """circuit Unit : | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<8><<0>> - | d <= bpshr(a, 2)""".stripMargin + | d <= decp(a, 2)""".stripMargin executeTest(input, check.split("\n") map normalized, passes) } @@ -171,13 +201,13 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed - | d <= bpset(a, 3)""".stripMargin + | d <= setp(a, 3)""".stripMargin val check = """circuit Unit : | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<11><<3>> - | d <= bpset(a, 3)""".stripMargin + | d <= setp(a, 3)""".stripMargin executeTest(input, check.split("\n") map normalized, passes) } diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala index 8686bd0f..f5b16e45 100644 --- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala +++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala @@ -14,7 +14,6 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { (c: CircuitState, p: Transform) => p.runTransform(c) }.circuit val lines = c.serialize.split("\n") map normalized - println(c.serialize) expected foreach { e => lines should contain(e) @@ -37,7 +36,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { """circuit Unit : | module Unit : | input a : Fixed<10><<2>> - | input b : Fixed<10> + | input b : Fixed<10><<0>> | input c : Fixed<4><<3>> | output d : Fixed<<5>> | d <= add(a, add(b, c))""".stripMargin @@ -67,7 +66,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { """circuit Unit : | module Unit : | input a : Fixed<10><<2>> - | input b : Fixed<10> + | input b : Fixed<10><<0>> | input c : Fixed<4><<3>> | output d : Fixed<<5>> | d <- add(a, add(b, c))""".stripMargin @@ -99,7 +98,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<12><<4>> - | d <= bpshl(a, 2)""".stripMargin + | d <= incp(a, 2)""".stripMargin val check = """circuit Unit : | module Unit : @@ -126,7 +125,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed<9><<1>> - | d <= bpshr(a, 1)""".stripMargin + | d <= decp(a, 1)""".stripMargin val check = """circuit Unit : | module Unit : @@ -153,7 +152,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | output d : Fixed - | d <= bpset(a, 3)""".stripMargin + | d <= setp(a, 3)""".stripMargin val check = """circuit Unit : | module Unit : @@ -181,7 +180,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { class CheckChirrtlTransform extends SeqTransform { def inputForm = ChirrtlForm def outputForm = ChirrtlForm - val transforms = Seq(passes.CheckChirrtl) + def transforms = Seq(passes.CheckChirrtl) } val chirrtlTransform = new CheckChirrtlTransform diff --git a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala new file mode 100644 index 00000000..20fdeee1 --- /dev/null +++ b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala @@ -0,0 +1,183 @@ +// See LICENSE for license details. + +package firrtlTests.interval + +import firrtl.Implicits.constraint2bound +import firrtl.{ChirrtlForm, CircuitState, LowFirrtlCompiler, Parser} +import firrtl.ir._ + +import scala.math.BigDecimal.RoundingMode._ +import firrtl.Parser.IgnoreInfo +import firrtl.constraint._ +import firrtlTests.FirrtlFlatSpec + +class IntervalMathSpec extends FirrtlFlatSpec { + val SumPattern = """.*output sum.*<(\d+)>.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r + val DifferencePattern = """.*output difference.*<(\d+)>.*""".r + val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r + val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r + val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r + val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r + val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r + val ArithAssignPattern = """\s*(\w+) <= asSInt\(bits\((\w+)\((.*)\).*\)\)\s*""".r + def getBound(bound: String, value: Double): IsKnown = bound match { + case "[" => Closed(BigDecimal(value)) + case "]" => Closed(BigDecimal(value)) + case "(" => Open(BigDecimal(value)) + case ")" => Open(BigDecimal(value)) + } + + val prec = 0.5 + + for { + lb1 <- Seq("[", "(") + lv1 <- Range.Double(-1.0, 1.0, prec) + uv1 <- if(lb1 == "[") Range.Double(lv1, 1.0, prec) else Range.Double(lv1 + prec, 1.0, prec) + ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") + bp1 <- 0 to 1 + lb2 <- Seq("[", "(") + lv2 <- Range.Double(-1.0, 1.0, prec) + uv2 <- if(lb2 == "[") Range.Double(lv2, 1.0, prec) else Range.Double(lv2 + prec, 1.0, prec) + ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") + bp2 <- 0 to 1 + } { + val it1 = IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1.toInt)) + val it2 = IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2.toInt)) + (it1.range, it2.range) match { + case (Some(Nil), _) => + case (_, Some(Nil)) => + case _ => + def config = s"$lb1$lv1,$uv1$ub1.$bp1 and $lb2$lv2,$uv2$ub2.$bp2" + + s"Configuration $config" should "pass" in { + + val input = + s"""circuit Unit : + | module Unit : + | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 + | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 + | input amt : UInt<3> + | output sum : Interval + | output difference : Interval + | output product : Interval + | output shl : Interval + | output shr : Interval + | output dshl : Interval + | output dshr : Interval + | output lt : UInt + | output leq : UInt + | output gt : UInt + | output geq : UInt + | output eq : UInt + | output neq : UInt + | output cat : UInt + | sum <= add(in1, in2) + | difference <= sub(in1, in2) + | product <= mul(in1, in2) + | shl <= shl(in1, 3) + | shr <= shr(in1, 3) + | dshl <= dshl(in1, amt) + | dshr <= dshr(in1, amt) + | lt <= lt(in1, in2) + | leq <= leq(in1, in2) + | gt <= gt(in1, in2) + | geq <= geq(in1, in2) + | eq <= eq(in1, in2) + | neq <= lt(in1, in2) + | cat <= cat(in1, in2) + | """.stripMargin + + val lowerer = new LowFirrtlCompiler + val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) + val output = res.getEmittedCircuit.value split "\n" + val min1 = Closed(it1.min.get) + val max1 = Closed(it1.max.get) + val min2 = Closed(it2.min.get) + val max2 = Closed(it2.max.get) + for (line <- output) { + line match { + case SumPattern(varWidth) => + val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) + val it = IntervalType(IsAdd(min1, min2), IsAdd(max1, max2), bp) + assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, s"$line,${it.range}") + case ProductPattern(varWidth) => + val bp = IntWidth(bp1.toInt + bp2.toInt) + val lv = IsMin(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) + val uv = IsMax(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "product") + case DifferencePattern(varWidth) => + val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) + val lv = min1 + max2.neg + val uv = max1 + min2.neg + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "diff") + case ShiftLeftPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(8) + val uv = max1 * Closed(8) + val it = IntervalType(lv, uv, bp) + assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, "shl") + case ShiftRightPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(1/3) + val uv = max1 * Closed(1/3) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "shr") + case DShiftLeftPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(128) + val uv = max1 * Closed(128) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshl") + case DShiftRightPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 + val uv = max1 + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshr") + case ComparisonPattern(varWidth) => assert(varWidth.toInt == 1, "==") + case ArithAssignPattern(varName, operation, args) => + val arg1 = if(IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) """SInt<1>("h0")""" else "in1" + val arg2 = if(IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) """SInt<1>("h0")""" else "in2" + varName match { + case "sum" => + assert(operation === "add", s"""var sum should be result of an add in ${output.mkString("\n")}""") + if (bp1 > bp2) { + if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), + s"$config second arg incorrect in $line") + } else if (bp1 < bp2) { + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), + s"$config second arg incorrect in $line") + assert(!args.contains("shl($arg2"), s"$config second arg should be just $arg2 in $line") + } else { + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } + case "product" => + assert(operation === "mul", s"var sum should be result of an add in $line") + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + case "difference" => + assert(operation === "sub", s"var difference should be result of an sub in $line") + if (bp1 > bp2) { + if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), + s"$config second arg incorrect in $line") + } else if (bp1 < bp2) { + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), + s"$config second arg incorrect in $line") + if (arg1 != arg2) assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } else { + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } + case _ => + } + case _ => + } + } + } + } + } +} + + +// vim: set ts=4 sw=4 et: diff --git a/src/test/scala/firrtlTests/interval/IntervalSpec.scala b/src/test/scala/firrtlTests/interval/IntervalSpec.scala new file mode 100644 index 00000000..37d79c84 --- /dev/null +++ b/src/test/scala/firrtlTests/interval/IntervalSpec.scala @@ -0,0 +1,530 @@ +package firrtlTests +package interval + +import java.io._ + +import firrtl._ +import firrtl.ir.Circuit +import firrtl.passes._ +import firrtl.Parser.IgnoreInfo +import firrtl.passes.CheckTypes.InvalidConnect +import firrtl.passes.CheckWidths.DisjointSqueeze + +class IntervalSpec extends FirrtlFlatSpec { + private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Transform) => + p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit + } + val lines = c.serialize.split("\n") map normalized + + expected foreach { e => + lines should contain(e) + } + } + + "Interval types" should "parse correctly" in { + val passes = Seq(ToWorkingIR) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].4 + | input in2 : Interval(-0.32, 10].4 + | input in3 : Interval[-3, 10.1).4 + | input in4 : Interval(-0.32, 10.1) + | input in5 : Interval.4 + | input in6 : Interval + | output out0 : Interval.2 + | output out1 : Interval + | out0 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6)))))) + | out1 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))""".stripMargin + executeTest(input, input.split("\n") map normalized, passes) + } + + "Interval types" should "infer bp correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].3 + | input in2 : Interval(-0.32, 10].2 + | output out0 : Interval + | out0 <= add(in0, add(in1, in2))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].3 + | input in2 : Interval(-0.32, 10].2 + | output out0 : Interval.4 + | out0 <= add(in0, add(in1, in2))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Interval types" should "trim known intervals correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].3 + | input in2 : Interval(-0.32, 10].2 + | output out0 : Interval + | out0 <= add(in0, add(in1, in2))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input in0 : Interval[-0.3125, 10.0625].4 + | input in1 : Interval[0, 10].3 + | input in2 : Interval[-0.25, 10].2 + | output out0 : Interval.4 + | out0 <= add(in0, incp(add(in1, incp(in2, 1)), 1))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Interval types" should "infer intervals correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(0, 10).4 + | input in1 : Interval(0, 10].3 + | input in2 : Interval(-1, 3].2 + | output out0 : Interval + | output out1 : Interval + | output out2 : Interval + | out0 <= add(in0, add(in1, in2)) + | out1 <= mul(in0, mul(in1, in2)) + | out2 <= sub(in0, sub(in1, in2))""".stripMargin + val check = + """output out0 : Interval[-0.5625, 22.9375].4 + |output out1 : Interval[-74.53125, 298.125].9 + |output out2 : Interval[-10.6875, 12.8125].4""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Interval types" should "be removed correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(0, 10).4 + | input in1 : Interval(0, 10].3 + | input in2 : Interval(-1, 3].2 + | output out0 : Interval + | output out1 : Interval + | output out2 : Interval + | out0 <= add(in0, add(in1, in2)) + | out1 <= mul(in0, mul(in1, in2)) + | out2 <= sub(in0, sub(in1, in2))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input in0 : SInt<9> + | input in1 : SInt<8> + | input in2 : SInt<5> + | output out0 : SInt<10> + | output out1 : SInt<19> + | output out2 : SInt<9> + | out0 <= add(in0, shl(add(in1, shl(in2, 1)), 1)) + | out1 <= mul(in0, mul(in1, in2)) + | out2 <= sub(in0, shl(sub(in1, shl(in2, 1)), 1))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + +"Interval types" should "infer multiplication by zero correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output mul : Interval + | mul <= mul(in2, in1) + | """.stripMargin + val check = s"""output mul : Interval[0, 0].2 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) +} + + "Interval types" should "infer muxes correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<1> + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output out : Interval + | out <= mux(p, in2, in1) + | """.stripMargin + val check = s"""output out : Interval[0, 0.5].1 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "infer dshl correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds, ResolveGenders, new InferBinaryPoints(), new TrimIntervals, new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | input in1 : Interval[-1, 1].0 + | output out : Interval + | out <= dshl(in1, p) + | """.stripMargin + val check = s"""output out : Interval[-128, 128].0 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "infer asInterval correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | output out : Interval + | out <= asInterval(p, 0, 4, 1) + | """.stripMargin + val check = s"""output out : Interval[0, 2].1 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "do wrap/clip correctly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck()) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap4: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + """.stripMargin + //| output wrap1: Interval + //| output wrap2: Interval + //| output clip1: Interval + //| output clip2: Interval + //| wrap1 <= wrap(in1, u, 0) + //| wrap2 <= wrap(in1, s, 0) + //| clip1 <= clip(in1, u) + //| clip2 <= clip(in1, s) + val check = s""" + | output wrap3 : Interval[-2, 4].0 + | output wrap4 : Interval[-1, 1].0 + | output wrap5 : Interval[-4, 4].0 + | output wrap6 : Interval[-1, 7].0 + | output wrap7 : Interval[-4, 7].0 + | output clip3 : Interval[-2, 4].0 + | output clip4 : Interval[-1, 1].0 + | output clip5 : Interval[-3, 4].0 + | output clip6 : Interval[-1, 5].0 + | output clip7 : Interval[-3, 5].0 """.stripMargin + // TODO: this optimization + //| output wrap1 : Interval[0, 7].0 + //| output wrap2 : Interval[-2, 1].0 + //| output clip1 : Interval[0, 5].0 + //| output clip2 : Interval[-2, 1].0 + //| output wrap7 : Interval[-3, 5].0 + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "remove wrap/clip correctly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck(), new RemoveIntervals()) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + | """.stripMargin + val check = s""" + | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1)) + | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1) + | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1) + | wrap7 <= in1 + | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1)) + | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)) + | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1) + | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1) + | clip7 <= in1 + """.stripMargin + //| output wrap4: Interval + //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0) + //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1")) + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "shift wrap/clip correctly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals()) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input in1: Interval[-3, 5].1 + | output wrap1: Interval + | output clip1: Interval + | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0)) + | clip1 <= clip(in1, asInterval(s, -2, 2, 0)) + | """.stripMargin + val check = s""" + | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1)) + | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1)) + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "infer negative binary points" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck()) + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval + | out <= add(in1, in2) + | """.stripMargin + val check = s""" + | output out : Interval[-6, 12].-1 + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "remove negative binary points" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval.0 + | out <= add(in1, in2) + | """.stripMargin + val check = s""" + | output out : SInt<5> + | out <= shl(add(in1, shl(in2, 1)), 1) + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "implement squz properly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck) + val input = + s"""circuit Unit : + | module Unit : + | input min: Interval[-1, 4].1 + | input max: Interval[-3, 5].1 + | input left: Interval[-3, 3].1 + | input right: Interval[0, 5].1 + | input off: Interval[-1, 4].2 + | output minMax: Interval + | output maxMin: Interval + | output minLeft: Interval + | output leftMin: Interval + | output minRight: Interval + | output rightMin: Interval + | output minOff: Interval + | output offMin: Interval + | + | minMax <= squz(min, max) + | maxMin <= squz(max, min) + | minLeft <= squz(min, left) + | leftMin <= squz(left, min) + | minRight <= squz(min, right) + | rightMin <= squz(right, min) + | minOff <= squz(min, off) + | offMin <= squz(off, min) + | """.stripMargin + val check = + s""" + | output minMax : Interval[-1, 4].1 + | output maxMin : Interval[-1, 4].1 + | output minLeft : Interval[-1, 3].1 + | output leftMin : Interval[-1, 3].1 + | output minRight : Interval[0, 4].1 + | output rightMin : Interval[0, 4].1 + | output minOff : Interval[-1, 4].1 + | output offMin : Interval[-1, 4].2 + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "lower squz properly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) + val input = + s"""circuit Unit : + | module Unit : + | input min: Interval[-1, 4].1 + | input max: Interval[-3, 5].1 + | input left: Interval[-3, 3].1 + | input right: Interval[0, 5].1 + | input off: Interval[-1, 4].2 + | output minMax: Interval + | output maxMin: Interval + | output minLeft: Interval + | output leftMin: Interval + | output minRight: Interval + | output rightMin: Interval + | output minOff: Interval + | output offMin: Interval + | + | minMax <= squz(min, max) + | maxMin <= squz(max, min) + | minLeft <= squz(min, left) + | leftMin <= squz(left, min) + | minRight <= squz(min, right) + | rightMin <= squz(right, min) + | minOff <= squz(min, off) + | offMin <= squz(off, min) + | """.stripMargin + val check = + s""" + | minMax <= asSInt(bits(min, 4, 0)) + | maxMin <= asSInt(bits(max, 4, 0)) + | minLeft <= asSInt(bits(min, 3, 0)) + | leftMin <= left + | minRight <= asSInt(bits(min, 4, 0)) + | rightMin <= asSInt(bits(right, 4, 0)) + | minOff <= asSInt(bits(min, 4, 0)) + | offMin <= asSInt(bits(off, 5, 0)) + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Assigning a larger interval to a smaller interval" should "error!" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) + val input = + s"""circuit Unit : + | module Unit : + | input in: Interval[1, 4].1 + | output out: Interval[2, 3].1 + | out <= in + | """.stripMargin + intercept[InvalidConnect]{ + executeTest(input, Nil, passes) + } + } + "Assigning a more precise interval to a less precise interval" should "error!" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) + val input = + s"""circuit Unit : + | module Unit : + | input in: Interval[2, 3].3 + | output out: Interval[2, 3].1 + | out <= in + | """.stripMargin + intercept[InvalidConnect]{ + executeTest(input, Nil, passes) + } + } + "Chick's example" should "work" in { + val input = + s"""circuit IntervalChainedSubTester : + | module IntervalChainedSubTester : + | input clock : Clock + | input reset : UInt<1> + | node _GEN_0 = sub(SInt<6>("h11"), SInt<6>("h2")) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26] + | node _GEN_1 = bits(_GEN_0, 4, 0) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26] + | node intervalResult = asSInt(_GEN_1) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26] + | skip + | node _T_1 = asUInt(intervalResult) @[IntervalSpec.scala 338:50] + | skip + | node _T_3 = eq(reset, UInt<1>("h0")) @[IntervalSpec.scala 338:9] + | node _T_4 = eq(intervalResult, SInt<5>("hf")) @[IntervalSpec.scala 339:25] + | skip + | node _T_6 = or(_T_4, reset) @[IntervalSpec.scala 339:9] + | node _T_7 = eq(_T_6, UInt<1>("h0")) @[IntervalSpec.scala 339:9] + | skip + | skip + | printf(clock, _T_3, "Interval result: %d", _T_1) @[IntervalSpec.scala 338:9] + | printf(clock, _T_7, "Assertion failed at IntervalSpec.scala:339 assert(intervalResult === 15.I)") @[IntervalSpec.scala 339:9] + | stop(clock, _T_7, 1) @[IntervalSpec.scala 339:9] + | stop(clock, _T_3, 0) @[IntervalSpec.scala 340:7] + | + """.stripMargin + compileToVerilog(input) + } + + "Squeeze with disjoint intervals" should "error" in { + intercept[DisjointSqueeze] { + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[3, 6].3 + | node out = squz(in1, in2) + """.stripMargin + compileToVerilog(input) + } + intercept[DisjointSqueeze] { + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[3, 6].3 + | node out = squz(in2, in1) + """.stripMargin + compileToVerilog(input) + } + } + + "Clip with disjoint intervals" should "work" in { + compileToVerilog( + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[3, 6].3 + | output out: Interval + | out <= clip(in1, in2) + """.stripMargin + ) + compileToVerilog( + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[4, 6].3 + | node out = clip(in1, in2) + """.stripMargin + ) + } + + + "Wrap with remainder" should "error" in { + intercept[WrapWithRemainder] { + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[0, 300).3 + | input in2: Interval[3, 6].3 + | node out = wrap(in1, in2) + """.stripMargin + compileToVerilog(input) + } + } +} diff --git a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala index 46fb310a..88095830 100644 --- a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala +++ b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala @@ -152,7 +152,7 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { } "InferWidthsWithAnnos" should "work with WiringTransform" in { - def transforms = Seq( + def transforms() = Seq( ToWorkingIR, ResolveKinds, InferTypes, |
