diff options
Diffstat (limited to 'src/test/scala/firrtlTests/interval')
| -rw-r--r-- | src/test/scala/firrtlTests/interval/IntervalMathSpec.scala | 183 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/interval/IntervalSpec.scala | 530 |
2 files changed, 713 insertions, 0 deletions
diff --git a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala new file mode 100644 index 00000000..20fdeee1 --- /dev/null +++ b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala @@ -0,0 +1,183 @@ +// See LICENSE for license details. + +package firrtlTests.interval + +import firrtl.Implicits.constraint2bound +import firrtl.{ChirrtlForm, CircuitState, LowFirrtlCompiler, Parser} +import firrtl.ir._ + +import scala.math.BigDecimal.RoundingMode._ +import firrtl.Parser.IgnoreInfo +import firrtl.constraint._ +import firrtlTests.FirrtlFlatSpec + +class IntervalMathSpec extends FirrtlFlatSpec { + val SumPattern = """.*output sum.*<(\d+)>.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r + val DifferencePattern = """.*output difference.*<(\d+)>.*""".r + val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r + val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r + val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r + val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r + val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r + val ArithAssignPattern = """\s*(\w+) <= asSInt\(bits\((\w+)\((.*)\).*\)\)\s*""".r + def getBound(bound: String, value: Double): IsKnown = bound match { + case "[" => Closed(BigDecimal(value)) + case "]" => Closed(BigDecimal(value)) + case "(" => Open(BigDecimal(value)) + case ")" => Open(BigDecimal(value)) + } + + val prec = 0.5 + + for { + lb1 <- Seq("[", "(") + lv1 <- Range.Double(-1.0, 1.0, prec) + uv1 <- if(lb1 == "[") Range.Double(lv1, 1.0, prec) else Range.Double(lv1 + prec, 1.0, prec) + ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") + bp1 <- 0 to 1 + lb2 <- Seq("[", "(") + lv2 <- Range.Double(-1.0, 1.0, prec) + uv2 <- if(lb2 == "[") Range.Double(lv2, 1.0, prec) else Range.Double(lv2 + prec, 1.0, prec) + ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") + bp2 <- 0 to 1 + } { + val it1 = IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1.toInt)) + val it2 = IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2.toInt)) + (it1.range, it2.range) match { + case (Some(Nil), _) => + case (_, Some(Nil)) => + case _ => + def config = s"$lb1$lv1,$uv1$ub1.$bp1 and $lb2$lv2,$uv2$ub2.$bp2" + + s"Configuration $config" should "pass" in { + + val input = + s"""circuit Unit : + | module Unit : + | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 + | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 + | input amt : UInt<3> + | output sum : Interval + | output difference : Interval + | output product : Interval + | output shl : Interval + | output shr : Interval + | output dshl : Interval + | output dshr : Interval + | output lt : UInt + | output leq : UInt + | output gt : UInt + | output geq : UInt + | output eq : UInt + | output neq : UInt + | output cat : UInt + | sum <= add(in1, in2) + | difference <= sub(in1, in2) + | product <= mul(in1, in2) + | shl <= shl(in1, 3) + | shr <= shr(in1, 3) + | dshl <= dshl(in1, amt) + | dshr <= dshr(in1, amt) + | lt <= lt(in1, in2) + | leq <= leq(in1, in2) + | gt <= gt(in1, in2) + | geq <= geq(in1, in2) + | eq <= eq(in1, in2) + | neq <= lt(in1, in2) + | cat <= cat(in1, in2) + | """.stripMargin + + val lowerer = new LowFirrtlCompiler + val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) + val output = res.getEmittedCircuit.value split "\n" + val min1 = Closed(it1.min.get) + val max1 = Closed(it1.max.get) + val min2 = Closed(it2.min.get) + val max2 = Closed(it2.max.get) + for (line <- output) { + line match { + case SumPattern(varWidth) => + val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) + val it = IntervalType(IsAdd(min1, min2), IsAdd(max1, max2), bp) + assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, s"$line,${it.range}") + case ProductPattern(varWidth) => + val bp = IntWidth(bp1.toInt + bp2.toInt) + val lv = IsMin(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) + val uv = IsMax(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "product") + case DifferencePattern(varWidth) => + val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) + val lv = min1 + max2.neg + val uv = max1 + min2.neg + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "diff") + case ShiftLeftPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(8) + val uv = max1 * Closed(8) + val it = IntervalType(lv, uv, bp) + assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, "shl") + case ShiftRightPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(1/3) + val uv = max1 * Closed(1/3) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "shr") + case DShiftLeftPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(128) + val uv = max1 * Closed(128) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshl") + case DShiftRightPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 + val uv = max1 + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshr") + case ComparisonPattern(varWidth) => assert(varWidth.toInt == 1, "==") + case ArithAssignPattern(varName, operation, args) => + val arg1 = if(IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) """SInt<1>("h0")""" else "in1" + val arg2 = if(IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) """SInt<1>("h0")""" else "in2" + varName match { + case "sum" => + assert(operation === "add", s"""var sum should be result of an add in ${output.mkString("\n")}""") + if (bp1 > bp2) { + if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), + s"$config second arg incorrect in $line") + } else if (bp1 < bp2) { + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), + s"$config second arg incorrect in $line") + assert(!args.contains("shl($arg2"), s"$config second arg should be just $arg2 in $line") + } else { + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } + case "product" => + assert(operation === "mul", s"var sum should be result of an add in $line") + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + case "difference" => + assert(operation === "sub", s"var difference should be result of an sub in $line") + if (bp1 > bp2) { + if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), + s"$config second arg incorrect in $line") + } else if (bp1 < bp2) { + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), + s"$config second arg incorrect in $line") + if (arg1 != arg2) assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } else { + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } + case _ => + } + case _ => + } + } + } + } + } +} + + +// vim: set ts=4 sw=4 et: diff --git a/src/test/scala/firrtlTests/interval/IntervalSpec.scala b/src/test/scala/firrtlTests/interval/IntervalSpec.scala new file mode 100644 index 00000000..37d79c84 --- /dev/null +++ b/src/test/scala/firrtlTests/interval/IntervalSpec.scala @@ -0,0 +1,530 @@ +package firrtlTests +package interval + +import java.io._ + +import firrtl._ +import firrtl.ir.Circuit +import firrtl.passes._ +import firrtl.Parser.IgnoreInfo +import firrtl.passes.CheckTypes.InvalidConnect +import firrtl.passes.CheckWidths.DisjointSqueeze + +class IntervalSpec extends FirrtlFlatSpec { + private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Transform) => + p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit + } + val lines = c.serialize.split("\n") map normalized + + expected foreach { e => + lines should contain(e) + } + } + + "Interval types" should "parse correctly" in { + val passes = Seq(ToWorkingIR) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].4 + | input in2 : Interval(-0.32, 10].4 + | input in3 : Interval[-3, 10.1).4 + | input in4 : Interval(-0.32, 10.1) + | input in5 : Interval.4 + | input in6 : Interval + | output out0 : Interval.2 + | output out1 : Interval + | out0 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6)))))) + | out1 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))""".stripMargin + executeTest(input, input.split("\n") map normalized, passes) + } + + "Interval types" should "infer bp correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].3 + | input in2 : Interval(-0.32, 10].2 + | output out0 : Interval + | out0 <= add(in0, add(in1, in2))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].3 + | input in2 : Interval(-0.32, 10].2 + | output out0 : Interval.4 + | out0 <= add(in0, add(in1, in2))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Interval types" should "trim known intervals correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(-0.32, 10.1).4 + | input in1 : Interval[0, 10.1].3 + | input in2 : Interval(-0.32, 10].2 + | output out0 : Interval + | out0 <= add(in0, add(in1, in2))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input in0 : Interval[-0.3125, 10.0625].4 + | input in1 : Interval[0, 10].3 + | input in2 : Interval[-0.25, 10].2 + | output out0 : Interval.4 + | out0 <= add(in0, incp(add(in1, incp(in2, 1)), 1))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Interval types" should "infer intervals correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(0, 10).4 + | input in1 : Interval(0, 10].3 + | input in2 : Interval(-1, 3].2 + | output out0 : Interval + | output out1 : Interval + | output out2 : Interval + | out0 <= add(in0, add(in1, in2)) + | out1 <= mul(in0, mul(in1, in2)) + | out2 <= sub(in0, sub(in1, in2))""".stripMargin + val check = + """output out0 : Interval[-0.5625, 22.9375].4 + |output out1 : Interval[-74.53125, 298.125].9 + |output out2 : Interval[-10.6875, 12.8125].4""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Interval types" should "be removed correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) + val input = + """circuit Unit : + | module Unit : + | input in0 : Interval(0, 10).4 + | input in1 : Interval(0, 10].3 + | input in2 : Interval(-1, 3].2 + | output out0 : Interval + | output out1 : Interval + | output out2 : Interval + | out0 <= add(in0, add(in1, in2)) + | out1 <= mul(in0, mul(in1, in2)) + | out2 <= sub(in0, sub(in1, in2))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input in0 : SInt<9> + | input in1 : SInt<8> + | input in2 : SInt<5> + | output out0 : SInt<10> + | output out1 : SInt<19> + | output out2 : SInt<9> + | out0 <= add(in0, shl(add(in1, shl(in2, 1)), 1)) + | out1 <= mul(in0, mul(in1, in2)) + | out2 <= sub(in0, shl(sub(in1, shl(in2, 1)), 1))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + +"Interval types" should "infer multiplication by zero correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output mul : Interval + | mul <= mul(in2, in1) + | """.stripMargin + val check = s"""output mul : Interval[0, 0].2 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) +} + + "Interval types" should "infer muxes correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<1> + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output out : Interval + | out <= mux(p, in2, in1) + | """.stripMargin + val check = s"""output out : Interval[0, 0.5].1 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "infer dshl correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds, ResolveGenders, new InferBinaryPoints(), new TrimIntervals, new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | input in1 : Interval[-1, 1].0 + | output out : Interval + | out <= dshl(in1, p) + | """.stripMargin + val check = s"""output out : Interval[-128, 128].0 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "infer asInterval correctly" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | output out : Interval + | out <= asInterval(p, 0, 4, 1) + | """.stripMargin + val check = s"""output out : Interval[0, 2].1 """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "do wrap/clip correctly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck()) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap4: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + """.stripMargin + //| output wrap1: Interval + //| output wrap2: Interval + //| output clip1: Interval + //| output clip2: Interval + //| wrap1 <= wrap(in1, u, 0) + //| wrap2 <= wrap(in1, s, 0) + //| clip1 <= clip(in1, u) + //| clip2 <= clip(in1, s) + val check = s""" + | output wrap3 : Interval[-2, 4].0 + | output wrap4 : Interval[-1, 1].0 + | output wrap5 : Interval[-4, 4].0 + | output wrap6 : Interval[-1, 7].0 + | output wrap7 : Interval[-4, 7].0 + | output clip3 : Interval[-2, 4].0 + | output clip4 : Interval[-1, 1].0 + | output clip5 : Interval[-3, 4].0 + | output clip6 : Interval[-1, 5].0 + | output clip7 : Interval[-3, 5].0 """.stripMargin + // TODO: this optimization + //| output wrap1 : Interval[0, 7].0 + //| output wrap2 : Interval[-2, 1].0 + //| output clip1 : Interval[0, 5].0 + //| output clip2 : Interval[-2, 1].0 + //| output wrap7 : Interval[-3, 5].0 + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "remove wrap/clip correctly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck(), new RemoveIntervals()) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + | """.stripMargin + val check = s""" + | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1)) + | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1) + | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1) + | wrap7 <= in1 + | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1)) + | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)) + | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1) + | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1) + | clip7 <= in1 + """.stripMargin + //| output wrap4: Interval + //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0) + //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1")) + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "shift wrap/clip correctly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals()) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input in1: Interval[-3, 5].1 + | output wrap1: Interval + | output clip1: Interval + | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0)) + | clip1 <= clip(in1, asInterval(s, -2, 2, 0)) + | """.stripMargin + val check = s""" + | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1)) + | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1)) + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "infer negative binary points" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck()) + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval + | out <= add(in1, in2) + | """.stripMargin + val check = s""" + | output out : Interval[-6, 12].-1 + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "remove negative binary points" in { + val passes = Seq(ToWorkingIR, InferTypes, ResolveGenders, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval.0 + | out <= add(in1, in2) + | """.stripMargin + val check = s""" + | output out : SInt<5> + | out <= shl(add(in1, shl(in2, 1)), 1) + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "implement squz properly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck) + val input = + s"""circuit Unit : + | module Unit : + | input min: Interval[-1, 4].1 + | input max: Interval[-3, 5].1 + | input left: Interval[-3, 3].1 + | input right: Interval[0, 5].1 + | input off: Interval[-1, 4].2 + | output minMax: Interval + | output maxMin: Interval + | output minLeft: Interval + | output leftMin: Interval + | output minRight: Interval + | output rightMin: Interval + | output minOff: Interval + | output offMin: Interval + | + | minMax <= squz(min, max) + | maxMin <= squz(max, min) + | minLeft <= squz(min, left) + | leftMin <= squz(left, min) + | minRight <= squz(min, right) + | rightMin <= squz(right, min) + | minOff <= squz(min, off) + | offMin <= squz(off, min) + | """.stripMargin + val check = + s""" + | output minMax : Interval[-1, 4].1 + | output maxMin : Interval[-1, 4].1 + | output minLeft : Interval[-1, 3].1 + | output leftMin : Interval[-1, 3].1 + | output minRight : Interval[0, 4].1 + | output rightMin : Interval[0, 4].1 + | output minOff : Interval[-1, 4].1 + | output offMin : Interval[-1, 4].2 + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Interval types" should "lower squz properly" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) + val input = + s"""circuit Unit : + | module Unit : + | input min: Interval[-1, 4].1 + | input max: Interval[-3, 5].1 + | input left: Interval[-3, 3].1 + | input right: Interval[0, 5].1 + | input off: Interval[-1, 4].2 + | output minMax: Interval + | output maxMin: Interval + | output minLeft: Interval + | output leftMin: Interval + | output minRight: Interval + | output rightMin: Interval + | output minOff: Interval + | output offMin: Interval + | + | minMax <= squz(min, max) + | maxMin <= squz(max, min) + | minLeft <= squz(min, left) + | leftMin <= squz(left, min) + | minRight <= squz(min, right) + | rightMin <= squz(right, min) + | minOff <= squz(min, off) + | offMin <= squz(off, min) + | """.stripMargin + val check = + s""" + | minMax <= asSInt(bits(min, 4, 0)) + | maxMin <= asSInt(bits(max, 4, 0)) + | minLeft <= asSInt(bits(min, 3, 0)) + | leftMin <= left + | minRight <= asSInt(bits(min, 4, 0)) + | rightMin <= asSInt(bits(right, 4, 0)) + | minOff <= asSInt(bits(min, 4, 0)) + | offMin <= asSInt(bits(off, 5, 0)) + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Assigning a larger interval to a smaller interval" should "error!" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) + val input = + s"""circuit Unit : + | module Unit : + | input in: Interval[1, 4].1 + | output out: Interval[2, 3].1 + | out <= in + | """.stripMargin + intercept[InvalidConnect]{ + executeTest(input, Nil, passes) + } + } + "Assigning a more precise interval to a less precise interval" should "error!" in { + val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) + val input = + s"""circuit Unit : + | module Unit : + | input in: Interval[2, 3].3 + | output out: Interval[2, 3].1 + | out <= in + | """.stripMargin + intercept[InvalidConnect]{ + executeTest(input, Nil, passes) + } + } + "Chick's example" should "work" in { + val input = + s"""circuit IntervalChainedSubTester : + | module IntervalChainedSubTester : + | input clock : Clock + | input reset : UInt<1> + | node _GEN_0 = sub(SInt<6>("h11"), SInt<6>("h2")) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26] + | node _GEN_1 = bits(_GEN_0, 4, 0) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26] + | node intervalResult = asSInt(_GEN_1) @[IntervalSpec.scala 337:26 IntervalSpec.scala 337:26] + | skip + | node _T_1 = asUInt(intervalResult) @[IntervalSpec.scala 338:50] + | skip + | node _T_3 = eq(reset, UInt<1>("h0")) @[IntervalSpec.scala 338:9] + | node _T_4 = eq(intervalResult, SInt<5>("hf")) @[IntervalSpec.scala 339:25] + | skip + | node _T_6 = or(_T_4, reset) @[IntervalSpec.scala 339:9] + | node _T_7 = eq(_T_6, UInt<1>("h0")) @[IntervalSpec.scala 339:9] + | skip + | skip + | printf(clock, _T_3, "Interval result: %d", _T_1) @[IntervalSpec.scala 338:9] + | printf(clock, _T_7, "Assertion failed at IntervalSpec.scala:339 assert(intervalResult === 15.I)") @[IntervalSpec.scala 339:9] + | stop(clock, _T_7, 1) @[IntervalSpec.scala 339:9] + | stop(clock, _T_3, 0) @[IntervalSpec.scala 340:7] + | + """.stripMargin + compileToVerilog(input) + } + + "Squeeze with disjoint intervals" should "error" in { + intercept[DisjointSqueeze] { + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[3, 6].3 + | node out = squz(in1, in2) + """.stripMargin + compileToVerilog(input) + } + intercept[DisjointSqueeze] { + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[3, 6].3 + | node out = squz(in2, in1) + """.stripMargin + compileToVerilog(input) + } + } + + "Clip with disjoint intervals" should "work" in { + compileToVerilog( + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[3, 6].3 + | output out: Interval + | out <= clip(in1, in2) + """.stripMargin + ) + compileToVerilog( + s"""circuit Unit : + | module Unit : + | input in1: Interval[2, 3).3 + | input in2: Interval[4, 6].3 + | node out = clip(in1, in2) + """.stripMargin + ) + } + + + "Wrap with remainder" should "error" in { + intercept[WrapWithRemainder] { + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[0, 300).3 + | input in2: Interval[3, 6].3 + | node out = wrap(in1, in2) + """.stripMargin + compileToVerilog(input) + } + } +} |
