aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala')
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala312
1 files changed, 102 insertions, 210 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index c1857667..d3a1ed68 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -5,9 +5,12 @@ package firrtl.backends.experimental.smt
import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation}
import FirrtlExpressionSemantics.getWidth
+import firrtl.backends.experimental.smt.random._
import firrtl.graph.MutableDiGraph
import firrtl.options.Dependency
+import firrtl.passes.MemPortUtils.memPortField
import firrtl.passes.PassException
+import firrtl.passes.memlib.VerilogMemDelays
import firrtl.stage.Forms
import firrtl.stage.TransformManager.TransformDependency
import firrtl.transforms.{DeadCodeElimination, PropagatePresetAnnotations}
@@ -58,7 +61,8 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
// TODO: We only really need [[Forms.MidForm]] + LowerTypes, but we also want to fail if there are CombLoops
// TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare
// precisely and PadWidths emits ill-typed firrtl.
- override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm
+ override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++
+ Seq(Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays))
override def invalidates(a: Transform): Boolean = false
// since this pass only runs on the main module, inlining needs to happen before
override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances])
@@ -169,8 +173,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs)
}
// turn memories into state
- val memoryEncoding = new MemoryEncoding(makeRandom, scan.namespace)
- val memoryStatesAndOutputs = scan.memories.map(m => memoryEncoding.onMemory(m, scan.connects, memInit.get(m.name)))
+ val memoryStatesAndOutputs = scan.memories.map(m => onMemory(m, scan.connects, memInit.get(m.name)))
// replace pseudo assigns for memory outputs
val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap
val signalsWithMem = signals.map { s =>
@@ -185,7 +188,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
case _ => true
}
)
- val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1)
+ val states = regStates.toArray ++ memoryStatesAndOutputs.map(_._1)
// generate comments from infos
val comments = mutable.HashMap[String, String]()
@@ -247,233 +250,116 @@ private class ModuleToTransitionSystem extends LazyLogging {
}
}
- private val InfoSeparator = ", "
- private val InfoPrefix = "@ "
- private def serializeInfo(info: ir.Info): Option[String] = info match {
- case ir.NoInfo => None
- case f: ir.FileInfo => Some(f.escaped)
- case m: ir.MultiInfo =>
- val infos = m.flatten
- if (infos.isEmpty) { None }
- else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
- }
-
- private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]()
- private def makeRandom(baseName: String, width: Int): BVExpr = {
- // TODO: actually ensure that there cannot be any name clashes with other identifiers
- val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii)
- val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get
- val sym = BVSymbol(name, width)
- randoms(name) = sym
- sym
- }
-}
-
-private class MemoryEncoding(makeRandom: (String, Int) => BVExpr, namespace: Namespace) extends LazyLogging {
type Connects = Iterable[(String, BVExpr)]
- def onMemory(
- defMem: ir.DefMemory,
- connects: Connects,
- initValue: Option[MemoryInitValue]
- ): (Iterable[State], Connects) = {
- // we can only work on appropriately lowered memories
- assert(
- defMem.dataType.isInstanceOf[ir.GroundType],
- s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!"
- )
- assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
+ private def onMemory(m: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (State, Connects) = {
+ checkMem(m)
- // collect all memory meta-data in a custom class
- val m = new MemInfo(defMem)
+ // map of inputs to the memory
+ val inputs = connects.filter(_._1.startsWith(m.name)).toMap
- // find all connections related to this memory
- val inputs = connects.filter(_._1.startsWith(m.prefix)).toMap
+ // derive the type of the memory from the dataType and depth
+ val dataWidth = getWidth(m.dataType)
+ val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
+ val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth)
// there could be a constant init
- val init = initValue.map(getInit(m, _))
-
- // parse and check read and write ports
- val writers = defMem.writers.map(w => new WritePort(m, w, inputs))
- val readers = defMem.readers.map(r => new ReadPort(m, r, inputs))
-
- // derive next state from all write ports
- assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.")
- val next: ArrayExpr = if (writers.isEmpty) { m.sym }
- else {
- if (writers.length > 2) {
- throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})")
- }
- val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) }
- if (writers.length == 1) { validData }
- else {
- assert(writers.length == 2)
- val conflict = writers.head.doesConflict(writers.last)
- val conflictData = writers.head.makeRandomData("_write_write_collision")
- val conflictStore = ArrayStore(m.sym, writers.head.addr, conflictData)
- ArrayIte(conflict, conflictStore, validData)
- }
- }
- val state = State(m.sym, init, Some(next))
+ val init = initValue.map(getInit(m, indexWidth, dataWidth, _))
+ init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth))
- // derive data signals from all read ports
- assert(defMem.readLatency >= 0)
- if (defMem.readLatency > 1) {
- throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})")
- }
- val readPortSignals = if (defMem.readLatency == 0) {
- readers.map { r =>
- // combinatorial read
- if (defMem.readUnderWrite != ir.ReadUnderWrite.New) {
- logger.warn(
- s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." +
- s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored."
- )
- }
- // since we do a combinatorial read, the "old" data is the current data
- val data = r.read()
- r.data.name -> data
- }
- } else { Seq() }
- val readPortSignalsAndStates = if (defMem.readLatency == 1) {
- readers.map { r =>
- defMem.readUnderWrite match {
- case ir.ReadUnderWrite.New =>
- // create a state to save the address and the enable signal
- val enPrev = BVSymbol(namespace.newName(r.en.name + "_prev"), r.en.width)
- val addrPrev = BVSymbol(namespace.newName(r.addr.name + "_prev"), r.addr.width)
- val signal = r.data.name -> r.read(addr = addrPrev, en = enPrev)
- val states = Seq(State(enPrev, None, next = Some(r.en)), State(addrPrev, None, next = Some(r.addr)))
- (Seq(signal), states)
- case ir.ReadUnderWrite.Undefined =>
- // check for potential read/write conflicts in which case we need to return an arbitrary value
- val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r)))
- val next = if (anyWriteToTheSameAddress == False) { r.read() }
- else {
- val readUnderWriteData = r.makeRandomData("_read_under_write_undefined")
- BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.read())
- }
- (Seq(), Seq(State(r.data, init = None, next = Some(next))))
- case ir.ReadUnderWrite.Old =>
- // we create a register for the read port data
- (Seq(), Seq(State(r.data, init = None, next = Some(r.read()))))
- }
+ // derive next state expression
+ val next = if (m.writers.isEmpty) {
+ memSymbol
+ } else {
+ m.writers.foldLeft[ArrayExpr](memSymbol) {
+ case (prev, write) =>
+ // update
+ val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth)
+ val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth)
+ val update = ArrayStore(prev, index = addr, data = data)
+
+ // update guard
+ val en = BVSymbol(memPortField(m, write, "en").serialize, 1)
+ val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1)
+ val alwaysEnabled = Seq(en, mask).forall(s => inputs(s.name) == True)
+ if (alwaysEnabled) { update }
+ else {
+ ArrayIte(and(en, mask), update, prev)
+ }
}
- } else { Seq() }
+ }
- val allReadPortSignals = readPortSignals ++ readPortSignalsAndStates.flatMap(_._1)
- val readPortStates = readPortSignalsAndStates.flatMap(_._2)
+ val state = State(memSymbol, init, Some(next))
- (state +: readPortStates, allReadPortSignals)
- }
+ // derive read expressions
+ val readSignals = m.readers.map { read =>
+ val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth)
+ memPortField(m, read, "data").serialize -> ArrayRead(memSymbol, addr)
+ }
- private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match {
- case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth)
- case MemoryArrayInit(values) =>
- assert(
- values.length == m.depth,
- s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!"
- )
- // in order to get a more compact encoding try to find the most common values
- val histogram = mutable.LinkedHashMap[BigInt, Int]()
- values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
- val baseValue = histogram.maxBy(_._2)._1
- val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth)
- values.zipWithIndex
- .filterNot(_._1 == baseValue)
- .foldLeft[ArrayExpr](base) {
- case (array, (value, index)) =>
- ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth))
- }
- case other => throw new RuntimeException(s"Unsupported memory init option: $other")
+ (state, readSignals)
}
- private class MemInfo(m: ir.DefMemory) {
- val name = m.name
- val depth = m.depth
- // derrive the type of the memory from the dataType and depth
- val dataWidth = getWidth(m.dataType)
- val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
- val sym = ArraySymbol(m.name, indexWidth, dataWidth)
- val prefix = m.name + "."
- val fullAddressRange = (BigInt(1) << indexWidth) == m.depth
- lazy val depthBV = BVLiteral(m.depth, indexWidth)
- def isValidAddress(addr: BVExpr): BVExpr = {
- if (fullAddressRange) { True }
- else {
- BVComparison(Compare.Greater, depthBV, addr, signed = false)
- }
- }
- }
- private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) {
- val en: BVSymbol = makeField("en", 1)
- val data: BVSymbol = makeField("data", memory.dataWidth)
- val addr: BVSymbol = makeField("addr", memory.indexWidth)
- protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width)
- // make sure that all widths are correct
- assert(inputs(en.name).width == en.width)
- assert(inputs(addr.name).width == addr.width)
- val enIsTrue: Boolean = inputs(en.name) == True
- def makeRandomData(suffix: String): BVExpr =
- makeRandom(memory.name + "_" + name + suffix, memory.dataWidth)
- def read(addr: BVSymbol = addr, en: BVSymbol = en): BVExpr = {
- val canBeOutOfRange = !memory.fullAddressRange
- val canBeDisabled = !enIsTrue
- val data = ArrayRead(memory.sym, addr)
- val dataWithRangeCheck = if (canBeOutOfRange) {
- val outOfRangeData = makeRandomData("_addr_out_of_range")
- BVIte(memory.isValidAddress(addr), data, outOfRangeData)
- } else { data }
- val dataWithEnabledCheck = if (canBeDisabled) {
- val disabledData = makeRandomData("_not_enabled")
- BVIte(en, dataWithRangeCheck, disabledData)
- } else { dataWithRangeCheck }
- dataWithEnabledCheck
- }
- }
- private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr)
- extends MemPort(memory, name, inputs) {
- assert(inputs(data.name).width == data.width)
- val mask: BVSymbol = makeField("mask", 1)
- assert(inputs(mask.name).width == mask.width)
- val maskIsTrue: Boolean = inputs(mask.name) == True
- val doWrite: BVExpr = (enIsTrue, maskIsTrue) match {
- case (true, true) => True
- case (true, false) => mask
- case (false, true) => en
- case (false, false) => and(en, mask)
- }
- def doesConflict(r: ReadPort): BVExpr = {
- val sameAddress = BVEqual(r.addr, addr)
- if (doWrite == True) { sameAddress }
- else { and(doWrite, sameAddress) }
- }
- def doesConflict(w: WritePort): BVExpr = {
- val bothWrite = and(doWrite, w.doWrite)
- val sameAddress = BVEqual(addr, w.addr)
- if (bothWrite == True) { sameAddress }
- else { and(bothWrite, sameAddress) }
- }
- def writeTo(array: ArrayExpr): ArrayExpr = {
- val doUpdate = if (memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr))
- val update = ArrayStore(array, index = addr, data = data)
- if (doUpdate == True) update else ArrayIte(doUpdate, update, array)
+ private def getInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr =
+ initValue match {
+ case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth)
+ case MemoryArrayInit(values) =>
+ assert(
+ values.length == m.depth,
+ s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!"
+ )
+ // in order to get a more compact encoding try to find the most common values
+ val histogram = mutable.LinkedHashMap[BigInt, Int]()
+ values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
+ val baseValue = histogram.maxBy(_._2)._1
+ val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth)
+ values.zipWithIndex
+ .filterNot(_._1 == baseValue)
+ .foldLeft[ArrayExpr](base) {
+ case (array, (value, index)) =>
+ ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth))
+ }
+ case other => throw new RuntimeException(s"Unsupported memory init option: $other")
}
- }
- private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr)
- extends MemPort(memory, name, inputs) {}
-
+ // TODO: add to BV expression library
private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match {
case (True, True) => True
case (True, x) => x
case (x, True) => x
case _ => BVOp(Op.And, a, b)
}
- private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b)
+
private val True = BVLiteral(1, 1)
- private val False = BVLiteral(0, 1)
- private def all(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) False else b.reduce((a, b) => and(a, b))
- private def any(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) True else b.reduce((a, b) => or(a, b))
+ private def checkMem(m: ir.DefMemory): Unit = {
+ assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?")
+ assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?")
+ assert(
+ m.dataType.isInstanceOf[ir.GroundType],
+ s"Memory $m is of type ${m.dataType} which is not a ground type!"
+ )
+ assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
+ }
+
+ private val InfoSeparator = ", "
+ private val InfoPrefix = "@ "
+ private def serializeInfo(info: ir.Info): Option[String] = info match {
+ case ir.NoInfo => None
+ case f: ir.FileInfo => Some(f.escaped)
+ case m: ir.MultiInfo =>
+ val infos = m.flatten
+ if (infos.isEmpty) { None }
+ else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
+ }
+
+ private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]()
+ private def makeRandom(baseName: String, width: Int): BVExpr = {
+ // TODO: actually ensure that there cannot be any name clashes with other identifiers
+ val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii)
+ val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get
+ val sym = BVSymbol(name, width)
+ randoms(name) = sym
+ sym
+ }
}
// performas a first pass over the module collecting all connections, wires, registers, input and outputs
@@ -526,6 +412,12 @@ private class ModuleScanner(
}
private[firrtl] def onStatement(s: ir.Statement): Unit = s match {
+ case DefRandom(info, name, tpe, _, _) =>
+ namespace.newName(name)
+ assert(!isClock(tpe), "rand should never be a clock!")
+ // we model random sources as inputs and ignore the enable signal
+ infos.append(name -> info)
+ inputs.append(BVSymbol(name, getWidth(tpe)))
case ir.DefWire(info, name, tpe) =>
namespace.newName(name)
if (!isClock(tpe)) {