diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/test/scala/firrtlTests/interval | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/test/scala/firrtlTests/interval')
| -rw-r--r-- | src/test/scala/firrtlTests/interval/IntervalMathSpec.scala | 162 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/interval/IntervalSpec.scala | 387 |
2 files changed, 289 insertions, 260 deletions
diff --git a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala index 656e1f8c..74e6cabf 100644 --- a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala +++ b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala @@ -10,14 +10,14 @@ import firrtl.constraint._ import firrtl.testutils.FirrtlFlatSpec class IntervalMathSpec extends FirrtlFlatSpec { - val SumPattern = """.*output sum.*<(\d+)>.*""".r - val ProductPattern = """.*output product.*<(\d+)>.*""".r - val DifferencePattern = """.*output difference.*<(\d+)>.*""".r - val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r - val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r - val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r - val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r - val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r + val SumPattern = """.*output sum.*<(\d+)>.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r + val DifferencePattern = """.*output difference.*<(\d+)>.*""".r + val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r + val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r + val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r + val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r + val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r val ArithAssignPattern = """\s*(\w+) <= asSInt\(bits\((\w+)\((.*)\).*\)\)\s*""".r def getBound(bound: String, value: BigDecimal): IsKnown = bound match { case "[" => Closed(value) @@ -29,16 +29,16 @@ class IntervalMathSpec extends FirrtlFlatSpec { val prec = 0.5 for { - lb1 <- Seq("[", "(") - lv1 <- Range.BigDecimal(-1.0, 1.0, prec) - uv1 <- if(lb1 == "[") Range.BigDecimal(lv1, 1.0, prec) else Range.BigDecimal(lv1 + prec, 1.0, prec) - ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") - bp1 <- 0 to 1 - lb2 <- Seq("[", "(") - lv2 <- Range.BigDecimal(-1.0, 1.0, prec) - uv2 <- if(lb2 == "[") Range.BigDecimal(lv2, 1.0, prec) else Range.BigDecimal(lv2 + prec, 1.0, prec) - ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") - bp2 <- 0 to 1 + lb1 <- Seq("[", "(") + lv1 <- Range.BigDecimal(-1.0, 1.0, prec) + uv1 <- if (lb1 == "[") Range.BigDecimal(lv1, 1.0, prec) else Range.BigDecimal(lv1 + prec, 1.0, prec) + ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") + bp1 <- 0 to 1 + lb2 <- Seq("[", "(") + lv2 <- Range.BigDecimal(-1.0, 1.0, prec) + uv2 <- if (lb2 == "[") Range.BigDecimal(lv2, 1.0, prec) else Range.BigDecimal(lv2 + prec, 1.0, prec) + ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") + bp2 <- 0 to 1 } { val it1 = IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1.toInt)) val it2 = IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2.toInt)) @@ -47,103 +47,108 @@ class IntervalMathSpec extends FirrtlFlatSpec { case (_, Some(Nil)) => case _ => def config = s"$lb1$lv1,$uv1$ub1.$bp1 and $lb2$lv2,$uv2$ub2.$bp2" - + s"Configuration $config" should "pass" in { - + val input = s"""circuit Unit : - | module Unit : - | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 - | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 - | input amt : UInt<3> - | output sum : Interval - | output difference : Interval - | output product : Interval - | output shl : Interval - | output shr : Interval - | output dshl : Interval - | output dshr : Interval - | output lt : UInt - | output leq : UInt - | output gt : UInt - | output geq : UInt - | output eq : UInt - | output neq : UInt - | output cat : UInt - | sum <= add(in1, in2) - | difference <= sub(in1, in2) - | product <= mul(in1, in2) - | shl <= shl(in1, 3) - | shr <= shr(in1, 3) - | dshl <= dshl(in1, amt) - | dshr <= dshr(in1, amt) - | lt <= lt(in1, in2) - | leq <= leq(in1, in2) - | gt <= gt(in1, in2) - | geq <= geq(in1, in2) - | eq <= eq(in1, in2) - | neq <= lt(in1, in2) - | cat <= cat(in1, in2) - | """.stripMargin - + | module Unit : + | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 + | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 + | input amt : UInt<3> + | output sum : Interval + | output difference : Interval + | output product : Interval + | output shl : Interval + | output shr : Interval + | output dshl : Interval + | output dshr : Interval + | output lt : UInt + | output leq : UInt + | output gt : UInt + | output geq : UInt + | output eq : UInt + | output neq : UInt + | output cat : UInt + | sum <= add(in1, in2) + | difference <= sub(in1, in2) + | product <= mul(in1, in2) + | shl <= shl(in1, 3) + | shr <= shr(in1, 3) + | dshl <= dshl(in1, amt) + | dshr <= dshr(in1, amt) + | lt <= lt(in1, in2) + | leq <= leq(in1, in2) + | gt <= gt(in1, in2) + | geq <= geq(in1, in2) + | eq <= eq(in1, in2) + | neq <= lt(in1, in2) + | cat <= cat(in1, in2) + | """.stripMargin + val lowerer = new LowFirrtlCompiler val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) - val output = res.getEmittedCircuit.value split "\n" + val output = res.getEmittedCircuit.value.split("\n") val min1 = Closed(it1.min.get) val max1 = Closed(it1.max.get) val min2 = Closed(it2.min.get) val max2 = Closed(it2.max.get) for (line <- output) { line match { - case SumPattern(varWidth) => + case SumPattern(varWidth) => val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) val it = IntervalType(IsAdd(min1, min2), IsAdd(max1, max2), bp) assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, s"$line,${it.range}") - case ProductPattern(varWidth) => + case ProductPattern(varWidth) => val bp = IntWidth(bp1.toInt + bp2.toInt) val lv = IsMin(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) val uv = IsMax(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "product") - case DifferencePattern(varWidth) => + case DifferencePattern(varWidth) => val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) val lv = min1 + max2.neg val uv = max1 + min2.neg assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "diff") - case ShiftLeftPattern(varWidth) => + case ShiftLeftPattern(varWidth) => val bp = IntWidth(bp1.toInt) val lv = min1 * Closed(8) val uv = max1 * Closed(8) val it = IntervalType(lv, uv, bp) assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, "shl") - case ShiftRightPattern(varWidth) => + case ShiftRightPattern(varWidth) => val bp = IntWidth(bp1.toInt) - val lv = min1 * Closed(1/3) - val uv = max1 * Closed(1/3) + val lv = min1 * Closed(1 / 3) + val uv = max1 * Closed(1 / 3) assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "shr") - case DShiftLeftPattern(varWidth) => + case DShiftLeftPattern(varWidth) => val bp = IntWidth(bp1.toInt) val lv = min1 * Closed(128) val uv = max1 * Closed(128) assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshl") - case DShiftRightPattern(varWidth) => + case DShiftRightPattern(varWidth) => val bp = IntWidth(bp1.toInt) val lv = min1 val uv = max1 assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshr") case ComparisonPattern(varWidth) => assert(varWidth.toInt == 1, "==") case ArithAssignPattern(varName, operation, args) => - val arg1 = if(IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) """SInt<1>("h0")""" else "in1" - val arg2 = if(IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) """SInt<1>("h0")""" else "in2" + val arg1 = + if (IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) + """SInt<1>("h0")""" + else "in1" + val arg2 = + if (IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) + """SInt<1>("h0")""" + else "in2" varName match { case "sum" => assert(operation === "add", s"""var sum should be result of an add in ${output.mkString("\n")}""") if (bp1 > bp2) { - if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") - assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), - s"$config second arg incorrect in $line") + if (arg1 != arg2) + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), s"$config second arg incorrect in $line") } else if (bp1 < bp2) { - assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), - s"$config second arg incorrect in $line") + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), s"$config second arg incorrect in $line") assert(!args.contains("shl($arg2"), s"$config second arg should be just $arg2 in $line") } else { assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") @@ -156,13 +161,13 @@ class IntervalMathSpec extends FirrtlFlatSpec { case "difference" => assert(operation === "sub", s"var difference should be result of an sub in $line") if (bp1 > bp2) { - if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") - assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), - s"$config second arg incorrect in $line") + if (arg1 != arg2) + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), s"$config second arg incorrect in $line") } else if (bp1 < bp2) { - assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), - s"$config second arg incorrect in $line") - if (arg1 != arg2) assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), s"$config second arg incorrect in $line") + if (arg1 != arg2) + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") } else { assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") @@ -170,12 +175,11 @@ class IntervalMathSpec extends FirrtlFlatSpec { case _ => } case _ => + } } } - } } } } - // vim: set ts=4 sw=4 et: diff --git a/src/test/scala/firrtlTests/interval/IntervalSpec.scala b/src/test/scala/firrtlTests/interval/IntervalSpec.scala index 5d82f6b5..1a39e98e 100644 --- a/src/test/scala/firrtlTests/interval/IntervalSpec.scala +++ b/src/test/scala/firrtlTests/interval/IntervalSpec.scala @@ -10,13 +10,12 @@ import firrtl.testutils.FirrtlFlatSpec class IntervalSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Transform) => - p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Transform) => + p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit } - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } @@ -37,7 +36,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output out1 : Interval | out0 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6)))))) | out1 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))""".stripMargin - executeTest(input, input.split("\n") map normalized, passes) + executeTest(input, input.split("\n").map(normalized), passes) } "Interval types" should "infer bp correctly" in { @@ -58,7 +57,7 @@ class IntervalSpec extends FirrtlFlatSpec { | input in2 : Interval(-0.32, 10].2 | output out0 : Interval.4 | out0 <= add(in0, add(in1, in2))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "trim known intervals correctly" in { @@ -79,11 +78,12 @@ class IntervalSpec extends FirrtlFlatSpec { | input in2 : Interval[-0.25, 10].2 | output out0 : Interval.4 | out0 <= add(in0, incp(add(in1, incp(in2, 1)), 1))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer intervals correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val passes = + Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) val input = """circuit Unit : | module Unit : @@ -100,11 +100,19 @@ class IntervalSpec extends FirrtlFlatSpec { """output out0 : Interval[-0.5625, 22.9375].4 |output out1 : Interval[-74.53125, 298.125].9 |output out2 : Interval[-10.6875, 12.8125].4""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "be removed correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) + val passes = Seq( + ToWorkingIR, + InferTypes, + ResolveFlows, + new InferBinaryPoints(), + new TrimIntervals(), + new InferWidths(), + new RemoveIntervals() + ) val input = """circuit Unit : | module Unit : @@ -129,209 +137,227 @@ class IntervalSpec extends FirrtlFlatSpec { | out0 <= add(in0, shl(add(in1, shl(in2, 1)), 1)) | out1 <= mul(in0, mul(in1, in2)) | out2 <= sub(in0, shl(sub(in1, shl(in2, 1)), 1))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } -"Interval types" should "infer multiplication by zero correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + "Interval types" should "infer multiplication by zero correctly" in { + val passes = + Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) val input = s"""circuit Unit : - | module Unit : - | input in1 : Interval[0, 0.5].1 - | input in2 : Interval[0, 0].1 - | output mul : Interval - | mul <= mul(in2, in1) - | """.stripMargin - val check = s"""output mul : Interval[0, 0].2 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) -} + | module Unit : + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output mul : Interval + | mul <= mul(in2, in1) + | """.stripMargin + val check = s"""output mul : Interval[0, 0].2 """.stripMargin + executeTest(input, check.split("\n").map(normalized), passes) + } "Interval types" should "infer muxes correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) - val input = - s"""circuit Unit : - | module Unit : - | input p : UInt<1> - | input in1 : Interval[0, 0.5].1 - | input in2 : Interval[0, 0].1 - | output out : Interval - | out <= mux(p, in2, in1) - | """.stripMargin + val passes = + Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<1> + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output out : Interval + | out <= mux(p, in2, in1) + | """.stripMargin val check = s"""output out : Interval[0, 0.5].1 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer dshl correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds, ResolveFlows, new InferBinaryPoints(), new TrimIntervals, new InferWidths()) - val input = - s"""circuit Unit : - | module Unit : - | input p : UInt<3> - | input in1 : Interval[-1, 1].0 - | output out : Interval - | out <= dshl(in1, p) - | """.stripMargin + val passes = Seq( + ToWorkingIR, + InferTypes, + ResolveKinds, + ResolveFlows, + new InferBinaryPoints(), + new TrimIntervals, + new InferWidths() + ) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | input in1 : Interval[-1, 1].0 + | output out : Interval + | out <= dshl(in1, p) + | """.stripMargin val check = s"""output out : Interval[-128, 128].0 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer asInterval correctly" in { val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferWidths()) - val input = - s"""circuit Unit : - | module Unit : - | input p : UInt<3> - | output out : Interval - | out <= asInterval(p, 0, 4, 1) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | output out : Interval + | out <= asInterval(p, 0, 4, 1) + | """.stripMargin val check = s"""output out : Interval[0, 2].1 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "do wrap/clip correctly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck()) - val input = - s"""circuit Unit : - | module Unit : - | input s: SInt<2> - | input u: UInt<3> - | input in1: Interval[-3, 5].0 - | output wrap3: Interval - | output wrap4: Interval - | output wrap5: Interval - | output wrap6: Interval - | output wrap7: Interval - | output clip3: Interval - | output clip4: Interval - | output clip5: Interval - | output clip6: Interval - | output clip7: Interval - | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) - | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0)) - | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) - | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) - | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) - | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) - | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) - | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) - | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) - | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap4: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) """.stripMargin - //| output wrap1: Interval - //| output wrap2: Interval - //| output clip1: Interval - //| output clip2: Interval - //| wrap1 <= wrap(in1, u, 0) - //| wrap2 <= wrap(in1, s, 0) - //| clip1 <= clip(in1, u) - //| clip2 <= clip(in1, s) + //| output wrap1: Interval + //| output wrap2: Interval + //| output clip1: Interval + //| output clip2: Interval + //| wrap1 <= wrap(in1, u, 0) + //| wrap2 <= wrap(in1, s, 0) + //| clip1 <= clip(in1, u) + //| clip2 <= clip(in1, s) val check = s""" - | output wrap3 : Interval[-2, 4].0 - | output wrap4 : Interval[-1, 1].0 - | output wrap5 : Interval[-4, 4].0 - | output wrap6 : Interval[-1, 7].0 - | output wrap7 : Interval[-4, 7].0 - | output clip3 : Interval[-2, 4].0 - | output clip4 : Interval[-1, 1].0 - | output clip5 : Interval[-3, 4].0 - | output clip6 : Interval[-1, 5].0 - | output clip7 : Interval[-3, 5].0 """.stripMargin - // TODO: this optimization - //| output wrap1 : Interval[0, 7].0 - //| output wrap2 : Interval[-2, 1].0 - //| output clip1 : Interval[0, 5].0 - //| output clip2 : Interval[-2, 1].0 - //| output wrap7 : Interval[-3, 5].0 - executeTest(input, check.split("\n") map normalized, passes) + | output wrap3 : Interval[-2, 4].0 + | output wrap4 : Interval[-1, 1].0 + | output wrap5 : Interval[-4, 4].0 + | output wrap6 : Interval[-1, 7].0 + | output wrap7 : Interval[-4, 7].0 + | output clip3 : Interval[-2, 4].0 + | output clip4 : Interval[-1, 1].0 + | output clip5 : Interval[-3, 4].0 + | output clip6 : Interval[-1, 5].0 + | output clip7 : Interval[-3, 5].0 """.stripMargin + // TODO: this optimization + //| output wrap1 : Interval[0, 7].0 + //| output wrap2 : Interval[-2, 1].0 + //| output clip1 : Interval[0, 5].0 + //| output clip2 : Interval[-2, 1].0 + //| output wrap7 : Interval[-3, 5].0 + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "remove wrap/clip correctly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck(), new RemoveIntervals()) - val input = - s"""circuit Unit : - | module Unit : - | input s: SInt<2> - | input u: UInt<3> - | input in1: Interval[-3, 5].0 - | output wrap3: Interval - | output wrap5: Interval - | output wrap6: Interval - | output wrap7: Interval - | output clip3: Interval - | output clip4: Interval - | output clip5: Interval - | output clip6: Interval - | output clip7: Interval - | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) - | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) - | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) - | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) - | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) - | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) - | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) - | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) - | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + | """.stripMargin val check = s""" - | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1)) - | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1) - | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1) - | wrap7 <= in1 - | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1)) - | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)) - | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1) - | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1) - | clip7 <= in1 + | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1)) + | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1) + | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1) + | wrap7 <= in1 + | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1)) + | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)) + | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1) + | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1) + | clip7 <= in1 """.stripMargin - //| output wrap4: Interval - //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0) - //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1")) - executeTest(input, check.split("\n") map normalized, passes) + //| output wrap4: Interval + //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0) + //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1")) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "shift wrap/clip correctly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals()) - val input = - s"""circuit Unit : - | module Unit : - | input s: SInt<2> - | input in1: Interval[-3, 5].1 - | output wrap1: Interval - | output clip1: Interval - | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0)) - | clip1 <= clip(in1, asInterval(s, -2, 2, 0)) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input in1: Interval[-3, 5].1 + | output wrap1: Interval + | output clip1: Interval + | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0)) + | clip1 <= clip(in1, asInterval(s, -2, 2, 0)) + | """.stripMargin val check = s""" - | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1)) - | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1)) + | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1)) + | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1)) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer negative binary points" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck()) - val input = - s"""circuit Unit : - | module Unit : - | input in1: Interval[-2, 4].-1 - | input in2: Interval[-4, 8].-2 - | output out: Interval - | out <= add(in1, in2) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval + | out <= add(in1, in2) + | """.stripMargin val check = s""" - | output out : Interval[-6, 12].-1 + | output out : Interval[-6, 12].-1 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "remove negative binary points" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) - val input = - s"""circuit Unit : - | module Unit : - | input in1: Interval[-2, 4].-1 - | input in2: Interval[-4, 8].-2 - | output out: Interval.0 - | out <= add(in1, in2) - | """.stripMargin + val passes = Seq( + ToWorkingIR, + InferTypes, + ResolveFlows, + new InferBinaryPoints(), + new TrimIntervals(), + new InferWidths(), + new RemoveIntervals() + ) + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval.0 + | out <= add(in1, in2) + | """.stripMargin val check = s""" - | output out : SInt<5> - | out <= shl(add(in1, shl(in2, 1)), 1) + | output out : SInt<5> + | out <= shl(add(in1, shl(in2, 1)), 1) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "implement squz properly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck) @@ -372,7 +398,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output minOff : Interval[-1, 4].1 | output offMin : Interval[-1, 4].2 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "lower squz properly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) @@ -413,7 +439,7 @@ class IntervalSpec extends FirrtlFlatSpec { | minOff <= asSInt(bits(min, 4, 0)) | offMin <= asSInt(bits(off, 5, 0)) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Assigning a larger interval to a smaller interval" should "error!" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) @@ -424,7 +450,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output out: Interval[2, 3].1 | out <= in | """.stripMargin - intercept[InvalidConnect]{ + intercept[InvalidConnect] { executeTest(input, Nil, passes) } } @@ -437,7 +463,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output out: Interval[2, 3].1 | out <= in | """.stripMargin - intercept[InvalidConnect]{ + intercept[InvalidConnect] { executeTest(input, Nil, passes) } } @@ -512,7 +538,6 @@ class IntervalSpec extends FirrtlFlatSpec { ) } - "Wrap with remainder" should "error" in { intercept[WrapWithRemainder] { val input = |
