diff options
7 files changed, 82 insertions, 58 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index 0b8e3ebf..246f0172 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -59,37 +59,24 @@ private case class TransitionSystem( private case class TransitionSystemAnnotation(sys: TransitionSystem) extends NoTargetAnnotation object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { - // TODO: We only really need [[Forms.MidForm]] + LowerTypes, but we also want to fail if there are CombLoops - // TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare - // precisely and PadWidths emits ill-typed firrtl. override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++ Seq( - Dependency(InvalidToRandomPass), - Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays), - Dependency(EnsureNamedStatements) // this is required to give assert/assume statements good names + Dependency(EnsureNamedStatements), // this is required to give assert/assume statements good names + Dependency[PropagatePresetAnnotations] ) override def invalidates(a: Transform): Boolean = false // since this pass only runs on the main module, inlining needs to happen before override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) - // We run the propagate preset annotations pass manually since we do not want to remove ValidIfs and other - // Verilog emission passes. - // Ideally we would go in and enable the [[PropagatePresetAnnotations]] to only depend on LowForm. - private val presetPass = new PropagatePresetAnnotations - // We also need to run the DeadCodeElimination since PropagatePresets does not remove possible remaining - // AsyncReset nodes. - private val deadCodeElimination = new DeadCodeElimination override protected def execute(state: CircuitState): CircuitState = { - // run the preset pass to extract all preset registers and remove preset reset signals - val afterPreset = deadCodeElimination.execute(presetPass.execute(state)) - val circuit = afterPreset.circuit - val presetRegs = afterPreset.annotations.collect { + val circuit = state.circuit + val presetRegs = state.annotations.collect { case PresetRegAnnotation(target) if target.module == circuit.main => target.ref }.toSet // collect all non-random memory initialization - val memInit = afterPreset.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a } + val memInit = state.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a } .filter(_.target.module == circuit.main) .map(a => a.target.ref -> a.initValue) .toMap @@ -98,7 +85,7 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { val modules = circuit.modules.map(m => m.name -> m).toMap // collect uninterpreted module annotations - val uninterpreted = afterPreset.annotations.collect { + val uninterpreted = state.annotations.collect { case a: UninterpretedModuleAnnotation => UninterpretedModuleAnnotation.checkModule(modules(a.target.module), a) a.target.module -> a @@ -117,7 +104,7 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { val sortedSys = TopologicalSort.run(sys) val anno = TransitionSystemAnnotation(sortedSys) - state.copy(circuit = circuit, annotations = afterPreset.annotations :+ anno) + state.copy(circuit = circuit, annotations = state.annotations :+ anno) } } @@ -415,13 +402,13 @@ private class ModuleScanner( inputs.append(BVSymbol(name, bitWidth(tpe).toInt)) case ir.DefWire(info, name, tpe) => namespace.newName(name) - if (!isClock(tpe)) { + if (!isClock(tpe) && !isAsyncReset(tpe)) { infos.append(name -> info) wires.append(name) } case ir.DefNode(info, name, expr) => namespace.newName(name) - if (!isClock(expr.tpe)) { + if (!isClock(expr.tpe) && !isAsyncReset(expr.tpe)) { insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) val e = onExpression(expr) diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala index f6788435..c9c1c943 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -6,7 +6,13 @@ import org.scalatest.flatspec.AnyFlatSpec class FirrtlExpressionSemanticsSpec extends AnyFlatSpec { - private def primopSys(op: String, resTpe: String, inTpes: Seq[String], consts: Seq[Int]): TransitionSystem = { + private def primopSys( + op: String, + resTpe: String, + inTpes: Seq[String], + consts: Seq[Int], + modelUndef: Boolean + ): TransitionSystem = { val inputs = inTpes.zipWithIndex.map { case (tpe, ii) => s" input i$ii : $tpe" }.mkString("\n") val args = (inTpes.zipWithIndex.map { case (_, ii) => s"i$ii" } ++ consts.map(_.toString)).mkString(", ") val src = @@ -17,11 +23,11 @@ class FirrtlExpressionSemanticsSpec extends AnyFlatSpec { | res <= $op($args) | |""".stripMargin - SMTBackendHelpers.toSys(src) + SMTBackendHelpers.toSys(src, modelUndef = modelUndef) } def primop(op: String, resTpe: String, inTpes: Seq[String], consts: Seq[Int]): String = { - val sys = primopSys(op, resTpe, inTpes, consts) + val sys = primopSys(op, resTpe, inTpes, consts, modelUndef = false) assert(sys.signals.length >= 1) sys.signals.last.e.toString } @@ -32,12 +38,13 @@ class FirrtlExpressionSemanticsSpec extends AnyFlatSpec { resWidth: Int, inWidth: Seq[Int], consts: Seq[Int], - resAlwaysUnsigned: Boolean + resAlwaysUnsigned: Boolean, + modelUndef: Boolean ): TransitionSystem = { val tpe = if (signed) "SInt" else "UInt" val resTpe = if (resAlwaysUnsigned) "UInt" else tpe val inTpes = inWidth.map(w => s"$tpe<$w>") - primopSys(op, s"$resTpe<$resWidth>", inTpes, consts) + primopSys(op, s"$resTpe<$resWidth>", inTpes, consts, modelUndef) } def primop( @@ -46,9 +53,10 @@ class FirrtlExpressionSemanticsSpec extends AnyFlatSpec { resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(), - resAlwaysUnsigned: Boolean = false + resAlwaysUnsigned: Boolean = false, + modelUndef: Boolean = false ): String = { - val sys = primopSys(signed, op, resWidth, inWidth, consts, resAlwaysUnsigned) + val sys = primopSys(signed, op, resWidth, inWidth, consts, resAlwaysUnsigned, modelUndef) assert(sys.signals.length >= 1) sys.signals.last.e.toString } @@ -81,23 +89,30 @@ class FirrtlExpressionSemanticsSpec extends AnyFlatSpec { //println(sys.serialize) assert( - primop(false, "div", 8, List(8, 8)) == + primop(false, "div", 8, List(8, 8), modelUndef = true) == "ite(res_invalid_cond, res_invalid, udiv(i0, i1))" ) assert( - primop(false, "div", 8, List(8, 4)) == + primop(false, "div", 8, List(8, 4), modelUndef = true) == "ite(res_invalid_cond, res_invalid, udiv(i0, zext(i1, 4)))" ) // signed division increases result width by 1 assert( - primop(true, "div", 8, List(7, 7)) == + primop(true, "div", 8, List(7, 7), modelUndef = true) == "ite(res_invalid_cond, res_invalid, sdiv(sext(i0, 1), sext(i1, 1)))" ) assert( - primop(true, "div", 8, List(7, 4)) + primop(true, "div", 8, List(7, 4), modelUndef = true) == "ite(res_invalid_cond, res_invalid, sdiv(sext(i0, 1), sext(i1, 4)))" ) + + // --------------------------------------------------------- + // without modelling the undefined-ness of division by zero: + assert(primop(false, "div", 8, List(8, 8), modelUndef = false) == "udiv(i0, i1)") + assert(primop(false, "div", 8, List(8, 4), modelUndef = false) == "udiv(i0, zext(i1, 4))") + assert(primop(true, "div", 8, List(7, 7), modelUndef = false) == "sdiv(sext(i0, 1), sext(i1, 1))") + assert(primop(true, "div", 8, List(7, 4), modelUndef = false) == "sdiv(sext(i0, 1), sext(i1, 4))") } it should "correctly translate the `rem` primitive operation" in { diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala index 7bc80102..c100da56 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala @@ -148,7 +148,7 @@ class FirrtlModuleToTransitionSystemSpec extends AnyFlatSpec { | when en: | o <= UInt<8>(0) |""".stripMargin - val sys = SMTBackendHelpers.toSys(src) + val sys = SMTBackendHelpers.toSys(src, modelUndef = true) assert(sys.inputs.length == 2) val invalids = sys.inputs.filter(_.name.contains("_invalid")) assert(invalids.length == 1) @@ -192,25 +192,19 @@ class FirrtlModuleToTransitionSystemSpec extends AnyFlatSpec { assert(err.getMessage.contains("clk, c.clk")) } - it should "throw an error on async reset" in { + it should "throw an error on async reset driving a register" in { val err = intercept[AsyncResetException] { SMTBackendHelpers.toSys( """circuit m: | module m: + | input clock : Clock | input reset : AsyncReset - |""".stripMargin - ) - } - assert(err.getMessage.contains("reset")) - } - - it should "throw an error on casting to async reset" in { - val err = intercept[AssertionError] { - SMTBackendHelpers.toSys( - """circuit m: - | module m: - | input reset : UInt<1> - | node async = asAsyncReset(reset) + | input in : UInt<4> + | output out : UInt<4> + | + | reg r : UInt<4>, clock with : (reset => (reset, UInt<8>(0))) + | r <= in + | out <= r |""".stripMargin ) } diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala index 4d212ad4..71d1d38c 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala @@ -3,25 +3,39 @@ package firrtl.backends.experimental.smt import firrtl.annotations.Annotation +import firrtl.backends.experimental.smt.random.{InvalidToRandomPass, UndefinedMemoryBehaviorPass} +import firrtl.options.Dependency import firrtl.{ir, MemoryInitValue} -import firrtl.stage.{Forms, TransformManager} +import firrtl.stage.{Forms, RunFirrtlTransformAnnotation, TransformManager} private object SMTBackendHelpers { private val dependencies = Forms.LowForm ++ FirrtlToTransitionSystem.prerequisites private val compiler = new TransformManager(dependencies) + private val undefCompiler = new TransformManager( + dependencies ++ Seq( + Dependency(InvalidToRandomPass), + Dependency(UndefinedMemoryBehaviorPass) + ) + ) def compile(src: String, annos: Seq[Annotation] = List()): ir.Circuit = { val c = firrtl.Parser.parse(src) compiler.runTransform(firrtl.CircuitState(c, annos)).circuit } + def compileUndef(src: String, annos: Seq[Annotation] = List()): ir.Circuit = { + val c = firrtl.Parser.parse(src) + undefCompiler.runTransform(firrtl.CircuitState(c, annos)).circuit + } + def toSys( src: String, mod: String = "m", presetRegs: Set[String] = Set(), - memInit: Map[String, MemoryInitValue] = Map() + memInit: Map[String, MemoryInitValue] = Map(), + modelUndef: Boolean = false ): TransitionSystem = { - val circuit = compile(src) + val circuit = if (modelUndef) compileUndef(src) else compile(src) val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module] // println(module.serialize) new ModuleToTransitionSystem().run(module, presetRegs = presetRegs, memInit = memInit) diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala index 5a697980..dc425149 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala @@ -3,8 +3,9 @@ package firrtl.backends.experimental.smt.end2end import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation} +import firrtl.backends.experimental.smt.random.{InvalidToRandomPass, UndefinedMemoryBehaviorPass} import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} -import firrtl.options.TargetDirAnnotation +import firrtl.options.{Dependency, TargetDirAnnotation} import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlStage, OutputFileAnnotation, RunFirrtlTransformAnnotation} import firrtl.util.BackendCompilationUtilities.timeStamp import logger.{LazyLogging, LogLevel, LogLevelAnnotation} @@ -155,6 +156,9 @@ abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { val r = Z3ModelChecker.bmc(testDir, name, kmax) assert(r == expected, clue + "\n" + s"$testDir") } + + val UndefinedMemAnnos = Seq(RunFirrtlTransformAnnotation(Dependency(UndefinedMemoryBehaviorPass))) + val InvalidToRandomAnnos = Seq(RunFirrtlTransformAnnotation(Dependency(InvalidToRandomPass))) } /** Minimal implementation of a Z3 based bounded model checker. diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala index 2a0276e1..8a8fbc4d 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -190,10 +190,10 @@ class MemorySpec extends EndToEndSMTBaseSpec { | assert(c, eq(m.r.data, prevData), and(pastValid, prevEn), "") |""".stripMargin "memory with two write ports" should "not have collisions when enables are mutually exclusive" taggedAs (RequiresZ3) in { - test(collisionTest("not(and(aEn, bEn))"), MCSuccess, kmax = 4) + test(collisionTest("not(and(aEn, bEn))"), MCSuccess, kmax = 4, annos = UndefinedMemAnnos) } "memory with two write ports" should "can have collisions when enables are unconstrained" taggedAs (RequiresZ3) in { - test(collisionTest("UInt(1)"), MCFail(1), kmax = 1) + test(collisionTest("UInt(1)"), MCFail(1), kmax = 1, annos = UndefinedMemAnnos) } private def readEnableSrc(pred: String, num: Int) = @@ -229,12 +229,22 @@ class MemorySpec extends EndToEndSMTBaseSpec { "a memory with read enable" should "supply valid data one cycle after en=1" in { val init = Seq(MemoryScalarInitAnnotation(CircuitTarget(s"ReadEnableTest1").module(s"ReadEnableTest1").ref("m"), 0)) // the read port is enabled on even cycles, so on odd cycles we should reliably get zeros - test(readEnableSrc("or(not(odd), eq(m.r.data, UInt(0)))", 1), MCSuccess, kmax = 3, annos = init) + test( + readEnableSrc("or(not(odd), eq(m.r.data, UInt(0)))", 1), + MCSuccess, + kmax = 3, + annos = init ++ UndefinedMemAnnos + ) } "a memory with read enable" should "supply invalid data one cycle after en=0" in { val init = Seq(MemoryScalarInitAnnotation(CircuitTarget(s"ReadEnableTest2").module(s"ReadEnableTest2").ref("m"), 0)) // the read port is disabled on odd cycles, so on even cycles we should *NOT* reliably get zeros - test(readEnableSrc("or(not(even), eq(m.r.data, UInt(0)))", 2), MCFail(1), kmax = 1, annos = init) + test( + readEnableSrc("or(not(even), eq(m.r.data, UInt(0)))", 2), + MCFail(1), + kmax = 1, + annos = init ++ UndefinedMemAnnos + ) } } diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala index c1587580..75eccaf8 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala @@ -19,7 +19,7 @@ class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec { | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") |""".stripMargin // we try to assert that (d = a / 0) is any fixed value which should be false - (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") } + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii", annos = InvalidToRandomAnnos) } } // TODO: rem should probably also be undefined, but the spec isn't 100% clear here @@ -34,6 +34,6 @@ class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec { | assert(c, eq(a, UInt($aEq)), UInt(1), "a = $aEq") |""".stripMargin // a should not be equivalent to any fixed value (0, 1, 2 or 3) - (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"a = $ii") } + (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"a = $ii", annos = InvalidToRandomAnnos) } } } |
