diff options
| author | Adam Izraelevitz | 2016-04-27 14:57:12 -0700 |
|---|---|---|
| committer | jackkoenig | 2016-05-10 14:52:04 -0700 |
| commit | a73efa2f67428101cf0984a8fb8ac3ebf32b914b (patch) | |
| tree | 5e54bf0a8366c8f2a953241782a4f08a390c1fad | |
| parent | 7f9814eb8464463983d3d6aeac45dadee493fb5c (diff) | |
Add test suite for Constant Propagation
Add unit tests for splitting expressions and padding widths
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 350 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/UnitTests.scala | 60 |
2 files changed, 408 insertions, 2 deletions
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala new file mode 100644 index 00000000..5f5705d9 --- /dev/null +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -0,0 +1,350 @@ +package firrtlTests + +import org.scalatest.Matchers +import java.io.{StringWriter,Writer} +import firrtl._ +import firrtl.passes._ + +// Tests the following cases for constant propagation: +// 1) Unsigned integers are always greater than or +// equal to zero +// 2) Values are always smaller than a number greater +// than their maximum value +// 3) Values are always greater than a number smaller +// than their minimum value +class ConstantPropagationSpec extends FirrtlFlatSpec { + val passes = Seq( + ToWorkingIR, + ResolveKinds, + InferTypes, + ResolveGenders, + InferWidths, + ConstProp) + def parse(input: String): Circuit = Parser.parse("", input.split("\n").toIterator, false) + private def exec (input: String) = { + passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + }.serialize + } + // ============================= + "The rule x >= 0 " should " always be true if x is a UInt" in { + val input = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= geq(x, UInt(0)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= UInt<1>("h1") +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule x < 0 " should " never be true if x is a UInt" in { + val input = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= lt(x, UInt(0)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule 0 <= x " should " always be true if x is a UInt" in { + val input = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= leq(UInt(0),x) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= UInt<1>(1) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule 0 > x " should " never be true if x is a UInt" in { + val input = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= gt(UInt(0),x) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule 1 < 3 " should " always be true" in { + val input = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= lt(UInt(0),UInt(3)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= UInt<1>(1) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule x < 8 " should " always be true if x only has 3 bits" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= lt(x,UInt(8)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= UInt<1>(1) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule x <= 7 " should " always be true if x only has 3 bits" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= leq(x,UInt(7)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= UInt<1>(1) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule 8 > x" should " always be true if x only has 3 bits" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= gt(UInt(8),x) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= UInt<1>(1) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule 7 >= x" should " always be true if x only has 3 bits" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= geq(UInt(7),x) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= UInt<1>(1) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule 10 == 10" should " always be true" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= eq(UInt(10),UInt(10)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= UInt<1>(1) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule x == z " should " not be true even if they have the same number of bits" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + input z : UInt<3> + output y : UInt<1> + y <= eq(x,z) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + input z : UInt<3> + output y : UInt<1> + y <= eq(x,z) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule 10 != 10 " should " always be false" in { + val input = +"""circuit Top : + module Top : + output y : UInt<1> + y <= neq(UInt(10),UInt(10)) +""" + val check = +"""circuit Top : + module Top : + output y : UInt<1> + y <= UInt(0) +""" + (parse(exec(input))) should be (parse(check)) + } + // ============================= + "The rule 1 >= 3 " should " always be false" in { + val input = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= geq(UInt(1),UInt(3)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<5> + output y : UInt<1> + y <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule x >= 8 " should " never be true if x only has 3 bits" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= geq(x,UInt(8)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule x > 7 " should " never be true if x only has 3 bits" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= gt(x,UInt(7)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule 8 <= x" should " never be true if x only has 3 bits" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= leq(UInt(8),x) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "The rule 7 < x" should " never be true if x only has 3 bits" in { + val input = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= lt(UInt(7),x) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<3> + output y : UInt<1> + y <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } +} diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index a2968ac5..7276aabb 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -33,8 +33,19 @@ import org.scalatest.prop._ import firrtl._ import firrtl.passes._ -class UnitTests extends FlatSpec with Matchers { +class UnitTests extends FirrtlFlatSpec { def parse (input:String) = Parser.parse("",input.split("\n").toIterator,false) + private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { + val c = passes.foldLeft(Parser.parse("", input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + val lines = c.serialize.split("\n") map normalized + + expected foreach { e => + lines should contain(e) + } + } + "Connecting bundles of different types" should "throw an exception" in { val passes = Seq( ToWorkingIR, @@ -130,10 +141,55 @@ class UnitTests extends FlatSpec with Matchers { "After splitting, emitting a nested expression" should "compile" in { val passes = Seq( ToWorkingIR, - SplitExp, + SplitExpressions, InferTypes) val c = Parser.parse("",splitExpTestCode.split("\n").toIterator) val c2 = passes.foldLeft(c)((c, p) => p run c) new VerilogEmitter().run(c2, new OutputStreamWriter(new ByteArrayOutputStream)) } + + "Simple compound expressions" should "be split" in { + val passes = Seq( + ToWorkingIR, + ResolveKinds, + InferTypes, + ResolveGenders, + InferWidths, + SplitExpressions + ) + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | input b : UInt<32> + | input d : UInt<32> + | output c : UInt<1> + | c <= geq(add(a, b),d)""".stripMargin + val check = Seq( + "node GEN_0 = add(a, b)", + "c <= geq(GEN_0, d)" + ) + executeTest(input, check, passes) + } + + "Smaller widths" should "be explicitly padded" in { + val passes = Seq( + ToWorkingIR, + ResolveKinds, + InferTypes, + ResolveGenders, + InferWidths, + PadWidths + ) + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | input b : UInt<20> + | input pred : UInt<1> + | output c : UInt<32> + | c <= mux(pred,a,b)""".stripMargin + val check = Seq("c <= mux(pred, a, pad(b, 32))") + executeTest(input, check, passes) + } } |
