aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-04-27 14:57:12 -0700
committerjackkoenig2016-05-10 14:52:04 -0700
commita73efa2f67428101cf0984a8fb8ac3ebf32b914b (patch)
tree5e54bf0a8366c8f2a953241782a4f08a390c1fad /src
parent7f9814eb8464463983d3d6aeac45dadee493fb5c (diff)
Add test suite for Constant Propagation
Add unit tests for splitting expressions and padding widths
Diffstat (limited to 'src')
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala350
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala60
2 files changed, 408 insertions, 2 deletions
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
new file mode 100644
index 00000000..5f5705d9
--- /dev/null
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -0,0 +1,350 @@
+package firrtlTests
+
+import org.scalatest.Matchers
+import java.io.{StringWriter,Writer}
+import firrtl._
+import firrtl.passes._
+
+// Tests the following cases for constant propagation:
+// 1) Unsigned integers are always greater than or
+// equal to zero
+// 2) Values are always smaller than a number greater
+// than their maximum value
+// 3) Values are always greater than a number smaller
+// than their minimum value
+class ConstantPropagationSpec extends FirrtlFlatSpec {
+ val passes = Seq(
+ ToWorkingIR,
+ ResolveKinds,
+ InferTypes,
+ ResolveGenders,
+ InferWidths,
+ ConstProp)
+ def parse(input: String): Circuit = Parser.parse("", input.split("\n").toIterator, false)
+ private def exec (input: String) = {
+ passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }.serialize
+ }
+ // =============================
+ "The rule x >= 0 " should " always be true if x is a UInt" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= geq(x, UInt(0))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= UInt<1>("h1")
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule x < 0 " should " never be true if x is a UInt" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= lt(x, UInt(0))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule 0 <= x " should " always be true if x is a UInt" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= leq(UInt(0),x)
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= UInt<1>(1)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule 0 > x " should " never be true if x is a UInt" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= gt(UInt(0),x)
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule 1 < 3 " should " always be true" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= lt(UInt(0),UInt(3))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= UInt<1>(1)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule x < 8 " should " always be true if x only has 3 bits" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= lt(x,UInt(8))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= UInt<1>(1)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule x <= 7 " should " always be true if x only has 3 bits" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= leq(x,UInt(7))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= UInt<1>(1)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule 8 > x" should " always be true if x only has 3 bits" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= gt(UInt(8),x)
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= UInt<1>(1)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule 7 >= x" should " always be true if x only has 3 bits" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= geq(UInt(7),x)
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= UInt<1>(1)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule 10 == 10" should " always be true" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= eq(UInt(10),UInt(10))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= UInt<1>(1)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule x == z " should " not be true even if they have the same number of bits" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ input z : UInt<3>
+ output y : UInt<1>
+ y <= eq(x,z)
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ input z : UInt<3>
+ output y : UInt<1>
+ y <= eq(x,z)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule 10 != 10 " should " always be false" in {
+ val input =
+"""circuit Top :
+ module Top :
+ output y : UInt<1>
+ y <= neq(UInt(10),UInt(10))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ output y : UInt<1>
+ y <= UInt(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+ // =============================
+ "The rule 1 >= 3 " should " always be false" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= geq(UInt(1),UInt(3))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<5>
+ output y : UInt<1>
+ y <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule x >= 8 " should " never be true if x only has 3 bits" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= geq(x,UInt(8))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule x > 7 " should " never be true if x only has 3 bits" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= gt(x,UInt(7))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule 8 <= x" should " never be true if x only has 3 bits" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= leq(UInt(8),x)
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ // =============================
+ "The rule 7 < x" should " never be true if x only has 3 bits" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= lt(UInt(7),x)
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<3>
+ output y : UInt<1>
+ y <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+}
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index a2968ac5..7276aabb 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -33,8 +33,19 @@ import org.scalatest.prop._
import firrtl._
import firrtl.passes._
-class UnitTests extends FlatSpec with Matchers {
+class UnitTests extends FirrtlFlatSpec {
def parse (input:String) = Parser.parse("",input.split("\n").toIterator,false)
+ private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
+ val c = passes.foldLeft(Parser.parse("", input.split("\n").toIterator)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val lines = c.serialize.split("\n") map normalized
+
+ expected foreach { e =>
+ lines should contain(e)
+ }
+ }
+
"Connecting bundles of different types" should "throw an exception" in {
val passes = Seq(
ToWorkingIR,
@@ -130,10 +141,55 @@ class UnitTests extends FlatSpec with Matchers {
"After splitting, emitting a nested expression" should "compile" in {
val passes = Seq(
ToWorkingIR,
- SplitExp,
+ SplitExpressions,
InferTypes)
val c = Parser.parse("",splitExpTestCode.split("\n").toIterator)
val c2 = passes.foldLeft(c)((c, p) => p run c)
new VerilogEmitter().run(c2, new OutputStreamWriter(new ByteArrayOutputStream))
}
+
+ "Simple compound expressions" should "be split" in {
+ val passes = Seq(
+ ToWorkingIR,
+ ResolveKinds,
+ InferTypes,
+ ResolveGenders,
+ InferWidths,
+ SplitExpressions
+ )
+ val input =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | input b : UInt<32>
+ | input d : UInt<32>
+ | output c : UInt<1>
+ | c <= geq(add(a, b),d)""".stripMargin
+ val check = Seq(
+ "node GEN_0 = add(a, b)",
+ "c <= geq(GEN_0, d)"
+ )
+ executeTest(input, check, passes)
+ }
+
+ "Smaller widths" should "be explicitly padded" in {
+ val passes = Seq(
+ ToWorkingIR,
+ ResolveKinds,
+ InferTypes,
+ ResolveGenders,
+ InferWidths,
+ PadWidths
+ )
+ val input =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<32>
+ | input b : UInt<20>
+ | input pred : UInt<1>
+ | output c : UInt<32>
+ | c <= mux(pred,a,b)""".stripMargin
+ val check = Seq("c <= mux(pred, a, pad(b, 32))")
+ executeTest(input, check, passes)
+ }
}