aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/interval/IntervalSpec.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/test/scala/firrtlTests/interval/IntervalSpec.scala')
-rw-r--r--src/test/scala/firrtlTests/interval/IntervalSpec.scala530
1 files changed, 530 insertions, 0 deletions
diff --git a/src/test/scala/firrtlTests/interval/IntervalSpec.scala b/src/test/scala/firrtlTests/interval/IntervalSpec.scala
new file mode 100644
index 00000000..37d79c84
--- /dev/null
+++ b/src/test/scala/firrtlTests/interval/IntervalSpec.scala
@@ -0,0 +1,530 @@
+package firrtlTests
+package interval
+
+import java.io._
+
+import firrtl._
+import firrtl.ir.Circuit
+import firrtl.passes._
+import firrtl.Parser.IgnoreInfo
+import firrtl.passes.CheckTypes.InvalidConnect
+import firrtl.passes.CheckWidths.DisjointSqueeze
+
+class IntervalSpec extends FirrtlFlatSpec {
+ private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = {
+ val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
+ (c: Circuit, p: Transform) =>
+ p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit
+ }
+ val lines = c.serialize.split("\n") map normalized
+
+ expected foreach { e =>
+ lines should contain(e)
+ }
+ }
+
+ "Interval types" should "parse correctly" in {
+ val passes = Seq(ToWorkingIR)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].4
+ | input in2 : Interval(-0.32, 10].4
+ | input in3 : Interval[-3, 10.1).4
+ | input in4 : Interval(-0.32, 10.1)
+ | input in5 : Interval.4
+ | input in6 : Interval
+ | output out0 : Interval.2
+ | output out1 : Interval
+ | out0 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))
+ | out1 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))""".stripMargin
+ executeTest(input, input.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "infer bp correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].3
+ | input in2 : Interval(-0.32, 10].2
+ | output out0 : Interval
+ | out0 <= add(in0, add(in1, in2))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].3
+ | input in2 : Interval(-0.32, 10].2
+ | output out0 : Interval.4
+ | out0 <= add(in0, add(in1, in2))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "trim known intervals correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(-0.32, 10.1).4
+ | input in1 : Interval[0, 10.1].3
+ | input in2 : Interval(-0.32, 10].2
+ | output out0 : Interval
+ | out0 <= add(in0, add(in1, in2))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval[-0.3125, 10.0625].4
+ | input in1 : Interval[0, 10].3
+ | input in2 : Interval[-0.25, 10].2
+ | output out0 : Interval.4
+ | out0 <= add(in0, incp(add(in1, incp(in2, 1)), 1))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "infer intervals correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(0, 10).4
+ | input in1 : Interval(0, 10].3
+ | input in2 : Interval(-1, 3].2
+ | output out0 : Interval
+ | output out1 : Interval
+ | output out2 : Interval
+ | out0 <= add(in0, add(in1, in2))
+ | out1 <= mul(in0, mul(in1, in2))
+ | out2 <= sub(in0, sub(in1, in2))""".stripMargin
+ val check =
+ """output out0 : Interval[-0.5625, 22.9375].4
+ |output out1 : Interval[-74.53125, 298.125].9
+ |output out2 : Interval[-10.6875, 12.8125].4""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+ "Interval types" should "be removed correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals())
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : Interval(0, 10).4
+ | input in1 : Interval(0, 10].3
+ | input in2 : Interval(-1, 3].2
+ | output out0 : Interval
+ | output out1 : Interval
+ | output out2 : Interval
+ | out0 <= add(in0, add(in1, in2))
+ | out1 <= mul(in0, mul(in1, in2))
+ | out2 <= sub(in0, sub(in1, in2))""".stripMargin
+ val check =
+ """circuit Unit :
+ | module Unit :
+ | input in0 : SInt<9>
+ | input in1 : SInt<8>
+ | input in2 : SInt<5>
+ | output out0 : SInt<10>
+ | output out1 : SInt<19>
+ | output out2 : SInt<9>
+ | out0 <= add(in0, shl(add(in1, shl(in2, 1)), 1))
+ | out1 <= mul(in0, mul(in1, in2))
+ | out2 <= sub(in0, shl(sub(in1, shl(in2, 1)), 1))""".stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+
+"Interval types" should "infer multiplication by zero correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1 : Interval[0, 0.5].1
+ | input in2 : Interval[0, 0].1
+ | output mul : Interval
+ | mul <= mul(in2, in1)
+ | """.stripMargin
+ val check = s"""output mul : Interval[0, 0].2 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+}
+
+ "Interval types" should "infer muxes correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input p : UInt<1>
+ | input in1 : Interval[0, 0.5].1
+ | input in2 : Interval[0, 0].1
+ | output out : Interval
+ | out <= mux(p, in2, in1)
+ | """.stripMargin
+ val check = s"""output out : Interval[0, 0.5].1 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "infer dshl correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds, ResolveGenders, new InferBinaryPoints(), new TrimIntervals, new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input p : UInt<3>
+ | input in1 : Interval[-1, 1].0
+ | output out : Interval
+ | out <= dshl(in1, p)
+ | """.stripMargin
+ val check = s"""output out : Interval[-128, 128].0 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "infer asInterval correctly" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferWidths())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input p : UInt<3>
+ | output out : Interval
+ | out <= asInterval(p, 0, 4, 1)
+ | """.stripMargin
+ val check = s"""output out : Interval[0, 2].1 """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "do wrap/clip correctly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input s: SInt<2>
+ | input u: UInt<3>
+ | input in1: Interval[-3, 5].0
+ | output wrap3: Interval
+ | output wrap4: Interval
+ | output wrap5: Interval
+ | output wrap6: Interval
+ | output wrap7: Interval
+ | output clip3: Interval
+ | output clip4: Interval
+ | output clip5: Interval
+ | output clip6: Interval
+ | output clip7: Interval
+ | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0))
+ | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0))
+ | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0))
+ | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0))
+ | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0))
+ | clip3 <= clip(in1, asInterval(s, -2, 4, 0))
+ | clip4 <= clip(in1, asInterval(s, -1, 1, 0))
+ | clip5 <= clip(in1, asInterval(s, -4, 4, 0))
+ | clip6 <= clip(in1, asInterval(s, -1, 7, 0))
+ | clip7 <= clip(in1, asInterval(s, -4, 7, 0))
+ """.stripMargin
+ //| output wrap1: Interval
+ //| output wrap2: Interval
+ //| output clip1: Interval
+ //| output clip2: Interval
+ //| wrap1 <= wrap(in1, u, 0)
+ //| wrap2 <= wrap(in1, s, 0)
+ //| clip1 <= clip(in1, u)
+ //| clip2 <= clip(in1, s)
+ val check = s"""
+ | output wrap3 : Interval[-2, 4].0
+ | output wrap4 : Interval[-1, 1].0
+ | output wrap5 : Interval[-4, 4].0
+ | output wrap6 : Interval[-1, 7].0
+ | output wrap7 : Interval[-4, 7].0
+ | output clip3 : Interval[-2, 4].0
+ | output clip4 : Interval[-1, 1].0
+ | output clip5 : Interval[-3, 4].0
+ | output clip6 : Interval[-1, 5].0
+ | output clip7 : Interval[-3, 5].0 """.stripMargin
+ // TODO: this optimization
+ //| output wrap1 : Interval[0, 7].0
+ //| output wrap2 : Interval[-2, 1].0
+ //| output clip1 : Interval[0, 5].0
+ //| output clip2 : Interval[-2, 1].0
+ //| output wrap7 : Interval[-3, 5].0
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "remove wrap/clip correctly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck(), new RemoveIntervals())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input s: SInt<2>
+ | input u: UInt<3>
+ | input in1: Interval[-3, 5].0
+ | output wrap3: Interval
+ | output wrap5: Interval
+ | output wrap6: Interval
+ | output wrap7: Interval
+ | output clip3: Interval
+ | output clip4: Interval
+ | output clip5: Interval
+ | output clip6: Interval
+ | output clip7: Interval
+ | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0))
+ | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0))
+ | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0))
+ | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0))
+ | clip3 <= clip(in1, asInterval(s, -2, 4, 0))
+ | clip4 <= clip(in1, asInterval(s, -1, 1, 0))
+ | clip5 <= clip(in1, asInterval(s, -4, 4, 0))
+ | clip6 <= clip(in1, asInterval(s, -1, 7, 0))
+ | clip7 <= clip(in1, asInterval(s, -4, 7, 0))
+ | """.stripMargin
+ val check = s"""
+ | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1))
+ | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1)
+ | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1)
+ | wrap7 <= in1
+ | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1))
+ | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1))
+ | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1)
+ | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)
+ | clip7 <= in1
+ """.stripMargin
+ //| output wrap4: Interval
+ //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0)
+ //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1"))
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "shift wrap/clip correctly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input s: SInt<2>
+ | input in1: Interval[-3, 5].1
+ | output wrap1: Interval
+ | output clip1: Interval
+ | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0))
+ | clip1 <= clip(in1, asInterval(s, -2, 2, 0))
+ | """.stripMargin
+ val check = s"""
+ | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1))
+ | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1))
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "infer negative binary points" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[-2, 4].-1
+ | input in2: Interval[-4, 8].-2
+ | output out: Interval
+ | out <= add(in1, in2)
+ | """.stripMargin
+ val check = s"""
+ | output out : Interval[-6, 12].-1
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "remove negative binary points" in {
+ val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals())
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[-2, 4].-1
+ | input in2: Interval[-4, 8].-2
+ | output out: Interval.0
+ | out <= add(in1, in2)
+ | """.stripMargin
+ val check = s"""
+ | output out : SInt<5>
+ | out <= shl(add(in1, shl(in2, 1)), 1)
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "implement squz properly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input min: Interval[-1, 4].1
+ | input max: Interval[-3, 5].1
+ | input left: Interval[-3, 3].1
+ | input right: Interval[0, 5].1
+ | input off: Interval[-1, 4].2
+ | output minMax: Interval
+ | output maxMin: Interval
+ | output minLeft: Interval
+ | output leftMin: Interval
+ | output minRight: Interval
+ | output rightMin: Interval
+ | output minOff: Interval
+ | output offMin: Interval
+ |
+ | minMax <= squz(min, max)
+ | maxMin <= squz(max, min)
+ | minLeft <= squz(min, left)
+ | leftMin <= squz(left, min)
+ | minRight <= squz(min, right)
+ | rightMin <= squz(right, min)
+ | minOff <= squz(min, off)
+ | offMin <= squz(off, min)
+ | """.stripMargin
+ val check =
+ s"""
+ | output minMax : Interval[-1, 4].1
+ | output maxMin : Interval[-1, 4].1
+ | output minLeft : Interval[-1, 3].1
+ | output leftMin : Interval[-1, 3].1
+ | output minRight : Interval[0, 4].1
+ | output rightMin : Interval[0, 4].1
+ | output minOff : Interval[-1, 4].1
+ | output offMin : Interval[-1, 4].2
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Interval types" should "lower squz properly" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input min: Interval[-1, 4].1
+ | input max: Interval[-3, 5].1
+ | input left: Interval[-3, 3].1
+ | input right: Interval[0, 5].1
+ | input off: Interval[-1, 4].2
+ | output minMax: Interval
+ | output maxMin: Interval
+ | output minLeft: Interval
+ | output leftMin: Interval
+ | output minRight: Interval
+ | output rightMin: Interval
+ | output minOff: Interval
+ | output offMin: Interval
+ |
+ | minMax <= squz(min, max)
+ | maxMin <= squz(max, min)
+ | minLeft <= squz(min, left)
+ | leftMin <= squz(left, min)
+ | minRight <= squz(min, right)
+ | rightMin <= squz(right, min)
+ | minOff <= squz(min, off)
+ | offMin <= squz(off, min)
+ | """.stripMargin
+ val check =
+ s"""
+ | minMax <= asSInt(bits(min, 4, 0))
+ | maxMin <= asSInt(bits(max, 4, 0))
+ | minLeft <= asSInt(bits(min, 3, 0))
+ | leftMin <= left
+ | minRight <= asSInt(bits(min, 4, 0))
+ | rightMin <= asSInt(bits(right, 4, 0))
+ | minOff <= asSInt(bits(min, 4, 0))
+ | offMin <= asSInt(bits(off, 5, 0))
+ """.stripMargin
+ executeTest(input, check.split("\n") map normalized, passes)
+ }
+ "Assigning a larger interval to a smaller interval" should "error!" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in: Interval[1, 4].1
+ | output out: Interval[2, 3].1
+ | out <= in
+ | """.stripMargin
+ intercept[InvalidConnect]{
+ executeTest(input, Nil, passes)
+ }
+ }
+ "Assigning a more precise interval to a less precise interval" should "error!" in {
+ val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals)
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in: Interval[2, 3].3
+ | output out: Interval[2, 3].1
+ | out <= in
+ | """.stripMargin
+ intercept[InvalidConnect]{
+ executeTest(input, Nil, passes)
+ }
+ }
+ "Chick's example" should "work" in {
+ val input =
+ s"""circuit IntervalChainedSubTester :
+ | module IntervalChainedSubTester :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | node _GEN_0 = sub(SInt<6>("h11"), SInt<6>("h2")) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26]
+ | node _GEN_1 = bits(_GEN_0, 4, 0) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26]
+ | node intervalResult = asSInt(_GEN_1) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26]
+ | skip
+ | node _T_1 = asUInt(intervalResult) @[IntervalSpec.scala 338:50]
+ | skip
+ | node _T_3 = eq(reset, UInt<1>("h0")) @[IntervalSpec.scala 338:9]
+ | node _T_4 = eq(intervalResult, SInt<5>("hf")) @[IntervalSpec.scala 339:25]
+ | skip
+ | node _T_6 = or(_T_4, reset) @[IntervalSpec.scala 339:9]
+ | node _T_7 = eq(_T_6, UInt<1>("h0")) @[IntervalSpec.scala 339:9]
+ | skip
+ | skip
+ | printf(clock, _T_3, "Interval result: %d", _T_1) @[IntervalSpec.scala 338:9]
+ | printf(clock, _T_7, "Assertion failed at IntervalSpec.scala:339 assert(intervalResult === 15.I)") @[IntervalSpec.scala 339:9]
+ | stop(clock, _T_7, 1) @[IntervalSpec.scala 339:9]
+ | stop(clock, _T_3, 0) @[IntervalSpec.scala 340:7]
+ |
+ """.stripMargin
+ compileToVerilog(input)
+ }
+
+ "Squeeze with disjoint intervals" should "error" in {
+ intercept[DisjointSqueeze] {
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[3, 6].3
+ | node out = squz(in1, in2)
+ """.stripMargin
+ compileToVerilog(input)
+ }
+ intercept[DisjointSqueeze] {
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[3, 6].3
+ | node out = squz(in2, in1)
+ """.stripMargin
+ compileToVerilog(input)
+ }
+ }
+
+ "Clip with disjoint intervals" should "work" in {
+ compileToVerilog(
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[3, 6].3
+ | output out: Interval
+ | out <= clip(in1, in2)
+ """.stripMargin
+ )
+ compileToVerilog(
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[2, 3).3
+ | input in2: Interval[4, 6].3
+ | node out = clip(in1, in2)
+ """.stripMargin
+ )
+ }
+
+
+ "Wrap with remainder" should "error" in {
+ intercept[WrapWithRemainder] {
+ val input =
+ s"""circuit Unit :
+ | module Unit :
+ | input in1: Interval[0, 300).3
+ | input in2: Interval[3, 6].3
+ | node out = wrap(in1, in2)
+ """.stripMargin
+ compileToVerilog(input)
+ }
+ }
+}