diff options
16 files changed, 848 insertions, 744 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala index f96fd4e8..a6eaa51b 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala @@ -3,28 +3,12 @@ package firrtl.backends.experimental.smt -import firrtl.backends.experimental.smt.Btor2Serializer.functionCallToArrayRead - import scala.collection.mutable private object Btor2Serializer { def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = { new Btor2Serializer().run(sys, skipOutput) } - - private def functionCallToArrayRead(call: BVFunctionCall): BVExpr = { - if (call.args.isEmpty) { - BVSymbol(call.name, call.width) - } else { - val index = concat(call.args) - val a = ArraySymbol(call.name, indexWidth = index.width, dataWidth = call.width) - ArrayRead(a, index) - } - } - private def concat(e: Iterable[BVExpr]): BVExpr = { - require(e.nonEmpty) - e.reduce((a, b) => BVConcat(a, b)) - } } private class Btor2Serializer private () { @@ -55,7 +39,7 @@ private class Btor2Serializer private () { // bit vector expression serialization private def s(expr: BVExpr): Int = expr match { case BVLiteral(value, width) => lit(value, width) - case BVSymbol(name, _) => symbols(name) + case BVSymbol(name, _) => symbols.getOrElse(name, throw new RuntimeException(s"Unknown symbol: $name")) case BVExtend(e, 0, _) => s(e) case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by") case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by") @@ -86,13 +70,13 @@ private class Btor2Serializer private () { line(s"read ${t(expr.width)} ${s(array)} ${s(index)}") case BVIte(cond, tru, fals) => line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}") - case r: BVRawExpr => - throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}") + case b @ BVAnd(terms) => variadic("and", b.width, terms) + case b @ BVOr(terms) => variadic("or", b.width, terms) + case forall: BVForall => + throw new RuntimeException(s"Quantifiers are not supported by the btor2 format: ${forall}") } private def s(op: Op.Value): String = op match { - case Op.And => "and" - case Op.Or => "or" case Op.Xor => "xor" case Op.ArithmeticShiftRight => "sra" case Op.ShiftRight => "srl" @@ -112,6 +96,14 @@ private class Btor2Serializer private () { private def binary(op: String, width: Int, a: BVExpr, b: BVExpr): Int = line(s"$op ${t(width)} ${s(a)} ${s(b)}") + private def variadic(op: String, width: Int, terms: List[BVExpr]): Int = terms match { + case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op") + case Seq(a, b) => binary(op, width, a, b) + case head :: tail => + val tailId = variadic(op, width, tail) + line(s"$op ${t(width)} ${s(head)} ${tailId}") + } + private def lit(value: BigInt, w: Int): Int = { val typ = t(w) lazy val mask = (BigInt(1) << w) - 1 @@ -141,10 +133,24 @@ private class Btor2Serializer private () { // While the spec does not seem to allow array ite, it seems to be supported in practice. // It is essential to model memories, so any support in the wild should be fairly well tested. line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}") - case ArrayConstant(e, _) => s(e) - case r: ArrayRawExpr => - throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}") + case ArrayConstant(e, indexWidth) => + // The problem we are facing here is that the only way to create a constant array from a bv expression + // seems to be to use the bv expression as the init value of a state variable. + // Thus we need to create a fake state for every array init expression. + arrayConstants.getOrElseUpdate( + e.toString, { + comment(s"$expr") + val eId = s(e) + val tpeId = t(indexWidth, e.width) + val state = line(s"state $tpeId") + line(s"init $tpeId $state $eId") + state + } + ) + case f: ArrayFunctionCall => + throw new RuntimeException(s"The btor2 format does not support uninterpreted functions that return arrays!: $f") } + private val arrayConstants = mutable.HashMap[String, Int]() private def s(expr: SMTExpr): Int = expr match { case b: BVExpr => s(b) @@ -157,38 +163,62 @@ private class Btor2Serializer private () { case a: ArrayExpr => t(a.indexWidth, a.dataWidth) } + private def functionCallToArrayRead(call: BVFunctionCall): BVExpr = { + if (call.args.isEmpty) { + BVSymbol(call.name, call.width) + } else { + val args: List[BVExpr] = call.args.map { + case b: BVExpr => b + case other => throw new RuntimeException(s"Unsupported call argument: $other in $call") + } + val index = concat(args) + val a = ArraySymbol(call.name, indexWidth = index.width, dataWidth = call.width) + ArrayRead(a, index) + } + } + private def concat(e: Iterable[BVExpr]): BVExpr = { + require(e.nonEmpty) + e.reduce((a, b) => BVConcat(a, b)) + } + def run(sys: TransitionSystem, skipOutput: Boolean): Iterable[String] = { - def declare(name: String, expr: => Int): Unit = { + def declare(name: String, lbl: Option[SignalLabel], expr: => Int): Unit = { assert(!symbols.contains(name), s"Trying to redeclare `$name`") val id = expr symbols(name) = id - if (!skipOutput && sys.outputs.contains(name)) line(s"output $id ; $name") - if (sys.assumes.contains(name)) line(s"constraint $id ; $name") - if (sys.asserts.contains(name)) { - val invertedId = line(s"not ${t(1)} $id") - line(s"bad $invertedId ; $name") + // add label + lbl match { + case Some(IsOutput) => if (!skipOutput) line(s"output $id ; $name") + case Some(IsConstraint) => line(s"constraint $id ; $name") + case Some(IsBad) => line(s"bad $id ; $name") + case Some(IsFair) => line(s"fair $id ; $name") + case _ => } - if (sys.fair.contains(name)) line(s"fair $id ; $name") // add trailing comment sys.comments.get(name).foreach(trailingComment) } // header - sys.header.foreach(comment) + if (sys.header.nonEmpty) { + sys.header.split('\n').foreach(comment) + } // declare inputs sys.inputs.foreach { ii => - declare(ii.name, line(s"input ${t(ii.width)} ${ii.name}")) + declare(ii.name, None, line(s"input ${t(ii.width)} ${ii.name}")) } // declare uninterpreted functions a constant arrays - sys.ufs.foreach { foo => - val sym = if (foo.argWidths.isEmpty) { BVSymbol(foo.name, foo.width) } + val ufs = TransitionSystem.findUninterpretedFunctions(sys) + ufs.foreach { foo => + // only functions returning bit-vectors are supported! + val bvSym = foo.sym.asInstanceOf[BVSymbol] + val sym = if (foo.args.isEmpty) { bvSym } else { - ArraySymbol(foo.name, foo.argWidths.sum, foo.width) + ArraySymbol(bvSym.name, foo.args.map(_.asInstanceOf[BVExpr].width).sum, bvSym.width) } comment(foo.toString) - declare(sym.name, line(s"state ${t(sym)} ${sym.name}")) + declare(sym.name, None, line(s"state ${t(sym)} ${sym.name}")) line(s"next ${t(sym)} ${s(sym)} ${s(sym)}") } @@ -196,14 +226,18 @@ private class Btor2Serializer private () { sys.states.foreach { st => // calculate init expression before declaring the state // this is required by btormc (presumably to avoid cycles in the init expression) - val initId = st.init.map { init => comment(s"${st.sym}.init"); s(init) } - declare(st.sym.name, line(s"state ${t(st.sym)} ${st.sym.name}")) + val initId = st.init.map { + // only in the context of initializing a state can we use a bv expression to model an array + case ArrayConstant(e, _) => comment(s"${st.sym}.init"); s(e) + case init => comment(s"${st.sym}.init"); s(init) + } + declare(st.sym.name, None, line(s"state ${t(st.sym)} ${st.sym.name}")) st.init.foreach { init => line(s"init ${t(init)} ${s(st.sym)} ${initId.get}") } } // define all other signals sys.signals.foreach { signal => - declare(signal.name, s(signal.e)) + declare(signal.name, Some(signal.lbl), s(signal.e)) } // define state next diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala index 2c08ff6a..c7524e21 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala @@ -7,17 +7,11 @@ import firrtl.ir import firrtl.PrimOps import firrtl.passes.CheckWidths.WidthTooBig -private trait TranslationContext { - def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, firrtl.bitWidth(tpe).toInt) -} - private object FirrtlExpressionSemantics { - def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = { + def toSMT(e: ir.Expression): BVExpr = { val eSMT = e match { case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts) - case r: ir.Reference => ctx.getReference(r.serialize, r.tpe) - case r: ir.SubField => ctx.getReference(r.serialize, r.tpe) - case r: ir.SubIndex => ctx.getReference(r.serialize, r.tpe) + case r: ir.RefLikeExpression => BVSymbol(r.serialize, getWidth(r)) case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt) case ir.SIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt) case ir.Mux(cond, tval, fval, _) => @@ -34,7 +28,7 @@ private object FirrtlExpressionSemantics { } /** Ensures that the result has the desired width by appropriately extending it. */ - def toSMT(e: ir.Expression, width: Int, allowNarrow: Boolean = false)(implicit ctx: TranslationContext): BVExpr = + def toSMT(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = forceWidth(toSMT(e), isSigned(e), width, allowNarrow) private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = { @@ -52,8 +46,6 @@ private object FirrtlExpressionSemantics { op: ir.PrimOp, args: Seq[ir.Expression], consts: Seq[BigInt] - )( - implicit ctx: TranslationContext ): BVExpr = { (op, args, consts) match { case (PrimOps.Add, Seq(e1, e2), _) => @@ -137,10 +129,10 @@ private object FirrtlExpressionSemantics { case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e)) case (PrimOps.And, Seq(e1, e2), _) => val width = args.map(getWidth).max - BVOp(Op.And, toSMT(e1, width), toSMT(e2, width)) + BVAnd(toSMT(e1, width), toSMT(e2, width)) case (PrimOps.Or, Seq(e1, e2), _) => val width = args.map(getWidth).max - BVOp(Op.Or, toSMT(e1, width), toSMT(e2, width)) + BVOr(toSMT(e1, width), toSMT(e2, width)) case (PrimOps.Xor, Seq(e1, e2), _) => val width = args.map(getWidth).max BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width)) @@ -170,7 +162,7 @@ private object FirrtlExpressionSemantics { private val BV1BitZero = BVLiteral(0, 1) - def isSigned(e: ir.Expression): Boolean = e.tpe match { + private def isSigned(e: ir.Expression): Boolean = e.tpe match { case _: ir.SIntType => true case _ => false } diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index 726a8854..c5fff849 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -4,60 +4,24 @@ package firrtl.backends.experimental.smt import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation} -import firrtl.bitWidth -import FirrtlExpressionSemantics.getWidth +import firrtl._ import firrtl.backends.experimental.smt.random._ -import firrtl.graph.MutableDiGraph import firrtl.options.Dependency import firrtl.passes.MemPortUtils.memPortField import firrtl.passes.PassException import firrtl.passes.memlib.VerilogMemDelays import firrtl.stage.Forms import firrtl.stage.TransformManager.TransformDependency -import firrtl.transforms.{DeadCodeElimination, EnsureNamedStatements, PropagatePresetAnnotations} -import firrtl.{ - ir, - CircuitState, - DependencyAPIMigration, - MemoryArrayInit, - MemoryInitValue, - MemoryScalarInit, - Namespace, - Transform, - Utils -} +import firrtl.transforms.{EnsureNamedStatements, PropagatePresetAnnotations} import logger.LazyLogging import scala.collection.mutable -// Contains code to convert a flat firrtl module into a functional transition system which -// can then be exported as SMTLib or Btor2 file. - -private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr]) -private case class Signal(name: String, e: BVExpr) { def toSymbol: BVSymbol = BVSymbol(name, e.width) } -private case class TransitionSystem( - name: String, - inputs: Array[BVSymbol], - states: Array[State], - signals: Array[Signal], - outputs: Set[String], - assumes: Set[String], - asserts: Set[String], - fair: Set[String], - ufs: List[BVFunctionSymbol] = List(), - comments: Map[String, String] = Map(), - header: Array[String] = Array()) { - def serialize: String = { - (Iterator(name) ++ - ufs.map(u => u.toString) ++ - inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++ - signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++ - states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")).mkString("\n") - } -} - private case class TransitionSystemAnnotation(sys: TransitionSystem) extends NoTargetAnnotation +/** Contains code to convert a flat firrtl module into a functional transition system which + * can then be exported as SMTLib or Btor2 file. + */ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++ Seq( @@ -94,12 +58,12 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { // convert the main module val main = modules(circuit.main) val sys = main match { - case x: ir.ExtModule => + case _: ir.ExtModule => throw new ExtModuleException( "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog." ) case m: ir.Module => - new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit, uninterpreted = uninterpreted) + new ModuleToTransitionSystem(presetRegs = presetRegs, memInit = memInit, uninterpreted = uninterpreted).run(m) } val sortedSys = TopologicalSort.run(sys) @@ -131,336 +95,118 @@ private class MultiClockException(s: String) extends PassException(s + Unsupport private class MissingFeatureException(s: String) extends PassException("Unfortunately the SMT backend does not yet support: " + s) -private class ModuleToTransitionSystem extends LazyLogging { +private class ModuleToTransitionSystem( + presetRegs: Set[String], + memInit: Map[String, MemoryInitValue], + uninterpreted: Map[String, UninterpretedModuleAnnotation]) + extends LazyLogging { - def run( - m: ir.Module, - presetRegs: Set[String] = Set(), - memInit: Map[String, MemoryInitValue] = Map(), - uninterpreted: Map[String, UninterpretedModuleAnnotation] = Map() - ): TransitionSystem = { + def run(m: ir.Module): TransitionSystem = { // first pass over the module to convert expressions; discover state and I/O - val scan = new ModuleScanner(uninterpreted) - m.foreachPort(scan.onPort) - m.foreachStmt(scan.onStatement) + m.foreachPort(onPort) + m.foreachStmt(onStatement) // multi-clock support requires the StutteringClock transform to be run - if (scan.clocks.size > 1) { - throw new MultiClockException(s"The module ${m.name} has more than one clock: ${scan.clocks.mkString(", ")}") + if (clocks.size > 1) { + throw new MultiClockException(s"The module ${m.name} has more than one clock: ${clocks.mkString(", ")}") } - // turn wires and nodes into signals - val outputs = scan.outputs.toSet - val constraints = scan.assumes.toSet - val bad = scan.asserts.toSet - val isSignal = (scan.wires ++ scan.nodes ++ scan.memSignals).toSet ++ outputs ++ constraints ++ bad - val signals = scan.connects.filter { case (name, _) => isSignal.contains(name) }.map { - case (name, expr) => Signal(name, expr) - } - - // turn registers and memories into states - val registers = scan.registers.map(r => r._1 -> r).toMap - val regStates = scan.connects.filter(s => registers.contains(s._1)).map { - case (name, nextExpr) => - val (_, width, resetExpr, initExpr) = registers(name) - onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs) - } - // turn memories into state - val memoryStatesAndOutputs = scan.memories.map(m => onMemory(m, scan.connects, memInit.get(m.name))) - // replace pseudo assigns for memory outputs - val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap - val signalsWithMem = signals.map { s => - if (memOutputs.contains(s.name)) { - s.copy(e = memOutputs(s.name)) - } else { s } - } - // filter out any left-over self assignments (this happens when we have a registered read port) - .filter(s => - s match { - case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false - case _ => true - } - ) - val states = regStates.toArray ++ memoryStatesAndOutputs.map(_._1) - // generate comments from infos val comments = mutable.HashMap[String, String]() - scan.infos.foreach { + infos.foreach { case (name, info) => - serializeInfo(info).foreach { infoString => - if (comments.contains(name)) { comments(name) += InfoSeparator + infoString } - else { comments(name) = InfoPrefix + infoString } + val infoStr = info.serialize.trim + if (infoStr.nonEmpty) { + val prefix = comments.get(name).map(_ + ", ").getOrElse("") + comments(name) = prefix + infoStr } } - // inputs are original module inputs and any DefRandom signal - val inputs = scan.inputs - // module info to the comment header - val header = serializeInfo(m.info).map(InfoPrefix + _).toArray - - val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints - - // collect unique functions - val ufs = scan.functionCalls.groupBy(_.name).map(_._2.head).toList - - TransitionSystem( - m.name, - inputs.toArray, - states, - signalsWithMem.toArray, - outputs, - constraints, - bad, - fair, - ufs, - comments.toMap, - header - ) - } - - private def onRegister( - name: String, - width: Int, - resetExpr: BVExpr, - initExpr: BVExpr, - nextExpr: BVExpr, - presetRegs: Set[String] - ): State = { - assert(initExpr.width == width) - assert(nextExpr.width == width) - assert(resetExpr.width == 1) - val sym = BVSymbol(name, width) - val hasReset = initExpr != sym - val isPreset = presetRegs.contains(name) - assert(!isPreset || hasReset, s"Expected preset register $name to have a reset value, not just $initExpr!") - if (hasReset) { - val init = if (isPreset) Some(initExpr) else None - val next = if (isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr) - State(sym, next = Some(next), init = init) - } else { - State(sym, next = Some(nextExpr), init = None) - } - } - - type Connects = Iterable[(String, BVExpr)] - private def onMemory(m: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (State, Connects) = { - checkMem(m) - - // map of inputs to the memory - val inputs = connects.filter(_._1.startsWith(m.name)).toMap - - // derive the type of the memory from the dataType and depth - val dataWidth = bitWidth(m.dataType).toInt - val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) - val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth) - - // there could be a constant init - val init = initValue.map(getInit(m, indexWidth, dataWidth, _)) - init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth)) - - // derive next state expression - val next = if (m.writers.isEmpty) { - memSymbol - } else { - m.writers.foldLeft[ArrayExpr](memSymbol) { - case (prev, write) => - // update - val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth) - val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth) - val update = ArrayStore(prev, index = addr, data = data) - - // update guard - val en = BVSymbol(memPortField(m, write, "en").serialize, 1) - val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1) - val alwaysEnabled = Seq(en, mask).forall(s => inputs(s.name) == True) - if (alwaysEnabled) { update } - else { - ArrayIte(and(en, mask), update, prev) - } - } - } - - val state = State(memSymbol, init, Some(next)) - - // derive read expressions - val readSignals = m.readers.map { read => - val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth) - memPortField(m, read, "data").serialize -> ArrayRead(memSymbol, addr) - } + val header = m.info.serialize.trim - (state, readSignals) + TransitionSystem(m.name, inputs.toList, states.values.toList, signals.toList, comments.toMap, header) } - private def getInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr = - initValue match { - case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth) - case MemoryArrayInit(values) => - assert( - values.length == m.depth, - s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" - ) - // in order to get a more compact encoding try to find the most common values - val histogram = mutable.LinkedHashMap[BigInt, Int]() - values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) - val baseValue = histogram.maxBy(_._2)._1 - val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth) - values.zipWithIndex - .filterNot(_._1 == baseValue) - .foldLeft[ArrayExpr](base) { - case (array, (value, index)) => - ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth)) - } - case other => throw new RuntimeException(s"Unsupported memory init option: $other") - } + private val inputs = mutable.ArrayBuffer[BVSymbol]() + private val clocks = mutable.ArrayBuffer[String]() + private val signals = mutable.ArrayBuffer[Signal]() + private val states = mutable.LinkedHashMap[String, State]() + private val infos = mutable.ArrayBuffer[(String, ir.Info)]() - // TODO: add to BV expression library - private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match { - case (True, True) => True - case (True, x) => x - case (x, True) => x - case _ => BVOp(Op.And, a, b) - } - - private val True = BVLiteral(1, 1) - private def checkMem(m: ir.DefMemory): Unit = { - assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?") - assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?") - assert( - m.dataType.isInstanceOf[ir.GroundType], - s"Memory $m is of type ${m.dataType} which is not a ground type!" - ) - assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") - } - - private val InfoSeparator = ", " - private val InfoPrefix = "@ " - private def serializeInfo(info: ir.Info): Option[String] = info match { - case ir.NoInfo => None - case f: ir.FileInfo => Some(f.escaped) - case m: ir.MultiInfo => - val infos = m.flatten - if (infos.isEmpty) { None } - else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } - } -} - -// performas a first pass over the module collecting all connections, wires, registers, input and outputs -private class ModuleScanner( - uninterpreted: Map[String, UninterpretedModuleAnnotation]) - extends LazyLogging { - import FirrtlExpressionSemantics.getWidth - - private[firrtl] val inputs = mutable.ArrayBuffer[BVSymbol]() - private[firrtl] val outputs = mutable.ArrayBuffer[String]() - private[firrtl] val clocks = mutable.LinkedHashSet[String]() - private[firrtl] val wires = mutable.ArrayBuffer[String]() - private[firrtl] val nodes = mutable.ArrayBuffer[String]() - private[firrtl] val memSignals = mutable.ArrayBuffer[String]() - private[firrtl] val registers = mutable.ArrayBuffer[(String, Int, BVExpr, BVExpr)]() - private[firrtl] val memories = mutable.ArrayBuffer[ir.DefMemory]() - // DefNode, Connect, IsInvalid and VerificationStatement connections - private[firrtl] val connects = mutable.ArrayBuffer[(String, BVExpr)]() - private[firrtl] val asserts = mutable.ArrayBuffer[String]() - private[firrtl] val assumes = mutable.ArrayBuffer[String]() - // maps identifiers to their info - private[firrtl] val infos = mutable.ArrayBuffer[(String, ir.Info)]() - // Keeps track of (so far) unused memory (data) and uninterpreted module outputs. - // This is used in order to delay declaring them for as long as possible. - private val unusedOutputs = mutable.LinkedHashMap[String, BVExpr]() - // ensure unique names for assert/assume signals - private[firrtl] val namespace = Namespace() - // keep track of all uninterpreted functions called - private[firrtl] val functionCalls = mutable.ArrayBuffer[BVFunctionSymbol]() - - private[firrtl] def onPort(p: ir.Port): Unit = { + private def onPort(p: ir.Port): Unit = { if (isAsyncReset(p.tpe)) { throw new AsyncResetException(s"Found AsyncReset ${p.name}.") } - namespace.newName(p.name) infos.append(p.name -> p.info) p.direction match { case ir.Input => if (isClock(p.tpe)) { - clocks.add(p.name) + clocks.append(p.name) } else { inputs.append(BVSymbol(p.name, bitWidth(p.tpe).toInt)) } case ir.Output => - if (!isClock(p.tpe)) { // we ignore clock outputs - outputs.append(p.name) - } } } - private[firrtl] def onStatement(s: ir.Statement): Unit = s match { - case DefRandom(info, name, tpe, _, _) => - namespace.newName(name) + private def onStatement(s: ir.Statement): Unit = s match { + case DefRandom(info, name, tpe, _, en) => assert(!isClock(tpe), "rand should never be a clock!") - // we model random sources as inputs and ignore the enable signal + // we model random sources as inputs and the enable signal as output infos.append(name -> info) inputs.append(BVSymbol(name, bitWidth(tpe).toInt)) - case ir.DefWire(info, name, tpe) => - namespace.newName(name) - if (!isClock(tpe) && !isAsyncReset(tpe)) { - infos.append(name -> info) - wires.append(name) + signals.append(Signal(name + ".en", onExpression(en, 1), IsOutput)) + case w: ir.DefWire => + if (!isClock(w.tpe)) { + // InlineInstances can insert wires without re-running RemoveWires for now we just deal with it when + // the Wires is connected to (ir.Connect). } case ir.DefNode(info, name, expr) => - namespace.newName(name) if (!isClock(expr.tpe) && !isAsyncReset(expr.tpe)) { - insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) - val e = onExpression(expr) - nodes.append(name) - connects.append((name, e)) + signals.append(Signal(name, onExpression(expr), IsNode)) } - case ir.DefRegister(info, name, tpe, _, reset, init) => - namespace.newName(name) - insertDummyAssignsForUnusedOutputs(reset) - insertDummyAssignsForUnusedOutputs(init) - infos.append(name -> info) - val width = bitWidth(tpe).toInt - val resetExpr = onExpression(reset, 1) - val initExpr = onExpression(init, width) - registers.append((name, width, resetExpr, initExpr)) + case r: ir.DefRegister => + infos.append(r.name -> r.info) + states(r.name) = onRegister(r) case m: ir.DefMemory => - namespace.newName(m.name) infos.append(m.name -> m.info) - val outputs = getMemOutputs(m) - (getMemInputs(m) ++ outputs).foreach(memSignals.append(_)) - val dataWidth = bitWidth(m.dataType).toInt - outputs.foreach(name => unusedOutputs(name) = BVSymbol(name, dataWidth)) - memories.append(m) + states(m.name) = onMemory(m) case ir.Connect(info, loc, expr) => if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") - if (!isClock(loc.tpe)) { // we ignore clock connections + if (!isClock(loc.tpe) && !isAsyncReset(expr.tpe)) { // we ignore clock connections val name = loc.serialize - insertDummyAssignsForUnusedOutputs(expr) - infos.append(name -> info) - connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = true))) + val e = onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = false) + Utils.kind(loc) match { + case RegKind => states(name) = states(name).copy(next = Some(e)) + case PortKind | InstanceKind => // module output or submodule input + infos.append(name -> info) + signals.append(Signal(name, e, IsOutput)) + case MemKind | WireKind => + // InlineInstances can insert wires without re-running RemoveWires for now we just deal with it. + infos.append(name -> info) + signals.append(Signal(name, e, IsNode)) + } } - case i @ ir.IsInvalid(info, loc) => - if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") + case i: ir.IsInvalid => throw new UnsupportedFeatureException(s"IsInvalid statements are not supported: ${i.serialize}") case ir.DefInstance(info, name, module, tpe) => onInstance(info, name, module, tpe) - case s @ ir.Verification(op, info, _, pred, en, msg) => - if (op == ir.Formal.Cover) { + case s: ir.Verification => + if (s.op == ir.Formal.Cover) { logger.info(s"[info] Cover statement was ignored: ${s.serialize}") } else { - insertDummyAssignsForUnusedOutputs(pred) - insertDummyAssignsForUnusedOutputs(en) val name = s.name - val predicate = onExpression(pred) - val enabled = onExpression(en) + val predicate = onExpression(s.pred) + val enabled = onExpression(s.en) val e = BVImplies(enabled, predicate) - infos.append(name -> info) - connects.append(name -> e) - if (op == ir.Formal.Assert) { - asserts.append(name) + infos.append(name -> s.info) + val signal = if (s.op == ir.Formal.Assert) { + Signal(name, BVNot(e), IsBad) } else { - assumes.append(name) + Signal(name, e, IsConstraint) } + signals.append(signal) } case s: ir.Conditionally => error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}") @@ -475,20 +221,29 @@ private class ModuleScanner( ) } else { // we treat Stop statements with a non-zero exit value as assertions that en will always be false! - insertDummyAssignsForUnusedOutputs(s.en) val name = s.name infos.append(name -> s.info) - val enabled = onExpression(s.en) - connects.append(name -> BVNot(enabled)) - asserts.append(name) + signals.append(Signal(name, onExpression(s.en), IsBad)) } case s: ir.Print => logger.info(s"Info: ignoring: ${s.serialize}") case other => other.foreachStmt(onStatement) } + private def onRegister(r: ir.DefRegister): State = { + val width = bitWidth(r.tpe).toInt + val resetExpr = onExpression(r.reset, 1) + assert(resetExpr == False(), s"Expected reset expression of ${r.name} to be 0, not $resetExpr") + val initExpr = onExpression(r.init, width) + val sym = BVSymbol(r.name, width) + val hasReset = initExpr != sym + val isPreset = presetRegs.contains(r.name) + assert(!isPreset || hasReset, s"Expected preset register ${r.name} to have a reset value, not just $initExpr!") + val state = State(sym, if (isPreset) Some(initExpr) else None, None) + state + } + private def onInstance(info: ir.Info, name: String, module: String, tpe: ir.Type): Unit = { - namespace.newName(name) if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") if (uninterpreted.contains(module)) { onUninterpretedInstance(info: ir.Info, name: String, module: String, tpe: ir.Type) @@ -509,14 +264,10 @@ private class ModuleScanner( // outputs of the submodule become inputs to our module if (isOutput) { if (isClock(p.tpe)) { - clocks.add(pName) + clocks.append(pName) } else { inputs.append(BVSymbol(pName, bitWidth(p.tpe).toInt)) } - } else { - if (!isClock(p.tpe)) { // we ignore clock outputs - outputs.append(pName) - } } } } @@ -539,112 +290,90 @@ private class ModuleScanner( val functionName = anno.prefix + "." + out.name val call = BVFunctionCall(functionName, args, out.width) val wireName = instanceName + "." + out.name - // remember which functions were called - functionCalls.append(call.toSymbol) - // insert the output definition right before its first use in an attempt to get SSA - unusedOutputs(wireName) = call - // treat these outputs as wires - wires.append(wireName) + signals.append(Signal(wireName, call)) } - - // we also treat the arguments as wires - wires ++= args.map(_.name) } - private val readInputFields = List("en", "addr") - private val writeInputFields = List("en", "mask", "addr", "data") - private def getMemInputs(m: ir.DefMemory): Iterable[String] = { - assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") - val p = m.name + "." - m.writers.flatMap(w => writeInputFields.map(p + w + "." + _)) ++ - m.readers.flatMap(r => readInputFields.map(p + r + "." + _)) - } - private def getMemOutputs(m: ir.DefMemory): Iterable[String] = { - assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") - val p = m.name + "." - m.readers.map(r => p + r + ".data") - } - // inserts a dummy assign right before a memory/uninterpreted module output is used for the first time - // example: - // m.r.data <= m.r.data ; this is the dummy assign - // test <= m.r.data ; this is the first use of m.r.data - private def insertDummyAssignsForUnusedOutputs(next: ir.Expression): Unit = if (unusedOutputs.nonEmpty) { - val uses = mutable.ArrayBuffer[String]() - findUnusedOutputUse(next)(uses) - if (uses.nonEmpty) { - val useSet = uses.toSet - unusedOutputs.foreach { - case (name, value) => - if (useSet.contains(name)) connects.append(name -> value) + private def onMemory(m: ir.DefMemory): State = { + checkMem(m) + + // derive the type of the memory from the dataType and depth + val dataWidth = bitWidth(m.dataType).toInt + val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) + val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth) + + // there could be a constant init + val init = memInit.get(m.name).map(getMemInit(m, indexWidth, dataWidth, _)) + init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth)) + + // derive next state expression + val next = if (m.writers.isEmpty) { + memSymbol + } else { + m.writers.foldLeft[ArrayExpr](memSymbol) { + case (prev, write) => + // update + val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth) + val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth) + val update = ArrayStore(prev, index = addr, data = data) + + // update guard + val en = BVSymbol(memPortField(m, write, "en").serialize, 1) + val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1) + ArrayIte(BVAnd(en, mask), update, prev) } - useSet.foreach(name => unusedOutputs.remove(name)) } - } - private def findUnusedOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match { - case s: ir.SubField => - val name = s.serialize - if (unusedOutputs.contains(name)) uses.append(name) - case other => other.foreachExpr(findUnusedOutputUse) - } - private case class Context() extends TranslationContext {} + val state = State(memSymbol, init, Some(next)) - private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = { - implicit val ctx: TranslationContext = Context() - FirrtlExpressionSemantics.toSMT(e, width, allowNarrow) + // derive read expressions + val readSignals = m.readers.map { read => + val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth) + Signal(memPortField(m, read, "data").serialize, ArrayRead(memSymbol, addr), IsNode) + } + signals ++= readSignals + + state } - private def onExpression(e: ir.Expression): BVExpr = { - implicit val ctx: TranslationContext = Context() - FirrtlExpressionSemantics.toSMT(e) + + private def getMemInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr = + initValue match { + case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth) + case MemoryArrayInit(values) => + assert( + values.length == m.depth, + s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" + ) + // in order to get a more compact encoding try to find the most common values + val histogram = mutable.LinkedHashMap[BigInt, Int]() + values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) + val baseValue = histogram.maxBy(_._2)._1 + val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth) + values.zipWithIndex + .filterNot(_._1 == baseValue) + .foldLeft[ArrayExpr](base) { + case (array, (value, index)) => + ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth)) + } + case other => throw new RuntimeException(s"Unsupported memory init option: $other") + } + + private def checkMem(m: ir.DefMemory): Unit = { + assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?") + assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?") + assert( + m.dataType.isInstanceOf[ir.GroundType], + s"Memory $m is of type ${m.dataType} which is not a ground type!" + ) + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") } + private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = + FirrtlExpressionSemantics.toSMT(e, width, allowNarrow) + private def onExpression(e: ir.Expression): BVExpr = FirrtlExpressionSemantics.toSMT(e) + private def error(msg: String): Unit = throw new RuntimeException(msg) private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType] private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType } - -private object TopologicalSort { - - /** Ensures that all signals in the resulting system are topologically sorted. - * This is necessary because [[firrtl.transforms.RemoveWires]] does - * not sort assignments to outputs, submodule inputs nor memory ports. - */ - def run(sys: TransitionSystem): TransitionSystem = { - val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name) - val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates) - // TODO: maybe sort init expressions of states (this should not be needed most of the time) - signalOrder match { - case None => sys - case Some(order) => - val signalMap = sys.signals.map(s => s.name -> s).toMap - // we flatMap over `get` in order to ignore inputs/states in the order - sys.copy(signals = order.flatMap(signalMap.get).toArray) - } - } - - private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = { - val known = new mutable.HashSet[String]() ++ globalSignals - var needsReordering = false - val digraph = new MutableDiGraph[String] - signals.foreach { - case (name, expr) => - digraph.addVertex(name) - val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr) - uniqueDependencies.foreach { d => - if (!known.contains(d)) { needsReordering = true } - digraph.addPairWithEdge(name, d) - } - known.add(name) - } - if (needsReordering) { - Some(digraph.linearize.reverse) - } else { None } - } - - private def findDependencies(expr: SMTExpr): List[String] = expr match { - case BVSymbol(name, _) => List(name) - case ArraySymbol(name, _, _) => List(name) - case other => other.children.flatMap(findDependencies) - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala new file mode 100644 index 00000000..21a64f98 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: Apache-2.0 +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +private sealed trait SMTCommand +private case class Comment(msg: String) extends SMTCommand +private case class SetLogic(logic: String) extends SMTCommand +private case class DefineFunction(name: String, args: Seq[SMTFunctionArg], e: SMTExpr) extends SMTCommand +private case class DeclareFunction(sym: SMTSymbol, args: Seq[SMTFunctionArg]) extends SMTCommand +private case class DeclareUninterpretedSort(name: String) extends SMTCommand +private case class DeclareUninterpretedSymbol(name: String, tpe: String) extends SMTCommand diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala index 0fc507e6..a40717f9 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala @@ -5,9 +5,20 @@ package firrtl.backends.experimental.smt -private sealed trait SMTExpr { def children: List[SMTExpr] } -private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { val name: String } +/** base trait for all SMT expressions */ +private sealed trait SMTExpr extends SMTFunctionArg { + def tpe: SMTType + def children: List[SMTExpr] +} +private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { + def name: String + + /** keeps the type of the symbol while changing the name */ + def rename(newName: String): SMTSymbol +} private object SMTSymbol { + + /** makes a SMTSymbol of the same type as the expression */ def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match { case b: BVExpr => BVSymbol(name, b.width) case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth) @@ -17,91 +28,115 @@ private sealed trait SMTNullaryExpr extends SMTExpr { override def children: List[SMTExpr] = List() } -private sealed trait BVExpr extends SMTExpr { def width: Int } +/** a SMT bit vector expression: https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml */ +private sealed trait BVExpr extends SMTExpr { + def width: Int + def tpe: BVType = BVType(width) + override def toString: String = SMTExprSerializer.serialize(this) +} private case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr { private def minWidth = value.bitLength + (if (value <= 0) 1 else 0) + assert(value >= 0, "Negative values are not supported! Please normalize by calculating 2s complement.") assert(width > 0, "Zero or negative width literals are not allowed!") assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.") - override def toString: String = if (width <= 8) { - width.toString + "'b" + value.toString(2) - } else { width.toString + "'x" + value.toString(16) } +} +private object BVLiteral { + def apply(nums: String): BVLiteral = nums.head match { + case 'b' => BVLiteral(BigInt(nums.drop(1), 2), nums.length - 1) + } } private case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol { assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") - assert(!name.contains("\\"), s"Invalid id $name contains `\\`") assert(width > 0, "Zero width bit vectors are not supported!") - override def toString: String = name - def toStringWithType: String = name + " : " + SMTExpr.serializeType(this) + override def rename(newName: String) = BVSymbol(newName, width) } private sealed trait BVUnaryExpr extends BVExpr { def e: BVExpr + + /** same function, different child, e.g.: not(x) -- reapply(Y) --> not(Y) */ + def reapply(expr: BVExpr): BVUnaryExpr override def children: List[BVExpr] = List(e) } private case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr { assert(by >= 0, "Extension must be non-negative!") override val width: Int = e.width + by - override def toString: String = if (signed) { s"sext($e, $by)" } - else { s"zext($e, $by)" } + override def reapply(expr: BVExpr) = BVExtend(expr, by, signed) } // also known as bit extract operation private case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr { assert(lo >= 0, s"lo (lsb) must be non-negative!") assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo") assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi") - override def width: Int = hi - lo + 1 - override def toString: String = if (hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]" + override def width: Int = hi - lo + 1 + override def reapply(expr: BVExpr) = BVSlice(expr, hi, lo) } private case class BVNot(e: BVExpr) extends BVUnaryExpr { - override val width: Int = e.width - override def toString: String = s"not($e)" + override val width: Int = e.width + override def reapply(expr: BVExpr) = new BVNot(expr) } private case class BVNegate(e: BVExpr) extends BVUnaryExpr { - override val width: Int = e.width - override def toString: String = s"neg($e)" + override val width: Int = e.width + override def reapply(expr: BVExpr) = BVNegate(expr) } + private case class BVReduceOr(e: BVExpr) extends BVUnaryExpr { - override def width: Int = 1 - override def toString: String = s"redor($e)" + override def width: Int = 1 + override def reapply(expr: BVExpr) = BVReduceOr(expr) } private case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr { - override def width: Int = 1 - override def toString: String = s"redand($e)" + override def width: Int = 1 + override def reapply(expr: BVExpr) = BVReduceAnd(expr) } private case class BVReduceXor(e: BVExpr) extends BVUnaryExpr { - override def width: Int = 1 - override def toString: String = s"redxor($e)" + override def width: Int = 1 + override def reapply(expr: BVExpr) = BVReduceXor(expr) } private sealed trait BVBinaryExpr extends BVExpr { def a: BVExpr def b: BVExpr override def children: List[BVExpr] = List(a, b) -} -private case class BVImplies(a: BVExpr, b: BVExpr) extends BVBinaryExpr { - assert(a.width == 1 && b.width == 1, s"Both arguments need to be 1-bit!") - override def width: Int = 1 - override def toString: String = s"impl($a, $b)" + + /** same function, different child, e.g.: add(a,b) -- reapply(a,c) --> add(a,c) */ + def reapply(nA: BVExpr, nB: BVExpr): BVBinaryExpr } private case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr { assert(a.width == b.width, s"Both argument need to be the same width!") - override def width: Int = 1 - override def toString: String = s"eq($a, $b)" + override def width: Int = 1 + override def reapply(nA: BVExpr, nB: BVExpr) = BVEqual(nA, nB) } +// added as a separate node because it is used a lot in model checking and benefits from pretty printing +private class BVImplies(val a: BVExpr, val b: BVExpr) extends BVBinaryExpr { + assert(a.width == 1, s"The antecedent needs to be a boolean expression!") + assert(b.width == 1, s"The consequent needs to be a boolean expression!") + override def width: Int = 1 + override def reapply(nA: BVExpr, nB: BVExpr) = new BVImplies(nA, nB) +} +private object BVImplies { + def apply(a: BVExpr, b: BVExpr): BVExpr = { + assert(a.width == b.width, s"Both argument need to be the same width!") + (a, b) match { + case (True(), b) => b // (!1 || b) = b + case (False(), _) => True() // (!0 || _) = (1 || _) = 1 + case (_, True()) => True() // (!a || 1) = 1 + case (a, False()) => BVNot(a) // (!a || 0) = !a + case (a, b) => new BVImplies(a, b) + } + } + def unapply(i: BVImplies): Some[(BVExpr, BVExpr)] = Some((i.a, i.b)) +} + private object Compare extends Enumeration { val Greater, GreaterEqual = Value } private case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed: Boolean) extends BVBinaryExpr { assert(a.width == b.width, s"Both argument need to be the same width!") override def width: Int = 1 - override def toString: String = op match { - case Compare.Greater => (if (signed) "sgt" else "ugt") + s"($a, $b)" - case Compare.GreaterEqual => (if (signed) "sgeq" else "ugeq") + s"($a, $b)" - } + override def reapply(nA: BVExpr, nB: BVExpr) = BVComparison(op, nA, nB, signed) } + private object Op extends Enumeration { - val And = Value("and") - val Or = Value("or") val Xor = Value("xor") val ShiftLeft = Value("logical_shift_left") val ArithmeticShiftRight = Value("arithmetic_shift_right") @@ -117,51 +152,65 @@ private object Op extends Enumeration { } private case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr { assert(a.width == b.width, s"Both argument need to be the same width!") - override val width: Int = a.width - override def toString: String = s"$op($a, $b)" + override val width: Int = a.width + override def reapply(nA: BVExpr, nB: BVExpr) = BVOp(op, nA, nB) } private case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr { - override val width: Int = a.width + b.width - override def toString: String = s"concat($a, $b)" + override val width: Int = a.width + b.width + override def reapply(nA: BVExpr, nB: BVExpr) = BVConcat(nA, nB) } private case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr { assert(array.indexWidth == index.width, "Index with does not match expected array index width!") override val width: Int = array.dataWidth - override def toString: String = s"$array[$index]" override def children: List[SMTExpr] = List(array, index) } private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr { assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!") assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}") override val width: Int = tru.width - override def toString: String = s"ite($cond, $tru, $fals)" override def children: List[BVExpr] = List(cond, tru, fals) } -/** apply bv arguments to a function which returns a result of bit vector type */ -private case class BVFunctionCall(name: String, args: List[BVExpr], width: Int) extends BVExpr { - override def children = args - def toSymbol: BVFunctionSymbol = BVFunctionSymbol(name, args.map(_.width), width) - override def toString: String = args.mkString(name + "(", ", ", ")") +private case class BVAnd(terms: List[BVExpr]) extends BVExpr { + require(terms.size > 1) + override val width: Int = terms.head.width + require(terms.forall(_.width == width)) + override def children: List[BVExpr] = terms } -private case class BVFunctionSymbol(name: String, argWidths: List[Int], width: Int) { - override def toString: String = s"$name : " + (argWidths :+ width).map(w => s"bv<$w>").mkString(" -> ") +private case class BVOr(terms: List[BVExpr]) extends BVExpr { + require(terms.size > 1) + override val width: Int = terms.head.width + require(terms.forall(_.width == width)) + override def children: List[BVExpr] = terms } -private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int } +private sealed trait ArrayExpr extends SMTExpr { + val indexWidth: Int + val dataWidth: Int + def tpe: ArrayType = ArrayType(indexWidth = indexWidth, dataWidth = dataWidth) + override def toString: String = SMTExprSerializer.serialize(this) +} private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol { assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") assert(!name.contains("\\"), s"Invalid id $name contains `\\`") - override def toString: String = name - def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>" + override def rename(newName: String) = ArraySymbol(newName, indexWidth, dataWidth) +} +private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr { + override val dataWidth: Int = e.width + override def children: List[SMTExpr] = List(e) +} +private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr { + assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!") + assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!") + override def width: Int = 1 + override def children: List[SMTExpr] = List(a, b) } private case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr { assert(array.indexWidth == index.width, "Index with does not match expected array index width!") assert(array.dataWidth == data.width, "Data with does not match expected array data width!") override val dataWidth: Int = array.dataWidth override val indexWidth: Int = array.indexWidth - override def toString: String = s"$array[$index := $data]" override def children: List[SMTExpr] = List(array, index, data) } private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr { @@ -176,20 +225,79 @@ private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) exten ) override val dataWidth: Int = tru.dataWidth override val indexWidth: Int = tru.indexWidth - override def toString: String = s"ite($cond, $tru, $fals)" override def children: List[SMTExpr] = List(cond, tru, fals) } -private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr { - assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!") - assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!") - override def width: Int = 1 - override def toString: String = s"eq($a, $b)" - override def children: List[SMTExpr] = List(a, b) + +private case class BVForall(variable: BVSymbol, e: BVExpr) extends BVUnaryExpr { + assert(e.width == 1, "Can only quantify over boolean expressions!") + override def width = 1 + override def reapply(expr: BVExpr) = BVForall(variable, expr) } -private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr { - override val dataWidth: Int = e.width - override def toString: String = s"([$e] x ${(BigInt(1) << indexWidth)})" - override def children: List[SMTExpr] = List(e) + +/** apply arguments to a function which returns a result of bit vector type */ +private case class BVFunctionCall(name: String, args: List[SMTFunctionArg], width: Int) extends BVExpr { + override def children = args.map(_.asInstanceOf[SMTExpr]) +} + +/** apply arguments to a function which returns a result of array type */ +private case class ArrayFunctionCall(name: String, args: List[SMTFunctionArg], indexWidth: Int, dataWidth: Int) + extends ArrayExpr { + override def children = args.map(_.asInstanceOf[SMTExpr]) +} +private sealed trait SMTFunctionArg +// we allow symbols with uninterpreted type to be function arguments +private case class UTSymbol(name: String, tpe: String) extends SMTFunctionArg + +private object BVAnd { + def apply(a: BVExpr, b: BVExpr): BVExpr = { + assert(a.width == b.width, s"Both argument need to be the same width!") + (a, b) match { + case (True(), b) => b + case (a, True()) => a + case (False(), _) => False() + case (_, False()) => False() + case (a, b) => new BVAnd(List(a, b)) + } + } + def apply(exprs: List[BVExpr]): BVExpr = { + assert(exprs.nonEmpty, "Don't know what to do with an empty list!") + val nonTriviallyTrue = exprs.filterNot(_ == True()) + nonTriviallyTrue.distinct match { + case Seq() => True() + case Seq(one) => one + case terms => new BVAnd(terms) + } + } +} +private object BVOr { + def apply(a: BVExpr, b: BVExpr): BVExpr = { + assert(a.width == b.width, s"Both argument need to be the same width!") + (a, b) match { + case (True(), _) => True() + case (_, True()) => True() + case (False(), b) => b + case (a, False()) => a + case (a, b) => new BVOr(List(a, b)) + } + } + def apply(exprs: List[BVExpr]): BVExpr = { + assert(exprs.nonEmpty, "Don't know what to do with an empty list!") + val nonTriviallyFalse = exprs.filterNot(_ == False()) + nonTriviallyFalse.distinct match { + case Seq() => False() + case Seq(one) => one + case terms => new BVOr(terms) + } + } +} + +private object BVNot { + def apply(e: BVExpr): BVExpr = e match { + case True() => False() + case False() => True() + case BVNot(inner) => inner + case other => new BVNot(other) + } } private object SMTEqual { @@ -200,6 +308,14 @@ private object SMTEqual { } } +private object SMTIte { + def apply(cond: BVExpr, tru: SMTExpr, fals: SMTExpr): SMTExpr = (tru, fals) match { + case (ab: BVExpr, bb: BVExpr) => BVIte(cond, ab, bb) + case (aa: ArrayExpr, ba: ArrayExpr) => ArrayIte(cond, aa, ba) + case _ => throw new RuntimeException(s"Cannot mux $tru and $fals") + } +} + private object SMTExpr { def serializeType(e: SMTExpr): String = e match { case b: BVExpr => s"bv<${b.width}>" @@ -207,8 +323,20 @@ private object SMTExpr { } } -// Raw SMTLib encoded expressions as an escape hatch used in the [[SMTTransitionSystemEncoder]] -private case class BVRawExpr(serialized: String, width: Int) extends BVExpr with SMTNullaryExpr -private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int) - extends ArrayExpr - with SMTNullaryExpr +// unapply for matching BVLiteral(1, 1) +private object True { + private val _True = BVLiteral(1, 1) + def apply(): BVLiteral = _True + def unapply(l: BVLiteral): Boolean = l.value == 1 && l.width == 1 +} + +// unapply for matching BVLiteral(0, 1) +private object False { + private val _False = BVLiteral(0, 1) + def apply(): BVLiteral = _False + def unapply(l: BVLiteral): Boolean = l.value == 0 && l.width == 1 +} + +private sealed trait SMTType +private case class BVType(width: Int) extends SMTType +private case class ArrayType(indexWidth: Int, dataWidth: Int) extends SMTType diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala new file mode 100644 index 00000000..c991941f --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> +package firrtl.backends.experimental.smt + +private object SMTExprMap { + def mapExpr(expr: SMTExpr, f: SMTExpr => SMTExpr): SMTExpr = { + val bv = (b: BVExpr) => f(b).asInstanceOf[BVExpr] + val ar = (a: ArrayExpr) => f(a).asInstanceOf[ArrayExpr] + expr match { + case b: BVExpr => mapExpr(b, bv, ar) + case a: ArrayExpr => mapExpr(a, bv, ar) + } + } + + /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */ + def mapExpr(expr: BVExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): BVExpr = expr match { + // nullary + case old: BVLiteral => old + case old: BVSymbol => old + // unary + case old @ BVExtend(e, by, signed) => val n = bv(e); if (n.eq(e)) old else BVExtend(n, by, signed) + case old @ BVSlice(e, hi, lo) => val n = bv(e); if (n.eq(e)) old else BVSlice(n, hi, lo) + case old @ BVNot(e) => val n = bv(e); if (n.eq(e)) old else BVNot(n) + case old @ BVNegate(e) => val n = bv(e); if (n.eq(e)) old else BVNegate(n) + case old @ BVForall(variables, e) => val n = bv(e); if (n.eq(e)) old else BVForall(variables, n) + case old @ BVReduceAnd(e) => val n = bv(e); if (n.eq(e)) old else BVReduceAnd(n) + case old @ BVReduceOr(e) => val n = bv(e); if (n.eq(e)) old else BVReduceOr(n) + case old @ BVReduceXor(e) => val n = bv(e); if (n.eq(e)) old else BVReduceXor(n) + // binary + case old @ BVEqual(a, b) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB) + case old @ ArrayEqual(a, b) => + val (nA, nB) = (ar(a), ar(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB) + case old @ BVComparison(op, a, b, signed) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed) + case old @ BVOp(op, a, b) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB) + case old @ BVConcat(a, b) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB) + case old @ ArrayRead(a, b) => + val (nA, nB) = (ar(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB) + case old @ BVImplies(a, b) => + val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB) + // ternary + case old @ BVIte(a, b, c) => + val (nA, nB, nC) = (bv(a), bv(b), bv(c)) + if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC) + // n-ary + case old @ BVFunctionCall(name, args, width) => + val nArgs = args.map { + case b: BVExpr => bv(b) + case a: ArrayExpr => ar(a) + case u: UTSymbol => u + } + val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) } + if (anyNew) BVFunctionCall(name, nArgs, width) else old + case old @ BVAnd(terms) => + val nTerms = terms.map(bv) + val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) } + if (anyNew) BVAnd(nTerms) else old + case old @ BVOr(terms) => + val nTerms = terms.map(bv) + val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) } + if (anyNew) BVOr(nTerms) else old + } + + /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */ + def mapExpr(expr: ArrayExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): ArrayExpr = expr match { + case old: ArraySymbol => old + case old @ ArrayConstant(e, indexWidth) => val n = bv(e); if (n.eq(e)) old else ArrayConstant(n, indexWidth) + case old @ ArrayStore(a, b, c) => + val (nA, nB, nC) = (ar(a), bv(b), bv(c)) + if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC) + case old @ ArrayIte(a, b, c) => + val (nA, nB, nC) = (bv(a), ar(b), ar(c)) + if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC) + case old @ ArrayFunctionCall(name, args, indexWidth, dataWidth) => + val nArgs = args.map { + case b: BVExpr => bv(b) + case a: ArrayExpr => ar(a) + case u: UTSymbol => u + } + val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) } + if (anyNew) ArrayFunctionCall(name, nArgs, indexWidth, dataWidth) else old + } +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala new file mode 100644 index 00000000..4aaf78a2 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +private object SMTExprSerializer { + def serialize(expr: BVExpr): String = expr match { + // nullary + case lit: BVLiteral => + if (lit.width <= 8) { + lit.width.toString + "'b" + lit.value.toString(2) + } else { + lit.width.toString + "'x" + lit.value.toString(16) + } + case BVSymbol(name, _) => name + // unary + case BVExtend(e, by, false) => s"zext(${serialize(e)}, $by)" + case BVExtend(e, by, true) => s"sext(${serialize(e)}, $by)" + case BVSlice(e, hi, lo) if hi == lo => s"${serialize(e)}[$hi]" + case BVSlice(e, hi, lo) => s"${serialize(e)}[$hi:$lo]" + case BVNot(e) => s"not(${serialize(e)})" + case BVNegate(e) => s"neg(${serialize(e)})" + case BVForall(variable, e) => s"forall(${variable.name} : bv<${variable.width}, ${serialize(e)})" + case BVReduceAnd(e) => s"redand(${serialize(e)})" + case BVReduceOr(e) => s"redor(${serialize(e)})" + case BVReduceXor(e) => s"redxor(${serialize(e)})" + // binary + case BVEqual(a, b) => s"eq(${serialize(a)}, ${serialize(b)})" + case BVComparison(Compare.Greater, a, b, false) => s"ugt(${serialize(a)}, ${serialize(b)})" + case BVComparison(Compare.Greater, a, b, true) => s"sgt(${serialize(a)}, ${serialize(b)})" + case BVComparison(Compare.GreaterEqual, a, b, false) => s"ugeq(${serialize(a)}, ${serialize(b)})" + case BVComparison(Compare.GreaterEqual, a, b, true) => s"sgeq(${serialize(a)}, ${serialize(b)})" + case BVOp(op, a, b) => s"$op(${serialize(a)}, ${serialize(b)})" + case BVConcat(a, b) => s"concat(${serialize(a)}, ${serialize(b)})" + case ArrayRead(array, index) => s"${serialize(array)}[${serialize(index)}]" + case ArrayEqual(a, b) => s"eq(${serialize(a)}, ${serialize(b)})" + case BVImplies(a, b) => s"implies(${serialize(a)}, ${serialize(b)})" + // ternary + case BVIte(cond, tru, fals) => s"ite(${serialize(cond)}, ${serialize(tru)}, ${serialize(fals)})" + // n-ary + case BVFunctionCall(name, args, _) => name + serialize(args).mkString("(", ",", ")") + case BVAnd(terms) => terms.map(serialize).mkString("and(", ", ", ")") + case BVOr(terms) => terms.map(serialize).mkString("or(", ", ", ")") + } + + def serialize(expr: ArrayExpr): String = expr match { + case ArraySymbol(name, _, _) => name + case ArrayConstant(e, indexWidth) => s"([${serialize(e)}] x ${(BigInt(1) << indexWidth)})" + case ArrayStore(array, index, data) => s"${serialize(array)}[${serialize(index)} := ${serialize(data)}]" + case ArrayIte(cond, tru, fals) => s"ite(${serialize(cond)}, ${serialize(tru)}, ${serialize(fals)})" + case ArrayFunctionCall(name, args, _, _) => name + serialize(args).mkString("(", ",", ")") + } + + private def serialize(args: Iterable[SMTFunctionArg]): Iterable[String] = + args.map { + case b: BVExpr => serialize(b) + case a: ArrayExpr => serialize(a) + case u: UTSymbol => u.name + } +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala deleted file mode 100644 index 13ed8bdd..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala +++ /dev/null @@ -1,77 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -/** Similar to the mapExpr and foreachExpr methods of the firrtl ir nodes, but external to the case classes */ -private object SMTExprVisitor { - type ArrayFun = ArrayExpr => ArrayExpr - type BVFun = BVExpr => BVExpr - - def map[T <: SMTExpr](bv: BVFun, ar: ArrayFun)(e: T): T = e match { - case b: BVExpr => map(b, bv, ar).asInstanceOf[T] - case a: ArrayExpr => map(a, bv, ar).asInstanceOf[T] - } - def map[T <: SMTExpr](f: SMTExpr => SMTExpr)(e: T): T = - map(b => f(b).asInstanceOf[BVExpr], a => f(a).asInstanceOf[ArrayExpr])(e) - - private def map(e: BVExpr, bv: BVFun, ar: ArrayFun): BVExpr = e match { - // nullary - case old: BVLiteral => bv(old) - case old: BVSymbol => bv(old) - case old: BVRawExpr => bv(old) - // unary - case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVExtend(n, by, signed)) - case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVSlice(n, hi, lo)) - case old @ BVNot(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNot(n)) - case old @ BVNegate(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNegate(n)) - case old @ BVReduceAnd(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceAnd(n)) - case old @ BVReduceOr(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceOr(n)) - case old @ BVReduceXor(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceXor(n)) - // binary - case old @ BVImplies(a, b) => - val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB)) - case old @ BVEqual(a, b) => - val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB)) - case old @ ArrayEqual(a, b) => - val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB)) - case old @ BVComparison(op, a, b, signed) => - val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed)) - case old @ BVOp(op, a, b) => - val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB)) - case old @ BVConcat(a, b) => - val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB)) - case old @ ArrayRead(a, b) => - val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB)) - // ternary - case old @ BVIte(a, b, c) => - val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) - bv(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC)) - // n-ary - case old @ BVFunctionCall(name, args, width) => - val nArgs = args.map(a => map(a, bv, ar)) - val noneNew = nArgs.zip(args).forall { case (n, o) => n.eq(o) } - bv(if (noneNew) old else BVFunctionCall(name, nArgs, width)) - } - - private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match { - case old: ArrayRawExpr => ar(old) - case old: ArraySymbol => ar(old) - case old @ ArrayConstant(e, indexWidth) => - val n = map(e, bv, ar); ar(if (n.eq(e)) old else ArrayConstant(n, indexWidth)) - case old @ ArrayStore(a, b, c) => - val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) - ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC)) - case old @ ArrayIte(a, b, c) => - val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) - ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC)) - } - -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala index 75bde09c..bb4e0348 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala @@ -19,14 +19,9 @@ private object SMTLibSerializer { case a: ArrayExpr => serialize(a) } - def serializeType(e: SMTExpr): String = e match { - case b: BVExpr => serializeBitVectorType(b.width) - case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth) - } - - def declareFunction(foo: BVFunctionSymbol): SMTCommand = { - val args = foo.argWidths.map(serializeBitVectorType) - DeclareFunction(BVSymbol(foo.name, foo.width), args) + def serialize(t: SMTType): String = t match { + case BVType(width) => serializeBitVectorType(width) + case ArrayType(indexWidth, dataWidth) => serializeArrayType(indexWidth, dataWidth) } private def serialize(e: BVExpr): String = e match { @@ -71,37 +66,57 @@ private object SMTLibSerializer { case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})" case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})" // boolean operations get a special treatment for 1-bit vectors aka bools - case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})" - case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})" + case b: BVAnd => serializeVariadic(if (b.width == 1) "and" else "bvand", b.terms) + case b: BVOr => serializeVariadic(if (b.width == 1) "or" else "bvor", b.terms) case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})" case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})") case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})" case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})" case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})" case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" - case BVFunctionCall(name, args, _) => args.map(serialize).mkString(s"($name ", " ", ")") - case BVRawExpr(serialized, _) => serialized + case BVFunctionCall(name, args, _) => args.map(serializeArg).mkString(s"($name ", " ", ")") + case BVForall(variable, e) => s"(forall ((${variable.name} ${serialize(variable.tpe)})) ${serialize(e)})" + } + + private def serializeVariadic(op: String, terms: List[BVExpr]): String = terms match { + case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op") + case Seq(a, b) => s"($op ${serialize(a)} ${serialize(b)})" + case head :: tail => s"($op ${serialize(head)} ${serializeVariadic(op, tail)})" } def serialize(e: ArrayExpr): String = e match { - case ArraySymbol(name, _, _) => escapeIdentifier(name) - case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})" - case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" - case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})" - case ArrayRawExpr(serialized, _, _) => serialized + case ArraySymbol(name, _, _) => escapeIdentifier(name) + case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})" + case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})" + case ArrayFunctionCall(name, args, _, _) => args.map(serializeArg).mkString(s"($name ", " ", ")") } def serialize(c: SMTCommand): String = c match { case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n") case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)" case DefineFunction(name, args, e) => - val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ") - s"(define-fun ${escapeIdentifier(name)} ($aa) ${serializeType(e)} ${serialize(e)})" + val aa = args.map(a => s"(${serializeArg(a)} ${serializeArgTpe(a)})").mkString(" ") + s"(define-fun ${escapeIdentifier(name)} ($aa) ${serialize(e.tpe)} ${serialize(e)})" case DeclareFunction(sym, tpes) => - val aa = tpes.mkString(" ") - s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serializeType(sym)})" + val aa = tpes.map(serializeArgTpe).mkString(" ") + s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serialize(sym.tpe)})" + case SetLogic(logic) => s"(set-logic $logic)" + case DeclareUninterpretedSymbol(name, tpe) => + s"(declare-fun ${escapeIdentifier(name)} () ${escapeIdentifier(tpe)})" } + private def serializeArgTpe(a: SMTFunctionArg): String = + a match { + case u: UTSymbol => escapeIdentifier(u.tpe) + case s: SMTExpr => serialize(s.tpe) + } + private def serializeArg(a: SMTFunctionArg): String = + a match { + case u: UTSymbol => escapeIdentifier(u.name) + case s: SMTExpr => serialize(s) + } + private def serializeArrayType(indexWidth: Int, dataWidth: Int): String = s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})" private def serializeBitVectorType(width: Int): String = @@ -109,8 +124,6 @@ private object SMTLibSerializer { else { assert(width > 1); s"(_ BitVec $width)" } private def serialize(op: Op.Value): String = op match { - case Op.And => "bvand" - case Op.Or => "bvor" case Op.Xor => "bvxor" case Op.ArithmeticShiftRight => "bvashr" case Op.ShiftRight => "bvlshr" diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala index d35fe139..472363cc 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala @@ -17,37 +17,41 @@ private object SMTTransitionSystemEncoder { val cmds = mutable.ArrayBuffer[SMTCommand]() val name = sys.name - // emit header as comments - cmds ++= sys.header.map(Comment) + // declare UFs if necessary + cmds ++= TransitionSystem.findUninterpretedFunctions(sys) - // declare uninterpreted functions used in model - cmds ++= sys.ufs.map(SMTLibSerializer.declareFunction) + // emit header as comments + if (sys.header.nonEmpty) { + cmds ++= sys.header.split('\n').map(Comment) + } // declare state type val stateType = id(name + "_s") cmds += DeclareUninterpretedSort(stateType) + // state symbol + val State = UTSymbol("state", stateType) + val StateNext = UTSymbol("state_n", stateType) + // inputs and states are modelled as constants def declare(sym: SMTSymbol, kind: String): Unit = { cmds ++= toDescription(sym, kind, sys.comments.get) val s = SMTSymbol.fromExpr(sym.name + SignalSuffix, sym) - cmds += DeclareFunction(s, List(stateType)) + cmds += DeclareFunction(s, List(State)) } sys.inputs.foreach(i => declare(i, "input")) sys.states.foreach(s => declare(s.sym, "register")) // signals are just functions of other signals, inputs and state def define(sym: SMTSymbol, e: SMTExpr, suffix: String = SignalSuffix): Unit = { - cmds += DefineFunction(sym.name + suffix, List((State, stateType)), replaceSymbols(e)) + val withReplacedSymbols = replaceSymbols(SignalSuffix, State)(e) + cmds += DefineFunction(sym.name + suffix, List(State), withReplacedSymbols) } sys.signals.foreach { signal => - val kind = if (sys.outputs.contains(signal.name)) { "output" } - else if (sys.assumes.contains(signal.name)) { "assume" } - else if (sys.asserts.contains(signal.name)) { "assert" } - else { "wire" } - val sym = SMTSymbol.fromExpr(signal.name, signal.e) - cmds ++= toDescription(sym, kind, sys.comments.get) - define(sym, signal.e) + val sym = signal.sym + cmds ++= toDescription(sym, lblToKind(signal.lbl), sys.comments.get) + val e = if (signal.lbl == IsBad) BVNot(signal.e.asInstanceOf[BVExpr]) else signal.e + define(sym, e) } // define the next and init functions for all states @@ -60,72 +64,70 @@ private object SMTTransitionSystemEncoder { } } - def defineConjunction(e: Iterable[BVExpr], suffix: String): Unit = { - define(BVSymbol(name, 1), andReduce(e), suffix) + def defineConjunction(e: List[BVExpr], suffix: String): Unit = { + define(BVSymbol(name, 1), if (e.isEmpty) True() else BVAnd(e), suffix) } // the transition relation asserts that the value of the next state is the next value from the previous state // e.g., (reg state_n) == (reg_next state) val transitionRelations = sys.states.map { state => - val newState = symbolToFunApp(state.sym, SignalSuffix, StateNext) - val nextOldState = symbolToFunApp(state.sym, NextSuffix, State) + val newState = replaceSymbols(SignalSuffix, StateNext)(state.sym) + val nextOldState = replaceSymbols(NextSuffix, State)(state.sym) SMTEqual(newState, nextOldState) } // the transition relation is over two states - val transitionExpr = replaceSymbols(andReduce(transitionRelations)) - cmds += DefineFunction(name + "_t", List((State, stateType), (StateNext, stateType)), transitionExpr) + val transitionExpr = if (transitionRelations.isEmpty) { True() } + else { + replaceSymbols(SignalSuffix, State)(BVAnd(transitionRelations)) + } + cmds += DefineFunction(name + "_t", List(State, StateNext), transitionExpr) // The init relation just asserts that all init function hold val initRelations = sys.states.filter(_.init.isDefined).map { state => - val stateSignal = symbolToFunApp(state.sym, SignalSuffix, State) - val initSignal = symbolToFunApp(state.sym, InitSuffix, State) + val stateSignal = replaceSymbols(SignalSuffix, State)(state.sym) + val initSignal = replaceSymbols(InitSuffix, State)(state.sym) SMTEqual(stateSignal, initSignal) } defineConjunction(initRelations, "_i") // assertions and assumptions - val assertions = sys.signals.filter(a => sys.asserts.contains(a.name)).map(a => replaceSymbols(a.toSymbol)) - defineConjunction(assertions, "_a") - val assumptions = sys.signals.filter(a => sys.assumes.contains(a.name)).map(a => replaceSymbols(a.toSymbol)) - defineConjunction(assumptions, "_u") + val assertions = sys.signals.filter(_.lbl == IsBad).map(a => replaceSymbols(SignalSuffix, State)(a.sym)) + defineConjunction(assertions.map(_.asInstanceOf[BVExpr]), AssertionSuffix) + val assumptions = sys.signals.filter(_.lbl == IsConstraint).map(a => replaceSymbols(SignalSuffix, State)(a.sym)) + defineConjunction(assumptions.map(_.asInstanceOf[BVExpr]), AssumptionSuffix) cmds } private def id(s: String): String = SMTLibSerializer.escapeIdentifier(s) - private val State = "state" - private val StateNext = "state_n" private val SignalSuffix = "_f" private val NextSuffix = "_next" private val InitSuffix = "_init" + val AssertionSuffix = "_a" + val AssumptionSuffix = "_u" + private def lblToKind(lbl: SignalLabel): String = lbl match { + case IsNode | IsInit | IsNext => "wire" + case IsOutput => "output" + // for the SMT encoding we turn bad state signals back into assertions + case IsBad => "assert" + case IsConstraint => "assume" + case IsFair => "fair" + } private def toDescription(sym: SMTSymbol, kind: String, comments: String => Option[String]): List[Comment] = { List(sym match { - case BVSymbol(name, width) => - Comment(s"firrtl-smt2-$kind $name $width") + case BVSymbol(name, width) => Comment(s"firrtl-smt2-$kind $name $width") case ArraySymbol(name, indexWidth, dataWidth) => Comment(s"firrtl-smt2-$kind $name $indexWidth $dataWidth") }) ++ comments(sym.name).map(Comment) } - - private def andReduce(e: Iterable[BVExpr]): BVExpr = - if (e.isEmpty) BVLiteral(1, 1) else e.reduce((a, b) => BVOp(Op.And, a, b)) - // All signals are modelled with functions that need to be called with the state as argument, // this replaces all Symbols with function applications to the state. - private def replaceSymbols(e: SMTExpr): SMTExpr = { - SMTExprVisitor.map(symbolToFunApp(_, SignalSuffix, State))(e) - } - private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr] - private def symbolToFunApp(sym: SMTExpr, suffix: String, arg: String): SMTExpr = sym match { - case BVSymbol(name, width) => BVRawExpr(s"(${id(name + suffix)} $arg)", width) - case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name + suffix)} $arg)", indexWidth, dataWidth) - case other => other - } + private def replaceSymbols(suffix: String, arg: SMTFunctionArg, vars: Set[String] = Set())(e: SMTExpr): SMTExpr = + e match { + case BVSymbol(name, width) if !vars(name) => BVFunctionCall(id(name + suffix), List(arg), width) + case ArraySymbol(name, indexWidth, dataWidth) if !vars(name) => + ArrayFunctionCall(id(name + suffix), List(arg), indexWidth, dataWidth) + case fa @ BVForall(variable, _) => SMTExprMap.mapExpr(fa, replaceSymbols(suffix, arg, vars + variable.name)) + case other => SMTExprMap.mapExpr(other, replaceSymbols(suffix, arg, vars)) + } } - -/** minimal set of pseudo SMT commands needed for our encoding */ -private sealed trait SMTCommand -private case class Comment(msg: String) extends SMTCommand -private case class DefineFunction(name: String, args: Seq[(String, String)], e: SMTExpr) extends SMTCommand -private case class DeclareFunction(sym: SMTSymbol, tpes: Seq[String]) extends SMTCommand -private case class DeclareUninterpretedSort(name: String) extends SMTCommand diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala index eac9f00a..5db39ac9 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala @@ -3,13 +3,14 @@ package firrtl.backends.experimental.smt -import firrtl.{ir, CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils} -import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation, ReferenceTarget, SingleTargetAnnotation} +import firrtl._ +import firrtl.annotations._ import firrtl.ir.EmptyStmt import firrtl.options.Dependency import firrtl.passes.PassException import firrtl.stage.Forms import firrtl.stage.TransformManager.TransformDependency +import firrtl.transforms.PropagatePresetAnnotations import scala.collection.mutable @@ -30,7 +31,10 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { // this pass needs to run *before* converting to a transition system override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq(Dependency(FirrtlToTransitionSystem)) // since this pass only runs on the main module, inlining needs to happen before - override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) + override def optionalPrerequisites: Seq[TransformDependency] = Seq( + Dependency[firrtl.passes.InlineInstances], + Dependency[PropagatePresetAnnotations] + ) override protected def execute(state: CircuitState): CircuitState = { if (state.circuit.modules.size > 1) { @@ -66,10 +70,10 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { // replace all other clocks with enable signals, unless they are the global clock val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name) val clockToEnable = clocks.map { c => - c -> ir.Reference(namespace.newName(c + "_en"), Bool, firrtl.PortKind, firrtl.SourceFlow) + c -> ir.Reference(namespace.newName(c + "_en"), Utils.BoolType, firrtl.PortKind, firrtl.SourceFlow) }.toMap val portsWithEnableSignals = portsWithGlobalClock.map { p => - if (clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) } + if (clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Utils.BoolType) } else { p } } // replace async reset with synchronous reset (since everything will we synchronous with the global clock) @@ -78,9 +82,12 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { val isPresetReset = state.annotations.collect { case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet val resetsToChange = asyncResets.filterNot(isPresetReset).toSet val portsWithSyncReset = portsWithEnableSignals.map { p => - if (resetsToChange.contains(p.name)) { p.copy(tpe = Bool) } + if (resetsToChange.contains(p.name)) { p.copy(tpe = Utils.BoolType) } else { p } } + val presetRegs = state.annotations.collect { + case PresetRegAnnotation(target) if target.module == mainName => target.ref + }.toSet // discover clock and reset connections val scan = scanClocks(main, clockToEnable, resetsToChange) @@ -94,7 +101,7 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { } // make changes - implicit val ctx: Context = new Context(globalClock, scan) + implicit val ctx: Context = new Context(globalClock, scan, presetRegs) val newMain = main.copy(ports = portsWithSyncReset).mapStmt(onStatement) val nonMainModules = state.circuit.modules.filterNot(_.name == state.circuit.main) @@ -119,15 +126,19 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { // for write ports we guard the write enable with the clock enable signal, similar to registers if (isWritePort) { val clockEn = ctx.memPortToClockEnable(mem + "." + port) - val guardedEnable = and(clockEn, c.expr) + val guardedEnable = Utils.and(clockEn, c.expr) c.copy(expr = guardedEnable) } else { c } } else { c } // register field connects case c @ ir.Connect(_, r: ir.Reference, next) if ctx.registerToEnable.contains(r.name) => val clockEnable = ctx.registerToEnable(r.name) - val guardedNext = mux(clockEnable, next, r) - c.copy(expr = guardedNext) + val guardedNext = Utils.mux(clockEnable, next, r) + val withReset = ctx.registerToAsyncReset.get(r.name) match { + case None => guardedNext + case Some((asyncReset, init)) => Utils.mux(asyncReset, init, guardedNext) + } + c.copy(expr = withReset) // remove other clock wires and nodes case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt @@ -135,21 +146,16 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { // change async reset to synchronous reset case ir.Connect(info, loc: ir.Reference, expr: ir.Reference) if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) => - ir.Connect(info, loc.copy(tpe = Bool), expr.copy(tpe = Bool)) + ir.Connect(info, loc.copy(tpe = Utils.BoolType), expr.copy(tpe = Utils.BoolType)) case d @ ir.DefNode(_, name, value: ir.Reference) if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) => - d.copy(value = value.copy(tpe = Bool)) - case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe = Bool) + d.copy(value = value.copy(tpe = Utils.BoolType)) + case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => + d.copy(tpe = Utils.BoolType) // change memory clock and synchronize reset - case ir.DefRegister(info, name, tpe, clock, reset, init) if ctx.registerToEnable.contains(name) => - val clockEnable = ctx.registerToEnable(name) - val newReset = reset match { - case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe = Bool) - case other => other - } - val synchronizedReset = if (reset.tpe == ir.AsyncResetType) { newReset } - else { and(newReset, clockEnable) } - ir.DefRegister(info, name, tpe, ctx.globalClock, synchronizedReset, init) + case ir.DefRegister(info, name, tpe, _, _, init) if ctx.registerToEnable.contains(name) => + val newInit = if (ctx.isPresetReg(name)) init else ir.Reference(name, tpe, RegKind, SourceFlow) + ir.DefRegister(info, name, tpe, ctx.globalClock, Utils.False(), newInit) case other => other.mapStmt(onStatement) } } @@ -189,10 +195,14 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { case ir.DefNode(_, name, value) if value.tpe == ir.AsyncResetType && ctx.resetsToChange(value.serialize) => ctx.resetsToChange.add(name) // modify clocked elements - case ir.DefRegister(_, name, _, clock, _, _) => + case ir.DefRegister(_, name, _, clock, reset, init) => ctx.clockToEnable.get(clock.serialize).foreach { clockEnable => ctx.registerToEnable.append(name -> clockEnable) } + reset match { + case Utils.False() => + case other => ctx.registerToAsyncReset.append(name -> (other, init)) + } case m: ir.DefMemory => assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!") @@ -229,18 +239,22 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { val resetsToChange = mutable.HashSet[String]() ++ initialResetsToChange // registers whose next function needs to be guarded with a clock enable val registerToEnable = mutable.ArrayBuffer[(String, ir.Reference)]() + // registers with asynchronous reset + val registerToAsyncReset = mutable.ArrayBuffer[(String, (ir.Expression, ir.Expression))]() // memory enables which need to be guarded with clock enables val memPortToClockEnable = mutable.ArrayBuffer[(String, ir.Reference)]() // keep track of memory names val mems = mutable.HashMap[String, ir.DefMemory]() } - private class Context(globalClockName: String, scanResults: ScanCtx) { + private class Context(globalClockName: String, scanResults: ScanCtx, val isPresetReg: String => Boolean) { val globalClock: ir.Reference = ir.Reference(globalClockName, ir.ClockType, firrtl.PortKind, firrtl.SourceFlow) // keeps track of which clock signals will be replaced by which clock enable signal val isRemovedClock: String => Boolean = scanResults.clockToEnable.contains // registers whose next function needs to be guarded with a clock enable val registerToEnable: Map[String, ir.Reference] = scanResults.registerToEnable.toMap + // registers with asynchronous reset + val registerToAsyncReset: Map[String, (ir.Expression, ir.Expression)] = scanResults.registerToAsyncReset.toMap // memory enables which need to be guarded with clock enables val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap // keep track of memory names @@ -252,13 +266,6 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { private var mainName: String = "" // for debugging private def unsupportedError(msg: String): Nothing = throw new UnsupportedFeatureException(s"StutteringClockTransform: [$mainName] $msg") - - private def mux(cond: ir.Expression, a: ir.Expression, b: ir.Expression): ir.Expression = { - ir.Mux(cond, a, b, Utils.mux_type_and_widths(a, b)) - } - private def and(a: ir.Expression, b: ir.Expression): ir.Expression = - ir.DoPrim(PrimOps.And, List(a, b), List(), Bool) - private val Bool = ir.UIntType(ir.IntWidth(1)) } private class UnsupportedFeatureException(s: String) extends PassException(s) diff --git a/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala new file mode 100644 index 00000000..66a1b385 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: Apache-2.0 +// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> + +package firrtl.backends.experimental.smt + +import firrtl.graph.MutableDiGraph +import scala.collection.mutable + +private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr]) { + def name: String = sym.name +} +private case class Signal(name: String, e: SMTExpr, lbl: SignalLabel = IsNode) { + def toSymbol: SMTSymbol = SMTSymbol.fromExpr(name, e) + def sym: SMTSymbol = toSymbol +} +private case class TransitionSystem( + name: String, + inputs: List[BVSymbol], + states: List[State], + signals: List[Signal], + comments: Map[String, String] = Map(), + header: String = "") { + def serialize: String = TransitionSystem.serialize(this) +} + +private sealed trait SignalLabel +private case object IsNode extends SignalLabel +private case object IsOutput extends SignalLabel +private case object IsConstraint extends SignalLabel +private case object IsBad extends SignalLabel +private case object IsFair extends SignalLabel +private case object IsNext extends SignalLabel +private case object IsInit extends SignalLabel + +private object SignalLabel { + private val labels = Seq(IsNode, IsOutput, IsConstraint, IsBad, IsFair, IsNext, IsInit) + val labelStrings = Seq("node", "output", "constraint", "bad", "fair", "next", "init") + val labelToString: SignalLabel => String = labels.zip(labelStrings).toMap + val stringToLabel: String => SignalLabel = labelStrings.zip(labels).toMap +} + +private object TransitionSystem { + def serialize(sys: TransitionSystem): String = { + (Iterator(sys.name) ++ + sys.inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++ + sys.signals.map(s => s"${SignalLabel.labelToString(s.lbl)} ${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++ + sys.states.map(serialize)).mkString("\n") + } + + def serialize(s: State): String = { + s"state ${s.sym.name} : ${SMTExpr.serializeType(s.sym)}" + + s.init.map("\n [init] " + _).getOrElse("") + + s.next.map("\n [next] " + _).getOrElse("") + } + + def systemExpressions(sys: TransitionSystem): List[SMTExpr] = + sys.signals.map(_.e) ++ sys.states.flatMap(s => s.init ++ s.next) + + def findUninterpretedFunctions(sys: TransitionSystem): List[DeclareFunction] = { + val calls = systemExpressions(sys).flatMap(findUFCalls) + // find unique functions + calls.groupBy(_.sym.name).map(_._2.head).toList + } + + private def findUFCalls(e: SMTExpr): List[DeclareFunction] = { + val f = e match { + case BVFunctionCall(name, args, width) => + Some(DeclareFunction(BVSymbol(name, width), args)) + case ArrayFunctionCall(name, args, indexWidth, dataWidth) => + Some(DeclareFunction(ArraySymbol(name, indexWidth, dataWidth), args)) + case _ => None + } + f.toList ++ e.children.flatMap(findUFCalls) + } +} + +private object TopologicalSort { + + /** Ensures that all signals in the resulting system are topologically sorted. + * This is necessary because [[firrtl.transforms.RemoveWires]] does + * not sort assignments to outputs, submodule inputs nor memory ports. + */ + def run(sys: TransitionSystem): TransitionSystem = { + val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name) + val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates) + // TODO: maybe sort init expressions of states (this should not be needed most of the time) + signalOrder match { + case None => sys + case Some(order) => + val signalMap = sys.signals.map(s => s.name -> s).toMap + // we flatMap over `get` in order to ignore inputs/states in the order + sys.copy(signals = order.flatMap(signalMap.get).toList) + } + } + + private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = { + val known = new mutable.HashSet[String]() ++ globalSignals + var needsReordering = false + val digraph = new MutableDiGraph[String] + signals.foreach { + case (name, expr) => + digraph.addVertex(name) + val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr) + uniqueDependencies.foreach { d => + if (!known.contains(d)) { needsReordering = true } + digraph.addPairWithEdge(name, d) + } + known.add(name) + } + if (needsReordering) { + Some(digraph.linearize.reverse) + } else { None } + } + + private def findDependencies(expr: SMTExpr): List[String] = expr match { + case BVSymbol(name, _) => List(name) + case ArraySymbol(name, _, _) => List(name) + case other => other.children.flatMap(findDependencies) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala index fdd51a37..21e8289e 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/Btor2Spec.scala @@ -26,9 +26,8 @@ class Btor2Spec extends AnyFlatSpec { |5 output 4 ; b |6 sort bitvec 1 |7 uext 3 2 8 - |8 eq 6 7 4 - |9 not 6 8 - |10 bad 9 ; a_eq_b + |8 neq 6 7 4 + |9 bad 8 ; a_eq_b |""".stripMargin assert(SMTBackendHelpers.toBotr2Str(src) == expected) @@ -46,17 +45,16 @@ class Btor2Spec extends AnyFlatSpec { |""".stripMargin val expected = - """; @ module 0:0 + """; @[module 0:0] |1 sort bitvec 8 - |2 input 1 a ; @ a 0:0 + |2 input 1 a ; @[a 0:0] |3 sort bitvec 16 |4 uext 3 2 8 - |5 output 4 ; b @ b 0:0, b_a 0:0 + |5 output 4 ; b @[b 0:0], @[b_a 0:0] |6 sort bitvec 1 |7 uext 3 2 8 - |8 eq 6 7 4 - |9 not 6 8 - |10 bad 9 ; assert_0 @ assert 0:0 + |8 neq 6 7 4 + |9 bad 8 ; assert_0 @[assert 0:0] |""".stripMargin assert(SMTBackendHelpers.toBotr2Str(src) == expected) diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala index c100da56..1fd0e99b 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala @@ -94,8 +94,7 @@ class FirrtlModuleToTransitionSystemSpec extends AnyFlatSpec { assert(sym.indexWidth == 5) assert(sym.dataWidth == 8) assert(m.init.isEmpty) - //assert(m.next.get.toString.contains("m[m.w.addr := m.w.data]")) - assert(m.next.get.toString == "m[m.w.addr := m.w.data]") + assert(m.next.get.toString.contains("m[m.w.addr := m.w.data]")) } it should "support scalar initialization of a memory to 0" in { @@ -170,7 +169,8 @@ class FirrtlModuleToTransitionSystemSpec extends AnyFlatSpec { |""".stripMargin val sys = SMTBackendHelpers.toSys(src) assert(sys.inputs.isEmpty, "Clock inputs should be ignored.") - assert(sys.outputs.isEmpty, "Clock outputs should be ignored.") + val outputs = sys.signals.filter(_.lbl == IsOutput) + assert(outputs.isEmpty, "Clock outputs should be ignored.") assert(sys.signals.isEmpty, "Connects of clock type should be ignored.") } diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala index 71d1d38c..8f4486ab 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendHelpers.scala @@ -38,7 +38,7 @@ private object SMTBackendHelpers { val circuit = if (modelUndef) compileUndef(src) else compile(src) val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module] // println(module.serialize) - new ModuleToTransitionSystem().run(module, presetRegs = presetRegs, memInit = memInit) + new ModuleToTransitionSystem(presetRegs = presetRegs, memInit = memInit, uninterpreted = Map()).run(module) } def toBotr2(src: String, mod: String = "m"): Iterable[String] = diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala index 338d760c..4d96631e 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTLibSpec.scala @@ -47,16 +47,16 @@ class SMTLibSpec extends AnyFlatSpec { |""".stripMargin val expected = - """; @ module 0:0 + """; @[module 0:0] |(declare-sort m_s 0) |; firrtl-smt2-input a 8 - |; @ a 0:0 + |; @[a 0:0] |(declare-fun a_f (m_s) (_ BitVec 8)) |; firrtl-smt2-output b 16 - |; @ b 0:0, b_a 0:0 + |; @[b 0:0], @[b_a 0:0] |(define-fun b_f ((state m_s)) (_ BitVec 16) ((_ zero_extend 8) (a_f state))) |; firrtl-smt2-assert assert_0 1 - |; @ assert 0:0 + |; @[assert 0:0] |(define-fun assert_0_f ((state m_s)) Bool (= ((_ zero_extend 8) (a_f state)) (b_f state))) |(define-fun m_t ((state m_s) (state_n m_s)) Bool true) |(define-fun m_i ((state m_s)) Bool true) |
