diff options
| author | Kevin Laeufer | 2020-08-14 18:39:42 -0700 |
|---|---|---|
| committer | GitHub | 2020-08-15 01:39:42 +0000 |
| commit | 2e5f942d25d7afab79ee1263c5d6833cad9d743d (patch) | |
| tree | add86d0b4b090807b48bb2307d10f2b7b38e0bce /src/test | |
| parent | 1b48fe5f5e94bdfdef700956e45d478b5706f25e (diff) | |
experimental SMTLib and btor2 emitter (#1826)
This adds an experimental new SMTLib and Btor2 emitter
that converts a firrtl module into a format
suitable for open source model checkers.
The format generally follows the behavior of yosys'
write_smt2 and write_btor commands.
To generate btor2 for the module in m.fir run
> ./utils/bin/firrtl -i m.fir -E experimental-btor2
for SMT:
> ./utils/bin/firrtl -i m.fir -E experimental-smt2
If you have a design with multiple clocks
or an asynchronous reset, try out the new StutteringClockTransform.
You can designate any input of type Clock to be your
global simulation clock using the new GlobalClockAnnotation.
If your toplevel module instantiates submodules,
you need to inline them if you want the submodule
logic to be included in the formal model.
Diffstat (limited to 'src/test')
11 files changed, 1181 insertions, 0 deletions
diff --git a/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala new file mode 100644 index 00000000..56f891e6 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala @@ -0,0 +1,61 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class Btor2Spec extends SMTBackendBaseSpec { + + it should "convert a hello world module" in { + val src = + """circuit m: + | module m: + | input clock: Clock + | input a: UInt<8> + | output b: UInt<16> + | b <= a + | assert(clock, eq(a, b), UInt(1), "") + |""".stripMargin + + val expected = + """1 sort bitvec 8 + |2 input 1 a + |3 sort bitvec 16 + |4 uext 3 2 8 + |5 output 4 ; b + |6 sort bitvec 1 + |7 uext 3 2 8 + |8 eq 6 7 4 + |9 not 6 8 + |10 bad 9 ; assert_ + |""".stripMargin + + assert(toBotr2Str(src) == expected) + } + + it should "include FileInfo in the output" in { + val src = + """circuit m: @[circuit 0:0] + | module m: @[module 0:0] + | input clock: Clock @[clock 0:0] + | input a: UInt<8> @[a 0:0] + | output b: UInt<16> @[b 0:0] + | b <= a @[b_a 0:0] + | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0] + |""".stripMargin + + val expected = + """; @ module 0:0 + |1 sort bitvec 8 + |2 input 1 a ; @ a 0:0 + |3 sort bitvec 16 + |4 uext 3 2 8 + |5 output 4 ; b @ b 0:0, b_a 0:0 + |6 sort bitvec 1 + |7 uext 3 2 8 + |8 eq 6 7 4 + |9 not 6 8 + |10 bad 9 ; assert_ @ assert 0:0 + |""".stripMargin + + assert(toBotr2Str(src) == expected) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala new file mode 100644 index 00000000..015ac4a9 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -0,0 +1,224 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { + + def primop(op: String, resTpe: String, inTpes: Seq[String], consts: Seq[Int]): String = { + val inputs = inTpes.zipWithIndex.map { case (tpe, ii) => s" input i$ii : $tpe" }.mkString("\n") + val args = (inTpes.zipWithIndex.map { case (_, ii) => s"i$ii" } ++ consts.map(_.toString)).mkString(", ") + val src = + s"""circuit m: + | module m: + |$inputs + | output res: $resTpe + | res <= $op($args) + | + |""".stripMargin + val sys = toSys(src) + assert(sys.signals.length == 1) + sys.signals.head.e.toString + } + + def primop(signed: Boolean, op: String, resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(), + resAlwaysUnsigned: Boolean = false): String = { + val tpe = if(signed) "SInt" else "UInt" + val resTpe = if(resAlwaysUnsigned) "UInt" else tpe + val inTpes = inWidth.map(w => s"$tpe<$w>") + primop(op, s"$resTpe<$resWidth>", inTpes, consts) + } + + it should "correctly translate the add primitive operation with different operand sizes" in { + assert(primop(false, "add", 5, List(3, 5)) == "add(zext(i0, 3), zext(i1, 1))[4:0]") + assert(primop(false, "add", 5, List(3, 4)) == "add(zext(i0, 2), zext(i1, 1))") + assert(primop(true, "add", 5, List(3, 5)) == "add(sext(i0, 3), sext(i1, 1))[4:0]") + assert(primop(true, "add", 5, List(3, 4)) == "add(sext(i0, 2), sext(i1, 1))") + + // could be simplified to just `add(i0, i1)` + assert(primop(false, "add", 8, List(8, 8)) == "add(zext(i0, 1), zext(i1, 1))[7:0]") + } + + it should "correctly translate the `add` primitive operation" in { + assert(primop(false, "add", 8, List(7, 7)) == "add(zext(i0, 1), zext(i1, 1))") + } + + it should "correctly translate the `sub` primitive operation" in { + assert(primop(false, "sub", 8, List(7, 7)) == "sub(zext(i0, 1), zext(i1, 1))") + } + + it should "correctly translate the `mul` primitive operation" in { + assert(primop(false, "mul", 8, List(4, 4)) == "mul(zext(i0, 4), zext(i1, 4))") + } + + it should "correctly translate the `div` primitive operation" in { + // division is a little bit more complicated because the result of division by zero is undefined + assert(primop(false, "div", 8, List(8, 8)) == + "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))") + assert(primop(false, "div", 8, List(8, 4)) == + "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))") + + // signed division increases result width by 1 + assert(primop(true, "div", 8, List(7, 7)) == + "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))") + assert(primop(true, "div", 8, List(7, 4)) + == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))") + } + + it should "correctly translate the `rem` primitive operation" in { + // rem can decrease the size of operands, but we should only do that decrease on the result + assert(primop(false, "rem", 4, List(4, 8)) == "urem(zext(i0, 4), i1)[3:0]") + assert(primop(false, "rem", 4, List(8, 4)) == "urem(i0, zext(i1, 4))[3:0]") + assert(primop(true, "rem", 4, List(4, 8)) == "srem(sext(i0, 4), i1)[3:0]") + assert(primop(true, "rem", 4, List(8, 4)) == "srem(i0, sext(i1, 4))[3:0]") + // TODO: add test to make sure we are using the correct mod/rem operation for signed and unsigned + // https://groups.google.com/g/stp-users/c/od43h8q5RSI has some tests that we could copy and + // use with a SMT solver + } + + it should "correctly translate the comparison primitive operations" in { + // some comparisons are represented as the negation of others + assert(primop(false, "lt", 1, List(8, 8)) == "not(ugeq(i0, i1))") + assert(primop(false, "leq", 1, List(8, 8)) == "not(ugt(i0, i1))") + assert(primop(false, "gt", 1, List(8, 8)) == "ugt(i0, i1)") + assert(primop(false, "geq", 1, List(8, 8)) == "ugeq(i0, i1)") + assert(primop(false, "eq", 1, List(8, 8)) == "eq(i0, i1)") + assert(primop(false, "neq", 1, List(8, 8)) == "not(eq(i0, i1))") + + assert(primop(true, "lt", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgeq(i0, i1))") + assert(primop(true, "leq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgt(i0, i1))") + assert(primop(true, "gt", 1, List(8, 8), resAlwaysUnsigned = true) == "sgt(i0, i1)") + assert(primop(true, "geq", 1, List(8, 8), resAlwaysUnsigned = true) == "sgeq(i0, i1)") + assert(primop(true, "eq", 1, List(8, 8), resAlwaysUnsigned = true) == "eq(i0, i1)") + assert(primop(true, "neq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(eq(i0, i1))") + + // it should always extend the width to the max of both + assert(primop(false, "gt", 1, List(7, 8)) == "ugt(zext(i0, 1), i1)") + } + + it should "correctly translate the `pad` primitive operation" in { + // firrtl pad takes new width as argument, whereas the smt zext takes the number of bits to extend by + assert(primop(false, "pad", 8, List(3), List(8)) == "zext(i0, 5)") + assert(primop(false, "pad", 8, List(3), List(5)) == "zext(zext(i0, 2), 3)") + + // there is no negative padding, instead the result is just e + assert(primop(false, "pad", 3, List(3), List(2)) == "i0") + + assert(primop(true, "pad", 8, List(3), List(8)) == "sext(i0, 5)") + assert(primop(true, "pad", 8, List(3), List(5)) == "sext(sext(i0, 2), 3)") + } + + it should "correctly translate the asX primitive operations" in { + // these are all essentially no-ops + assert(primop(false, "asUInt", 3, List(3)) == "i0") + assert(primop(true, "asSInt", 3, List(3)) == "i0") + } + + it should "correctly translate the `shl` primitive operation" in { + assert(primop(false, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)") + assert(primop(true, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)") + assert(primop(false, "shl", 3, List(3), List(0)) == "i0") + } + + it should "correctly translate the `shr` primitive operation" in { + assert(primop(false, "shr", 6, List(9), List(3)) == "i0[8:3]") + assert(primop(true, "shr", 6, List(9), List(3)) == "i0[8:3]") + + // "If n is greater than or equal to the bit-width of e, + // the resulting value will be zero for unsigned types and the sign bit for signed types." + assert(primop(false, "shr", 1, List(3), List(3)) == "1'b0") + assert(primop(false, "shr", 1, List(3), List(4)) == "1'b0") + assert(primop(true, "shr", 1, List(3), List(3)) == "i0[2]") + assert(primop(true, "shr", 1, List(3), List(4)) == "i0[2]") + } + + it should "correctly translate the `dshl` primitive operation" in { + assert(primop(false, "dshl", 31, List(16, 4)) == "logical_shift_left(zext(i0, 15), zext(i1, 27))") + assert(primop(false, "dshl", 19, List(16, 2)) == "logical_shift_left(zext(i0, 3), zext(i1, 17))") + assert(primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == + "logical_shift_left(sext(i0, 3), zext(i1, 17))") + } + + it should "correctly translate the `dshr` primitive operation" in { + assert(primop(false, "dshr", 16, List(16, 4)) == "logical_shift_right(i0, zext(i1, 12))") + assert(primop(false, "dshr", 16, List(16, 2)) == "logical_shift_right(i0, zext(i1, 14))") + assert(primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == + "arithmetic_shift_right(i0, zext(i1, 14))") + } + + it should "correctly translate the `cvt` primitive operation" in { + // for signed operands, this is a no-op + assert(primop(true, "cvt", 3, List(3)) == "i0") + + // for unsigned, a zero is prepended + assert(primop("cvt", "SInt<16>", List("UInt<15>"), List()) == "concat(1'b0, i0)") + assert(primop("cvt", "SInt<16>", List("UInt<14>"), List()) == "sext(concat(1'b0, i0), 1)") + } + + it should "correctly translate the `neg` primitive operation" in { + assert(primop(true, "neg", 4, List(3)) == "neg(sext(i0, 1))") + assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "neg(zext(i0, 1))") + } + + it should "correctly translate the `not` primitive operation" in { + assert(primop(false, "not", 4, List(4)) == "not(i0)") + assert(primop("not", "UInt<4>", List("SInt<4>"), List()) == "not(i0)") + } + + it should "correctly translate the binary bitwise primitive operations" in { + assert(primop(false, "and", 4, List(4, 3)) == "and(i0, zext(i1, 1))") + assert(primop("and", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "and(i0, sext(i1, 1))") + + assert(primop(false, "or", 4, List(4, 3)) == "or(i0, zext(i1, 1))") + assert(primop("or", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "or(i0, sext(i1, 1))") + + assert(primop(false, "xor", 4, List(4, 3)) == "xor(i0, zext(i1, 1))") + assert(primop("xor", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "xor(i0, sext(i1, 1))") + } + + it should "correctly translate the bitwise reduction primitive operation" in { + // zero width special cases are removed by the firrtl compiler + assert(primop(false, "andr", 1, List(0)) == "1'b1") + assert(primop(false, "orr", 1, List(0)) == "redor(1'b0)") + assert(primop(false, "xorr", 1, List(0)) == "redxor(1'b0)") + + assert(primop(false, "andr", 1, List(3)) == "redand(i0)") + assert(primop(true, "andr", 1, List(3), resAlwaysUnsigned = true) == "redand(i0)") + + assert(primop(false, "orr", 1, List(3)) == "redor(i0)") + assert(primop(true, "orr", 1, List(3), resAlwaysUnsigned = true) == "redor(i0)") + + assert(primop(false, "xorr", 1, List(3)) == "redxor(i0)") + assert(primop(true, "xorr", 1, List(3), resAlwaysUnsigned = true) == "redxor(i0)") + } + + it should "correctly translate the `cat` primitive operation" in { + assert(primop(false, "cat", 7, List(4, 3)) == "concat(i0, i1)") + assert(primop(true, "cat", 7, List(4, 3), resAlwaysUnsigned = true) == "concat(i0, i1)") + } + + it should "correctly translate the `bits` primitive operation" in { + assert(primop(false, "bits", 1, List(4), List(2,2)) == "i0[2]") + assert(primop(false, "bits", 2, List(4), List(2,1)) == "i0[2:1]") + assert(primop(false, "bits", 1, List(4), List(2,1)) == "i0[2:1][0]") + assert(primop(false, "bits", 3, List(4), List(2,1)) == "zext(i0[2:1], 1)") + + assert(primop(true, "bits", 1, List(4), List(2,2), resAlwaysUnsigned = true) == "i0[2]") + assert(primop(true, "bits", 2, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1]") + assert(primop(true, "bits", 1, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1][0]") + assert(primop(true, "bits", 3, List(4), List(2,1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") + } + + it should "correctly translate the `head` primitive operation" in { + // "The result of the head operation are the n most significant bits of e" + assert(primop(false, "head", 1, List(4), List(1)) == "i0[3]") + assert(primop(false, "head", 1, List(5), List(1)) == "i0[4]") + assert(primop(false, "head", 3, List(5), List(3)) == "i0[4:2]") + } + + it should "correctly translate the `tail` primitive operation" in { + // "The tail operation truncates the n most significant bits from e" + assert(primop(false, "tail", 3, List(4), List(1)) == "i0[2:0]") + assert(primop(false, "tail", 4, List(5), List(1)) == "i0[3:0]") + assert(primop(false, "tail", 2, List(5), List(3)) == "i0[1:0]") + } +}
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala new file mode 100644 index 00000000..ca7974c5 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala @@ -0,0 +1,314 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +import firrtl.{MemoryArrayInit, MemoryScalarInit, Utils} + +private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { + behavior of "ModuleToTransitionSystem.run" + + + it should "model registers as state" in { + // if a signal is invalid, it could take on an arbitrary value in that cycle + val src = + """circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input en : UInt<1> + | input in : UInt<8> + | output out : UInt<8> + | + | reg r : UInt<8>, clock with : (reset => (reset, UInt<8>(0))) + | when en: + | r <= in + | out <= r + | + |""".stripMargin + val sys = toSys(src) + + assert(sys.signals.length == 2) + + // the when is translated as a ITE + val genSignal = sys.signals.filterNot(_.name == "out").head + assert(genSignal.e.toString == "ite(en, in, r)") + + // the reset is synchronous + val r = sys.states.head + assert(r.sym.name == "r") + assert(r.init.isEmpty, "we are not using any preset, so the initial register content is arbitrary") + assert(r.next.get.toString == s"ite(reset, 8'b0, ${genSignal.name})") + } + + private def memCircuit(depth: Int = 32) = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<${Utils.getUIntWidth(depth)}> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => $depth + | reader => r + | writer => w + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + it should "model memories as state" in { + val sys = toSys(memCircuit()) + + assert(sys.signals.length == 9-2+1, "9 connects - 2 clock connects + 1 combinatorial read port") + + val sig = sys.signals.map(s => s.name -> s.e).toMap + + // masks and enables should all be true + val True = BVLiteral(1, 1) + assert(sig("m.w.mask") == True) + assert(sig("m.w.en") == True) + assert(sig("m.r.en") == True) + + // read data should always be enabled + assert(sig("m.r.data").toString == "m[m.r.addr]") + + // the memory is modelled as a state + val m = sys.states.find(_.sym.name == "m").get + assert(m.sym.isInstanceOf[ArraySymbol]) + val sym = m.sym.asInstanceOf[ArraySymbol] + assert(sym.indexWidth == 5) + assert(sym.dataWidth == 8) + assert(m.init.isEmpty) + //assert(m.next.get.toString.contains("m[m.w.addr := m.w.data]")) + assert(m.next.get.toString == "m[m.w.addr := m.w.data]") + } + + it should "support scalar initialization of a memory to 0" in { + val sys = toSys(memCircuit(), memInit = Map("m" -> MemoryScalarInit(0))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 32)") + } + + it should "support scalar initialization of a memory to 127" in { + val sys = toSys(memCircuit(31), memInit = Map("m" -> MemoryScalarInit(127))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b1111111] x 32)") + } + + it should "support array initialization of a memory to Seq(0, 1, 2, 3)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(0, 1, 2, 3)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 4)[2'b1 := 8'b1][2'b10 := 8'b10][2'b11 := 8'b11]") + } + + it should "support array initialization of a memory to Seq(1, 0, 1, 0)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 1, 0)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b1] x 4)[2'b1 := 8'b0][2'b11 := 8'b0]") + } + + it should "support array initialization of a memory to Seq(1, 0, 0, 0)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 0, 0)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 4)[2'b0 := 8'b1]") + } + + it should "support array initialization from a file" ignore { + assert(false, "TODO") + } + + it should "support memories with registered read port" in { + def src(readUnderWrite: String) = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + + val oldValue = toSys(src("old")) + val oldMData = oldValue.states.find(_.sym.name == "m.r.data").get + assert(oldMData.sym.toString == "m.r.data") + assert(oldMData.next.get.toString == "m[m.r.addr]", "we just need to read the current value") + val readDataSignal = oldValue.signals.find(_.name == "m.r.data") + assert(readDataSignal.isEmpty, s"${readDataSignal.map(_.toString)} should not exist") + + val undefinedValue = toSys(src("undefined")) + val undefinedMData = undefinedValue.states.find(_.sym.name == "m.r.data").get + assert(undefinedMData.sym.toString == "m.r.data") + val undefined = "RANDOM.m_r_read_under_write_undefined" + assert(undefinedMData.next.get.toString == + s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", + "randomize result if there is a write") + } + + it should "support memories with potential write-write conflicts" in { + val src = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + + val sys = toSys(src) + val m = sys.states.find(_.sym.name == "m").get + + val regularUpdate = "m[m.w1.addr := m.w1.data][m.w2.addr := m.w2.data]" + val collision = "eq(m.w1.addr, m.w2.addr)" + val collisionUpdate = "m[m.w1.addr := RANDOM.m_w1_write_write_collision]" + + assert(m.next.get.toString == s"ite($collision, $collisionUpdate, $regularUpdate)") + } + + it should "model invalid signals as inputs" in { + // if a signal is invalid, it could take on an arbitrary value in that cycle + val src = + """circuit m: + | module m: + | input en : UInt<1> + | output o : UInt<8> + | o is invalid + | when en: + | o <= UInt<8>(0) + |""".stripMargin + val sys = toSys(src) + assert(sys.inputs.length == 2) + val random = sys.inputs.filter(_.name.contains("RANDOM")) + assert(random.length == 1) + assert(random.head.width == 8) + } + + it should "throw an error on async reset" in { + val err = intercept[AsyncResetException] { + toSys( + """circuit m: + | module m: + | input reset : AsyncReset + |""".stripMargin + ) + } + assert(err.getMessage.contains("reset")) + } + + it should "throw an error on casting to async reset" in { + val err = intercept[AssertionError] { + toSys( + """circuit m: + | module m: + | input reset : UInt<1> + | node async = asAsyncReset(reset) + |""".stripMargin + ) + } + assert(err.getMessage.contains("reset")) + } + + it should "throw an error on multiple clocks" in { + val err = intercept[MultiClockException] { + toSys( + """circuit m: + | module m: + | input clk1 : Clock + | input clk2 : Clock + |""".stripMargin + ) + } + assert(err.getMessage.contains("clk1, clk2")) + } + + it should "throw an error on using a clock as uInt" in { + // While this could potentially be supported in the future, for now we do not allow + // a clock to be used for anything besides updating registers and memories. + val err = intercept[AssertionError] { + toSys( + """circuit m: + | module m: + | input clk : Clock + | output o : UInt<1> + | o <= asUInt(clk) + | + |""".stripMargin + ) + } + assert(err.getMessage.contains("clk")) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala new file mode 100644 index 00000000..6bfb5437 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala @@ -0,0 +1,38 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +import firrtl.annotations.Annotation +import firrtl.{MemoryInitValue, ir} +import firrtl.stage.{Forms, TransformManager} +import org.scalatest.flatspec.AnyFlatSpec + +private abstract class SMTBackendBaseSpec extends AnyFlatSpec { + private val dependencies = Forms.LowForm + private val compiler = new TransformManager(dependencies) + + protected def compile(src: String, annos: Seq[Annotation] = List()): ir.Circuit = { + val c = firrtl.Parser.parse(src) + compiler.runTransform(firrtl.CircuitState(c, annos)).circuit + } + + protected def toSys(src: String, mod: String = "m", presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + val circuit = compile(src) + val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module] + // println(module.serialize) + new ModuleToTransitionSystem().run(module, presetRegs = presetRegs, memInit = memInit) + } + + protected def toBotr2(src: String, mod: String = "m"): Iterable[String] = + Btor2Serializer.serialize(toSys(src, mod)) + + protected def toBotr2Str(src: String, mod: String = "m"): String = + toBotr2(src, mod).mkString("\n") + "\n" + + protected def toSMTLib(src: String, mod: String = "m"): Iterable[String] = + SMTTransitionSystemEncoder.encode(toSys(src, mod)).map(SMTLibSerializer.serialize) + + protected def toSMTLibStr(src: String, mod: String = "m"): String = + toSMTLib(src, mod).mkString("\n") + "\n" +}
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala new file mode 100644 index 00000000..7193474d --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala @@ -0,0 +1,66 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class SMTLibSpec extends SMTBackendBaseSpec { + + it should "convert a hello world module" in { + val src = + """circuit m: + | module m: + | input clock: Clock + | input a: UInt<8> + | output b: UInt<16> + | b <= a + | assert(clock, eq(a, b), UInt(1), "") + |""".stripMargin + + val expected = + """(declare-sort m_s 0) + |; firrtl-smt2-input a 8 + |(declare-fun a_f (m_s) (_ BitVec 8)) + |; firrtl-smt2-output b 16 + |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state))) + |; firrtl-smt2-assert assert_ 1 + |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state))) + |(define-fun m_t ((state m_s) (state_n m_s)) Bool true) + |(define-fun m_i ((state m_s)) Bool true) + |(define-fun m_a ((state m_s)) Bool (assert__f state)) + |(define-fun m_u ((state m_s)) Bool true) + |""".stripMargin + + assert(toSMTLibStr(src) == expected) + } + + it should "include FileInfo in the output" in { + val src = + """circuit m: @[circuit 0:0] + | module m: @[module 0:0] + | input clock: Clock @[clock 0:0] + | input a: UInt<8> @[a 0:0] + | output b: UInt<16> @[b 0:0] + | b <= a @[b_a 0:0] + | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0] + |""".stripMargin + + val expected = + """; @ module 0:0 + |(declare-sort m_s 0) + |; firrtl-smt2-input a 8 + |; @ a 0:0 + |(declare-fun a_f (m_s) (_ BitVec 8)) + |; firrtl-smt2-output b 16 + |; @ b 0:0, b_a 0:0 + |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state))) + |; firrtl-smt2-assert assert_ 1 + |; @ assert 0:0 + |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state))) + |(define-fun m_t ((state m_s) (state_n m_s)) Bool true) + |(define-fun m_i ((state m_s)) Bool true) + |(define-fun m_a ((state m_s)) Bool (assert__f state)) + |(define-fun m_u ((state m_s)) Bool true) + |""".stripMargin + + assert(toSMTLibStr(src) == expected) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala new file mode 100644 index 00000000..4c6901ea --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.CircuitTarget +import firrtl.backends.experimental.smt.{GlobalClockAnnotation, StutteringClockTransform} +import firrtl.options.Dependency +import firrtl.stage.RunFirrtlTransformAnnotation + +class AsyncResetSpec extends EndToEndSMTBaseSpec { + def annos(name: String) = Seq( + RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]), + GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock"))) + + "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs(RequiresZ3) in { + def in(resetType: String) = + s"""circuit AsyncReset00: + | module AsyncReset00: + | input global_clock: Clock + | input c: Clock + | input reset: $resetType + | input preset: AsyncReset + | + | ; a register with async reset + | reg r: UInt<4>, c with: (reset => (reset, UInt(3))) + | + | ; a counter/toggler connected to the clock c + | reg count: UInt<1>, c with: (reset => (preset, UInt(0))) + | count <= add(count, UInt(1)) + | + | ; the past machinery and the assertion uses the global clock + | reg past_valid: UInt<1>, global_clock with: (reset => (preset, UInt(0))) + | past_valid <= UInt(1) + | reg past_r: UInt<4>, global_clock + | past_r <= r + | reg past_count: UInt<1>, global_clock + | past_count <= count + | + | ; can the value of r change without the count changing? + | assert(global_clock, or(not(eq(count, past_count)), eq(r, past_r)), past_valid, "count = past(count) |-> r = past(r)") + |""".stripMargin + test(in("AsyncReset"), MCFail(1), kmax=2, annos=annos("AsyncReset00")) + test(in("UInt<1>"), MCSuccess, kmax=2, annos=annos("AsyncReset00")) + } + +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala new file mode 100644 index 00000000..2227719b --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala @@ -0,0 +1,230 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import java.io.{File, PrintWriter} + +import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation} +import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} +import firrtl.options.TargetDirAnnotation +import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlStage, OutputFileAnnotation, RunFirrtlTransformAnnotation} +import firrtl.util.BackendCompilationUtilities +import logger.{LazyLogging, LogLevel, LogLevelAnnotation} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.must.Matchers + +import scala.sys.process._ + +class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { + "we" should "check if Z3 is available" taggedAs(RequiresZ3) in { + val log = ProcessLogger(_ => (), logger.warn(_)) + val ret = Process(Seq("which", "z3")).run(log).exitValue() + if(ret != 0) { + logger.error( + """The z3 SMT-Solver seems not to be installed. + |You can exclude the end-to-end smt backend tests which rely on z3 like this: + |sbt testOnly -- -l RequiresZ3 + |""".stripMargin) + } + assert(ret == 0) + } + + "Z3" should "be available in version 4" taggedAs(RequiresZ3) in { + assert(Z3ModelChecker.getZ3Version.startsWith("4.")) + } + + "a simple combinatorial check" should "pass" taggedAs(RequiresZ3) in { + val in = + """circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, lt(add(a, b), UInt(3)), UInt(1), "a + b < 3") + |""".stripMargin + test(in, MCSuccess) + } + + "a simple combinatorial check" should "fail immediately" taggedAs(RequiresZ3) in { + val in = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, gt(add(a, b), UInt(3)), UInt(1), "a + b > 3") + |""".stripMargin + test(in, MCFail(0)) + } + + "adding the right assumption" should "make a test pass" taggedAs(RequiresZ3) in { + val in0 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + |""".stripMargin + val in1 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + | assume(c, neq(a, UInt(0)), UInt(1), "a != 0") + |""".stripMargin + test(in0, MCFail(0)) + test(in1, MCSuccess) + + val in2 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | input en: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + | assume(c, neq(a, UInt(0)), en, "a != 0 if en") + |""".stripMargin + test(in2, MCFail(0)) + } + + "a register connected to preset reset" should "be initialized with the reset value" taggedAs(RequiresZ3) in { + def in(rEq: Int) = + s"""circuit Preset00: + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) + | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") + |""".stripMargin + test(in(3), MCSuccess, kmax = 1) + test(in(2), MCFail(0)) + } + + "a register's initial value" should "should not change" taggedAs(RequiresZ3) in { + val in = + """circuit Preset00: + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | + | ; the past value of our register will only be valid in the 1st unrolling + | reg past_valid: UInt<1>, c with: (reset => (preset, UInt(0))) + | past_valid <= UInt(1) + | + | reg r: UInt<4>, c + | reg r_past: UInt<4>, c + | r_past <= r + | assert(c, eq(r, r_past), past_valid, "past_valid => r == r_past") + |""".stripMargin + test(in, MCSuccess, kmax = 2) + } +} + +abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { + def test(src: String, expected: MCResult, kmax: Int = 0, clue: String = "", annos: Seq[Annotation] = Seq()): Unit = { + expected match { + case MCFail(k) => assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") + case _ => + } + val fir = firrtl.Parser.parse(src) + val name = fir.main + val testDir = BackendCompilationUtilities.createTestDirectory("EndToEndSMT." + name) + // we automagically add a preset annotation if an input called preset exists + val presetAnno = if(!src.contains("input preset")) { None } else { + Some(PresetAnnotation(CircuitTarget(name).module(name).ref("preset"))) + } + val res = (new FirrtlStage).execute(Array(), Seq( + LogLevelAnnotation(LogLevel.Error), // silence warnings for tests + RunFirrtlTransformAnnotation(new SMTLibEmitter), + RunFirrtlTransformAnnotation(new Btor2Emitter), + FirrtlCircuitAnnotation(fir), + TargetDirAnnotation(testDir.getAbsolutePath) + ) ++ presetAnno ++ annos) + assert(res.collectFirst{ case _: OutputFileAnnotation => true }.isDefined) + val r = Z3ModelChecker.bmc(testDir, name, kmax) + assert(r == expected, clue + "\n" + s"$testDir") + } +} + +/** Minimal implementation of a Z3 based bounded model checker. + * A more complete version of this with better use feedback should eventually be provided by a + * chisel3 formal verification library. Do not use this implementation outside of the firrtl test suite! + * */ +private object Z3ModelChecker extends LazyLogging { + def getZ3Version: String = { + val (out, ret) = executeCmd("-version") + assert(ret == 0, "failed to call z3") + assert(out.startsWith("Z3 version"), s"$out does not start with 'Z3 version'") + val version = out.split(" ")(2) + version + } + + def bmc(testDir: File, main: String, kmax: Int): MCResult = { + assert(kmax >=0 && kmax < 50, "Trying to keep kmax in a reasonable range.") + val smtFile = new File(testDir, main + ".smt2") + val header = read(smtFile) + val steps = (0 to kmax).map(k => new File(testDir, main + s"_step$k.smt2")).zipWithIndex + steps.foreach { case (f,k) => + writeStep(f, main, header, k) + val success = executeStep(f.getAbsolutePath) + if(!success) return MCFail(k) + } + MCSuccess + } + + private def executeStep(filename: String): Boolean = { + val (out, ret) = executeCmd(filename) + assert(ret == 0, s"expected success (0), not $ret: `$out`\nz3 $filename") + assert(out == "sat" || out == "unsat", s"Unexpected output: $out") + out == "unsat" + } + + private def executeCmd(cmd: String): (String, Int) = { + var out = "" + val log = ProcessLogger(s => out = s, logger.warn(_)) + val ret = Process(Seq("z3", cmd)).run(log).exitValue() + (out, ret) + } + + private def writeStep(f: File, main: String, header: Iterable[String], k: Int): Unit = { + val pw = new PrintWriter(f) + val lines = header ++ step(main, k) ++ List("(check-sat)") + lines.foreach(pw.println) + pw.close() + } + + private def step(main: String, k: Int): Iterable[String] = { + // define all states + (0 to k).map(ii => s"(declare-fun s$ii () $main$StateTpe)") ++ + // assert that init holds in state 0 + List(s"(assert ($main$Init s0))") ++ + // assert transition relation + (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii+1}))") ++ + // assert that assumptions hold in all states + (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ + // assert that assertions hold for all but last state + (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ + // check to see if we can violate the assertions in the last state + List(s"(assert (not ($main$Asserts s$k)))") + } + + private def read(f: File): Iterable[String] = { + val source = scala.io.Source.fromFile(f) + try source.getLines().toVector finally source.close() + } + + // the following suffixes have to match the ones in [[SMTTransitionSystemEncoder]] + private val Transition = "_t" + private val Init = "_i" + private val Asserts = "_a" + private val Assumes = "_u" + private val StateTpe = "_s" +} + +private sealed trait MCResult +private case object MCSuccess extends MCResult +private case class MCFail(k: Int) extends MCResult diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala new file mode 100644 index 00000000..10de9cda --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -0,0 +1,107 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.{CircuitTarget, MemoryArrayInitAnnotation, MemoryScalarInitAnnotation} + +class MemorySpec extends EndToEndSMTBaseSpec { + private def registeredTestMem(name: String, cmds: String, readUnderWrite: String): String = + registeredTestMem(name, cmds.split("\n"), readUnderWrite) + private def registeredTestMem(name: String, cmds: Iterable[String], readUnderWrite: String): String = + s"""circuit $name: + | module $name: + | input reset : UInt<1> + | input clock : Clock + | input preset: AsyncReset + | input write_addr : UInt<5> + | input read_addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= write_addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) + | cycle <= add(cycle, UInt(1)) + | node past_valid = geq(cycle, UInt(1)) + | + | ${cmds.mkString("\n ")} + |""".stripMargin + + "Registered test memory" should "return written data after two cycles" taggedAs(RequiresZ3) in { + val cmds = + """node past_past_valid = geq(cycle, UInt(2)) + |reg past_in: UInt<8>, clock + |past_in <= in + |reg past_past_in: UInt<8>, clock + |past_past_in <= past_in + |reg past_write_addr: UInt<5>, clock + |past_write_addr <= write_addr + | + |assume(clock, eq(read_addr, past_write_addr), past_valid, "read_addr = past(write_addr)") + |assert(clock, eq(out, past_past_in), past_past_valid, "out = past(past(in))") + |""".stripMargin + test(registeredTestMem("Mem00", cmds, "old"), MCSuccess, kmax = 3) + } + + private def readOnlyMem(pred: String, num: Int) = + s"""circuit Mem0$num: + | module Mem0$num: + | input c : Clock + | input read_addr : UInt<2> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 4 + | reader => r + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | assert(c, $pred, UInt(1), "") + |""".stripMargin + private def m(num: Int) = CircuitTarget(s"Mem0$num").module(s"Mem0$num").ref("m") + + "read-only memory" should "always return 0" taggedAs(RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax=2, + annos=Seq(MemoryScalarInitAnnotation(m(1), 0))) + } + + "read-only memory" should "not always return 1" taggedAs(RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax=2, + annos=Seq(MemoryScalarInitAnnotation(m(2), 0))) + } + + "read-only memory" should "always return 1 or 2" taggedAs(RequiresZ3) in { + test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), MCSuccess, kmax=2, + annos=Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1)))) + } + + "read-only memory" should "not always return 1 or 2 or 3" taggedAs(RequiresZ3) in { + test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), MCFail(0), kmax=2, + annos=Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3)))) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala new file mode 100644 index 00000000..d633a1a0 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala @@ -0,0 +1,9 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import org.scalatest.Tag + +// To disable tests that require the Z3 SMT solver to be installed use the following: +// `sbt testOnly -- -l RequiresZ3` +object RequiresZ3 extends Tag("RequiresZ3") diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala new file mode 100644 index 00000000..cbf194dd --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import java.io.File + +import firrtl.stage.{FirrtlStage, OutputFileAnnotation} +import firrtl.util.BackendCompilationUtilities +import logger.LazyLogging +import org.scalatest.flatspec.AnyFlatSpec + +import scala.sys.process.{Process, ProcessLogger} + +/** compiles the regression tests to SMTLib and parses the result with z3 */ +class SMTCompilationTest extends AnyFlatSpec with LazyLogging { + it should "generate valid SMTLib for AddNot" taggedAs(RequiresZ3) in { compileAndParse("AddNot") } + it should "generate valid SMTLib for FPU" taggedAs(RequiresZ3) in { compileAndParse("FPU") } + // we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer + it should "generate valid SMTLib for HwachaSequencer" taggedAs(RequiresZ3) ignore { compileAndParse("HwachaSequencer") } + it should "generate valid SMTLib for ICache" taggedAs(RequiresZ3) in { compileAndParse("ICache") } + it should "generate valid SMTLib for Ops" taggedAs(RequiresZ3) in { compileAndParse("Ops") } + // TODO: enable Rob test once we support more than 2 write ports on a memory + it should "generate valid SMTLib for Rob" taggedAs(RequiresZ3) ignore { compileAndParse("Rob") } + it should "generate valid SMTLib for RocketCore" taggedAs(RequiresZ3) in { compileAndParse("RocketCore") } + + private def compileAndParse(name: String): Unit = { + val testDir = BackendCompilationUtilities.createTestDirectory(name + "-smt") + val inputFile = new File(testDir, s"${name}.fir") + BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", inputFile) + + val args = Array( + "-ll", "error", // surpress warnings to keep test output clean + "--target-dir", testDir.toString, + "-i", inputFile.toString, + "-E", "experimental-smt2" + // "-fct", "firrtl.backends.experimental.smt.StutteringClockTransform" + ) + val res = (new FirrtlStage).execute(args, Seq()) + val fileName = res.collectFirst{ case OutputFileAnnotation(file) => file }.get + + val smtFile = testDir.toString + "/" + fileName + ".smt2" + val log = ProcessLogger(_ => (), logger.error(_)) + val z3Ret = Process(Seq("z3", smtFile)).run(log).exitValue() + assert(z3Ret == 0, s"Failed to parse SMTLib file $smtFile generated for $name") + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala new file mode 100644 index 00000000..8fa80b4c --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala @@ -0,0 +1,40 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +/** undefined values in firrtl are modelled as fresh auxiliary variables (inputs) */ +class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec { + + "division by zero" should "result in an arbitrary value" taggedAs(RequiresZ3) in { + // the SMTLib spec defines the result of division by zero to be all 1s + // https://cs.nyu.edu/pipermail/smt-lib/2015/000977.html + def in(dEq: Int) = + s"""circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<2> + | input b: UInt<2> + | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") + | node d = div(a, b) + | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") + |""".stripMargin + // we try to assert that (d = a / 0) is any fixed value which should be false + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") } + } + + // TODO: rem should probably also be undefined, but the spec isn't 100% clear here + + + "invalid signals" should "have an arbitrary values" taggedAs(RequiresZ3) in { + def in(aEq: Int) = + s"""circuit CC00: + | module CC00: + | input c: Clock + | wire a: UInt<2> + | a is invalid + | assert(c, eq(a, UInt($aEq)), UInt(1), "a = $aEq") + |""".stripMargin + // a should not be equivalent to any fixed value (0, 1, 2 or 3) + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"a = $ii") } + } +} |
