aboutsummaryrefslogtreecommitdiff
path: root/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/test')
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala61
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala224
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala314
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala38
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala66
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala46
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala230
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala107
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala9
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala46
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala40
11 files changed, 1181 insertions, 0 deletions
diff --git a/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala
new file mode 100644
index 00000000..56f891e6
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala
@@ -0,0 +1,61 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt
+
+private class Btor2Spec extends SMTBackendBaseSpec {
+
+ it should "convert a hello world module" in {
+ val src =
+ """circuit m:
+ | module m:
+ | input clock: Clock
+ | input a: UInt<8>
+ | output b: UInt<16>
+ | b <= a
+ | assert(clock, eq(a, b), UInt(1), "")
+ |""".stripMargin
+
+ val expected =
+ """1 sort bitvec 8
+ |2 input 1 a
+ |3 sort bitvec 16
+ |4 uext 3 2 8
+ |5 output 4 ; b
+ |6 sort bitvec 1
+ |7 uext 3 2 8
+ |8 eq 6 7 4
+ |9 not 6 8
+ |10 bad 9 ; assert_
+ |""".stripMargin
+
+ assert(toBotr2Str(src) == expected)
+ }
+
+ it should "include FileInfo in the output" in {
+ val src =
+ """circuit m: @[circuit 0:0]
+ | module m: @[module 0:0]
+ | input clock: Clock @[clock 0:0]
+ | input a: UInt<8> @[a 0:0]
+ | output b: UInt<16> @[b 0:0]
+ | b <= a @[b_a 0:0]
+ | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0]
+ |""".stripMargin
+
+ val expected =
+ """; @ module 0:0
+ |1 sort bitvec 8
+ |2 input 1 a ; @ a 0:0
+ |3 sort bitvec 16
+ |4 uext 3 2 8
+ |5 output 4 ; b @ b 0:0, b_a 0:0
+ |6 sort bitvec 1
+ |7 uext 3 2 8
+ |8 eq 6 7 4
+ |9 not 6 8
+ |10 bad 9 ; assert_ @ assert 0:0
+ |""".stripMargin
+
+ assert(toBotr2Str(src) == expected)
+ }
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
new file mode 100644
index 00000000..015ac4a9
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
@@ -0,0 +1,224 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt
+
+private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec {
+
+ def primop(op: String, resTpe: String, inTpes: Seq[String], consts: Seq[Int]): String = {
+ val inputs = inTpes.zipWithIndex.map { case (tpe, ii) => s" input i$ii : $tpe" }.mkString("\n")
+ val args = (inTpes.zipWithIndex.map { case (_, ii) => s"i$ii" } ++ consts.map(_.toString)).mkString(", ")
+ val src =
+ s"""circuit m:
+ | module m:
+ |$inputs
+ | output res: $resTpe
+ | res <= $op($args)
+ |
+ |""".stripMargin
+ val sys = toSys(src)
+ assert(sys.signals.length == 1)
+ sys.signals.head.e.toString
+ }
+
+ def primop(signed: Boolean, op: String, resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(),
+ resAlwaysUnsigned: Boolean = false): String = {
+ val tpe = if(signed) "SInt" else "UInt"
+ val resTpe = if(resAlwaysUnsigned) "UInt" else tpe
+ val inTpes = inWidth.map(w => s"$tpe<$w>")
+ primop(op, s"$resTpe<$resWidth>", inTpes, consts)
+ }
+
+ it should "correctly translate the add primitive operation with different operand sizes" in {
+ assert(primop(false, "add", 5, List(3, 5)) == "add(zext(i0, 3), zext(i1, 1))[4:0]")
+ assert(primop(false, "add", 5, List(3, 4)) == "add(zext(i0, 2), zext(i1, 1))")
+ assert(primop(true, "add", 5, List(3, 5)) == "add(sext(i0, 3), sext(i1, 1))[4:0]")
+ assert(primop(true, "add", 5, List(3, 4)) == "add(sext(i0, 2), sext(i1, 1))")
+
+ // could be simplified to just `add(i0, i1)`
+ assert(primop(false, "add", 8, List(8, 8)) == "add(zext(i0, 1), zext(i1, 1))[7:0]")
+ }
+
+ it should "correctly translate the `add` primitive operation" in {
+ assert(primop(false, "add", 8, List(7, 7)) == "add(zext(i0, 1), zext(i1, 1))")
+ }
+
+ it should "correctly translate the `sub` primitive operation" in {
+ assert(primop(false, "sub", 8, List(7, 7)) == "sub(zext(i0, 1), zext(i1, 1))")
+ }
+
+ it should "correctly translate the `mul` primitive operation" in {
+ assert(primop(false, "mul", 8, List(4, 4)) == "mul(zext(i0, 4), zext(i1, 4))")
+ }
+
+ it should "correctly translate the `div` primitive operation" in {
+ // division is a little bit more complicated because the result of division by zero is undefined
+ assert(primop(false, "div", 8, List(8, 8)) ==
+ "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))")
+ assert(primop(false, "div", 8, List(8, 4)) ==
+ "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))")
+
+ // signed division increases result width by 1
+ assert(primop(true, "div", 8, List(7, 7)) ==
+ "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))")
+ assert(primop(true, "div", 8, List(7, 4))
+ == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))")
+ }
+
+ it should "correctly translate the `rem` primitive operation" in {
+ // rem can decrease the size of operands, but we should only do that decrease on the result
+ assert(primop(false, "rem", 4, List(4, 8)) == "urem(zext(i0, 4), i1)[3:0]")
+ assert(primop(false, "rem", 4, List(8, 4)) == "urem(i0, zext(i1, 4))[3:0]")
+ assert(primop(true, "rem", 4, List(4, 8)) == "srem(sext(i0, 4), i1)[3:0]")
+ assert(primop(true, "rem", 4, List(8, 4)) == "srem(i0, sext(i1, 4))[3:0]")
+ // TODO: add test to make sure we are using the correct mod/rem operation for signed and unsigned
+ // https://groups.google.com/g/stp-users/c/od43h8q5RSI has some tests that we could copy and
+ // use with a SMT solver
+ }
+
+ it should "correctly translate the comparison primitive operations" in {
+ // some comparisons are represented as the negation of others
+ assert(primop(false, "lt", 1, List(8, 8)) == "not(ugeq(i0, i1))")
+ assert(primop(false, "leq", 1, List(8, 8)) == "not(ugt(i0, i1))")
+ assert(primop(false, "gt", 1, List(8, 8)) == "ugt(i0, i1)")
+ assert(primop(false, "geq", 1, List(8, 8)) == "ugeq(i0, i1)")
+ assert(primop(false, "eq", 1, List(8, 8)) == "eq(i0, i1)")
+ assert(primop(false, "neq", 1, List(8, 8)) == "not(eq(i0, i1))")
+
+ assert(primop(true, "lt", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgeq(i0, i1))")
+ assert(primop(true, "leq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgt(i0, i1))")
+ assert(primop(true, "gt", 1, List(8, 8), resAlwaysUnsigned = true) == "sgt(i0, i1)")
+ assert(primop(true, "geq", 1, List(8, 8), resAlwaysUnsigned = true) == "sgeq(i0, i1)")
+ assert(primop(true, "eq", 1, List(8, 8), resAlwaysUnsigned = true) == "eq(i0, i1)")
+ assert(primop(true, "neq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(eq(i0, i1))")
+
+ // it should always extend the width to the max of both
+ assert(primop(false, "gt", 1, List(7, 8)) == "ugt(zext(i0, 1), i1)")
+ }
+
+ it should "correctly translate the `pad` primitive operation" in {
+ // firrtl pad takes new width as argument, whereas the smt zext takes the number of bits to extend by
+ assert(primop(false, "pad", 8, List(3), List(8)) == "zext(i0, 5)")
+ assert(primop(false, "pad", 8, List(3), List(5)) == "zext(zext(i0, 2), 3)")
+
+ // there is no negative padding, instead the result is just e
+ assert(primop(false, "pad", 3, List(3), List(2)) == "i0")
+
+ assert(primop(true, "pad", 8, List(3), List(8)) == "sext(i0, 5)")
+ assert(primop(true, "pad", 8, List(3), List(5)) == "sext(sext(i0, 2), 3)")
+ }
+
+ it should "correctly translate the asX primitive operations" in {
+ // these are all essentially no-ops
+ assert(primop(false, "asUInt", 3, List(3)) == "i0")
+ assert(primop(true, "asSInt", 3, List(3)) == "i0")
+ }
+
+ it should "correctly translate the `shl` primitive operation" in {
+ assert(primop(false, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)")
+ assert(primop(true, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)")
+ assert(primop(false, "shl", 3, List(3), List(0)) == "i0")
+ }
+
+ it should "correctly translate the `shr` primitive operation" in {
+ assert(primop(false, "shr", 6, List(9), List(3)) == "i0[8:3]")
+ assert(primop(true, "shr", 6, List(9), List(3)) == "i0[8:3]")
+
+ // "If n is greater than or equal to the bit-width of e,
+ // the resulting value will be zero for unsigned types and the sign bit for signed types."
+ assert(primop(false, "shr", 1, List(3), List(3)) == "1'b0")
+ assert(primop(false, "shr", 1, List(3), List(4)) == "1'b0")
+ assert(primop(true, "shr", 1, List(3), List(3)) == "i0[2]")
+ assert(primop(true, "shr", 1, List(3), List(4)) == "i0[2]")
+ }
+
+ it should "correctly translate the `dshl` primitive operation" in {
+ assert(primop(false, "dshl", 31, List(16, 4)) == "logical_shift_left(zext(i0, 15), zext(i1, 27))")
+ assert(primop(false, "dshl", 19, List(16, 2)) == "logical_shift_left(zext(i0, 3), zext(i1, 17))")
+ assert(primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) ==
+ "logical_shift_left(sext(i0, 3), zext(i1, 17))")
+ }
+
+ it should "correctly translate the `dshr` primitive operation" in {
+ assert(primop(false, "dshr", 16, List(16, 4)) == "logical_shift_right(i0, zext(i1, 12))")
+ assert(primop(false, "dshr", 16, List(16, 2)) == "logical_shift_right(i0, zext(i1, 14))")
+ assert(primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) ==
+ "arithmetic_shift_right(i0, zext(i1, 14))")
+ }
+
+ it should "correctly translate the `cvt` primitive operation" in {
+ // for signed operands, this is a no-op
+ assert(primop(true, "cvt", 3, List(3)) == "i0")
+
+ // for unsigned, a zero is prepended
+ assert(primop("cvt", "SInt<16>", List("UInt<15>"), List()) == "concat(1'b0, i0)")
+ assert(primop("cvt", "SInt<16>", List("UInt<14>"), List()) == "sext(concat(1'b0, i0), 1)")
+ }
+
+ it should "correctly translate the `neg` primitive operation" in {
+ assert(primop(true, "neg", 4, List(3)) == "neg(sext(i0, 1))")
+ assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "neg(zext(i0, 1))")
+ }
+
+ it should "correctly translate the `not` primitive operation" in {
+ assert(primop(false, "not", 4, List(4)) == "not(i0)")
+ assert(primop("not", "UInt<4>", List("SInt<4>"), List()) == "not(i0)")
+ }
+
+ it should "correctly translate the binary bitwise primitive operations" in {
+ assert(primop(false, "and", 4, List(4, 3)) == "and(i0, zext(i1, 1))")
+ assert(primop("and", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "and(i0, sext(i1, 1))")
+
+ assert(primop(false, "or", 4, List(4, 3)) == "or(i0, zext(i1, 1))")
+ assert(primop("or", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "or(i0, sext(i1, 1))")
+
+ assert(primop(false, "xor", 4, List(4, 3)) == "xor(i0, zext(i1, 1))")
+ assert(primop("xor", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "xor(i0, sext(i1, 1))")
+ }
+
+ it should "correctly translate the bitwise reduction primitive operation" in {
+ // zero width special cases are removed by the firrtl compiler
+ assert(primop(false, "andr", 1, List(0)) == "1'b1")
+ assert(primop(false, "orr", 1, List(0)) == "redor(1'b0)")
+ assert(primop(false, "xorr", 1, List(0)) == "redxor(1'b0)")
+
+ assert(primop(false, "andr", 1, List(3)) == "redand(i0)")
+ assert(primop(true, "andr", 1, List(3), resAlwaysUnsigned = true) == "redand(i0)")
+
+ assert(primop(false, "orr", 1, List(3)) == "redor(i0)")
+ assert(primop(true, "orr", 1, List(3), resAlwaysUnsigned = true) == "redor(i0)")
+
+ assert(primop(false, "xorr", 1, List(3)) == "redxor(i0)")
+ assert(primop(true, "xorr", 1, List(3), resAlwaysUnsigned = true) == "redxor(i0)")
+ }
+
+ it should "correctly translate the `cat` primitive operation" in {
+ assert(primop(false, "cat", 7, List(4, 3)) == "concat(i0, i1)")
+ assert(primop(true, "cat", 7, List(4, 3), resAlwaysUnsigned = true) == "concat(i0, i1)")
+ }
+
+ it should "correctly translate the `bits` primitive operation" in {
+ assert(primop(false, "bits", 1, List(4), List(2,2)) == "i0[2]")
+ assert(primop(false, "bits", 2, List(4), List(2,1)) == "i0[2:1]")
+ assert(primop(false, "bits", 1, List(4), List(2,1)) == "i0[2:1][0]")
+ assert(primop(false, "bits", 3, List(4), List(2,1)) == "zext(i0[2:1], 1)")
+
+ assert(primop(true, "bits", 1, List(4), List(2,2), resAlwaysUnsigned = true) == "i0[2]")
+ assert(primop(true, "bits", 2, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1]")
+ assert(primop(true, "bits", 1, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1][0]")
+ assert(primop(true, "bits", 3, List(4), List(2,1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)")
+ }
+
+ it should "correctly translate the `head` primitive operation" in {
+ // "The result of the head operation are the n most significant bits of e"
+ assert(primop(false, "head", 1, List(4), List(1)) == "i0[3]")
+ assert(primop(false, "head", 1, List(5), List(1)) == "i0[4]")
+ assert(primop(false, "head", 3, List(5), List(3)) == "i0[4:2]")
+ }
+
+ it should "correctly translate the `tail` primitive operation" in {
+ // "The tail operation truncates the n most significant bits from e"
+ assert(primop(false, "tail", 3, List(4), List(1)) == "i0[2:0]")
+ assert(primop(false, "tail", 4, List(5), List(1)) == "i0[3:0]")
+ assert(primop(false, "tail", 2, List(5), List(3)) == "i0[1:0]")
+ }
+} \ No newline at end of file
diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala
new file mode 100644
index 00000000..ca7974c5
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala
@@ -0,0 +1,314 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt
+
+import firrtl.{MemoryArrayInit, MemoryScalarInit, Utils}
+
+private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec {
+ behavior of "ModuleToTransitionSystem.run"
+
+
+ it should "model registers as state" in {
+ // if a signal is invalid, it could take on an arbitrary value in that cycle
+ val src =
+ """circuit m:
+ | module m:
+ | input reset : UInt<1>
+ | input clock : Clock
+ | input en : UInt<1>
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | reg r : UInt<8>, clock with : (reset => (reset, UInt<8>(0)))
+ | when en:
+ | r <= in
+ | out <= r
+ |
+ |""".stripMargin
+ val sys = toSys(src)
+
+ assert(sys.signals.length == 2)
+
+ // the when is translated as a ITE
+ val genSignal = sys.signals.filterNot(_.name == "out").head
+ assert(genSignal.e.toString == "ite(en, in, r)")
+
+ // the reset is synchronous
+ val r = sys.states.head
+ assert(r.sym.name == "r")
+ assert(r.init.isEmpty, "we are not using any preset, so the initial register content is arbitrary")
+ assert(r.next.get.toString == s"ite(reset, 8'b0, ${genSignal.name})")
+ }
+
+ private def memCircuit(depth: Int = 32) =
+ s"""circuit m:
+ | module m:
+ | input reset : UInt<1>
+ | input clock : Clock
+ | input addr : UInt<${Utils.getUIntWidth(depth)}>
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | mem m:
+ | data-type => UInt<8>
+ | depth => $depth
+ | reader => r
+ | writer => w
+ | read-latency => 0
+ | write-latency => 1
+ | read-under-write => new
+ |
+ | m.w.clk <= clock
+ | m.w.mask <= UInt(1)
+ | m.w.en <= UInt(1)
+ | m.w.data <= in
+ | m.w.addr <= addr
+ |
+ | m.r.clk <= clock
+ | m.r.en <= UInt(1)
+ | out <= m.r.data
+ | m.r.addr <= addr
+ |
+ |""".stripMargin
+
+ it should "model memories as state" in {
+ val sys = toSys(memCircuit())
+
+ assert(sys.signals.length == 9-2+1, "9 connects - 2 clock connects + 1 combinatorial read port")
+
+ val sig = sys.signals.map(s => s.name -> s.e).toMap
+
+ // masks and enables should all be true
+ val True = BVLiteral(1, 1)
+ assert(sig("m.w.mask") == True)
+ assert(sig("m.w.en") == True)
+ assert(sig("m.r.en") == True)
+
+ // read data should always be enabled
+ assert(sig("m.r.data").toString == "m[m.r.addr]")
+
+ // the memory is modelled as a state
+ val m = sys.states.find(_.sym.name == "m").get
+ assert(m.sym.isInstanceOf[ArraySymbol])
+ val sym = m.sym.asInstanceOf[ArraySymbol]
+ assert(sym.indexWidth == 5)
+ assert(sym.dataWidth == 8)
+ assert(m.init.isEmpty)
+ //assert(m.next.get.toString.contains("m[m.w.addr := m.w.data]"))
+ assert(m.next.get.toString == "m[m.w.addr := m.w.data]")
+ }
+
+ it should "support scalar initialization of a memory to 0" in {
+ val sys = toSys(memCircuit(), memInit = Map("m" -> MemoryScalarInit(0)))
+ val m = sys.states.find(_.sym.name == "m").get
+ assert(m.init.isDefined)
+ assert(m.init.get.toString == "([8'b0] x 32)")
+ }
+
+ it should "support scalar initialization of a memory to 127" in {
+ val sys = toSys(memCircuit(31), memInit = Map("m" -> MemoryScalarInit(127)))
+ val m = sys.states.find(_.sym.name == "m").get
+ assert(m.init.isDefined)
+ assert(m.init.get.toString == "([8'b1111111] x 32)")
+ }
+
+ it should "support array initialization of a memory to Seq(0, 1, 2, 3)" in {
+ val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(0, 1, 2, 3))))
+ val m = sys.states.find(_.sym.name == "m").get
+ assert(m.init.isDefined)
+ assert(m.init.get.toString == "([8'b0] x 4)[2'b1 := 8'b1][2'b10 := 8'b10][2'b11 := 8'b11]")
+ }
+
+ it should "support array initialization of a memory to Seq(1, 0, 1, 0)" in {
+ val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 1, 0))))
+ val m = sys.states.find(_.sym.name == "m").get
+ assert(m.init.isDefined)
+ assert(m.init.get.toString == "([8'b1] x 4)[2'b1 := 8'b0][2'b11 := 8'b0]")
+ }
+
+ it should "support array initialization of a memory to Seq(1, 0, 0, 0)" in {
+ val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 0, 0))))
+ val m = sys.states.find(_.sym.name == "m").get
+ assert(m.init.isDefined)
+ assert(m.init.get.toString == "([8'b0] x 4)[2'b0 := 8'b1]")
+ }
+
+ it should "support array initialization from a file" ignore {
+ assert(false, "TODO")
+ }
+
+ it should "support memories with registered read port" in {
+ def src(readUnderWrite: String) =
+ s"""circuit m:
+ | module m:
+ | input reset : UInt<1>
+ | input clock : Clock
+ | input addr : UInt<5>
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | mem m:
+ | data-type => UInt<8>
+ | depth => 32
+ | reader => r
+ | writer => w1, w2
+ | read-latency => 1
+ | write-latency => 1
+ | read-under-write => $readUnderWrite
+ |
+ | m.w1.clk <= clock
+ | m.w1.mask <= UInt(1)
+ | m.w1.en <= UInt(1)
+ | m.w1.data <= in
+ | m.w1.addr <= addr
+ | m.w2.clk <= clock
+ | m.w2.mask <= UInt(1)
+ | m.w2.en <= UInt(1)
+ | m.w2.data <= in
+ | m.w2.addr <= addr
+ |
+ | m.r.clk <= clock
+ | m.r.en <= UInt(1)
+ | out <= m.r.data
+ | m.r.addr <= addr
+ |
+ |""".stripMargin
+
+
+ val oldValue = toSys(src("old"))
+ val oldMData = oldValue.states.find(_.sym.name == "m.r.data").get
+ assert(oldMData.sym.toString == "m.r.data")
+ assert(oldMData.next.get.toString == "m[m.r.addr]", "we just need to read the current value")
+ val readDataSignal = oldValue.signals.find(_.name == "m.r.data")
+ assert(readDataSignal.isEmpty, s"${readDataSignal.map(_.toString)} should not exist")
+
+ val undefinedValue = toSys(src("undefined"))
+ val undefinedMData = undefinedValue.states.find(_.sym.name == "m.r.data").get
+ assert(undefinedMData.sym.toString == "m.r.data")
+ val undefined = "RANDOM.m_r_read_under_write_undefined"
+ assert(undefinedMData.next.get.toString ==
+ s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])",
+ "randomize result if there is a write")
+ }
+
+ it should "support memories with potential write-write conflicts" in {
+ val src =
+ s"""circuit m:
+ | module m:
+ | input reset : UInt<1>
+ | input clock : Clock
+ | input addr : UInt<5>
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | mem m:
+ | data-type => UInt<8>
+ | depth => 32
+ | reader => r
+ | writer => w1, w2
+ | read-latency => 1
+ | write-latency => 1
+ | read-under-write => undefined
+ |
+ | m.w1.clk <= clock
+ | m.w1.mask <= UInt(1)
+ | m.w1.en <= UInt(1)
+ | m.w1.data <= in
+ | m.w1.addr <= addr
+ | m.w2.clk <= clock
+ | m.w2.mask <= UInt(1)
+ | m.w2.en <= UInt(1)
+ | m.w2.data <= in
+ | m.w2.addr <= addr
+ |
+ | m.r.clk <= clock
+ | m.r.en <= UInt(1)
+ | out <= m.r.data
+ | m.r.addr <= addr
+ |
+ |""".stripMargin
+
+
+ val sys = toSys(src)
+ val m = sys.states.find(_.sym.name == "m").get
+
+ val regularUpdate = "m[m.w1.addr := m.w1.data][m.w2.addr := m.w2.data]"
+ val collision = "eq(m.w1.addr, m.w2.addr)"
+ val collisionUpdate = "m[m.w1.addr := RANDOM.m_w1_write_write_collision]"
+
+ assert(m.next.get.toString == s"ite($collision, $collisionUpdate, $regularUpdate)")
+ }
+
+ it should "model invalid signals as inputs" in {
+ // if a signal is invalid, it could take on an arbitrary value in that cycle
+ val src =
+ """circuit m:
+ | module m:
+ | input en : UInt<1>
+ | output o : UInt<8>
+ | o is invalid
+ | when en:
+ | o <= UInt<8>(0)
+ |""".stripMargin
+ val sys = toSys(src)
+ assert(sys.inputs.length == 2)
+ val random = sys.inputs.filter(_.name.contains("RANDOM"))
+ assert(random.length == 1)
+ assert(random.head.width == 8)
+ }
+
+ it should "throw an error on async reset" in {
+ val err = intercept[AsyncResetException] {
+ toSys(
+ """circuit m:
+ | module m:
+ | input reset : AsyncReset
+ |""".stripMargin
+ )
+ }
+ assert(err.getMessage.contains("reset"))
+ }
+
+ it should "throw an error on casting to async reset" in {
+ val err = intercept[AssertionError] {
+ toSys(
+ """circuit m:
+ | module m:
+ | input reset : UInt<1>
+ | node async = asAsyncReset(reset)
+ |""".stripMargin
+ )
+ }
+ assert(err.getMessage.contains("reset"))
+ }
+
+ it should "throw an error on multiple clocks" in {
+ val err = intercept[MultiClockException] {
+ toSys(
+ """circuit m:
+ | module m:
+ | input clk1 : Clock
+ | input clk2 : Clock
+ |""".stripMargin
+ )
+ }
+ assert(err.getMessage.contains("clk1, clk2"))
+ }
+
+ it should "throw an error on using a clock as uInt" in {
+ // While this could potentially be supported in the future, for now we do not allow
+ // a clock to be used for anything besides updating registers and memories.
+ val err = intercept[AssertionError] {
+ toSys(
+ """circuit m:
+ | module m:
+ | input clk : Clock
+ | output o : UInt<1>
+ | o <= asUInt(clk)
+ |
+ |""".stripMargin
+ )
+ }
+ assert(err.getMessage.contains("clk"))
+ }
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala
new file mode 100644
index 00000000..6bfb5437
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala
@@ -0,0 +1,38 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt
+
+import firrtl.annotations.Annotation
+import firrtl.{MemoryInitValue, ir}
+import firrtl.stage.{Forms, TransformManager}
+import org.scalatest.flatspec.AnyFlatSpec
+
+private abstract class SMTBackendBaseSpec extends AnyFlatSpec {
+ private val dependencies = Forms.LowForm
+ private val compiler = new TransformManager(dependencies)
+
+ protected def compile(src: String, annos: Seq[Annotation] = List()): ir.Circuit = {
+ val c = firrtl.Parser.parse(src)
+ compiler.runTransform(firrtl.CircuitState(c, annos)).circuit
+ }
+
+ protected def toSys(src: String, mod: String = "m", presetRegs: Set[String] = Set(),
+ memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = {
+ val circuit = compile(src)
+ val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module]
+ // println(module.serialize)
+ new ModuleToTransitionSystem().run(module, presetRegs = presetRegs, memInit = memInit)
+ }
+
+ protected def toBotr2(src: String, mod: String = "m"): Iterable[String] =
+ Btor2Serializer.serialize(toSys(src, mod))
+
+ protected def toBotr2Str(src: String, mod: String = "m"): String =
+ toBotr2(src, mod).mkString("\n") + "\n"
+
+ protected def toSMTLib(src: String, mod: String = "m"): Iterable[String] =
+ SMTTransitionSystemEncoder.encode(toSys(src, mod)).map(SMTLibSerializer.serialize)
+
+ protected def toSMTLibStr(src: String, mod: String = "m"): String =
+ toSMTLib(src, mod).mkString("\n") + "\n"
+} \ No newline at end of file
diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala
new file mode 100644
index 00000000..7193474d
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala
@@ -0,0 +1,66 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt
+
+private class SMTLibSpec extends SMTBackendBaseSpec {
+
+ it should "convert a hello world module" in {
+ val src =
+ """circuit m:
+ | module m:
+ | input clock: Clock
+ | input a: UInt<8>
+ | output b: UInt<16>
+ | b <= a
+ | assert(clock, eq(a, b), UInt(1), "")
+ |""".stripMargin
+
+ val expected =
+ """(declare-sort m_s 0)
+ |; firrtl-smt2-input a 8
+ |(declare-fun a_f (m_s) (_ BitVec 8))
+ |; firrtl-smt2-output b 16
+ |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state)))
+ |; firrtl-smt2-assert assert_ 1
+ |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state)))
+ |(define-fun m_t ((state m_s) (state_n m_s)) Bool true)
+ |(define-fun m_i ((state m_s)) Bool true)
+ |(define-fun m_a ((state m_s)) Bool (assert__f state))
+ |(define-fun m_u ((state m_s)) Bool true)
+ |""".stripMargin
+
+ assert(toSMTLibStr(src) == expected)
+ }
+
+ it should "include FileInfo in the output" in {
+ val src =
+ """circuit m: @[circuit 0:0]
+ | module m: @[module 0:0]
+ | input clock: Clock @[clock 0:0]
+ | input a: UInt<8> @[a 0:0]
+ | output b: UInt<16> @[b 0:0]
+ | b <= a @[b_a 0:0]
+ | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0]
+ |""".stripMargin
+
+ val expected =
+ """; @ module 0:0
+ |(declare-sort m_s 0)
+ |; firrtl-smt2-input a 8
+ |; @ a 0:0
+ |(declare-fun a_f (m_s) (_ BitVec 8))
+ |; firrtl-smt2-output b 16
+ |; @ b 0:0, b_a 0:0
+ |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state)))
+ |; firrtl-smt2-assert assert_ 1
+ |; @ assert 0:0
+ |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state)))
+ |(define-fun m_t ((state m_s) (state_n m_s)) Bool true)
+ |(define-fun m_i ((state m_s)) Bool true)
+ |(define-fun m_a ((state m_s)) Bool (assert__f state))
+ |(define-fun m_u ((state m_s)) Bool true)
+ |""".stripMargin
+
+ assert(toSMTLibStr(src) == expected)
+ }
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala
new file mode 100644
index 00000000..4c6901ea
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala
@@ -0,0 +1,46 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt.end2end
+
+import firrtl.annotations.CircuitTarget
+import firrtl.backends.experimental.smt.{GlobalClockAnnotation, StutteringClockTransform}
+import firrtl.options.Dependency
+import firrtl.stage.RunFirrtlTransformAnnotation
+
+class AsyncResetSpec extends EndToEndSMTBaseSpec {
+ def annos(name: String) = Seq(
+ RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]),
+ GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock")))
+
+ "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs(RequiresZ3) in {
+ def in(resetType: String) =
+ s"""circuit AsyncReset00:
+ | module AsyncReset00:
+ | input global_clock: Clock
+ | input c: Clock
+ | input reset: $resetType
+ | input preset: AsyncReset
+ |
+ | ; a register with async reset
+ | reg r: UInt<4>, c with: (reset => (reset, UInt(3)))
+ |
+ | ; a counter/toggler connected to the clock c
+ | reg count: UInt<1>, c with: (reset => (preset, UInt(0)))
+ | count <= add(count, UInt(1))
+ |
+ | ; the past machinery and the assertion uses the global clock
+ | reg past_valid: UInt<1>, global_clock with: (reset => (preset, UInt(0)))
+ | past_valid <= UInt(1)
+ | reg past_r: UInt<4>, global_clock
+ | past_r <= r
+ | reg past_count: UInt<1>, global_clock
+ | past_count <= count
+ |
+ | ; can the value of r change without the count changing?
+ | assert(global_clock, or(not(eq(count, past_count)), eq(r, past_r)), past_valid, "count = past(count) |-> r = past(r)")
+ |""".stripMargin
+ test(in("AsyncReset"), MCFail(1), kmax=2, annos=annos("AsyncReset00"))
+ test(in("UInt<1>"), MCSuccess, kmax=2, annos=annos("AsyncReset00"))
+ }
+
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala
new file mode 100644
index 00000000..2227719b
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala
@@ -0,0 +1,230 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt.end2end
+
+import java.io.{File, PrintWriter}
+
+import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation}
+import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter}
+import firrtl.options.TargetDirAnnotation
+import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlStage, OutputFileAnnotation, RunFirrtlTransformAnnotation}
+import firrtl.util.BackendCompilationUtilities
+import logger.{LazyLogging, LogLevel, LogLevelAnnotation}
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.must.Matchers
+
+import scala.sys.process._
+
+class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging {
+ "we" should "check if Z3 is available" taggedAs(RequiresZ3) in {
+ val log = ProcessLogger(_ => (), logger.warn(_))
+ val ret = Process(Seq("which", "z3")).run(log).exitValue()
+ if(ret != 0) {
+ logger.error(
+ """The z3 SMT-Solver seems not to be installed.
+ |You can exclude the end-to-end smt backend tests which rely on z3 like this:
+ |sbt testOnly -- -l RequiresZ3
+ |""".stripMargin)
+ }
+ assert(ret == 0)
+ }
+
+ "Z3" should "be available in version 4" taggedAs(RequiresZ3) in {
+ assert(Z3ModelChecker.getZ3Version.startsWith("4."))
+ }
+
+ "a simple combinatorial check" should "pass" taggedAs(RequiresZ3) in {
+ val in =
+ """circuit CC00:
+ | module CC00:
+ | input c: Clock
+ | input a: UInt<1>
+ | input b: UInt<1>
+ | assert(c, lt(add(a, b), UInt(3)), UInt(1), "a + b < 3")
+ |""".stripMargin
+ test(in, MCSuccess)
+ }
+
+ "a simple combinatorial check" should "fail immediately" taggedAs(RequiresZ3) in {
+ val in =
+ """circuit CC01:
+ | module CC01:
+ | input c: Clock
+ | input a: UInt<1>
+ | input b: UInt<1>
+ | assert(c, gt(add(a, b), UInt(3)), UInt(1), "a + b > 3")
+ |""".stripMargin
+ test(in, MCFail(0))
+ }
+
+ "adding the right assumption" should "make a test pass" taggedAs(RequiresZ3) in {
+ val in0 =
+ """circuit CC01:
+ | module CC01:
+ | input c: Clock
+ | input a: UInt<1>
+ | input b: UInt<1>
+ | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0")
+ |""".stripMargin
+ val in1 =
+ """circuit CC01:
+ | module CC01:
+ | input c: Clock
+ | input a: UInt<1>
+ | input b: UInt<1>
+ | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0")
+ | assume(c, neq(a, UInt(0)), UInt(1), "a != 0")
+ |""".stripMargin
+ test(in0, MCFail(0))
+ test(in1, MCSuccess)
+
+ val in2 =
+ """circuit CC01:
+ | module CC01:
+ | input c: Clock
+ | input a: UInt<1>
+ | input b: UInt<1>
+ | input en: UInt<1>
+ | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0")
+ | assume(c, neq(a, UInt(0)), en, "a != 0 if en")
+ |""".stripMargin
+ test(in2, MCFail(0))
+ }
+
+ "a register connected to preset reset" should "be initialized with the reset value" taggedAs(RequiresZ3) in {
+ def in(rEq: Int) =
+ s"""circuit Preset00:
+ | module Preset00:
+ | input c: Clock
+ | input preset: AsyncReset
+ | reg r: UInt<4>, c with: (reset => (preset, UInt(3)))
+ | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq")
+ |""".stripMargin
+ test(in(3), MCSuccess, kmax = 1)
+ test(in(2), MCFail(0))
+ }
+
+ "a register's initial value" should "should not change" taggedAs(RequiresZ3) in {
+ val in =
+ """circuit Preset00:
+ | module Preset00:
+ | input c: Clock
+ | input preset: AsyncReset
+ |
+ | ; the past value of our register will only be valid in the 1st unrolling
+ | reg past_valid: UInt<1>, c with: (reset => (preset, UInt(0)))
+ | past_valid <= UInt(1)
+ |
+ | reg r: UInt<4>, c
+ | reg r_past: UInt<4>, c
+ | r_past <= r
+ | assert(c, eq(r, r_past), past_valid, "past_valid => r == r_past")
+ |""".stripMargin
+ test(in, MCSuccess, kmax = 2)
+ }
+}
+
+abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers {
+ def test(src: String, expected: MCResult, kmax: Int = 0, clue: String = "", annos: Seq[Annotation] = Seq()): Unit = {
+ expected match {
+ case MCFail(k) => assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)")
+ case _ =>
+ }
+ val fir = firrtl.Parser.parse(src)
+ val name = fir.main
+ val testDir = BackendCompilationUtilities.createTestDirectory("EndToEndSMT." + name)
+ // we automagically add a preset annotation if an input called preset exists
+ val presetAnno = if(!src.contains("input preset")) { None } else {
+ Some(PresetAnnotation(CircuitTarget(name).module(name).ref("preset")))
+ }
+ val res = (new FirrtlStage).execute(Array(), Seq(
+ LogLevelAnnotation(LogLevel.Error), // silence warnings for tests
+ RunFirrtlTransformAnnotation(new SMTLibEmitter),
+ RunFirrtlTransformAnnotation(new Btor2Emitter),
+ FirrtlCircuitAnnotation(fir),
+ TargetDirAnnotation(testDir.getAbsolutePath)
+ ) ++ presetAnno ++ annos)
+ assert(res.collectFirst{ case _: OutputFileAnnotation => true }.isDefined)
+ val r = Z3ModelChecker.bmc(testDir, name, kmax)
+ assert(r == expected, clue + "\n" + s"$testDir")
+ }
+}
+
+/** Minimal implementation of a Z3 based bounded model checker.
+ * A more complete version of this with better use feedback should eventually be provided by a
+ * chisel3 formal verification library. Do not use this implementation outside of the firrtl test suite!
+ * */
+private object Z3ModelChecker extends LazyLogging {
+ def getZ3Version: String = {
+ val (out, ret) = executeCmd("-version")
+ assert(ret == 0, "failed to call z3")
+ assert(out.startsWith("Z3 version"), s"$out does not start with 'Z3 version'")
+ val version = out.split(" ")(2)
+ version
+ }
+
+ def bmc(testDir: File, main: String, kmax: Int): MCResult = {
+ assert(kmax >=0 && kmax < 50, "Trying to keep kmax in a reasonable range.")
+ val smtFile = new File(testDir, main + ".smt2")
+ val header = read(smtFile)
+ val steps = (0 to kmax).map(k => new File(testDir, main + s"_step$k.smt2")).zipWithIndex
+ steps.foreach { case (f,k) =>
+ writeStep(f, main, header, k)
+ val success = executeStep(f.getAbsolutePath)
+ if(!success) return MCFail(k)
+ }
+ MCSuccess
+ }
+
+ private def executeStep(filename: String): Boolean = {
+ val (out, ret) = executeCmd(filename)
+ assert(ret == 0, s"expected success (0), not $ret: `$out`\nz3 $filename")
+ assert(out == "sat" || out == "unsat", s"Unexpected output: $out")
+ out == "unsat"
+ }
+
+ private def executeCmd(cmd: String): (String, Int) = {
+ var out = ""
+ val log = ProcessLogger(s => out = s, logger.warn(_))
+ val ret = Process(Seq("z3", cmd)).run(log).exitValue()
+ (out, ret)
+ }
+
+ private def writeStep(f: File, main: String, header: Iterable[String], k: Int): Unit = {
+ val pw = new PrintWriter(f)
+ val lines = header ++ step(main, k) ++ List("(check-sat)")
+ lines.foreach(pw.println)
+ pw.close()
+ }
+
+ private def step(main: String, k: Int): Iterable[String] = {
+ // define all states
+ (0 to k).map(ii => s"(declare-fun s$ii () $main$StateTpe)") ++
+ // assert that init holds in state 0
+ List(s"(assert ($main$Init s0))") ++
+ // assert transition relation
+ (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii+1}))") ++
+ // assert that assumptions hold in all states
+ (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++
+ // assert that assertions hold for all but last state
+ (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++
+ // check to see if we can violate the assertions in the last state
+ List(s"(assert (not ($main$Asserts s$k)))")
+ }
+
+ private def read(f: File): Iterable[String] = {
+ val source = scala.io.Source.fromFile(f)
+ try source.getLines().toVector finally source.close()
+ }
+
+ // the following suffixes have to match the ones in [[SMTTransitionSystemEncoder]]
+ private val Transition = "_t"
+ private val Init = "_i"
+ private val Asserts = "_a"
+ private val Assumes = "_u"
+ private val StateTpe = "_s"
+}
+
+private sealed trait MCResult
+private case object MCSuccess extends MCResult
+private case class MCFail(k: Int) extends MCResult
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala
new file mode 100644
index 00000000..10de9cda
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala
@@ -0,0 +1,107 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt.end2end
+
+import firrtl.annotations.{CircuitTarget, MemoryArrayInitAnnotation, MemoryScalarInitAnnotation}
+
+class MemorySpec extends EndToEndSMTBaseSpec {
+ private def registeredTestMem(name: String, cmds: String, readUnderWrite: String): String =
+ registeredTestMem(name, cmds.split("\n"), readUnderWrite)
+ private def registeredTestMem(name: String, cmds: Iterable[String], readUnderWrite: String): String =
+ s"""circuit $name:
+ | module $name:
+ | input reset : UInt<1>
+ | input clock : Clock
+ | input preset: AsyncReset
+ | input write_addr : UInt<5>
+ | input read_addr : UInt<5>
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | mem m:
+ | data-type => UInt<8>
+ | depth => 32
+ | reader => r
+ | writer => w
+ | read-latency => 1
+ | write-latency => 1
+ | read-under-write => $readUnderWrite
+ |
+ | m.w.clk <= clock
+ | m.w.mask <= UInt(1)
+ | m.w.en <= UInt(1)
+ | m.w.data <= in
+ | m.w.addr <= write_addr
+ |
+ | m.r.clk <= clock
+ | m.r.en <= UInt(1)
+ | out <= m.r.data
+ | m.r.addr <= read_addr
+ |
+ | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0)))
+ | cycle <= add(cycle, UInt(1))
+ | node past_valid = geq(cycle, UInt(1))
+ |
+ | ${cmds.mkString("\n ")}
+ |""".stripMargin
+
+ "Registered test memory" should "return written data after two cycles" taggedAs(RequiresZ3) in {
+ val cmds =
+ """node past_past_valid = geq(cycle, UInt(2))
+ |reg past_in: UInt<8>, clock
+ |past_in <= in
+ |reg past_past_in: UInt<8>, clock
+ |past_past_in <= past_in
+ |reg past_write_addr: UInt<5>, clock
+ |past_write_addr <= write_addr
+ |
+ |assume(clock, eq(read_addr, past_write_addr), past_valid, "read_addr = past(write_addr)")
+ |assert(clock, eq(out, past_past_in), past_past_valid, "out = past(past(in))")
+ |""".stripMargin
+ test(registeredTestMem("Mem00", cmds, "old"), MCSuccess, kmax = 3)
+ }
+
+ private def readOnlyMem(pred: String, num: Int) =
+ s"""circuit Mem0$num:
+ | module Mem0$num:
+ | input c : Clock
+ | input read_addr : UInt<2>
+ | output out : UInt<8>
+ |
+ | mem m:
+ | data-type => UInt<8>
+ | depth => 4
+ | reader => r
+ | read-latency => 0
+ | write-latency => 1
+ | read-under-write => new
+ |
+ | m.r.clk <= c
+ | m.r.en <= UInt(1)
+ | out <= m.r.data
+ | m.r.addr <= read_addr
+ |
+ | assert(c, $pred, UInt(1), "")
+ |""".stripMargin
+ private def m(num: Int) = CircuitTarget(s"Mem0$num").module(s"Mem0$num").ref("m")
+
+ "read-only memory" should "always return 0" taggedAs(RequiresZ3) in {
+ test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax=2,
+ annos=Seq(MemoryScalarInitAnnotation(m(1), 0)))
+ }
+
+ "read-only memory" should "not always return 1" taggedAs(RequiresZ3) in {
+ test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax=2,
+ annos=Seq(MemoryScalarInitAnnotation(m(2), 0)))
+ }
+
+ "read-only memory" should "always return 1 or 2" taggedAs(RequiresZ3) in {
+ test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), MCSuccess, kmax=2,
+ annos=Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1))))
+ }
+
+ "read-only memory" should "not always return 1 or 2 or 3" taggedAs(RequiresZ3) in {
+ test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), MCFail(0), kmax=2,
+ annos=Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3))))
+ }
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala
new file mode 100644
index 00000000..d633a1a0
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala
@@ -0,0 +1,9 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt.end2end
+
+import org.scalatest.Tag
+
+// To disable tests that require the Z3 SMT solver to be installed use the following:
+// `sbt testOnly -- -l RequiresZ3`
+object RequiresZ3 extends Tag("RequiresZ3")
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala
new file mode 100644
index 00000000..cbf194dd
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala
@@ -0,0 +1,46 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt.end2end
+
+import java.io.File
+
+import firrtl.stage.{FirrtlStage, OutputFileAnnotation}
+import firrtl.util.BackendCompilationUtilities
+import logger.LazyLogging
+import org.scalatest.flatspec.AnyFlatSpec
+
+import scala.sys.process.{Process, ProcessLogger}
+
+/** compiles the regression tests to SMTLib and parses the result with z3 */
+class SMTCompilationTest extends AnyFlatSpec with LazyLogging {
+ it should "generate valid SMTLib for AddNot" taggedAs(RequiresZ3) in { compileAndParse("AddNot") }
+ it should "generate valid SMTLib for FPU" taggedAs(RequiresZ3) in { compileAndParse("FPU") }
+ // we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer
+ it should "generate valid SMTLib for HwachaSequencer" taggedAs(RequiresZ3) ignore { compileAndParse("HwachaSequencer") }
+ it should "generate valid SMTLib for ICache" taggedAs(RequiresZ3) in { compileAndParse("ICache") }
+ it should "generate valid SMTLib for Ops" taggedAs(RequiresZ3) in { compileAndParse("Ops") }
+ // TODO: enable Rob test once we support more than 2 write ports on a memory
+ it should "generate valid SMTLib for Rob" taggedAs(RequiresZ3) ignore { compileAndParse("Rob") }
+ it should "generate valid SMTLib for RocketCore" taggedAs(RequiresZ3) in { compileAndParse("RocketCore") }
+
+ private def compileAndParse(name: String): Unit = {
+ val testDir = BackendCompilationUtilities.createTestDirectory(name + "-smt")
+ val inputFile = new File(testDir, s"${name}.fir")
+ BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", inputFile)
+
+ val args = Array(
+ "-ll", "error", // surpress warnings to keep test output clean
+ "--target-dir", testDir.toString,
+ "-i", inputFile.toString,
+ "-E", "experimental-smt2"
+ // "-fct", "firrtl.backends.experimental.smt.StutteringClockTransform"
+ )
+ val res = (new FirrtlStage).execute(args, Seq())
+ val fileName = res.collectFirst{ case OutputFileAnnotation(file) => file }.get
+
+ val smtFile = testDir.toString + "/" + fileName + ".smt2"
+ val log = ProcessLogger(_ => (), logger.error(_))
+ val z3Ret = Process(Seq("z3", smtFile)).run(log).exitValue()
+ assert(z3Ret == 0, s"Failed to parse SMTLib file $smtFile generated for $name")
+ }
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala
new file mode 100644
index 00000000..8fa80b4c
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala
@@ -0,0 +1,40 @@
+// See LICENSE for license details.
+
+package firrtl.backends.experimental.smt.end2end
+
+/** undefined values in firrtl are modelled as fresh auxiliary variables (inputs) */
+class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec {
+
+ "division by zero" should "result in an arbitrary value" taggedAs(RequiresZ3) in {
+ // the SMTLib spec defines the result of division by zero to be all 1s
+ // https://cs.nyu.edu/pipermail/smt-lib/2015/000977.html
+ def in(dEq: Int) =
+ s"""circuit CC00:
+ | module CC00:
+ | input c: Clock
+ | input a: UInt<2>
+ | input b: UInt<2>
+ | assume(c, eq(b, UInt(0)), UInt(1), "b = 0")
+ | node d = div(a, b)
+ | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq")
+ |""".stripMargin
+ // we try to assert that (d = a / 0) is any fixed value which should be false
+ (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") }
+ }
+
+ // TODO: rem should probably also be undefined, but the spec isn't 100% clear here
+
+
+ "invalid signals" should "have an arbitrary values" taggedAs(RequiresZ3) in {
+ def in(aEq: Int) =
+ s"""circuit CC00:
+ | module CC00:
+ | input c: Clock
+ | wire a: UInt<2>
+ | a is invalid
+ | assert(c, eq(a, UInt($aEq)), UInt(1), "a = $aEq")
+ |""".stripMargin
+ // a should not be equivalent to any fixed value (0, 1, 2 or 3)
+ (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"a = $ii") }
+ }
+}