diff options
Diffstat (limited to 'src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala')
| -rw-r--r-- | src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala | 312 |
1 files changed, 102 insertions, 210 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index c1857667..d3a1ed68 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -5,9 +5,12 @@ package firrtl.backends.experimental.smt import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation} import FirrtlExpressionSemantics.getWidth +import firrtl.backends.experimental.smt.random._ import firrtl.graph.MutableDiGraph import firrtl.options.Dependency +import firrtl.passes.MemPortUtils.memPortField import firrtl.passes.PassException +import firrtl.passes.memlib.VerilogMemDelays import firrtl.stage.Forms import firrtl.stage.TransformManager.TransformDependency import firrtl.transforms.{DeadCodeElimination, PropagatePresetAnnotations} @@ -58,7 +61,8 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { // TODO: We only really need [[Forms.MidForm]] + LowerTypes, but we also want to fail if there are CombLoops // TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare // precisely and PadWidths emits ill-typed firrtl. - override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm + override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++ + Seq(Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays)) override def invalidates(a: Transform): Boolean = false // since this pass only runs on the main module, inlining needs to happen before override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) @@ -169,8 +173,7 @@ private class ModuleToTransitionSystem extends LazyLogging { onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs) } // turn memories into state - val memoryEncoding = new MemoryEncoding(makeRandom, scan.namespace) - val memoryStatesAndOutputs = scan.memories.map(m => memoryEncoding.onMemory(m, scan.connects, memInit.get(m.name))) + val memoryStatesAndOutputs = scan.memories.map(m => onMemory(m, scan.connects, memInit.get(m.name))) // replace pseudo assigns for memory outputs val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap val signalsWithMem = signals.map { s => @@ -185,7 +188,7 @@ private class ModuleToTransitionSystem extends LazyLogging { case _ => true } ) - val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1) + val states = regStates.toArray ++ memoryStatesAndOutputs.map(_._1) // generate comments from infos val comments = mutable.HashMap[String, String]() @@ -247,233 +250,116 @@ private class ModuleToTransitionSystem extends LazyLogging { } } - private val InfoSeparator = ", " - private val InfoPrefix = "@ " - private def serializeInfo(info: ir.Info): Option[String] = info match { - case ir.NoInfo => None - case f: ir.FileInfo => Some(f.escaped) - case m: ir.MultiInfo => - val infos = m.flatten - if (infos.isEmpty) { None } - else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } - } - - private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]() - private def makeRandom(baseName: String, width: Int): BVExpr = { - // TODO: actually ensure that there cannot be any name clashes with other identifiers - val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii) - val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get - val sym = BVSymbol(name, width) - randoms(name) = sym - sym - } -} - -private class MemoryEncoding(makeRandom: (String, Int) => BVExpr, namespace: Namespace) extends LazyLogging { type Connects = Iterable[(String, BVExpr)] - def onMemory( - defMem: ir.DefMemory, - connects: Connects, - initValue: Option[MemoryInitValue] - ): (Iterable[State], Connects) = { - // we can only work on appropriately lowered memories - assert( - defMem.dataType.isInstanceOf[ir.GroundType], - s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!" - ) - assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") + private def onMemory(m: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (State, Connects) = { + checkMem(m) - // collect all memory meta-data in a custom class - val m = new MemInfo(defMem) + // map of inputs to the memory + val inputs = connects.filter(_._1.startsWith(m.name)).toMap - // find all connections related to this memory - val inputs = connects.filter(_._1.startsWith(m.prefix)).toMap + // derive the type of the memory from the dataType and depth + val dataWidth = getWidth(m.dataType) + val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) + val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth) // there could be a constant init - val init = initValue.map(getInit(m, _)) - - // parse and check read and write ports - val writers = defMem.writers.map(w => new WritePort(m, w, inputs)) - val readers = defMem.readers.map(r => new ReadPort(m, r, inputs)) - - // derive next state from all write ports - assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.") - val next: ArrayExpr = if (writers.isEmpty) { m.sym } - else { - if (writers.length > 2) { - throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})") - } - val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) } - if (writers.length == 1) { validData } - else { - assert(writers.length == 2) - val conflict = writers.head.doesConflict(writers.last) - val conflictData = writers.head.makeRandomData("_write_write_collision") - val conflictStore = ArrayStore(m.sym, writers.head.addr, conflictData) - ArrayIte(conflict, conflictStore, validData) - } - } - val state = State(m.sym, init, Some(next)) + val init = initValue.map(getInit(m, indexWidth, dataWidth, _)) + init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth)) - // derive data signals from all read ports - assert(defMem.readLatency >= 0) - if (defMem.readLatency > 1) { - throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})") - } - val readPortSignals = if (defMem.readLatency == 0) { - readers.map { r => - // combinatorial read - if (defMem.readUnderWrite != ir.ReadUnderWrite.New) { - logger.warn( - s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." + - s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored." - ) - } - // since we do a combinatorial read, the "old" data is the current data - val data = r.read() - r.data.name -> data - } - } else { Seq() } - val readPortSignalsAndStates = if (defMem.readLatency == 1) { - readers.map { r => - defMem.readUnderWrite match { - case ir.ReadUnderWrite.New => - // create a state to save the address and the enable signal - val enPrev = BVSymbol(namespace.newName(r.en.name + "_prev"), r.en.width) - val addrPrev = BVSymbol(namespace.newName(r.addr.name + "_prev"), r.addr.width) - val signal = r.data.name -> r.read(addr = addrPrev, en = enPrev) - val states = Seq(State(enPrev, None, next = Some(r.en)), State(addrPrev, None, next = Some(r.addr))) - (Seq(signal), states) - case ir.ReadUnderWrite.Undefined => - // check for potential read/write conflicts in which case we need to return an arbitrary value - val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r))) - val next = if (anyWriteToTheSameAddress == False) { r.read() } - else { - val readUnderWriteData = r.makeRandomData("_read_under_write_undefined") - BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.read()) - } - (Seq(), Seq(State(r.data, init = None, next = Some(next)))) - case ir.ReadUnderWrite.Old => - // we create a register for the read port data - (Seq(), Seq(State(r.data, init = None, next = Some(r.read())))) - } + // derive next state expression + val next = if (m.writers.isEmpty) { + memSymbol + } else { + m.writers.foldLeft[ArrayExpr](memSymbol) { + case (prev, write) => + // update + val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth) + val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth) + val update = ArrayStore(prev, index = addr, data = data) + + // update guard + val en = BVSymbol(memPortField(m, write, "en").serialize, 1) + val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1) + val alwaysEnabled = Seq(en, mask).forall(s => inputs(s.name) == True) + if (alwaysEnabled) { update } + else { + ArrayIte(and(en, mask), update, prev) + } } - } else { Seq() } + } - val allReadPortSignals = readPortSignals ++ readPortSignalsAndStates.flatMap(_._1) - val readPortStates = readPortSignalsAndStates.flatMap(_._2) + val state = State(memSymbol, init, Some(next)) - (state +: readPortStates, allReadPortSignals) - } + // derive read expressions + val readSignals = m.readers.map { read => + val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth) + memPortField(m, read, "data").serialize -> ArrayRead(memSymbol, addr) + } - private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match { - case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth) - case MemoryArrayInit(values) => - assert( - values.length == m.depth, - s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" - ) - // in order to get a more compact encoding try to find the most common values - val histogram = mutable.LinkedHashMap[BigInt, Int]() - values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) - val baseValue = histogram.maxBy(_._2)._1 - val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth) - values.zipWithIndex - .filterNot(_._1 == baseValue) - .foldLeft[ArrayExpr](base) { - case (array, (value, index)) => - ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth)) - } - case other => throw new RuntimeException(s"Unsupported memory init option: $other") + (state, readSignals) } - private class MemInfo(m: ir.DefMemory) { - val name = m.name - val depth = m.depth - // derrive the type of the memory from the dataType and depth - val dataWidth = getWidth(m.dataType) - val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) - val sym = ArraySymbol(m.name, indexWidth, dataWidth) - val prefix = m.name + "." - val fullAddressRange = (BigInt(1) << indexWidth) == m.depth - lazy val depthBV = BVLiteral(m.depth, indexWidth) - def isValidAddress(addr: BVExpr): BVExpr = { - if (fullAddressRange) { True } - else { - BVComparison(Compare.Greater, depthBV, addr, signed = false) - } - } - } - private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) { - val en: BVSymbol = makeField("en", 1) - val data: BVSymbol = makeField("data", memory.dataWidth) - val addr: BVSymbol = makeField("addr", memory.indexWidth) - protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width) - // make sure that all widths are correct - assert(inputs(en.name).width == en.width) - assert(inputs(addr.name).width == addr.width) - val enIsTrue: Boolean = inputs(en.name) == True - def makeRandomData(suffix: String): BVExpr = - makeRandom(memory.name + "_" + name + suffix, memory.dataWidth) - def read(addr: BVSymbol = addr, en: BVSymbol = en): BVExpr = { - val canBeOutOfRange = !memory.fullAddressRange - val canBeDisabled = !enIsTrue - val data = ArrayRead(memory.sym, addr) - val dataWithRangeCheck = if (canBeOutOfRange) { - val outOfRangeData = makeRandomData("_addr_out_of_range") - BVIte(memory.isValidAddress(addr), data, outOfRangeData) - } else { data } - val dataWithEnabledCheck = if (canBeDisabled) { - val disabledData = makeRandomData("_not_enabled") - BVIte(en, dataWithRangeCheck, disabledData) - } else { dataWithRangeCheck } - dataWithEnabledCheck - } - } - private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr) - extends MemPort(memory, name, inputs) { - assert(inputs(data.name).width == data.width) - val mask: BVSymbol = makeField("mask", 1) - assert(inputs(mask.name).width == mask.width) - val maskIsTrue: Boolean = inputs(mask.name) == True - val doWrite: BVExpr = (enIsTrue, maskIsTrue) match { - case (true, true) => True - case (true, false) => mask - case (false, true) => en - case (false, false) => and(en, mask) - } - def doesConflict(r: ReadPort): BVExpr = { - val sameAddress = BVEqual(r.addr, addr) - if (doWrite == True) { sameAddress } - else { and(doWrite, sameAddress) } - } - def doesConflict(w: WritePort): BVExpr = { - val bothWrite = and(doWrite, w.doWrite) - val sameAddress = BVEqual(addr, w.addr) - if (bothWrite == True) { sameAddress } - else { and(bothWrite, sameAddress) } - } - def writeTo(array: ArrayExpr): ArrayExpr = { - val doUpdate = if (memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr)) - val update = ArrayStore(array, index = addr, data = data) - if (doUpdate == True) update else ArrayIte(doUpdate, update, array) + private def getInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr = + initValue match { + case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth) + case MemoryArrayInit(values) => + assert( + values.length == m.depth, + s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" + ) + // in order to get a more compact encoding try to find the most common values + val histogram = mutable.LinkedHashMap[BigInt, Int]() + values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) + val baseValue = histogram.maxBy(_._2)._1 + val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth) + values.zipWithIndex + .filterNot(_._1 == baseValue) + .foldLeft[ArrayExpr](base) { + case (array, (value, index)) => + ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth)) + } + case other => throw new RuntimeException(s"Unsupported memory init option: $other") } - } - private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr) - extends MemPort(memory, name, inputs) {} - + // TODO: add to BV expression library private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match { case (True, True) => True case (True, x) => x case (x, True) => x case _ => BVOp(Op.And, a, b) } - private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b) + private val True = BVLiteral(1, 1) - private val False = BVLiteral(0, 1) - private def all(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) False else b.reduce((a, b) => and(a, b)) - private def any(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) True else b.reduce((a, b) => or(a, b)) + private def checkMem(m: ir.DefMemory): Unit = { + assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?") + assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?") + assert( + m.dataType.isInstanceOf[ir.GroundType], + s"Memory $m is of type ${m.dataType} which is not a ground type!" + ) + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") + } + + private val InfoSeparator = ", " + private val InfoPrefix = "@ " + private def serializeInfo(info: ir.Info): Option[String] = info match { + case ir.NoInfo => None + case f: ir.FileInfo => Some(f.escaped) + case m: ir.MultiInfo => + val infos = m.flatten + if (infos.isEmpty) { None } + else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } + } + + private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]() + private def makeRandom(baseName: String, width: Int): BVExpr = { + // TODO: actually ensure that there cannot be any name clashes with other identifiers + val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii) + val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get + val sym = BVSymbol(name, width) + randoms(name) = sym + sym + } } // performas a first pass over the module collecting all connections, wires, registers, input and outputs @@ -526,6 +412,12 @@ private class ModuleScanner( } private[firrtl] def onStatement(s: ir.Statement): Unit = s match { + case DefRandom(info, name, tpe, _, _) => + namespace.newName(name) + assert(!isClock(tpe), "rand should never be a clock!") + // we model random sources as inputs and ignore the enable signal + infos.append(name -> info) + inputs.append(BVSymbol(name, getWidth(tpe))) case ir.DefWire(info, name, tpe) => namespace.newName(name) if (!isClock(tpe)) { |
