diff options
Diffstat (limited to 'src/test/scala/firrtlTests/interval/IntervalMathSpec.scala')
| -rw-r--r-- | src/test/scala/firrtlTests/interval/IntervalMathSpec.scala | 183 |
1 files changed, 183 insertions, 0 deletions
diff --git a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala new file mode 100644 index 00000000..20fdeee1 --- /dev/null +++ b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala @@ -0,0 +1,183 @@ +// See LICENSE for license details. + +package firrtlTests.interval + +import firrtl.Implicits.constraint2bound +import firrtl.{ChirrtlForm, CircuitState, LowFirrtlCompiler, Parser} +import firrtl.ir._ + +import scala.math.BigDecimal.RoundingMode._ +import firrtl.Parser.IgnoreInfo +import firrtl.constraint._ +import firrtlTests.FirrtlFlatSpec + +class IntervalMathSpec extends FirrtlFlatSpec { + val SumPattern = """.*output sum.*<(\d+)>.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r + val DifferencePattern = """.*output difference.*<(\d+)>.*""".r + val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r + val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r + val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r + val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r + val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r + val ArithAssignPattern = """\s*(\w+) <= asSInt\(bits\((\w+)\((.*)\).*\)\)\s*""".r + def getBound(bound: String, value: Double): IsKnown = bound match { + case "[" => Closed(BigDecimal(value)) + case "]" => Closed(BigDecimal(value)) + case "(" => Open(BigDecimal(value)) + case ")" => Open(BigDecimal(value)) + } + + val prec = 0.5 + + for { + lb1 <- Seq("[", "(") + lv1 <- Range.Double(-1.0, 1.0, prec) + uv1 <- if(lb1 == "[") Range.Double(lv1, 1.0, prec) else Range.Double(lv1 + prec, 1.0, prec) + ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") + bp1 <- 0 to 1 + lb2 <- Seq("[", "(") + lv2 <- Range.Double(-1.0, 1.0, prec) + uv2 <- if(lb2 == "[") Range.Double(lv2, 1.0, prec) else Range.Double(lv2 + prec, 1.0, prec) + ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") + bp2 <- 0 to 1 + } { + val it1 = IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1.toInt)) + val it2 = IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2.toInt)) + (it1.range, it2.range) match { + case (Some(Nil), _) => + case (_, Some(Nil)) => + case _ => + def config = s"$lb1$lv1,$uv1$ub1.$bp1 and $lb2$lv2,$uv2$ub2.$bp2" + + s"Configuration $config" should "pass" in { + + val input = + s"""circuit Unit : + | module Unit : + | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 + | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 + | input amt : UInt<3> + | output sum : Interval + | output difference : Interval + | output product : Interval + | output shl : Interval + | output shr : Interval + | output dshl : Interval + | output dshr : Interval + | output lt : UInt + | output leq : UInt + | output gt : UInt + | output geq : UInt + | output eq : UInt + | output neq : UInt + | output cat : UInt + | sum <= add(in1, in2) + | difference <= sub(in1, in2) + | product <= mul(in1, in2) + | shl <= shl(in1, 3) + | shr <= shr(in1, 3) + | dshl <= dshl(in1, amt) + | dshr <= dshr(in1, amt) + | lt <= lt(in1, in2) + | leq <= leq(in1, in2) + | gt <= gt(in1, in2) + | geq <= geq(in1, in2) + | eq <= eq(in1, in2) + | neq <= lt(in1, in2) + | cat <= cat(in1, in2) + | """.stripMargin + + val lowerer = new LowFirrtlCompiler + val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) + val output = res.getEmittedCircuit.value split "\n" + val min1 = Closed(it1.min.get) + val max1 = Closed(it1.max.get) + val min2 = Closed(it2.min.get) + val max2 = Closed(it2.max.get) + for (line <- output) { + line match { + case SumPattern(varWidth) => + val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) + val it = IntervalType(IsAdd(min1, min2), IsAdd(max1, max2), bp) + assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, s"$line,${it.range}") + case ProductPattern(varWidth) => + val bp = IntWidth(bp1.toInt + bp2.toInt) + val lv = IsMin(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) + val uv = IsMax(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "product") + case DifferencePattern(varWidth) => + val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) + val lv = min1 + max2.neg + val uv = max1 + min2.neg + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "diff") + case ShiftLeftPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(8) + val uv = max1 * Closed(8) + val it = IntervalType(lv, uv, bp) + assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, "shl") + case ShiftRightPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(1/3) + val uv = max1 * Closed(1/3) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "shr") + case DShiftLeftPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 * Closed(128) + val uv = max1 * Closed(128) + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshl") + case DShiftRightPattern(varWidth) => + val bp = IntWidth(bp1.toInt) + val lv = min1 + val uv = max1 + assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshr") + case ComparisonPattern(varWidth) => assert(varWidth.toInt == 1, "==") + case ArithAssignPattern(varName, operation, args) => + val arg1 = if(IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) """SInt<1>("h0")""" else "in1" + val arg2 = if(IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) """SInt<1>("h0")""" else "in2" + varName match { + case "sum" => + assert(operation === "add", s"""var sum should be result of an add in ${output.mkString("\n")}""") + if (bp1 > bp2) { + if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), + s"$config second arg incorrect in $line") + } else if (bp1 < bp2) { + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), + s"$config second arg incorrect in $line") + assert(!args.contains("shl($arg2"), s"$config second arg should be just $arg2 in $line") + } else { + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } + case "product" => + assert(operation === "mul", s"var sum should be result of an add in $line") + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + case "difference" => + assert(operation === "sub", s"var difference should be result of an sub in $line") + if (bp1 > bp2) { + if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), + s"$config second arg incorrect in $line") + } else if (bp1 < bp2) { + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), + s"$config second arg incorrect in $line") + if (arg1 != arg2) assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } else { + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + } + case _ => + } + case _ => + } + } + } + } + } +} + + +// vim: set ts=4 sw=4 et: |
