diff options
Diffstat (limited to 'src/test')
11 files changed, 1181 insertions, 0 deletions
diff --git a/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala new file mode 100644 index 00000000..56f891e6 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala @@ -0,0 +1,61 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class Btor2Spec extends SMTBackendBaseSpec { + + it should "convert a hello world module" in { + val src = + """circuit m: + | module m: + | input clock: Clock + | input a: UInt<8> + | output b: UInt<16> + | b <= a + | assert(clock, eq(a, b), UInt(1), "") + |""".stripMargin + + val expected = + """1 sort bitvec 8 + |2 input 1 a + |3 sort bitvec 16 + |4 uext 3 2 8 + |5 output 4 ; b + |6 sort bitvec 1 + |7 uext 3 2 8 + |8 eq 6 7 4 + |9 not 6 8 + |10 bad 9 ; assert_ + |""".stripMargin + + assert(toBotr2Str(src) == expected) + } + + it should "include FileInfo in the output" in { + val src = + """circuit m: @[circuit 0:0] + | module m: @[module 0:0] + | input clock: Clock @[clock 0:0] + | input a: UInt<8> @[a 0:0] + | output b: UInt<16> @[b 0:0] + | b <= a @[b_a 0:0] + | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0] + |""".stripMargin + + val expected = + """; @ module 0:0 + |1 sort bitvec 8 + |2 input 1 a ; @ a 0:0 + |3 sort bitvec 16 + |4 uext 3 2 8 + |5 output 4 ; b @ b 0:0, b_a 0:0 + |6 sort bitvec 1 + |7 uext 3 2 8 + |8 eq 6 7 4 + |9 not 6 8 + |10 bad 9 ; assert_ @ assert 0:0 + |""".stripMargin + + assert(toBotr2Str(src) == expected) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala new file mode 100644 index 00000000..015ac4a9 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -0,0 +1,224 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { + + def primop(op: String, resTpe: String, inTpes: Seq[String], consts: Seq[Int]): String = { + val inputs = inTpes.zipWithIndex.map { case (tpe, ii) => s" input i$ii : $tpe" }.mkString("\n") + val args = (inTpes.zipWithIndex.map { case (_, ii) => s"i$ii" } ++ consts.map(_.toString)).mkString(", ") + val src = + s"""circuit m: + | module m: + |$inputs + | output res: $resTpe + | res <= $op($args) + | + |""".stripMargin + val sys = toSys(src) + assert(sys.signals.length == 1) + sys.signals.head.e.toString + } + + def primop(signed: Boolean, op: String, resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(), + resAlwaysUnsigned: Boolean = false): String = { + val tpe = if(signed) "SInt" else "UInt" + val resTpe = if(resAlwaysUnsigned) "UInt" else tpe + val inTpes = inWidth.map(w => s"$tpe<$w>") + primop(op, s"$resTpe<$resWidth>", inTpes, consts) + } + + it should "correctly translate the add primitive operation with different operand sizes" in { + assert(primop(false, "add", 5, List(3, 5)) == "add(zext(i0, 3), zext(i1, 1))[4:0]") + assert(primop(false, "add", 5, List(3, 4)) == "add(zext(i0, 2), zext(i1, 1))") + assert(primop(true, "add", 5, List(3, 5)) == "add(sext(i0, 3), sext(i1, 1))[4:0]") + assert(primop(true, "add", 5, List(3, 4)) == "add(sext(i0, 2), sext(i1, 1))") + + // could be simplified to just `add(i0, i1)` + assert(primop(false, "add", 8, List(8, 8)) == "add(zext(i0, 1), zext(i1, 1))[7:0]") + } + + it should "correctly translate the `add` primitive operation" in { + assert(primop(false, "add", 8, List(7, 7)) == "add(zext(i0, 1), zext(i1, 1))") + } + + it should "correctly translate the `sub` primitive operation" in { + assert(primop(false, "sub", 8, List(7, 7)) == "sub(zext(i0, 1), zext(i1, 1))") + } + + it should "correctly translate the `mul` primitive operation" in { + assert(primop(false, "mul", 8, List(4, 4)) == "mul(zext(i0, 4), zext(i1, 4))") + } + + it should "correctly translate the `div` primitive operation" in { + // division is a little bit more complicated because the result of division by zero is undefined + assert(primop(false, "div", 8, List(8, 8)) == + "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))") + assert(primop(false, "div", 8, List(8, 4)) == + "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))") + + // signed division increases result width by 1 + assert(primop(true, "div", 8, List(7, 7)) == + "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))") + assert(primop(true, "div", 8, List(7, 4)) + == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))") + } + + it should "correctly translate the `rem` primitive operation" in { + // rem can decrease the size of operands, but we should only do that decrease on the result + assert(primop(false, "rem", 4, List(4, 8)) == "urem(zext(i0, 4), i1)[3:0]") + assert(primop(false, "rem", 4, List(8, 4)) == "urem(i0, zext(i1, 4))[3:0]") + assert(primop(true, "rem", 4, List(4, 8)) == "srem(sext(i0, 4), i1)[3:0]") + assert(primop(true, "rem", 4, List(8, 4)) == "srem(i0, sext(i1, 4))[3:0]") + // TODO: add test to make sure we are using the correct mod/rem operation for signed and unsigned + // https://groups.google.com/g/stp-users/c/od43h8q5RSI has some tests that we could copy and + // use with a SMT solver + } + + it should "correctly translate the comparison primitive operations" in { + // some comparisons are represented as the negation of others + assert(primop(false, "lt", 1, List(8, 8)) == "not(ugeq(i0, i1))") + assert(primop(false, "leq", 1, List(8, 8)) == "not(ugt(i0, i1))") + assert(primop(false, "gt", 1, List(8, 8)) == "ugt(i0, i1)") + assert(primop(false, "geq", 1, List(8, 8)) == "ugeq(i0, i1)") + assert(primop(false, "eq", 1, List(8, 8)) == "eq(i0, i1)") + assert(primop(false, "neq", 1, List(8, 8)) == "not(eq(i0, i1))") + + assert(primop(true, "lt", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgeq(i0, i1))") + assert(primop(true, "leq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(sgt(i0, i1))") + assert(primop(true, "gt", 1, List(8, 8), resAlwaysUnsigned = true) == "sgt(i0, i1)") + assert(primop(true, "geq", 1, List(8, 8), resAlwaysUnsigned = true) == "sgeq(i0, i1)") + assert(primop(true, "eq", 1, List(8, 8), resAlwaysUnsigned = true) == "eq(i0, i1)") + assert(primop(true, "neq", 1, List(8, 8), resAlwaysUnsigned = true) == "not(eq(i0, i1))") + + // it should always extend the width to the max of both + assert(primop(false, "gt", 1, List(7, 8)) == "ugt(zext(i0, 1), i1)") + } + + it should "correctly translate the `pad` primitive operation" in { + // firrtl pad takes new width as argument, whereas the smt zext takes the number of bits to extend by + assert(primop(false, "pad", 8, List(3), List(8)) == "zext(i0, 5)") + assert(primop(false, "pad", 8, List(3), List(5)) == "zext(zext(i0, 2), 3)") + + // there is no negative padding, instead the result is just e + assert(primop(false, "pad", 3, List(3), List(2)) == "i0") + + assert(primop(true, "pad", 8, List(3), List(8)) == "sext(i0, 5)") + assert(primop(true, "pad", 8, List(3), List(5)) == "sext(sext(i0, 2), 3)") + } + + it should "correctly translate the asX primitive operations" in { + // these are all essentially no-ops + assert(primop(false, "asUInt", 3, List(3)) == "i0") + assert(primop(true, "asSInt", 3, List(3)) == "i0") + } + + it should "correctly translate the `shl` primitive operation" in { + assert(primop(false, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)") + assert(primop(true, "shl", 6, List(3), List(3)) == "concat(i0, 3'b0)") + assert(primop(false, "shl", 3, List(3), List(0)) == "i0") + } + + it should "correctly translate the `shr` primitive operation" in { + assert(primop(false, "shr", 6, List(9), List(3)) == "i0[8:3]") + assert(primop(true, "shr", 6, List(9), List(3)) == "i0[8:3]") + + // "If n is greater than or equal to the bit-width of e, + // the resulting value will be zero for unsigned types and the sign bit for signed types." + assert(primop(false, "shr", 1, List(3), List(3)) == "1'b0") + assert(primop(false, "shr", 1, List(3), List(4)) == "1'b0") + assert(primop(true, "shr", 1, List(3), List(3)) == "i0[2]") + assert(primop(true, "shr", 1, List(3), List(4)) == "i0[2]") + } + + it should "correctly translate the `dshl` primitive operation" in { + assert(primop(false, "dshl", 31, List(16, 4)) == "logical_shift_left(zext(i0, 15), zext(i1, 27))") + assert(primop(false, "dshl", 19, List(16, 2)) == "logical_shift_left(zext(i0, 3), zext(i1, 17))") + assert(primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == + "logical_shift_left(sext(i0, 3), zext(i1, 17))") + } + + it should "correctly translate the `dshr` primitive operation" in { + assert(primop(false, "dshr", 16, List(16, 4)) == "logical_shift_right(i0, zext(i1, 12))") + assert(primop(false, "dshr", 16, List(16, 2)) == "logical_shift_right(i0, zext(i1, 14))") + assert(primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == + "arithmetic_shift_right(i0, zext(i1, 14))") + } + + it should "correctly translate the `cvt` primitive operation" in { + // for signed operands, this is a no-op + assert(primop(true, "cvt", 3, List(3)) == "i0") + + // for unsigned, a zero is prepended + assert(primop("cvt", "SInt<16>", List("UInt<15>"), List()) == "concat(1'b0, i0)") + assert(primop("cvt", "SInt<16>", List("UInt<14>"), List()) == "sext(concat(1'b0, i0), 1)") + } + + it should "correctly translate the `neg` primitive operation" in { + assert(primop(true, "neg", 4, List(3)) == "neg(sext(i0, 1))") + assert(primop("neg", "SInt<4>", List("UInt<3>"), List()) == "neg(zext(i0, 1))") + } + + it should "correctly translate the `not` primitive operation" in { + assert(primop(false, "not", 4, List(4)) == "not(i0)") + assert(primop("not", "UInt<4>", List("SInt<4>"), List()) == "not(i0)") + } + + it should "correctly translate the binary bitwise primitive operations" in { + assert(primop(false, "and", 4, List(4, 3)) == "and(i0, zext(i1, 1))") + assert(primop("and", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "and(i0, sext(i1, 1))") + + assert(primop(false, "or", 4, List(4, 3)) == "or(i0, zext(i1, 1))") + assert(primop("or", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "or(i0, sext(i1, 1))") + + assert(primop(false, "xor", 4, List(4, 3)) == "xor(i0, zext(i1, 1))") + assert(primop("xor", "UInt<4>", List("SInt<4>", "SInt<3>"), List()) == "xor(i0, sext(i1, 1))") + } + + it should "correctly translate the bitwise reduction primitive operation" in { + // zero width special cases are removed by the firrtl compiler + assert(primop(false, "andr", 1, List(0)) == "1'b1") + assert(primop(false, "orr", 1, List(0)) == "redor(1'b0)") + assert(primop(false, "xorr", 1, List(0)) == "redxor(1'b0)") + + assert(primop(false, "andr", 1, List(3)) == "redand(i0)") + assert(primop(true, "andr", 1, List(3), resAlwaysUnsigned = true) == "redand(i0)") + + assert(primop(false, "orr", 1, List(3)) == "redor(i0)") + assert(primop(true, "orr", 1, List(3), resAlwaysUnsigned = true) == "redor(i0)") + + assert(primop(false, "xorr", 1, List(3)) == "redxor(i0)") + assert(primop(true, "xorr", 1, List(3), resAlwaysUnsigned = true) == "redxor(i0)") + } + + it should "correctly translate the `cat` primitive operation" in { + assert(primop(false, "cat", 7, List(4, 3)) == "concat(i0, i1)") + assert(primop(true, "cat", 7, List(4, 3), resAlwaysUnsigned = true) == "concat(i0, i1)") + } + + it should "correctly translate the `bits` primitive operation" in { + assert(primop(false, "bits", 1, List(4), List(2,2)) == "i0[2]") + assert(primop(false, "bits", 2, List(4), List(2,1)) == "i0[2:1]") + assert(primop(false, "bits", 1, List(4), List(2,1)) == "i0[2:1][0]") + assert(primop(false, "bits", 3, List(4), List(2,1)) == "zext(i0[2:1], 1)") + + assert(primop(true, "bits", 1, List(4), List(2,2), resAlwaysUnsigned = true) == "i0[2]") + assert(primop(true, "bits", 2, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1]") + assert(primop(true, "bits", 1, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1][0]") + assert(primop(true, "bits", 3, List(4), List(2,1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") + } + + it should "correctly translate the `head` primitive operation" in { + // "The result of the head operation are the n most significant bits of e" + assert(primop(false, "head", 1, List(4), List(1)) == "i0[3]") + assert(primop(false, "head", 1, List(5), List(1)) == "i0[4]") + assert(primop(false, "head", 3, List(5), List(3)) == "i0[4:2]") + } + + it should "correctly translate the `tail` primitive operation" in { + // "The tail operation truncates the n most significant bits from e" + assert(primop(false, "tail", 3, List(4), List(1)) == "i0[2:0]") + assert(primop(false, "tail", 4, List(5), List(1)) == "i0[3:0]") + assert(primop(false, "tail", 2, List(5), List(3)) == "i0[1:0]") + } +}
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala new file mode 100644 index 00000000..ca7974c5 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala @@ -0,0 +1,314 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +import firrtl.{MemoryArrayInit, MemoryScalarInit, Utils} + +private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { + behavior of "ModuleToTransitionSystem.run" + + + it should "model registers as state" in { + // if a signal is invalid, it could take on an arbitrary value in that cycle + val src = + """circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input en : UInt<1> + | input in : UInt<8> + | output out : UInt<8> + | + | reg r : UInt<8>, clock with : (reset => (reset, UInt<8>(0))) + | when en: + | r <= in + | out <= r + | + |""".stripMargin + val sys = toSys(src) + + assert(sys.signals.length == 2) + + // the when is translated as a ITE + val genSignal = sys.signals.filterNot(_.name == "out").head + assert(genSignal.e.toString == "ite(en, in, r)") + + // the reset is synchronous + val r = sys.states.head + assert(r.sym.name == "r") + assert(r.init.isEmpty, "we are not using any preset, so the initial register content is arbitrary") + assert(r.next.get.toString == s"ite(reset, 8'b0, ${genSignal.name})") + } + + private def memCircuit(depth: Int = 32) = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<${Utils.getUIntWidth(depth)}> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => $depth + | reader => r + | writer => w + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + it should "model memories as state" in { + val sys = toSys(memCircuit()) + + assert(sys.signals.length == 9-2+1, "9 connects - 2 clock connects + 1 combinatorial read port") + + val sig = sys.signals.map(s => s.name -> s.e).toMap + + // masks and enables should all be true + val True = BVLiteral(1, 1) + assert(sig("m.w.mask") == True) + assert(sig("m.w.en") == True) + assert(sig("m.r.en") == True) + + // read data should always be enabled + assert(sig("m.r.data").toString == "m[m.r.addr]") + + // the memory is modelled as a state + val m = sys.states.find(_.sym.name == "m").get + assert(m.sym.isInstanceOf[ArraySymbol]) + val sym = m.sym.asInstanceOf[ArraySymbol] + assert(sym.indexWidth == 5) + assert(sym.dataWidth == 8) + assert(m.init.isEmpty) + //assert(m.next.get.toString.contains("m[m.w.addr := m.w.data]")) + assert(m.next.get.toString == "m[m.w.addr := m.w.data]") + } + + it should "support scalar initialization of a memory to 0" in { + val sys = toSys(memCircuit(), memInit = Map("m" -> MemoryScalarInit(0))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 32)") + } + + it should "support scalar initialization of a memory to 127" in { + val sys = toSys(memCircuit(31), memInit = Map("m" -> MemoryScalarInit(127))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b1111111] x 32)") + } + + it should "support array initialization of a memory to Seq(0, 1, 2, 3)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(0, 1, 2, 3)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 4)[2'b1 := 8'b1][2'b10 := 8'b10][2'b11 := 8'b11]") + } + + it should "support array initialization of a memory to Seq(1, 0, 1, 0)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 1, 0)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b1] x 4)[2'b1 := 8'b0][2'b11 := 8'b0]") + } + + it should "support array initialization of a memory to Seq(1, 0, 0, 0)" in { + val sys = toSys(memCircuit(4), memInit = Map("m" -> MemoryArrayInit(Seq(1, 0, 0, 0)))) + val m = sys.states.find(_.sym.name == "m").get + assert(m.init.isDefined) + assert(m.init.get.toString == "([8'b0] x 4)[2'b0 := 8'b1]") + } + + it should "support array initialization from a file" ignore { + assert(false, "TODO") + } + + it should "support memories with registered read port" in { + def src(readUnderWrite: String) = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + + val oldValue = toSys(src("old")) + val oldMData = oldValue.states.find(_.sym.name == "m.r.data").get + assert(oldMData.sym.toString == "m.r.data") + assert(oldMData.next.get.toString == "m[m.r.addr]", "we just need to read the current value") + val readDataSignal = oldValue.signals.find(_.name == "m.r.data") + assert(readDataSignal.isEmpty, s"${readDataSignal.map(_.toString)} should not exist") + + val undefinedValue = toSys(src("undefined")) + val undefinedMData = undefinedValue.states.find(_.sym.name == "m.r.data").get + assert(undefinedMData.sym.toString == "m.r.data") + val undefined = "RANDOM.m_r_read_under_write_undefined" + assert(undefinedMData.next.get.toString == + s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", + "randomize result if there is a write") + } + + it should "support memories with potential write-write conflicts" in { + val src = + s"""circuit m: + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin + + + val sys = toSys(src) + val m = sys.states.find(_.sym.name == "m").get + + val regularUpdate = "m[m.w1.addr := m.w1.data][m.w2.addr := m.w2.data]" + val collision = "eq(m.w1.addr, m.w2.addr)" + val collisionUpdate = "m[m.w1.addr := RANDOM.m_w1_write_write_collision]" + + assert(m.next.get.toString == s"ite($collision, $collisionUpdate, $regularUpdate)") + } + + it should "model invalid signals as inputs" in { + // if a signal is invalid, it could take on an arbitrary value in that cycle + val src = + """circuit m: + | module m: + | input en : UInt<1> + | output o : UInt<8> + | o is invalid + | when en: + | o <= UInt<8>(0) + |""".stripMargin + val sys = toSys(src) + assert(sys.inputs.length == 2) + val random = sys.inputs.filter(_.name.contains("RANDOM")) + assert(random.length == 1) + assert(random.head.width == 8) + } + + it should "throw an error on async reset" in { + val err = intercept[AsyncResetException] { + toSys( + """circuit m: + | module m: + | input reset : AsyncReset + |""".stripMargin + ) + } + assert(err.getMessage.contains("reset")) + } + + it should "throw an error on casting to async reset" in { + val err = intercept[AssertionError] { + toSys( + """circuit m: + | module m: + | input reset : UInt<1> + | node async = asAsyncReset(reset) + |""".stripMargin + ) + } + assert(err.getMessage.contains("reset")) + } + + it should "throw an error on multiple clocks" in { + val err = intercept[MultiClockException] { + toSys( + """circuit m: + | module m: + | input clk1 : Clock + | input clk2 : Clock + |""".stripMargin + ) + } + assert(err.getMessage.contains("clk1, clk2")) + } + + it should "throw an error on using a clock as uInt" in { + // While this could potentially be supported in the future, for now we do not allow + // a clock to be used for anything besides updating registers and memories. + val err = intercept[AssertionError] { + toSys( + """circuit m: + | module m: + | input clk : Clock + | output o : UInt<1> + | o <= asUInt(clk) + | + |""".stripMargin + ) + } + assert(err.getMessage.contains("clk")) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala new file mode 100644 index 00000000..6bfb5437 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala @@ -0,0 +1,38 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +import firrtl.annotations.Annotation +import firrtl.{MemoryInitValue, ir} +import firrtl.stage.{Forms, TransformManager} +import org.scalatest.flatspec.AnyFlatSpec + +private abstract class SMTBackendBaseSpec extends AnyFlatSpec { + private val dependencies = Forms.LowForm + private val compiler = new TransformManager(dependencies) + + protected def compile(src: String, annos: Seq[Annotation] = List()): ir.Circuit = { + val c = firrtl.Parser.parse(src) + compiler.runTransform(firrtl.CircuitState(c, annos)).circuit + } + + protected def toSys(src: String, mod: String = "m", presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + val circuit = compile(src) + val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module] + // println(module.serialize) + new ModuleToTransitionSystem().run(module, presetRegs = presetRegs, memInit = memInit) + } + + protected def toBotr2(src: String, mod: String = "m"): Iterable[String] = + Btor2Serializer.serialize(toSys(src, mod)) + + protected def toBotr2Str(src: String, mod: String = "m"): String = + toBotr2(src, mod).mkString("\n") + "\n" + + protected def toSMTLib(src: String, mod: String = "m"): Iterable[String] = + SMTTransitionSystemEncoder.encode(toSys(src, mod)).map(SMTLibSerializer.serialize) + + protected def toSMTLibStr(src: String, mod: String = "m"): String = + toSMTLib(src, mod).mkString("\n") + "\n" +}
\ No newline at end of file diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala new file mode 100644 index 00000000..7193474d --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala @@ -0,0 +1,66 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt + +private class SMTLibSpec extends SMTBackendBaseSpec { + + it should "convert a hello world module" in { + val src = + """circuit m: + | module m: + | input clock: Clock + | input a: UInt<8> + | output b: UInt<16> + | b <= a + | assert(clock, eq(a, b), UInt(1), "") + |""".stripMargin + + val expected = + """(declare-sort m_s 0) + |; firrtl-smt2-input a 8 + |(declare-fun a_f (m_s) (_ BitVec 8)) + |; firrtl-smt2-output b 16 + |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state))) + |; firrtl-smt2-assert assert_ 1 + |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state))) + |(define-fun m_t ((state m_s) (state_n m_s)) Bool true) + |(define-fun m_i ((state m_s)) Bool true) + |(define-fun m_a ((state m_s)) Bool (assert__f state)) + |(define-fun m_u ((state m_s)) Bool true) + |""".stripMargin + + assert(toSMTLibStr(src) == expected) + } + + it should "include FileInfo in the output" in { + val src = + """circuit m: @[circuit 0:0] + | module m: @[module 0:0] + | input clock: Clock @[clock 0:0] + | input a: UInt<8> @[a 0:0] + | output b: UInt<16> @[b 0:0] + | b <= a @[b_a 0:0] + | assert(clock, eq(a, b), UInt(1), "") @[assert 0:0] + |""".stripMargin + + val expected = + """; @ module 0:0 + |(declare-sort m_s 0) + |; firrtl-smt2-input a 8 + |; @ a 0:0 + |(declare-fun a_f (m_s) (_ BitVec 8)) + |; firrtl-smt2-output b 16 + |; @ b 0:0, b_a 0:0 + |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state))) + |; firrtl-smt2-assert assert_ 1 + |; @ assert 0:0 + |(define-fun assert__f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state))) + |(define-fun m_t ((state m_s) (state_n m_s)) Bool true) + |(define-fun m_i ((state m_s)) Bool true) + |(define-fun m_a ((state m_s)) Bool (assert__f state)) + |(define-fun m_u ((state m_s)) Bool true) + |""".stripMargin + + assert(toSMTLibStr(src) == expected) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala new file mode 100644 index 00000000..4c6901ea --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.CircuitTarget +import firrtl.backends.experimental.smt.{GlobalClockAnnotation, StutteringClockTransform} +import firrtl.options.Dependency +import firrtl.stage.RunFirrtlTransformAnnotation + +class AsyncResetSpec extends EndToEndSMTBaseSpec { + def annos(name: String) = Seq( + RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]), + GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock"))) + + "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs(RequiresZ3) in { + def in(resetType: String) = + s"""circuit AsyncReset00: + | module AsyncReset00: + | input global_clock: Clock + | input c: Clock + | input reset: $resetType + | input preset: AsyncReset + | + | ; a register with async reset + | reg r: UInt<4>, c with: (reset => (reset, UInt(3))) + | + | ; a counter/toggler connected to the clock c + | reg count: UInt<1>, c with: (reset => (preset, UInt(0))) + | count <= add(count, UInt(1)) + | + | ; the past machinery and the assertion uses the global clock + | reg past_valid: UInt<1>, global_clock with: (reset => (preset, UInt(0))) + | past_valid <= UInt(1) + | reg past_r: UInt<4>, global_clock + | past_r <= r + | reg past_count: UInt<1>, global_clock + | past_count <= count + | + | ; can the value of r change without the count changing? + | assert(global_clock, or(not(eq(count, past_count)), eq(r, past_r)), past_valid, "count = past(count) |-> r = past(r)") + |""".stripMargin + test(in("AsyncReset"), MCFail(1), kmax=2, annos=annos("AsyncReset00")) + test(in("UInt<1>"), MCSuccess, kmax=2, annos=annos("AsyncReset00")) + } + +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala new file mode 100644 index 00000000..2227719b --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala @@ -0,0 +1,230 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import java.io.{File, PrintWriter} + +import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation} +import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} +import firrtl.options.TargetDirAnnotation +import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlStage, OutputFileAnnotation, RunFirrtlTransformAnnotation} +import firrtl.util.BackendCompilationUtilities +import logger.{LazyLogging, LogLevel, LogLevelAnnotation} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.must.Matchers + +import scala.sys.process._ + +class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { + "we" should "check if Z3 is available" taggedAs(RequiresZ3) in { + val log = ProcessLogger(_ => (), logger.warn(_)) + val ret = Process(Seq("which", "z3")).run(log).exitValue() + if(ret != 0) { + logger.error( + """The z3 SMT-Solver seems not to be installed. + |You can exclude the end-to-end smt backend tests which rely on z3 like this: + |sbt testOnly -- -l RequiresZ3 + |""".stripMargin) + } + assert(ret == 0) + } + + "Z3" should "be available in version 4" taggedAs(RequiresZ3) in { + assert(Z3ModelChecker.getZ3Version.startsWith("4.")) + } + + "a simple combinatorial check" should "pass" taggedAs(RequiresZ3) in { + val in = + """circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, lt(add(a, b), UInt(3)), UInt(1), "a + b < 3") + |""".stripMargin + test(in, MCSuccess) + } + + "a simple combinatorial check" should "fail immediately" taggedAs(RequiresZ3) in { + val in = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, gt(add(a, b), UInt(3)), UInt(1), "a + b > 3") + |""".stripMargin + test(in, MCFail(0)) + } + + "adding the right assumption" should "make a test pass" taggedAs(RequiresZ3) in { + val in0 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + |""".stripMargin + val in1 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + | assume(c, neq(a, UInt(0)), UInt(1), "a != 0") + |""".stripMargin + test(in0, MCFail(0)) + test(in1, MCSuccess) + + val in2 = + """circuit CC01: + | module CC01: + | input c: Clock + | input a: UInt<1> + | input b: UInt<1> + | input en: UInt<1> + | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") + | assume(c, neq(a, UInt(0)), en, "a != 0 if en") + |""".stripMargin + test(in2, MCFail(0)) + } + + "a register connected to preset reset" should "be initialized with the reset value" taggedAs(RequiresZ3) in { + def in(rEq: Int) = + s"""circuit Preset00: + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) + | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") + |""".stripMargin + test(in(3), MCSuccess, kmax = 1) + test(in(2), MCFail(0)) + } + + "a register's initial value" should "should not change" taggedAs(RequiresZ3) in { + val in = + """circuit Preset00: + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | + | ; the past value of our register will only be valid in the 1st unrolling + | reg past_valid: UInt<1>, c with: (reset => (preset, UInt(0))) + | past_valid <= UInt(1) + | + | reg r: UInt<4>, c + | reg r_past: UInt<4>, c + | r_past <= r + | assert(c, eq(r, r_past), past_valid, "past_valid => r == r_past") + |""".stripMargin + test(in, MCSuccess, kmax = 2) + } +} + +abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { + def test(src: String, expected: MCResult, kmax: Int = 0, clue: String = "", annos: Seq[Annotation] = Seq()): Unit = { + expected match { + case MCFail(k) => assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") + case _ => + } + val fir = firrtl.Parser.parse(src) + val name = fir.main + val testDir = BackendCompilationUtilities.createTestDirectory("EndToEndSMT." + name) + // we automagically add a preset annotation if an input called preset exists + val presetAnno = if(!src.contains("input preset")) { None } else { + Some(PresetAnnotation(CircuitTarget(name).module(name).ref("preset"))) + } + val res = (new FirrtlStage).execute(Array(), Seq( + LogLevelAnnotation(LogLevel.Error), // silence warnings for tests + RunFirrtlTransformAnnotation(new SMTLibEmitter), + RunFirrtlTransformAnnotation(new Btor2Emitter), + FirrtlCircuitAnnotation(fir), + TargetDirAnnotation(testDir.getAbsolutePath) + ) ++ presetAnno ++ annos) + assert(res.collectFirst{ case _: OutputFileAnnotation => true }.isDefined) + val r = Z3ModelChecker.bmc(testDir, name, kmax) + assert(r == expected, clue + "\n" + s"$testDir") + } +} + +/** Minimal implementation of a Z3 based bounded model checker. + * A more complete version of this with better use feedback should eventually be provided by a + * chisel3 formal verification library. Do not use this implementation outside of the firrtl test suite! + * */ +private object Z3ModelChecker extends LazyLogging { + def getZ3Version: String = { + val (out, ret) = executeCmd("-version") + assert(ret == 0, "failed to call z3") + assert(out.startsWith("Z3 version"), s"$out does not start with 'Z3 version'") + val version = out.split(" ")(2) + version + } + + def bmc(testDir: File, main: String, kmax: Int): MCResult = { + assert(kmax >=0 && kmax < 50, "Trying to keep kmax in a reasonable range.") + val smtFile = new File(testDir, main + ".smt2") + val header = read(smtFile) + val steps = (0 to kmax).map(k => new File(testDir, main + s"_step$k.smt2")).zipWithIndex + steps.foreach { case (f,k) => + writeStep(f, main, header, k) + val success = executeStep(f.getAbsolutePath) + if(!success) return MCFail(k) + } + MCSuccess + } + + private def executeStep(filename: String): Boolean = { + val (out, ret) = executeCmd(filename) + assert(ret == 0, s"expected success (0), not $ret: `$out`\nz3 $filename") + assert(out == "sat" || out == "unsat", s"Unexpected output: $out") + out == "unsat" + } + + private def executeCmd(cmd: String): (String, Int) = { + var out = "" + val log = ProcessLogger(s => out = s, logger.warn(_)) + val ret = Process(Seq("z3", cmd)).run(log).exitValue() + (out, ret) + } + + private def writeStep(f: File, main: String, header: Iterable[String], k: Int): Unit = { + val pw = new PrintWriter(f) + val lines = header ++ step(main, k) ++ List("(check-sat)") + lines.foreach(pw.println) + pw.close() + } + + private def step(main: String, k: Int): Iterable[String] = { + // define all states + (0 to k).map(ii => s"(declare-fun s$ii () $main$StateTpe)") ++ + // assert that init holds in state 0 + List(s"(assert ($main$Init s0))") ++ + // assert transition relation + (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii+1}))") ++ + // assert that assumptions hold in all states + (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ + // assert that assertions hold for all but last state + (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ + // check to see if we can violate the assertions in the last state + List(s"(assert (not ($main$Asserts s$k)))") + } + + private def read(f: File): Iterable[String] = { + val source = scala.io.Source.fromFile(f) + try source.getLines().toVector finally source.close() + } + + // the following suffixes have to match the ones in [[SMTTransitionSystemEncoder]] + private val Transition = "_t" + private val Init = "_i" + private val Asserts = "_a" + private val Assumes = "_u" + private val StateTpe = "_s" +} + +private sealed trait MCResult +private case object MCSuccess extends MCResult +private case class MCFail(k: Int) extends MCResult diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala new file mode 100644 index 00000000..10de9cda --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -0,0 +1,107 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.{CircuitTarget, MemoryArrayInitAnnotation, MemoryScalarInitAnnotation} + +class MemorySpec extends EndToEndSMTBaseSpec { + private def registeredTestMem(name: String, cmds: String, readUnderWrite: String): String = + registeredTestMem(name, cmds.split("\n"), readUnderWrite) + private def registeredTestMem(name: String, cmds: Iterable[String], readUnderWrite: String): String = + s"""circuit $name: + | module $name: + | input reset : UInt<1> + | input clock : Clock + | input preset: AsyncReset + | input write_addr : UInt<5> + | input read_addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= write_addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) + | cycle <= add(cycle, UInt(1)) + | node past_valid = geq(cycle, UInt(1)) + | + | ${cmds.mkString("\n ")} + |""".stripMargin + + "Registered test memory" should "return written data after two cycles" taggedAs(RequiresZ3) in { + val cmds = + """node past_past_valid = geq(cycle, UInt(2)) + |reg past_in: UInt<8>, clock + |past_in <= in + |reg past_past_in: UInt<8>, clock + |past_past_in <= past_in + |reg past_write_addr: UInt<5>, clock + |past_write_addr <= write_addr + | + |assume(clock, eq(read_addr, past_write_addr), past_valid, "read_addr = past(write_addr)") + |assert(clock, eq(out, past_past_in), past_past_valid, "out = past(past(in))") + |""".stripMargin + test(registeredTestMem("Mem00", cmds, "old"), MCSuccess, kmax = 3) + } + + private def readOnlyMem(pred: String, num: Int) = + s"""circuit Mem0$num: + | module Mem0$num: + | input c : Clock + | input read_addr : UInt<2> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 4 + | reader => r + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | assert(c, $pred, UInt(1), "") + |""".stripMargin + private def m(num: Int) = CircuitTarget(s"Mem0$num").module(s"Mem0$num").ref("m") + + "read-only memory" should "always return 0" taggedAs(RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax=2, + annos=Seq(MemoryScalarInitAnnotation(m(1), 0))) + } + + "read-only memory" should "not always return 1" taggedAs(RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax=2, + annos=Seq(MemoryScalarInitAnnotation(m(2), 0))) + } + + "read-only memory" should "always return 1 or 2" taggedAs(RequiresZ3) in { + test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), MCSuccess, kmax=2, + annos=Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1)))) + } + + "read-only memory" should "not always return 1 or 2 or 3" taggedAs(RequiresZ3) in { + test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), MCFail(0), kmax=2, + annos=Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3)))) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala new file mode 100644 index 00000000..d633a1a0 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/RequiresZ3.scala @@ -0,0 +1,9 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import org.scalatest.Tag + +// To disable tests that require the Z3 SMT solver to be installed use the following: +// `sbt testOnly -- -l RequiresZ3` +object RequiresZ3 extends Tag("RequiresZ3") diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala new file mode 100644 index 00000000..cbf194dd --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +import java.io.File + +import firrtl.stage.{FirrtlStage, OutputFileAnnotation} +import firrtl.util.BackendCompilationUtilities +import logger.LazyLogging +import org.scalatest.flatspec.AnyFlatSpec + +import scala.sys.process.{Process, ProcessLogger} + +/** compiles the regression tests to SMTLib and parses the result with z3 */ +class SMTCompilationTest extends AnyFlatSpec with LazyLogging { + it should "generate valid SMTLib for AddNot" taggedAs(RequiresZ3) in { compileAndParse("AddNot") } + it should "generate valid SMTLib for FPU" taggedAs(RequiresZ3) in { compileAndParse("FPU") } + // we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer + it should "generate valid SMTLib for HwachaSequencer" taggedAs(RequiresZ3) ignore { compileAndParse("HwachaSequencer") } + it should "generate valid SMTLib for ICache" taggedAs(RequiresZ3) in { compileAndParse("ICache") } + it should "generate valid SMTLib for Ops" taggedAs(RequiresZ3) in { compileAndParse("Ops") } + // TODO: enable Rob test once we support more than 2 write ports on a memory + it should "generate valid SMTLib for Rob" taggedAs(RequiresZ3) ignore { compileAndParse("Rob") } + it should "generate valid SMTLib for RocketCore" taggedAs(RequiresZ3) in { compileAndParse("RocketCore") } + + private def compileAndParse(name: String): Unit = { + val testDir = BackendCompilationUtilities.createTestDirectory(name + "-smt") + val inputFile = new File(testDir, s"${name}.fir") + BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", inputFile) + + val args = Array( + "-ll", "error", // surpress warnings to keep test output clean + "--target-dir", testDir.toString, + "-i", inputFile.toString, + "-E", "experimental-smt2" + // "-fct", "firrtl.backends.experimental.smt.StutteringClockTransform" + ) + val res = (new FirrtlStage).execute(args, Seq()) + val fileName = res.collectFirst{ case OutputFileAnnotation(file) => file }.get + + val smtFile = testDir.toString + "/" + fileName + ".smt2" + val log = ProcessLogger(_ => (), logger.error(_)) + val z3Ret = Process(Seq("z3", smtFile)).run(log).exitValue() + assert(z3Ret == 0, s"Failed to parse SMTLib file $smtFile generated for $name") + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala new file mode 100644 index 00000000..8fa80b4c --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala @@ -0,0 +1,40 @@ +// See LICENSE for license details. + +package firrtl.backends.experimental.smt.end2end + +/** undefined values in firrtl are modelled as fresh auxiliary variables (inputs) */ +class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec { + + "division by zero" should "result in an arbitrary value" taggedAs(RequiresZ3) in { + // the SMTLib spec defines the result of division by zero to be all 1s + // https://cs.nyu.edu/pipermail/smt-lib/2015/000977.html + def in(dEq: Int) = + s"""circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<2> + | input b: UInt<2> + | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") + | node d = div(a, b) + | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") + |""".stripMargin + // we try to assert that (d = a / 0) is any fixed value which should be false + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") } + } + + // TODO: rem should probably also be undefined, but the spec isn't 100% clear here + + + "invalid signals" should "have an arbitrary values" taggedAs(RequiresZ3) in { + def in(aEq: Int) = + s"""circuit CC00: + | module CC00: + | input c: Clock + | wire a: UInt<2> + | a is invalid + | assert(c, eq(a, UInt($aEq)), UInt(1), "a = $aEq") + |""".stripMargin + // a should not be equivalent to any fixed value (0, 1, 2 or 3) + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"a = $ii") } + } +} |
