diff options
| author | Kevin Laeufer | 2021-03-08 15:39:15 -0800 |
|---|---|---|
| committer | GitHub | 2021-03-08 23:39:15 +0000 |
| commit | 29d57a612df69ae4a6db4b3755fc292e5a539e11 (patch) | |
| tree | 653723bacce2419c76c64dfcf847e47e81980905 /src | |
| parent | c93d6f5319efd7ba42147180c6e2b6f3796ef943 (diff) | |
SMT: memory port inout fields cannot be used as RHS expressions (#2105)
* SMT: memory port inout fields cannot be used as RHS expressions
* smt: add end2end check for read enable modelling
Diffstat (limited to 'src')
3 files changed, 134 insertions, 88 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala index 91e77433..5fd0e680 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala @@ -2,8 +2,7 @@ package firrtl.backends.experimental.smt.random -import firrtl.Utils.{isLiteral, kind, BoolType} -import firrtl.WrappedExpression.{we, weq} +import firrtl.Utils.{isLiteral, BoolType} import firrtl._ import firrtl.annotations.NoTargetAnnotation import firrtl.backends.experimental.smt._ @@ -14,7 +13,7 @@ import firrtl.passes.memlib.AnalysisUtils.Connects import firrtl.passes.memlib.InferReadWritePass.checkComplement import firrtl.passes.memlib.{AnalysisUtils, InferReadWritePass, VerilogMemDelays} import firrtl.stage.Forms -import firrtl.transforms.{ConstantPropagation, RemoveWires} +import firrtl.transforms.RemoveWires import scala.collection.mutable @@ -111,7 +110,10 @@ private class InstrumentMems( private def onMem(m: DefMemory): Statement = { // collect wire and random statement defines - val declarations = mutable.ListBuffer[Statement]() + implicit val declarations: mutable.ListBuffer[Statement] = mutable.ListBuffer[Statement]() + + // cache for the expressions of memory inputs + implicit val cache: mutable.HashMap[String, Expression] = mutable.HashMap[String, Expression]() // only for non power of 2 memories do we have to worry about reading or writing out of bounds val canBeOutOfBounds = !isPow2(m.depth) @@ -132,42 +134,23 @@ private class InstrumentMems( val maskRef = memPortField(m, write, "mask") val prods = getProductTerms(enRef) ++ getProductTerms(maskRef) + val expr = Utils.and(readInput(m.info, enRef), readInput(m.info, maskRef)) - // if we can have write/write conflicts, we are going to change the mask and enable pins - val expr = if (canHaveWriteWriteConflicts) { - val maskIsOne = isTrue(connects(maskRef.serialize)) - // if the mask is connected to a constant true, we do not need to consider it, this is a common case - if (maskIsOne) { - val enWire = disconnectInput(m.info, enRef) - declarations += enWire - Reference(enWire) - } else { - val maskWire = disconnectInput(m.info, maskRef) - val enWire = disconnectInput(m.info, enRef) - // create a node for the conjunction - val nodeName = namespace.newName(s"${m.name}_${write}_mask_and_en") - val node = DefNode(m.info, nodeName, Utils.and(Reference(maskWire), Reference(enWire))) - declarations ++= List(maskWire, enWire, node) - Reference(node) - } - } else { - Utils.and(enRef, maskRef) - } (expr, prods) } // implement the three undefined read behaviors m.readers.foreach { read => // many memories have their read enable hard wired to true - val canBeDisabled = !isTrue(memPortField(m, read, "en")) - val readEn = if (canBeDisabled) memPortField(m, read, "en") else Utils.True() - val addr = memPortField(m, read, "addr") + val canBeDisabled = !isTrue(readInput(m, read, "en")) + val readEn = if (canBeDisabled) readInput(m, read, "en") else Utils.True() // collect signals that would lead to a randomization var doRand = List[Expression]() // randomize the read value when the address is out of bounds if (canBeOutOfBounds && opt.randomizeOutOfBoundsRead) { + val addr = readInput(m, read, "addr") val cond = Utils.and(readEn, Utils.not(isInBounds(m.depth, addr))) val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_oob"), cond) declarations += node @@ -175,8 +158,7 @@ private class InstrumentMems( } if (readWriteUndefined && opt.randomizeReadWriteConflicts) { - val (cond, d) = readWriteConflict(m, read, writeEn) - declarations ++= d + val cond = readWriteConflict(m, read, writeEn) val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_rwc"), cond) declarations += node doRand = Reference(node) +: doRand @@ -203,7 +185,7 @@ private class InstrumentMems( } val doRandSignal = if (m.readLatency == 0) { doRandNode } else { - val clock = memPortField(m, read, "clk") + val clock = readInput(m, read, "clk") val (signal, regDecls) = pipeline(m.info, clock, doRandName, doRandNode, m.readLatency) declarations ++= regDecls signal @@ -217,7 +199,7 @@ private class InstrumentMems( // create a source of randomness and connect the new wire either to the actual data port or to the random value val randName = namespace.newName(s"${m.name}_${read}_rand_data") - val random = DefRandom(m.info, randName, m.dataType, Some(memPortField(m, read, "clk")), doRandSignal) + val random = DefRandom(m.info, randName, m.dataType, Some(readInput(m, read, "clk")), doRandSignal) declarations += random val data = Utils.mux(doRandSignal, Reference(random), dataRef) newStmts.append(Connect(m.info, Reference(dataWire), data)) @@ -226,16 +208,16 @@ private class InstrumentMems( // write if (opt.randomizeWriteWriteConflicts) { - declarations ++= writeWriteConflicts(m, writeEn) + writeWriteConflicts(m, writeEn) } // add an assertion that if the write is taking place, then the address must be in range if (canBeOutOfBounds && opt.assertNoOutOfBoundsWrites) { m.writers.zip(writeEn).foreach { case (write, (combinedEn, _)) => - val addr = memPortField(m, write, "addr") + val addr = readInput(m, write, "addr") val cond = Utils.implies(combinedEn, isInBounds(m.depth, addr)) - val clk = memPortField(m, write, "clk") + val clk = readInput(m, write, "clk") val a = Verification(Formal.Assert, m.info, clk, cond, Utils.True(), StringLit("out of bounds read")) newStmts.append(a) } @@ -268,11 +250,13 @@ private class InstrumentMems( m: DefMemory, read: String, writeEn: Seq[(Expression, ProdTerms)] - ): (Expression, Seq[Statement]) = { - if (m.writers.isEmpty) return (Utils.False(), List()) - val declarations = mutable.ListBuffer[Statement]() + )( + implicit cache: mutable.HashMap[String, Expression], + declarations: mutable.ListBuffer[Statement] + ): Expression = { + if (m.writers.isEmpty) return Utils.False() - val readEn = memPortField(m, read, "en") + val readEn = readInput(m, read, "en") val readProd = getProductTerms(readEn) // create all conflict signals @@ -283,7 +267,7 @@ private class InstrumentMems( } else { val name = namespace.newName(s"${m.name}_${read}_${write}_rwc") val bothEn = Utils.and(readEn, writeEn) - val sameAddr = Utils.eq(memPortField(m, read, "addr"), memPortField(m, write, "addr")) + val sameAddr = Utils.eq(readInput(m, read, "addr"), readInput(m, write, "addr")) // we need a wire because this condition might be used in a random statement val wire = DefWire(m.info, name, BoolType) declarations += wire @@ -292,13 +276,18 @@ private class InstrumentMems( } } - (conflicts.reduce(Utils.or), declarations.toList) + conflicts.reduce(Utils.or) } private type ProdTerms = Seq[Expression] - private def writeWriteConflicts(m: DefMemory, writeEn: Seq[(Expression, ProdTerms)]): Seq[Statement] = { - if (m.writers.size < 2) return List() - val declarations = mutable.ListBuffer[Statement]() + private def writeWriteConflicts( + m: DefMemory, + writeEn: Seq[(Expression, ProdTerms)] + )( + implicit cache: mutable.HashMap[String, Expression], + declarations: mutable.ListBuffer[Statement] + ): Unit = { + if (m.writers.size < 2) return // we first create all conflict signals: val conflict = @@ -314,7 +303,7 @@ private class InstrumentMems( } else { val name = namespace.newName(s"${m.name}_${w1}_${w2}_wwc") val bothEn = Utils.and(en1, en2) - val sameAddr = Utils.eq(memPortField(m, w1, "addr"), memPortField(m, w2, "addr")) + val sameAddr = Utils.eq(readInput(m, w1, "addr"), readInput(m, w2, "addr")) // we need a wire because this condition might be used in a random statement val wire = DefWire(m.info, name, BoolType) declarations += wire @@ -355,25 +344,24 @@ private class InstrumentMems( // create the source of randomness val name = namespace.newName(s"${m.name}_${w1}_wwc_data") - val random = DefRandom(m.info, name, m.dataType, Some(memPortField(m, w1, "clk")), hasConflict) + val random = DefRandom(m.info, name, m.dataType, Some(readInput(m, w1, "clk")), hasConflict) declarations.append(random) - // replace the old data input - val dataWire = disconnectInput(m.info, memPortField(m, w1, "data")) - declarations += dataWire + // generate new data input - val data = Utils.mux(hasConflict, Reference(random), Reference(dataWire)) + val data = Utils.mux(hasConflict, Reference(random), readInput(m, w1, "data")) newStmts.append(Connect(m.info, memPortField(m, w1, "data"), data)) + doDisconnect.add(memPortField(m, w1, "data").serialize) } // connect data enable signals - val maskIsOne = isTrue(connects(memPortField(m, w1, "mask").serialize)) + val maskIsOne = isTrue(readInput(m, w1, "mask")) if (!maskIsOne) { newStmts.append(Connect(m.info, memPortField(m, w1, "mask"), Utils.True())) + doDisconnect.add(memPortField(m, w1, "mask").serialize) } newStmts.append(Connect(m.info, memPortField(m, w1, "en"), en)) + doDisconnect.add(memPortField(m, w1, "en").serialize) } - - declarations.toList } /** check whether two signals can be proven to be mutually exclusive */ @@ -383,27 +371,43 @@ private class InstrumentMems( proofOfMutualExclusion.nonEmpty } - /** replace a memory port with a wire */ - private def disconnectInput(info: Info, signal: RefLikeExpression): DefWire = { - // disconnect the old value - doDisconnect.add(signal.serialize) - - // if the old value is a literal, we just replace all references to it with this literal - val oldValue = connects(signal.serialize) - if (isLiteral(oldValue)) { - println("TODO: better code for literal") - } + /** memory inputs my not be read, only assigned to, thus we might need to add a wire to make them accessible */ + private def readInput( + info: Info, + signal: RefLikeExpression + )( + implicit cache: mutable.HashMap[String, Expression], + declarations: mutable.ListBuffer[Statement] + ): Expression = + cache.getOrElseUpdate( + signal.serialize, { + // if it is a literal, we just return it + val value = connects(signal.serialize) + if (isLiteral(value)) { + value + } else { + // otherwise we make a wire that refelect the value + val wire = DefWire(info, copyName(signal), signal.tpe) + declarations += wire - // create a new wire and replace all references to the original port with this wire - val wire = DefWire(info, copyName(signal), signal.tpe) - exprReplacements(signal.serialize) = Reference(wire) - // connect the old expression to the new wire - val con = Connect(info, Reference(wire), connects(signal.serialize)) - newStmts.append(con) + // connect the old expression to the new wire + val con = Connect(info, Reference(wire), value) + newStmts.append(con) - // the wire definition should end up right after the memory definition - wire - } + // use a reference to this new wire + Reference(wire) + } + } + ) + private def readInput( + m: DefMemory, + port: String, + field: String + )( + implicit cache: mutable.HashMap[String, Expression], + declarations: mutable.ListBuffer[Statement] + ): Expression = + readInput(m.info, memPortField(m, port, field)) private def copyName(ref: RefLikeExpression): String = namespace.newName(ref.serialize.replace('.', '_')) diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala index e489db7d..2a0276e1 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -195,4 +195,46 @@ class MemorySpec extends EndToEndSMTBaseSpec { "memory with two write ports" should "can have collisions when enables are unconstrained" taggedAs (RequiresZ3) in { test(collisionTest("UInt(1)"), MCFail(1), kmax = 1) } + + private def readEnableSrc(pred: String, num: Int) = + s""" + |circuit ReadEnableTest$num: + | module ReadEnableTest$num: + | input c : Clock + | input preset: AsyncReset + | + | reg first: UInt<1>, c with: (reset => (preset, UInt(1))) + | first <= UInt(0) + | + | reg even: UInt<1>, c with: (reset => (preset, UInt(0))) + | node odd = not(even) + | even <= not(even) + | + | mem m: + | data-type => UInt<8> + | depth => 4 + | reader => r + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.addr <= UInt(0) + | ; the read port is enabled in even cycles + | m.r.en <= even + | + | assert(c, $pred, not(first), "") + |""".stripMargin + + "a memory with read enable" should "supply valid data one cycle after en=1" in { + val init = Seq(MemoryScalarInitAnnotation(CircuitTarget(s"ReadEnableTest1").module(s"ReadEnableTest1").ref("m"), 0)) + // the read port is enabled on even cycles, so on odd cycles we should reliably get zeros + test(readEnableSrc("or(not(odd), eq(m.r.data, UInt(0)))", 1), MCSuccess, kmax = 3, annos = init) + } + + "a memory with read enable" should "supply invalid data one cycle after en=0" in { + val init = Seq(MemoryScalarInitAnnotation(CircuitTarget(s"ReadEnableTest2").module(s"ReadEnableTest2").ref("m"), 0)) + // the read port is disabled on odd cycles, so on even cycles we should *NOT* reliably get zeros + test(readEnableSrc("or(not(even), eq(m.r.data, UInt(0)))", 2), MCFail(1), kmax = 1, annos = init) + } } diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala index c6f89b93..f8f889ac 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala @@ -13,10 +13,10 @@ class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(Undef val result = circuit.serialize.split('\n').map(_.trim) // a random value should be declared for the data written on a write-write conflict - assert(result.contains("rand m_a_wwc_data : UInt<32>, m.a.clk when m_a_b_wwc")) + assert(result.contains("rand m_a_wwc_data : UInt<32>, m_a_clk when m_a_b_wwc")) // a write-write conflict occurs when both ports are enabled and the addresses match - assert(result.contains("m_a_b_wwc <= and(and(m_a_en, m_b_en), eq(m.a.addr, m.b.addr))")) + assert(result.contains("m_a_b_wwc <= and(and(m_a_en, m_b_en), eq(m_a_addr, m_b_addr))")) // the data of read port a depends on whether there is a write-write conflict assert(result.contains("m.a.data <= mux(m_a_b_wwc, m_a_wwc_data, m_a_data)")) @@ -35,14 +35,14 @@ class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(Undef assert(result.contains("node m_a_wwc_active = or(m_a_b_wwc, m_a_c_wwc)")) // a random value should be declared for the data written on a write-write conflict - assert(result.contains("rand m_a_wwc_data : UInt<32>, m.a.clk when m_a_wwc_active")) - assert(result.contains("rand m_b_wwc_data : UInt<32>, m.b.clk when m_b_c_wwc")) + assert(result.contains("rand m_a_wwc_data : UInt<32>, m_a_clk when m_a_wwc_active")) + assert(result.contains("rand m_b_wwc_data : UInt<32>, m_b_clk when m_b_c_wwc")) // a write-write conflict occurs when both ports are enabled and the addresses match Seq(("a", "b"), ("a", "c"), ("b", "c")).foreach { case (w1, w2) => assert( - result.contains(s"m_${w1}_${w2}_wwc <= and(and(m_${w1}_en, m_${w2}_en), eq(m.${w1}.addr, m.${w2}.addr))") + result.contains(s"m_${w1}_${w2}_wwc <= and(and(m_${w1}_en, m_${w2}_en), eq(m_${w1}_addr, m_${w2}_addr))") ) } @@ -67,7 +67,7 @@ class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(Undef val result = circuit.serialize.split('\n').map(_.trim) // we should not compute the conflict between a and c since it is impossible - assert(!result.contains("node m_a_c_wwc = and(and(m_a_en, m_c_en), eq(m.a.addr, m.c.addr))")) + assert(!result.contains("node m_a_c_wwc = and(and(m_a_en, m_c_en), eq(m_a_addr, m_c_addr))")) // the enable of port b depends on whether there is a conflict with a assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))")) @@ -91,7 +91,7 @@ class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(Undef assert( result.contains( - """assert(m.a.clk, or(not(and(m.a.en, m.a.mask)), geq(UInt<5>("h1e"), m.a.addr)), UInt<1>("h1"), "out of bounds read")""" + """assert(m_a_clk, geq(UInt<5>("h1e"), m_a_addr), UInt<1>("h1"), "out of bounds read")""" ) ) } @@ -102,10 +102,10 @@ class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(Undef val result = circuit.serialize.split('\n').map(_.trim) // an out of bounds read happens if the depth is not greater or equal to the address - assert(result.contains("node m_r_oob = not(geq(UInt<5>(\"h1e\"), m.r.addr))")) + assert(result.contains("node m_r_oob = not(geq(UInt<5>(\"h1e\"), m_r_addr))")) // the source of randomness needs to be triggered when there is an out of bounds read - assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_oob")) + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_oob")) // the data is random when there is an oob assert(result.contains("m_r_data <= mux(m_r_oob, m_r_rand_data, m.r.data)")) @@ -118,10 +118,10 @@ class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(Undef val result = circuit.serialize.split('\n').map(_.trim) // the memory is disabled when it is not enabled - assert(result.contains("node m_r_disabled = not(m.r.en)")) + assert(result.contains("node m_r_disabled = not(m_r_en)")) // the source of randomness needs to be triggered when there is an read while the port is disabled - assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_disabled")) + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_disabled")) // the data is random when there is an un-enabled read assert(result.contains("m_r_data <= mux(m_r_disabled, m_r_rand_data, m.r.data)")) @@ -134,16 +134,16 @@ class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(Undef val result = circuit.serialize.split('\n').map(_.trim) // the memory is disabled when it is not enabled - assert(result.contains("node m_r_disabled = not(m.r.en)")) + assert(result.contains("node m_r_disabled = not(m_r_en)")) // an out of bounds read happens if the depth is not greater or equal to the address and the memory is enabled - assert(result.contains("node m_r_oob = and(m.r.en, not(geq(UInt<5>(\"h1e\"), m.r.addr)))")) + assert(result.contains("node m_r_oob = and(m_r_en, not(geq(UInt<5>(\"h1e\"), m_r_addr)))")) // the two possible issues are combined into a single signal assert(result.contains("node m_r_do_rand = or(m_r_disabled, m_r_oob)")) // the source of randomness needs to be triggered when either issue occurs - assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_do_rand")) + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_do_rand")) // the data is random when either issue occurs assert(result.contains("m_r_data <= mux(m_r_do_rand, m_r_rand_data, m.r.data)")) @@ -159,7 +159,7 @@ class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(Undef assert(result.contains("m_r_do_rand_r1 <= m_r_do_rand")) // the source of randomness needs to be triggered by the pipeline register - assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_do_rand_r1")) + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_do_rand_r1")) // the data is random when the pipeline register is 1 assert(result.contains("m_r_data <= mux(m_r_do_rand_r1, m_r_rand_data, m.r.data)")) @@ -171,13 +171,13 @@ class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(Undef val result = circuit.serialize.split('\n').map(_.trim) // detect read/write conflicts - assert(result.contains("m_r_a_rwc <= and(and(m.r.en, and(m.a.en, m.a.mask)), eq(m.r.addr, m.a.addr))")) + assert(result.contains("m_r_a_rwc <= eq(m_r_addr, m_a_addr)")) // delay the signal assert(result.contains("m_r_do_rand_r1 <= m_r_rwc")) // randomize the data - assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_do_rand_r1")) + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_do_rand_r1")) assert(result.contains("m_r_data <= mux(m_r_do_rand_r1, m_r_rand_data, m.r.data)")) } } |
