aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/backends/experimental
diff options
context:
space:
mode:
authorAditya Naik2024-05-29 16:57:13 -0700
committerAditya Naik2024-05-29 16:57:13 -0700
commit165804ee58cb18443042b9655328278434ddedf4 (patch)
tree4e167eff9e7b3ec09d73dbd9feaa6f9964cd8a68 /src/main/scala/firrtl/backends/experimental
parent57b8a395ee8d5fdabb2deed3db7d0c644f0a7eed (diff)
Add Scala3 support
Diffstat (limited to 'src/main/scala/firrtl/backends/experimental')
-rw-r--r--src/main/scala/firrtl/backends/experimental/rtlil/RtlilEmitter.scala1083
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala253
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala179
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala379
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala12
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala81
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala342
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala88
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala60
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala177
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala133
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala272
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala120
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala86
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala31
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala125
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala461
17 files changed, 0 insertions, 3882 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/rtlil/RtlilEmitter.scala b/src/main/scala/firrtl/backends/experimental/rtlil/RtlilEmitter.scala
deleted file mode 100644
index 6c6c0b69..00000000
--- a/src/main/scala/firrtl/backends/experimental/rtlil/RtlilEmitter.scala
+++ /dev/null
@@ -1,1083 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package firrtl.backends.experimental.rtlil
-
-import java.io.Writer
-import firrtl._
-import firrtl.PrimOps._
-import firrtl.ir._
-import firrtl.Utils.{throwInternalError, _}
-import firrtl.WrappedExpression._
-import firrtl.traversals.Foreachers._
-import firrtl.annotations._
-import firrtl.options.Viewer.view
-import firrtl.options.{CustomFileEmission, Dependency}
-import firrtl.passes.LowerTypes
-import firrtl.passes.MemPortUtils.memPortField
-import firrtl.stage.{FirrtlOptions, TransformManager}
-
-import scala.annotation.tailrec
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-import scala.language.postfixOps
-
-case class EmittedRtlilCircuitAnnotation(name: String, value: String, outputSuffix: String)
- extends NoTargetAnnotation
- with CustomFileEmission {
- override protected def baseFileName(annotations: AnnotationSeq): String =
- view[FirrtlOptions](annotations).outputFileName.getOrElse(name)
- override protected def suffix: Option[String] = Some(outputSuffix)
- override def getBytes: Iterable[Byte] = value.getBytes
-}
-case class EmittedRtlilModuleAnnotation(name: String, value: String, outputSuffix: String)
- extends NoTargetAnnotation
- with CustomFileEmission {
- override protected def baseFileName(annotations: AnnotationSeq): String =
- view[FirrtlOptions](annotations).outputFileName.getOrElse(name)
- override protected def suffix: Option[String] = Some(outputSuffix)
- override def getBytes: Iterable[Byte] = value.getBytes
-}
-
-private[firrtl] class RtlilEmitter extends SeqTransform with Emitter with DependencyAPIMigration {
-
- override def prerequisites: Seq[TransformManager.TransformDependency] =
- Seq(
- Dependency[firrtl.transforms.CombineCats],
- Dependency(firrtl.passes.memlib.VerilogMemDelays)
- ) ++: firrtl.stage.Forms.LowFormOptimized
-
- override def outputSuffix: String = ".il"
- val tab = " "
-
- override def transforms: Seq[Transform] = new TransformManager(prerequisites).flattenedTransformOrder
-
- def emit(state: CircuitState, writer: Writer): Unit = {
- val cs = runTransforms(state)
- val emissionOptions = new EmissionOptions(cs.annotations)
- val moduleMap = cs.circuit.modules.map(m => m.name -> m).toMap
- cs.circuit.modules.foreach {
- case DescribedMod(d, pds, m: Module) =>
- val renderer = new RtlilRender(d, pds, m, moduleMap, cs.circuit.main, emissionOptions)(writer)
- renderer.emit_rtlil()
- case m: Module =>
- val renderer = new RtlilRender(m, moduleMap, cs.circuit.main, emissionOptions)(writer)
- renderer.emit_rtlil()
- case _ => // do nothing
- }
- }
-
- override def execute(state: CircuitState): CircuitState = {
- val writerToString =
- (writer: java.io.StringWriter) => writer.toString.replaceAll("""(?m) +$""", "") // trim trailing whitespace
-
- val newAnnos = state.annotations.flatMap {
- case EmitCircuitAnnotation(a) if this.getClass == a =>
- val writer = new java.io.StringWriter
- emit(state, writer)
- Seq(
- EmittedRtlilModuleAnnotation(state.circuit.main, writerToString(writer), outputSuffix)
- )
-
- case EmitAllModulesAnnotation(a) if this.getClass == a =>
- val cs = runTransforms(state)
- val emissionOptions = new EmissionOptions(cs.annotations)
- val moduleMap = cs.circuit.modules.map(m => m.name -> m).toMap
-
- cs.circuit.modules.flatMap {
- case DescribedMod(d, pds, module: Module) =>
- val writer = new java.io.StringWriter
- val renderer = new RtlilRender(d, pds, module, moduleMap, cs.circuit.main, emissionOptions)(writer)
- renderer.emit_rtlil()
- Some(
- EmittedRtlilModuleAnnotation(module.name, writerToString(writer), outputSuffix)
- )
- case module: Module =>
- val writer = new java.io.StringWriter
- val renderer = new RtlilRender(module, moduleMap, cs.circuit.main, emissionOptions)(writer)
- renderer.emit_rtlil()
- Some(
- EmittedRtlilModuleAnnotation(module.name, writerToString(writer), outputSuffix)
- )
- case _ => None
- }
- case _ => Seq()
- }
- state.copy(annotations = newAnnos ++ state.annotations)
- }
-
- private class RtlilRender(
- description: Seq[Description],
- portDescriptions: Map[String, Seq[Description]],
- m: Module,
- moduleMap: Map[String, DefModule],
- circuitName: String,
- emissionOptions: EmissionOptions
- )(
- implicit writer: Writer) {
- def this(
- m: Module,
- moduleMap: Map[String, DefModule],
- circuitName: String,
- emissionOptions: EmissionOptions
- )(
- implicit writer: Writer
- ) = {
- this(Seq(), Map.empty, m, moduleMap, circuitName, emissionOptions)(writer)
- }
-
- private val netlist: mutable.LinkedHashMap[WrappedExpression, InfoExpr] = mutable.LinkedHashMap()
- private val namespace: Namespace = Namespace(m)
-
- private val portdefs: ArrayBuffer[Seq[Any]] = ArrayBuffer[Seq[Any]]()
- private val declares: ArrayBuffer[Seq[Any]] = ArrayBuffer()
- private val instdeclares: mutable.Map[String, InstInfo] = mutable.Map()
- private val assigns: ArrayBuffer[Seq[Any]] = ArrayBuffer()
- private val attachSynAssigns: ArrayBuffer[Seq[Any]] = ArrayBuffer()
- private val processes: ArrayBuffer[Seq[Any]] = ArrayBuffer()
- // Used to determine type of initvar for initializing memories
- private val initials: ArrayBuffer[Seq[Any]] = ArrayBuffer()
- private val formals: ArrayBuffer[Seq[Any]] = ArrayBuffer()
- private val moduleTarget: ModuleTarget = CircuitTarget(circuitName).module(m.name)
-
- private def getLeadingTabs(x: Any): String = {
- x match {
- case seq: Seq[_] =>
- val head = seq.takeWhile(_ == tab).mkString
- val tail = seq.dropWhile(_ == tab).headOption.map(getLeadingTabs).getOrElse(tab)
- head + tail
- case _ => tab
- }
- }
-
- private def emit(x: Any)(implicit w: Writer): Unit = {
- this.emitCol(x, 0, getLeadingTabs(x))(writer)
- }
-
- private def emit(x: Any, top: Int)(implicit w: Writer): Unit = {
- emitCol(x, top, "")(writer)
- }
-
- private def emitCol(x: Any, top: Int, tabs: String)(implicit w: Writer): Unit = {
- x match {
- case e: SrcInfo => w.write(e.str_rep)
- case e: Reference => w.write(ref_to_name(e))
- case e: ValidIf => emitCol(Seq(e.value), top + 1, tabs)(writer)
- case e: WSubField => w.write(SrcInfo(e).str_rep)
- case e: WSubAccess =>
- w.write("\\" + s"${LowerTypes.loweredName(e.expr)} [ ${LowerTypes.loweredName(e.index)} ]")
- case e: Literal => w.write(bigint_to_str_rep(e.value, get_type_width(e.tpe)))
- case t: GroundType => w.write(stringify(t))
- case t: VectorType =>
- emit(t.tpe, top + 1)(writer)
- w.write(s"[${t.size - 1}:0]")
- case s: String => w.write(s)
- case i: Int => w.write(i.toString)
- case i: Long => w.write(i.toString)
- case i: BigInt => w.write(bigint_to_str_rep(i, if (i > 0) i.bitLength else i.bitLength + 1))
- case i: Info =>
- infos_to_attr(i) match {
- case Some(attr) =>
- w.write(attr)
- case None =>
- }
- case s: Seq[Any] =>
- s.foreach { e => emitCol(e, top + 1, tabs)(writer) }
- if (top == 0)
- w.write("\n")
- case x => throwInternalError(s"trying to emit unsupported operation: $x")
- }
- }
-
- private def build_netlist(s: Statement): Unit = {
- s.foreach(build_netlist)
- s match {
- case sx: Connect => netlist(sx.loc) = InfoExpr(sx.info, sx.expr)
- case _: IsInvalid => error("Should have removed these!")
- // TODO Since only register update and memories use the netlist anymore, I think nodes are unnecessary
- case sx: DefNode =>
- val e = WRef(sx.name, sx.value.tpe, NodeKind, SourceFlow)
- netlist(e) = InfoExpr(sx.info, sx.value)
- case _ =>
- }
- }
-
- @tailrec
- private def remove_root(ex: Expression): Expression = ex match {
- case ex: WSubField =>
- ex.expr match {
- case e: WSubField => remove_root(e)
- case _: WRef => WRef(ex.name, ex.tpe, InstanceKind, UnknownFlow)
- }
- case _ => throwInternalError(s"shouldn't be here: remove_root($ex)")
- }
-
- private def stringify(tpe: GroundType): String = tpe match {
- case _: UIntType | _: AnalogType =>
- val wx = bitWidth(tpe)
- if (wx > 1) s"width $wx" else ""
- case _: SIntType =>
- val wx = bitWidth(tpe)
- if (wx > 1) s"signed width $wx" else "signed"
- case ClockType | AsyncResetType => ""
- case _ => throwInternalError(s"trying to write unsupported type in the Rtlil Emitter: $tpe")
- }
-
- private def stringify(param: Param): String = param match {
- case IntParam(name, value) =>
- val lit =
- if (value.isValidInt) {
- s"$value"
- } else {
- val blen = value.bitLength
- if (value > 0) s"$blen'd$value" else s"-${blen + 1}'sd${value.abs}"
- }
- s"parameter \\$name $lit"
- case DoubleParam(name, value) => s"parameter \\$name $value"
- case StringParam(name, value) => s"parameter \\$name ${value.verilogEscape}"
- case RawStringParam(name, value) => s"parameter \\$name $value"
- }
-
- // turn strings into Seq[String] verilog comments
- private def build_comment(desc: String): Seq[Seq[String]] = {
- val lines = desc.split("\n").toSeq
- lines.tail.map {
- case "" => Seq("#")
- case nonEmpty => Seq("#", nonEmpty)
- }
- }
- private def build_attribute(attr: String): Seq[Seq[String]] = {
- Seq(Seq("attribute \\") ++ Seq(attr))
- }
-
- private def build_description(d: Seq[Description]): Seq[Seq[String]] = d.flatMap {
- case DocString(desc) => build_comment(desc.string)
- case Attribute(attr) => build_attribute(attr.string)
- }
-
- // Turn ports into Seq[String] and add to portdefs
- private def build_ports(): Unit = {
- def padToMax(strs: Seq[String]): Seq[String] = {
- val len = if (strs.nonEmpty) strs.map(_.length).max else 0
- strs.map(_.padTo(len, ' '))
- }
-
- // Turn directions into strings (and AnalogType into inout)
- val dirs = m.ports.map {
- case Port(_, _, dir, tpe) =>
- (dir, tpe) match {
- case (_, AnalogType(_)) => "inout " // padded to length of output
- case (Input, _) => "input "
- case (Output, _) => "output"
- }
- }
- // Turn types into strings, all ports must be GroundTypes
- val tpes = m.ports.map {
- case Port(_, _, _, tpe: GroundType) => stringify(tpe)
- case port: Port => error(s"Trying to emit non-GroundType Port $port")
- }
-
- // dirs are already padded
- (dirs, padToMax(tpes), m.ports).zipped.toSeq.zipWithIndex.foreach {
- case ((dir, tpe, Port(info, name, _, _)), i) =>
- portDescriptions.get(name).map { d =>
- portdefs += Seq("")
- portdefs ++= build_description(d)
- }
- portdefs += Seq("wire ", tpe, " ", dir, " ", i + 1, " \\", name, info)
- }
- }
-
- private def infos_to_attr(info: Info): Option[String] = {
- def info_extract(info: Info, prev: Seq[String] = Seq()): Seq[String] = info match {
- case FileInfo(str) =>
- val (file, line, col) = FileInfo(str).split
- prev :+ (file + ":" + line + "." + col)
- case MultiInfo(infos) =>
- infos.foldLeft(prev)((a, b) => {
- info_extract(b, a)
- })
- case NoInfo =>
- prev
- }
- val srcinfo = info_extract(info)
- if (srcinfo.isEmpty)
- Option.empty
- else
- Option("attribute \\src \"" + srcinfo.mkString("|") + "\"")
- }
-
- private def string_to_rtlil_name(name: String): String = {
- if (name.head == '_') {
- "$" + name
- } else {
- "\\" + name
- }
- }
-
- private def ref_to_name(ref: Reference): String = {
- string_to_rtlil_name(ref.name)
- }
-
- private def regUpdate(r: Expression, clk: Expression, reset: Expression, init: Expression) = {
- val procName = namespace.newName("$process$" + this.m.name)
- val regTempName = "\\" + r.serialize + procName
- val loweredReset = SrcInfo(reset)
- val loweredClk = SrcInfo(clk)
- val loweredInit = SrcInfo(init)
- val loweredReg = SrcInfo(r)
- def addUpdate(info: Info, expr: Expression, tabs: Seq[String]): Seq[Seq[Any]] = expr match {
- case m: Mux =>
- if (m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly")
- if (m.tpe == AsyncResetType) throw EmitterException("Cannot emit async reset muxes directly")
-
- val (eninfo, tinfo, finfo) = MultiInfo.demux(info)
- lazy val _if: Seq[Seq[Any]] =
- Seq(Seq(tabs, eninfo), Seq(tabs, "switch ", SrcInfo(m.cond, eninfo).str_rep)) ++ (
- if (infos_to_attr(tinfo).nonEmpty)
- Seq(Seq(tabs, tab, tinfo), Seq(tabs, tab, "case 1'1"))
- else
- Seq(Seq(tabs, tab, "case 1'1"))
- )
- lazy val _else: Seq[Seq[Any]] = infos_to_attr(finfo) match {
- case Some(_) =>
- Seq(Seq(tabs, tab, finfo), Seq(tabs, tab, "case"))
- case None =>
- Seq(Seq(tabs, tab, "case"))
- }
- lazy val _ifNot: Seq[Seq[Any]] =
- Seq(Seq(tabs, eninfo), Seq(tabs, "switch ", SrcInfo(m.cond, eninfo).str_rep)) ++ (
- if (infos_to_attr(finfo).nonEmpty)
- Seq(Seq(tabs, tab, finfo), Seq(tabs, tab, "case 1'0"))
- else
- Seq(Seq(tabs, tab, "case 1'0"))
- )
- lazy val _end = Seq(Seq(tabs, "end"))
- lazy val _true = addUpdate(tinfo, m.tval, Seq(tab, tab) ++ tabs)
- lazy val _false = addUpdate(finfo, m.fval, Seq(tab, tab) ++ tabs)
- /* For a Mux assignment, there are five possibilities, with one subcase for asynchronous reset:
- * 1. Both the true and false condition are self-assignments; do nothing
- * 2. The true condition is a self-assignment; invert the false condition and use that only
- * 3. The false condition is a self-assignment
- * a) The reset is asynchronous; emit both 'if' and a trivial 'else' to avoid latches
- * b) The reset is synchronous; skip the false condition
- * 4. The false condition is a Mux; use the true condition and use 'else if' for the false condition
- * 5. Default; use both the true and false conditions
- */
- (m.tval, m.fval) match {
- case (t, f) if weq(t, r) && weq(f, r) => Nil
- case (t, _) if weq(t, r) => _ifNot ++ _false ++ _end
- case (_, f) if weq(f, r) =>
- m.cond.tpe match {
- case AsyncResetType => (_if ++ _true ++ _else) ++ _true ++ _end
- case _ => _if ++ _true ++ _end
- }
- case _ => (_if ++ _true ++ _else) ++ _false ++ _end
- }
- case e =>
- Seq(Seq(tabs, "assign ", regTempName, " ", SrcInfo(e, info).str_rep))
- }
- if (weq(init, r)) { // Synchronous Reset
- val InfoExpr(info, e) = netlist(r)
- processes += Seq(info)
- processes += Seq("wire ", r.tpe, " ", regTempName)
- processes += Seq("process ", procName)
- processes += Seq("assign ", regTempName, " ", loweredInit.str_rep)
- processes ++= addUpdate(info, e, Seq(tab))
- processes += Seq(tab, "sync posedge ", clk)
- processes += Seq(tab, tab, "update ", SrcInfo(r).str_rep, " ", regTempName)
- processes += Seq("end")
- } else { // Asynchronous Reset
- assert(reset.tpe == AsyncResetType, "Error! Synchronous reset should have been removed!")
- val tv = init
- val InfoExpr(finfo, fv) = netlist(r)
- processes += Seq(finfo)
- processes += Seq("wire ", r.tpe, " ", regTempName)
- processes += Seq("process ", procName)
- processes += Seq("assign ", regTempName, " ", loweredInit.str_rep)
- processes ++= addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), Seq.empty)
- processes += Seq("sync posedge ", loweredClk.str_rep)
- processes += Seq(tab, "update ", loweredReset.str_rep, " ", regTempName)
- processes += Seq("sync posedge ", reset)
- processes += Seq(tab, "update ", loweredReg.str_rep, " ", regTempName)
- processes += Seq("end")
- }
- }
-
- private def bigint_to_str_rep(bigInt: BigInt, width: BigInt): String = {
- if (width > 31) {
- var bigboi = bigInt
- var widthcnt = width
- var concatlist: Seq[String] = List()
-
- while (widthcnt > 32) {
- val lowbits = bigboi & 0xffffffff
- concatlist = concatlist :+ "%d'%s".format(32, lowbits.toString(2))
- bigboi >>= 32
- widthcnt -= 32
- }
- concatlist = concatlist :+ "%d'%s".format(widthcnt, bigboi.toString(2))
- "{ " + concatlist.reverse.mkString(" ") + " }"
- } else
- "%d'%s".format(width, bigInt.toString(2))
- }
-
- private case class InstInfo(inst_name: String, mod_name: String, info: Info) {
- val conns: mutable.Map[String, String] = mutable.Map()
- var params: Seq[String] = Seq()
- def getConnection(port: String): Option[String] = {
- conns.get(port)
- }
- def addConnection(port: String, targetValue: String): Unit = {
- conns(port) = targetValue
- }
- }
-
- private case class SrcInfo(str_rep: String, signed: Boolean, width: BigInt)
- private object SrcInfo {
- def apply(e: Expression, i: Info = NoInfo): SrcInfo = e match {
- case InfoExpr(info, expr) =>
- SrcInfo(expr, MultiInfo(info, i))
- case x: Reference =>
- SrcInfo(ref_to_name(x), x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe))
- case x: Literal =>
- val width = x.width.asInstanceOf[IntWidth].width
- SrcInfo(bigint_to_str_rep(x.value, width), x.isInstanceOf[SIntLiteral], width)
- case x @ DoPrim(op, args, consts, tpe) =>
- op match {
- case Cat =>
- SrcInfo(
- Seq(" { ", args.map(SrcInfo(_).str_rep).mkString(" "), " }").mkString,
- tpe.isInstanceOf[SIntType],
- get_type_width(tpe)
- )
- case Head =>
- val src0 = SrcInfo(args.head)
- SrcInfo(
- Seq(src0.str_rep, " [", (src0.width - 1).toInt, ":", consts.head.toInt, "]").mkString,
- tpe.isInstanceOf[SIntType],
- get_type_width(tpe)
- )
- case Tail =>
- val src0 = SrcInfo(args.head)
- SrcInfo(
- Seq(src0.str_rep, " [", (src0.width - 1 - consts.head).toInt, ":0]").mkString,
- tpe.isInstanceOf[SIntType],
- get_type_width(tpe)
- )
- case Pad =>
- val src0 = SrcInfo(args.head)
- if (src0.width >= consts.head)
- SrcInfo(
- Seq(src0.str_rep, " [", (consts.head - 1).toInt, ":0]").mkString,
- tpe.isInstanceOf[SIntType],
- get_type_width(tpe)
- )
- else if (src0.signed)
- SrcInfo(
- Seq(
- " { ",
- s"${src0.str_rep} [${src0.width - 1}] " * (consts.head - src0.width).toInt,
- src0.str_rep,
- " }"
- ).mkString,
- tpe.isInstanceOf[SIntType],
- get_type_width(tpe)
- )
- else
- SrcInfo(
- Seq(" { ", (consts.head - src0.width).toInt, "'0 ", src0.str_rep, " }").mkString,
- tpe.isInstanceOf[SIntType],
- get_type_width(tpe)
- )
- case _ =>
- val tempNetName = namespace.newName("$_PRIM_EX")
- if (infos_to_attr(i).nonEmpty) declares += Seq(i)
- declares += Seq("wire ", x.tpe, " ", tempNetName)
- assigns ++= output_expr(tempNetName, x, i)
- SrcInfo(tempNetName, x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe))
- }
- case x @ SubField(Reference(modname, _, InstanceKind, _), portname, _, _) =>
- val currentPortConn = instdeclares(modname).getConnection(portname)
- if (currentPortConn.isEmpty) {
- val tempNetName = "\\" + LowerTypes.loweredName(x)
- if (infos_to_attr(i).nonEmpty) declares += Seq(i)
- declares += Seq("wire ", x.tpe, " ", tempNetName)
- instdeclares(modname).addConnection(portname, tempNetName)
- SrcInfo(tempNetName, x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe))
- } else {
- SrcInfo(currentPortConn.get, x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe))
- }
- case x: SubField =>
- SrcInfo("\\" + LowerTypes.loweredName(x), x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe))
- case x: Mux =>
- val tempNetName = namespace.newName("$_MUX_EX")
- if (infos_to_attr(i).nonEmpty) declares += Seq(i)
- declares += Seq("wire ", x.tpe, " ", tempNetName)
- assigns ++= output_expr(tempNetName, e, i)
- SrcInfo(tempNetName, x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe))
- case x =>
- throw EmitterException(s"Internal error! unhandled value $x passed to SrcInfo()")
- }
- }
-
- private def emit_streams(): Unit = {
- build_description(description).foreach(emit(_))
- emit(Seq("# Generated by firrtl.RtlilEmitter (FIRRTL Version ", BuildInfo.version + ")"))
- emit(Seq("autoidx 1"))
- emit(Seq("attribute \\cells_not_processed 1"))
- emit(Seq("module \\", m.name, m.info))
- for (x <- portdefs) emit(Seq(tab, x))
- for (x <- declares) emit(Seq(tab, x))
- for ((_, x) <- instdeclares) {
- emit(Seq(tab, "attribute \\module_not_derived 1"))
- emit(Seq(tab, x.info))
- emit(Seq(tab, "cell \\", x.mod_name, " \\", x.inst_name))
- for (p <- x.params) emit(Seq(tab, tab, p))
- for ((a, b) <- x.conns) emit(Seq(tab, tab, "connect \\", a, " ", b))
- emit(Seq(tab, "end"))
- }
- for (x <- assigns) emit(Seq(tab, x))
- for (x <- processes) emit(Seq(tab, x))
- for (x <- attachSynAssigns) emit(Seq(tab, x))
- for (x <- initials) emit(Seq(tab, x))
- emit(Seq("end"))
- emit(Seq())
- }
-
- private def primop_to_cell(p: PrimOp): String = p match {
- case Not => "$not"
- case Neg => "$neg"
- case Andr => "$reduce_and"
- case Orr => "$reduce_or"
- case Xorr => "$reduce_xor"
- case And => "$and"
- case Or => "$or"
- case Xor => "$xor"
- case Shl => "$shl"
- case Dshl => "$shl"
- case Eq => "$eq"
- case Lt => "$lt"
- case Leq => "$le"
- case Neq => "$ne"
- case Geq => "$ge"
- case Gt => "$gt"
- case Add => "$add"
- case Addw => "$add"
- case Sub => "$sub"
- case Subw => "$sub"
- case Mul => "$mul"
- case Div => "$div"
- case Rem => "$rem"
- case _ =>
- throwInternalError(
- "Internal Error! primop %s shouldn't have propagated this far!".format(p.serialize)
- );
- }
-
- private def unary_cells = List("$not", "$neg", "$reduce_and", "$reduce_or", "$reduce_xor")
- private def get_type_width(e: Type): BigInt = { // just trust me bro, its lofirrtl
- e.asInstanceOf[GroundType].width.asInstanceOf[IntWidth].width
- }
-
- private def emit_cell(
- i: Info,
- name: String,
- params: Seq[(String, String)],
- connections: Seq[(String, String)]
- ): Seq[Seq[Any]] = {
- Seq(Seq(i), Seq("cell ", name, " ", namespace.newName(name + "$" + m.name))) ++
- params.map { p => Seq(tab, "parameter \\", p._1, " ", p._2) } ++
- connections.map { c => Seq(tab, "connect \\", c._1, " ", c._2) } ++
- Seq(Seq("end"))
- }
-
- private def emit_unary_cell(cell: String, src: SrcInfo, target: String, tgt_width: BigInt): Seq[Seq[Any]] = {
- emit_cell(
- NoInfo,
- cell,
- Seq(
- (
- "A_SIGNED",
- if (src.signed) { "1" }
- else { "0" }
- ),
- ("A_WIDTH", src.width.toString),
- ("Y_WIDTH", tgt_width.toString)
- ),
- Seq(("A", src.str_rep), ("Y", target))
- )
- }
-
- private def emit_binary_cell(
- cell: String,
- src_a: SrcInfo,
- src_b: SrcInfo,
- target: String,
- tgt_width: BigInt
- ): Seq[Seq[Any]] = {
- emit_cell(
- NoInfo,
- cell,
- Seq(
- (
- "A_SIGNED",
- if (src_a.signed) "1" else "0"
- ),
- ("A_WIDTH", src_a.width.toString),
- (
- "B_SIGNED",
- if (src_b.signed) "1" else "0"
- ),
- ("B_WIDTH", src_b.width.toString),
- ("Y_WIDTH", tgt_width.toString)
- ),
- Seq(("A", src_a.str_rep), ("B", src_b.str_rep), ("Y", target))
- )
- }
-
- @tailrec
- private def output_expr(n: String, d: Expression, i: Info): Seq[Seq[Any]] = d match {
- case UIntLiteral(_, _) | SIntLiteral(_, _) | Reference(_, _, _, _) | SubField(_, _, _, _) =>
- Seq(Seq("connect ", n, " ", SrcInfo(d, i).str_rep))
- case InfoExpr(info, expr) =>
- output_expr(n, expr, MultiInfo(Seq(i, info)))
- case Mux(cond, tval, fval, tpe) =>
- val (eninfo, tinfo, finfo) = MultiInfo.demux(i)
- val csrc = SrcInfo(cond, eninfo)
- val tsrc = SrcInfo(tval, tinfo)
- val fsrc = SrcInfo(fval, finfo)
- emit_cell(
- i,
- "$mux",
- Seq(("WIDTH", get_type_width(tpe).toString)),
- Seq(("A", fsrc.str_rep), ("B", tsrc.str_rep), ("S", csrc.str_rep), ("Y", n))
- )
- case DoPrim(op, args, consts, _) =>
- val sources = args.map(SrcInfo(_, i))
- val src0 = sources.head
- if (sources.map(_.width).contains(-1)) return Seq()
- op match {
- case AsSInt | AsUInt | AsClock | AsAsyncReset =>
- Seq(Seq("connect ", n, " ", src0))
- case Cvt =>
- if (src0.signed)
- Seq(Seq("connect ", n, " ", src0))
- else
- Seq(Seq("connect ", n, " { 1'0 ", src0, " }"))
- case Bits =>
- if (consts.head == consts.last)
- Seq(Seq("connect ", n, " ", src0, " [", consts.head.toInt, "]"))
- else
- Seq(Seq("connect ", n, " ", src0, " [", consts.head.toInt, ":", consts.last.toInt, "]"))
- case Shr | Shl =>
- val prim = if (op == Shr) (if (src0.signed) "$sshr" else "$shr") else "$shl"
- emit_binary_cell(
- prim,
- src0,
- SrcInfo(bigint_to_str_rep(consts.head, consts.head.bitLength), signed = false, consts.head.bitLength),
- n,
- get_type_width(d.tpe)
- )
- case Add =>
- if (src0.signed && sources(1).signed) {
- val src0_ext = SrcInfo(s"{ ${src0.str_rep} [${src0.width - 1}] ${src0.str_rep} }", true, src0.width + 1)
- val src1_ext = SrcInfo(
- s"{ ${sources(1).str_rep} [${sources(1).width - 1}] ${sources(1).str_rep} }",
- true,
- sources(1).width + 1
- )
- emit_binary_cell("$add", src0_ext, src1_ext, n, get_type_width(d.tpe))
- } else {
- emit_binary_cell("$add", src0, sources(1), n, get_type_width(d.tpe))
- }
- case Dshr | Dshl =>
- val prim = if (op == Dshr) (if (src0.signed) "$sshr" else "$shr") else "$shl"
- emit_binary_cell(prim, src0, sources(1), n, get_type_width(d.tpe))
- case Cat =>
- Seq(Seq("connect ", n, " { ", sources.map(_.str_rep).mkString(" "), " }"))
- case Head =>
- Seq(Seq("connect ", n, " ", src0, " [", (src0.width - 1).toInt, ":", consts.head.toInt, "]"))
- case Tail =>
- Seq(Seq("connect ", n, " ", src0, " [", (src0.width - 1 - consts.head).toInt, ":0]"))
- case Pad =>
- if (src0.width >= consts.head)
- Seq(Seq("connect ", n, " ", src0, " [", (consts.head - 1).toInt, ":0]"))
- else if (src0.signed)
- Seq(
- Seq("connect ", n) ++
- Seq(
- " { ",
- s"${src0.str_rep} [${src0.width - 1}] " * (consts.head - src0.width).toInt,
- src0.str_rep,
- " }"
- )
- )
- else
- Seq(Seq("connect ", n, " { ", (consts.head - src0.width).toInt, "'0 ", src0, " }"))
- case _ =>
- val cell = primop_to_cell(op)
- if (unary_cells.contains(cell))
- Seq(i) +: emit_unary_cell(cell, src0, n, get_type_width(d.tpe))
- else
- Seq(i) +: emit_binary_cell(cell, src0, sources(1), n, get_type_width(d.tpe))
- }
- case unk =>
- throw EmitterException(s"Internal error! unhandled output expression $unk passed to output_expr()")
- }
-
- private def build_streams(s: Statement): Unit = {
- val withoutDescription = s match {
- case DescribedStmt(d, stmt) =>
- stmt match {
- case _: IsDeclaration =>
- declares ++= build_description(d)
- case _ =>
- }
- stmt
- case stmt => stmt
- }
- withoutDescription.foreach(build_streams)
- withoutDescription match {
- case DefInstance(info, name, mdle, _) =>
- val (module, params) = moduleMap(mdle) match {
- case DescribedMod(_, _, ExtModule(_, _, _, extname, params)) => (extname, params)
- case DescribedMod(_, _, Module(_, name, _, _)) => (name, Seq.empty)
- case ExtModule(_, _, _, extname, params) => (extname, params)
- case Module(_, name, _, _) => (name, Seq.empty)
- }
- instdeclares(name) = InstInfo(name, module, info)
- instdeclares(name).params = if (params.nonEmpty) params.map(stringify) else Seq()
- case WDefInstanceConnector(info, name, mdle, _, portCons) =>
- val (_, params) = moduleMap(mdle) match {
- case DescribedMod(_, _, ExtModule(_, _, _, extname, params)) => (extname, params)
- case DescribedMod(_, _, Module(_, name, _, _)) => (name, Seq.empty)
- case ExtModule(_, _, _, extname, params) => (extname, params)
- case Module(_, name, _, _) => (name, Seq.empty)
- }
- instdeclares(name) = InstInfo(name, mdle, info)
- instdeclares(name).params = if (params.nonEmpty) params.map(stringify) else Seq()
- for ((port, ref) <- portCons) {
- val portName = SrcInfo(remove_root(port)).str_rep.tail
- if (instdeclares(name).getConnection(portName).nonEmpty) {
- assigns ++= output_expr(instdeclares(name).getConnection(portName).get, ref, NoInfo)
- } else {
- instdeclares(name).addConnection(SrcInfo(remove_root(port)).str_rep.tail, SrcInfo(ref).str_rep)
- }
- }
- case Connect(info, loc @ WRef(_, _, PortKind | WireKind | InstanceKind, _), expr) =>
- assigns ++= output_expr(ref_to_name(loc), expr, info)
- case Connect(info, SubField(Reference(modname, _, InstanceKind, _), portname, _, _), expr) =>
- if (instdeclares(modname).getConnection(portname).nonEmpty) {
- assigns ++= output_expr(instdeclares(modname).getConnection(portname).get, expr, NoInfo)
- } else {
- instdeclares(modname).addConnection(portname, SrcInfo(expr, info).str_rep)
- }
- case sx: DefWire =>
- declares += Seq(sx.info)
- declares += Seq("wire ", sx.tpe, " ", string_to_rtlil_name(sx.name))
- case sx: DefRegister =>
- val options = emissionOptions.getRegisterEmissionOption(moduleTarget.ref(sx.name))
- val e = WRef(sx.name, sx.tpe, ExpKind, UnknownFlow)
- declares += Seq(sx.info)
- declares += Seq("wire ", sx.tpe, " ", string_to_rtlil_name(sx.name))
- if (options.useInitAsPreset)
- regUpdate(e, sx.clock, sx.reset, e)
- else
- regUpdate(e, sx.clock, sx.reset, sx.init)
- case sx: DefNode =>
- declares += Seq(sx.info)
- declares += Seq("wire ", sx.value.tpe, " ", string_to_rtlil_name(sx.name))
- assigns ++= output_expr(string_to_rtlil_name(sx.name), sx.value, sx.info)
- case x @ Verification(value, info, _, pred, en, _) =>
- value match {
- case Formal.Assert =>
- formals += emit_cell(
- info,
- "$assert",
- Seq(),
- Seq(("A", SrcInfo(pred).str_rep), ("EN", SrcInfo(en).str_rep))
- )
- case Formal.Assume =>
- formals += emit_cell(
- info,
- "$assume",
- Seq(),
- Seq(("A", SrcInfo(pred).str_rep), ("EN", SrcInfo(en).str_rep))
- )
- case Formal.Cover =>
- formals += emit_cell(
- info,
- "$cover",
- Seq(),
- Seq(("A", SrcInfo(pred).str_rep), ("EN", SrcInfo(en).str_rep))
- )
- }
- case x @ DefMemory(i, name, tpe, depth, wlat, rlat, rd, wr, rdwr, runderw) =>
- val options = emissionOptions.getMemoryEmissionOption(moduleTarget.ref(name))
- val hasComplexRW = rdwr.nonEmpty && (rlat != 1)
- if (rlat > 1 || wlat != 1 || hasComplexRW)
- throw EmitterException(
- Seq(
- s"Memory $name is too complex to emit directly.",
- "Consider running VerilogMemDelays to simplify complex memories.",
- "Alternatively, add the --repl-seq-mem flag to replace memories with blackboxes."
- ).mkString(" ")
- )
- val dataWidth = bitWidth(tpe)
- val maxDataValue = (BigInt(1) << dataWidth.toInt) - 1
-
- def checkValueRange(value: BigInt, at: String): Unit = {
- if (value > maxDataValue)
- throw EmitterException(
- s"Memory $at cannot be initialized with value: $value. Too large (> $maxDataValue)!"
- )
- }
- declares += Seq("memory width ", dataWidth.toString, " size ", depth.toString, " \\", name)
- options.initValue match {
- case MemoryArrayInit(values) =>
- values.zipWithIndex.foreach {
- case (value, addr) =>
- checkValueRange(value, s"$name[$addr]")
- initials ++= emit_cell(
- i,
- "$meminit_v2",
- Seq(
- ("MEMID", "\"\\\\" + name + "\""),
- ("ABITS", "32"),
- ("WIDTH", dataWidth.toString),
- ("WORDS", "1"),
- ("PRIORITY", addr.toString)
- ),
- Seq(
- ("ADDR", addr.toString),
- ("DATA", bigint_to_str_rep(value, dataWidth)),
- ("EN", bigint_to_str_rep(BigInt(2).pow(dataWidth.toInt) - BigInt(1), dataWidth))
- )
- )
- }
-
- case MemoryScalarInit(value) =>
- for (addr <- 0 until depth.intValue) {
- initials ++= emit_cell(
- i,
- "$meminit_v2",
- Seq(
- ("MEMID", "\"\\\\" + name + "\""),
- ("ABITS", "32"),
- ("WIDTH", dataWidth.toString),
- ("WORDS", "1"),
- ("PRIORITY", addr.toString)
- ),
- Seq(
- ("ADDR", addr.toString),
- ("DATA", bigint_to_str_rep(value, dataWidth)),
- ("EN", bigint_to_str_rep(BigInt(2).pow(dataWidth.toInt) - BigInt(1), dataWidth))
- )
- )
- }
- case MemoryRandomInit =>
- println(s"Memory $name cannot be initialized with random data, RTLIL cannot express this.")
- println("Leaving memory uninitialized.")
- case MemoryFileInlineInit(_, _) =>
- throw EmitterException(s"Memory $name cannot be initialized from a file, RTLIL cannot express this.")
- case MemoryNoInit =>
- // No initialization to emit
- }
- for (r <- rd) {
- val data = memPortField(x, r, "data")
- val addr = memPortField(x, r, "addr")
- val en = memPortField(x, r, "en")
- val hasClk = if (rlat == 1) { "1'1" }
- else { "1'0" }
- val clkSrc = netlist(memPortField(x, r, "clk")).expr
- val transparent = runderw match {
- case ReadUnderWrite.New => "1'1"
- case ReadUnderWrite.Old => "1'0"
- case ReadUnderWrite.Undefined => "1'x"
- }
- declares += Seq("wire ", data.tpe, " ", SrcInfo(data).str_rep)
- assigns ++= emit_cell(
- i,
- "$memrd",
- Seq(
- ("ABITS", get_type_width(addr.tpe).toString),
- ("MEMID", "\"\\\\" + name + "\""),
- ("WIDTH", get_type_width(data.tpe).toString),
- ("CLK_ENABLE", hasClk),
- ("CLK_POLARITY", "1'1"),
- ("TRANSPARENT", transparent)
- ),
- Seq(
- ("CLK", SrcInfo(clkSrc, i).str_rep),
- ("EN", if (rlat == 1) SrcInfo(netlist(en), i).str_rep else "1'1"),
- ("ADDR", SrcInfo(netlist(addr), i).str_rep),
- ("DATA", SrcInfo(data, i).str_rep)
- )
- )
- }
- for (w <- wr) {
- val data = memPortField(x, w, "data")
- val addr = memPortField(x, w, "addr")
- val en = memPortField(x, w, "en")
- val mask = memPortField(x, w, "mask")
- val enSrc = SrcInfo(netlist(en))
- val maskSrc = SrcInfo(netlist(mask))
- if (maskSrc.width > 1) {
- throw EmitterException("Compound type memory write ports arent fully supported yet.")
- }
- var memwr_enmask = enSrc.str_rep
- if (bitWidth(data.tpe) != 1) {
- memwr_enmask = namespace.newName("$memwr_enmask$" + m.name)
- declares += Seq("wire signed width ", bitWidth(data.tpe).toInt, " ", memwr_enmask)
- assigns ++= emit_cell(
- i,
- "$and",
- Seq(
- ("A_SIGNED", "1"),
- ("B_SIGNED", "1"),
- ("A_WIDTH", bitWidth(en.tpe).toString()),
- ("B_WIDTH", maskSrc.width.toString()),
- ("Y_WIDTH", bitWidth(data.tpe).toString())
- ),
- Seq(("A", enSrc.str_rep), ("B", maskSrc.str_rep), ("Y", memwr_enmask))
- )
- }
- val hasClk = if (wlat == 1) { "1'1" }
- else { "1'0" }
- val clkSrc = netlist(memPortField(x, w, "clk")).expr
- assigns ++= emit_cell(
- i,
- "$memwr",
- Seq(
- ("ABITS", get_type_width(addr.tpe).toString),
- ("MEMID", "\"\\\\" + name + "\""),
- ("WIDTH", get_type_width(data.tpe).toString),
- ("CLK_ENABLE", hasClk),
- ("CLK_POLARITY", "1'1"),
- ("PRIORITY", "32'1")
- ),
- Seq(
- ("CLK", SrcInfo(clkSrc).str_rep),
- ("EN", memwr_enmask),
- ("ADDR", SrcInfo(netlist(addr)).str_rep),
- ("DATA", SrcInfo(netlist(data)).str_rep)
- )
- )
- }
- case sx: Attach =>
- for (set <- sx.exprs.toSet.subsets(2)) {
- val (a, b) = set.toSeq match {
- case Seq(x, y) => (x, y)
- }
- attachSynAssigns += Seq("connect ", SrcInfo(a, sx.info).str_rep, " ", SrcInfo(b, sx.info).str_rep)
- }
- case _ =>
- }
- }
-
- def emit_rtlil(): DefModule = {
- build_netlist(m.body)
- build_ports()
- build_streams(m.body)
- emit_streams()
- m
- }
- }
-}
-
-private[firrtl] class EmissionOptionMap[V <: EmissionOption](val df: V) {
- private val m = collection.mutable.HashMap[ReferenceTarget, V]().withDefaultValue(df)
- def +=(elem: (ReferenceTarget, V)): EmissionOptionMap.this.type = {
- if (m.contains(elem._1))
- throw EmitterException(s"Multiple EmissionOption for the target ${elem._1} (${m(elem._1)} ; ${elem._2})")
- m += elem
- this
- }
- def apply(key: ReferenceTarget): V = m.apply(key)
-}
-
-private[firrtl] class EmissionOptions(annotations: AnnotationSeq) {
- // Private so that we can present an immutable API
- private val memoryEmissionOption = new EmissionOptionMap[MemoryEmissionOption](
- annotations.collectFirst { case a: CustomDefaultMemoryEmission => a }.getOrElse(MemoryEmissionOptionDefault)
- )
- private val registerEmissionOption = new EmissionOptionMap[RegisterEmissionOption](
- annotations.collectFirst { case a: CustomDefaultRegisterEmission => a }.getOrElse(RegisterEmissionOptionDefault)
- )
- private val wireEmissionOption = new EmissionOptionMap[WireEmissionOption](WireEmissionOptionDefault)
- private val portEmissionOption = new EmissionOptionMap[PortEmissionOption](PortEmissionOptionDefault)
- private val nodeEmissionOption = new EmissionOptionMap[NodeEmissionOption](NodeEmissionOptionDefault)
- private val connectEmissionOption = new EmissionOptionMap[ConnectEmissionOption](ConnectEmissionOptionDefault)
-
- def getMemoryEmissionOption(target: ReferenceTarget): MemoryEmissionOption =
- memoryEmissionOption(target)
-
- def getRegisterEmissionOption(target: ReferenceTarget): RegisterEmissionOption =
- registerEmissionOption(target)
-
- def getWireEmissionOption(target: ReferenceTarget): WireEmissionOption =
- wireEmissionOption(target)
-
- def getPortEmissionOption(target: ReferenceTarget): PortEmissionOption =
- portEmissionOption(target)
-
- def getNodeEmissionOption(target: ReferenceTarget): NodeEmissionOption =
- nodeEmissionOption(target)
-
- def getConnectEmissionOption(target: ReferenceTarget): ConnectEmissionOption =
- connectEmissionOption(target)
-
- def emitMemoryInitAsNoSynth: Boolean = {
- val annos = annotations.collect { case a @ (MemoryNoSynthInit | MemorySynthInit) => a }
- annos match {
- case Seq() => true
- case Seq(MemoryNoSynthInit) => true
- case Seq(MemorySynthInit) => false
- case _ =>
- throw new FirrtlUserException(
- "There should only be at most one memory initialization option annotation, got $other"
- )
- }
- }
-
- private val emissionAnnos = annotations.collect {
- case m: SingleTargetAnnotation[ReferenceTarget] @unchecked with EmissionOption => m
- }
-
- annotations.foreach {
- case a: Annotation if a.dedup.nonEmpty =>
- val (_, _, target) = a.dedup.get
- if (!target.isLocal) {
- throw new FirrtlUserException(
- s"At least one dedupable annotation did not deduplicate: got non-local annotation $a from [[DedupAnnotationsTransform]]"
- )
- }
- case _ =>
- }
-
- // using multiple foreach instead of a single partial function as an Annotation can gather multiple EmissionOptions for simplicity
- emissionAnnos.foreach {
- case a: MemoryEmissionOption => memoryEmissionOption += ((a.target, a))
- case _ =>
- }
- emissionAnnos.foreach {
- case a: RegisterEmissionOption => registerEmissionOption += ((a.target, a))
- case _ =>
- }
- emissionAnnos.foreach {
- case a: WireEmissionOption => wireEmissionOption += ((a.target, a))
- case _ =>
- }
- emissionAnnos.foreach {
- case a: PortEmissionOption => portEmissionOption += ((a.target, a))
- case _ =>
- }
- emissionAnnos.foreach {
- case a: NodeEmissionOption => nodeEmissionOption += ((a.target, a))
- case _ =>
- }
- emissionAnnos.foreach {
- case a: ConnectEmissionOption => connectEmissionOption += ((a.target, a))
- case _ =>
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
deleted file mode 100644
index 37f9228f..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
+++ /dev/null
@@ -1,253 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-import scala.collection.mutable
-
-object Btor2Serializer {
- def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = {
- new Btor2Serializer().run(sys, skipOutput)
- }
-}
-
-private class Btor2Serializer private () {
- private val symbols = mutable.HashMap[String, Int]()
- private val lines = mutable.ArrayBuffer[String]()
- private var index = 1
-
- private def line(l: String): Int = {
- val ii = index
- lines += s"$ii $l"
- index += 1
- ii
- }
-
- private def comment(c: String): Unit = { lines += s"; $c" }
- private def trailingComment(c: String): Unit = {
- val lastLine = lines.last
- val newLine = if (lastLine.contains(';')) { lastLine + " " + c }
- else { lastLine + " ; " + c }
- lines(lines.size - 1) = newLine
- }
-
- // bit vector type serialization
- private val bitVecTypeCache = mutable.HashMap[Int, Int]()
-
- private def t(width: Int): Int = bitVecTypeCache.getOrElseUpdate(width, line(s"sort bitvec $width"))
-
- // bit vector expression serialization
- private def s(expr: BVExpr): Int = expr match {
- case BVLiteral(value, width) => lit(value, width)
- case BVSymbol(name, _) => symbols.getOrElse(name, throw new RuntimeException(s"Unknown symbol: $name"))
- case BVExtend(e, 0, _) => s(e)
- case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by")
- case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by")
- case BVSlice(e, hi, lo) =>
- if (lo == 0 && hi == e.width - 1) { s(e) }
- else {
- line(s"slice ${t(expr.width)} ${s(e)} $hi $lo")
- }
- case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b)
- case BVNot(BVNot(e)) => s(e)
- case BVNot(e) => unary("not", expr.width, e)
- case BVNegate(e) => unary("neg", expr.width, e)
- case BVReduceAnd(e) => unary("redand", expr.width, e)
- case BVReduceOr(e) => unary("redor", expr.width, e)
- case BVReduceXor(e) => unary("redxor", expr.width, e)
- case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b)
- case BVImplies(a, b) => binary("implies", expr.width, a, b)
- case BVEqual(a, b) => binary("eq", expr.width, a, b)
- case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}")
- case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b)
- case BVComparison(Compare.GreaterEqual, a, b, false) => binary("ugte", expr.width, a, b)
- case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b)
- case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b)
- case BVOp(op, a, b) => binary(s(op), expr.width, a, b)
- case BVConcat(a, b) => binary("concat", expr.width, a, b)
- case call: BVFunctionCall => s(functionCallToArrayRead(call))
- case ArrayRead(array, index) =>
- line(s"read ${t(expr.width)} ${s(array)} ${s(index)}")
- case BVIte(cond, tru, fals) =>
- line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}")
- case b @ BVAnd(terms) => variadic("and", b.width, terms)
- case b @ BVOr(terms) => variadic("or", b.width, terms)
- case forall: BVForall =>
- throw new RuntimeException(s"Quantifiers are not supported by the btor2 format: ${forall}")
- }
-
- private def s(op: Op.Value): String = op match {
- case Op.Xor => "xor"
- case Op.ArithmeticShiftRight => "sra"
- case Op.ShiftRight => "srl"
- case Op.ShiftLeft => "sll"
- case Op.Add => "add"
- case Op.Mul => "mul"
- case Op.Sub => "sub"
- case Op.SignedDiv => "sdiv"
- case Op.UnsignedDiv => "udiv"
- case Op.SignedMod => "smod"
- case Op.SignedRem => "srem"
- case Op.UnsignedRem => "urem"
- }
-
- private def unary(op: String, width: Int, e: BVExpr): Int = line(s"$op ${t(width)} ${s(e)}")
-
- private def binary(op: String, width: Int, a: BVExpr, b: BVExpr): Int =
- line(s"$op ${t(width)} ${s(a)} ${s(b)}")
-
- private def variadic(op: String, width: Int, terms: List[BVExpr]): Int = terms match {
- case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op")
- case Seq(a, b) => binary(op, width, a, b)
- case head :: tail =>
- val tailId = variadic(op, width, tail)
- line(s"$op ${t(width)} ${s(head)} ${tailId}")
- }
-
- private def lit(value: BigInt, w: Int): Int = {
- val typ = t(w)
- lazy val mask = (BigInt(1) << w) - 1
- if (value == 0) line(s"zero $typ")
- else if (value == 1) line(s"one $typ")
- else if (value == mask) line(s"ones $typ")
- else {
- val digits = value.toString(2)
- val padded = digits.reverse.padTo(w, '0').reverse
- line(s"const $typ $padded")
- }
- }
-
- // array type serialization
- private val arrayTypeCache = mutable.HashMap[(Int, Int), Int]()
-
- private def t(indexWidth: Int, dataWidth: Int): Int =
- arrayTypeCache.getOrElseUpdate((indexWidth, dataWidth), line(s"sort array ${t(indexWidth)} ${t(dataWidth)}"))
-
- // array expression serialization
- private def s(expr: ArrayExpr): Int = expr match {
- case ArraySymbol(name, _, _) => symbols(name)
- case ArrayStore(array, index, data) =>
- line(s"write ${t(expr.indexWidth, expr.dataWidth)} ${s(array)} ${s(index)} ${s(data)}")
- case ArrayIte(cond, tru, fals) =>
- // println("WARN: ITE on array is probably not supported by btor2")
- // While the spec does not seem to allow array ite, it seems to be supported in practice.
- // It is essential to model memories, so any support in the wild should be fairly well tested.
- line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}")
- case ArrayConstant(e, indexWidth) =>
- // The problem we are facing here is that the only way to create a constant array from a bv expression
- // seems to be to use the bv expression as the init value of a state variable.
- // Thus we need to create a fake state for every array init expression.
- arrayConstants.getOrElseUpdate(
- e.toString, {
- comment(s"$expr")
- val eId = s(e)
- val tpeId = t(indexWidth, e.width)
- val state = line(s"state $tpeId")
- line(s"init $tpeId $state $eId")
- state
- }
- )
- case f: ArrayFunctionCall =>
- throw new RuntimeException(s"The btor2 format does not support uninterpreted functions that return arrays!: $f")
- }
- private val arrayConstants = mutable.HashMap[String, Int]()
-
- private def s(expr: SMTExpr): Int = expr match {
- case b: BVExpr => s(b)
- case a: ArrayExpr => s(a)
- }
-
- // serialize the type of the expression
- private def t(expr: SMTExpr): Int = expr match {
- case b: BVExpr => t(b.width)
- case a: ArrayExpr => t(a.indexWidth, a.dataWidth)
- }
-
- private def functionCallToArrayRead(call: BVFunctionCall): BVExpr = {
- if (call.args.isEmpty) {
- BVSymbol(call.name, call.width)
- } else {
- val args: List[BVExpr] = call.args.map {
- case b: BVExpr => b
- case other => throw new RuntimeException(s"Unsupported call argument: $other in $call")
- }
- val index = concat(args)
- val a = ArraySymbol(call.name, indexWidth = index.width, dataWidth = call.width)
- ArrayRead(a, index)
- }
- }
- private def concat(e: Iterable[BVExpr]): BVExpr = {
- require(e.nonEmpty)
- e.reduce((a, b) => BVConcat(a, b))
- }
-
- def run(sys: TransitionSystem, skipOutput: Boolean): Iterable[String] = {
- def declare(name: String, lbl: Option[SignalLabel], expr: => Int): Unit = {
- assert(!symbols.contains(name), s"Trying to redeclare `$name`")
- val id = expr
- symbols(name) = id
- // add label
- lbl match {
- case Some(IsOutput) => if (!skipOutput) line(s"output $id ; $name")
- case Some(IsConstraint) => line(s"constraint $id ; $name")
- case Some(IsBad) => line(s"bad $id ; $name")
- case Some(IsFair) => line(s"fair $id ; $name")
- case _ =>
- }
- // add trailing comment
- sys.comments.get(name).foreach(trailingComment)
- }
-
- // header
- if (sys.header.nonEmpty) {
- sys.header.split('\n').foreach(comment)
- }
-
- // declare inputs
- sys.inputs.foreach { ii =>
- declare(ii.name, None, line(s"input ${t(ii.width)} ${ii.name}"))
- }
-
- // declare uninterpreted functions a constant arrays
- val ufs = TransitionSystem.findUninterpretedFunctions(sys)
- ufs.foreach { foo =>
- // only functions returning bit-vectors are supported!
- val bvSym = foo.sym.asInstanceOf[BVSymbol]
- val sym = if (foo.args.isEmpty) { bvSym }
- else {
- ArraySymbol(bvSym.name, foo.args.map(_.asInstanceOf[BVExpr].width).sum, bvSym.width)
- }
- comment(foo.toString)
- declare(sym.name, None, line(s"state ${t(sym)} ${sym.name}"))
- line(s"next ${t(sym)} ${s(sym)} ${s(sym)}")
- }
-
- // define state init
- sys.states.foreach { st =>
- // calculate init expression before declaring the state
- // this is required by btormc (presumably to avoid cycles in the init expression)
- val initId = st.init.map {
- // only in the context of initializing a state can we use a bv expression to model an array
- case ArrayConstant(e, _) => comment(s"${st.sym}.init"); s(e)
- case init => comment(s"${st.sym}.init"); s(init)
- }
- declare(st.sym.name, None, line(s"state ${t(st.sym)} ${st.sym.name}"))
- st.init.foreach { init => line(s"init ${t(init)} ${s(st.sym)} ${initId.get}") }
- }
-
- // define all other signals
- sys.signals.foreach { signal =>
- declare(signal.name, Some(signal.lbl), s(signal.e))
- }
-
- // define state next
- sys.states.foreach { st =>
- st.next.foreach { next =>
- comment(s"${st.sym}.next")
- line(s"next ${t(next)} ${s(st.sym)} ${s(next)}")
- }
- }
-
- lines
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
deleted file mode 100644
index 865382c9..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
+++ /dev/null
@@ -1,179 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-import firrtl.ir
-import firrtl.PrimOps
-import firrtl.passes.CheckWidths.WidthTooBig
-
-private object FirrtlExpressionSemantics {
- def toSMT(e: ir.Expression): BVExpr = {
- val eSMT = e match {
- case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts)
- case r: ir.RefLikeExpression => BVSymbol(r.serialize, getWidth(r))
- case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt)
- case ir.SIntLiteral(value, ir.IntWidth(width)) =>
- val twosComplementValue = value & ((BigInt(1) << width.toInt) - 1)
- BVLiteral(twosComplementValue, width.toInt)
- case ir.Mux(cond, tval, fval, _) =>
- val width = List(tval, fval).map(getWidth).max
- BVIte(toSMT(cond), toSMT(tval, width), toSMT(fval, width))
- case v: ir.ValidIf =>
- throw new RuntimeException(s"Unsupported expression: ValidIf ${v.serialize}")
- }
- assert(
- eSMT.width == getWidth(e),
- "We aim to always produce a SMT expression of the same width as the firrtl expression."
- )
- eSMT
- }
-
- /** Ensures that the result has the desired width by appropriately extending it. */
- def toSMT(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr =
- forceWidth(toSMT(e), isSigned(e), width, allowNarrow)
-
- private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = {
- if (eSMT.width == width) { eSMT }
- else if (width < eSMT.width) {
- assert(allowNarrow, s"Narrowing from ${eSMT.width} bits to $width bits is not allowed!")
- BVSlice(eSMT, width - 1, 0)
- } else {
- BVExtend(eSMT, width - eSMT.width, eSigned)
- }
- }
-
- // see "Primitive Operations" section in the Firrtl Specification
- private def onPrim(
- op: ir.PrimOp,
- args: Seq[ir.Expression],
- consts: Seq[BigInt]
- ): BVExpr = {
- (op, args, consts) match {
- case (PrimOps.Add, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max + 1
- BVOp(Op.Add, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Sub, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max + 1
- BVOp(Op.Sub, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Mul, Seq(e1, e2), _) =>
- val width = args.map(getWidth).sum
- BVOp(Op.Mul, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Div, Seq(num, den), _) =>
- val signed = isSigned(num)
- val resWidth = if (signed) { getWidth(num) + 1 }
- else { getWidth(num) }
- val op = if (signed) { Op.SignedDiv }
- else { Op.UnsignedDiv }
- // we do the calculation on the widened values and then narrow the result if needed
- val width = args.map(getWidth).max + (if (signed) 1 else 0)
- val res = BVOp(op, toSMT(num, width), toSMT(den, width))
- forceWidth(res, signed, resWidth, allowNarrow = true)
- case (PrimOps.Rem, Seq(num, den), _) =>
- val signed = isSigned(num)
- val op = if (signed) Op.SignedRem else Op.UnsignedRem
- val width = args.map(getWidth).max
- val resWidth = args.map(getWidth).min
- val res = BVOp(op, toSMT(num, width), toSMT(den, width))
- forceWidth(res, signed, resWidth, allowNarrow = true)
- case (PrimOps.Lt, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max
- BVNot(BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1)))
- case (PrimOps.Leq, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max
- BVNot(BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1)))
- case (PrimOps.Gt, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max
- BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1))
- case (PrimOps.Geq, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max
- BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1))
- case (PrimOps.Eq, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max
- BVEqual(toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Neq, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max
- BVNot(BVEqual(toSMT(e1, width), toSMT(e2, width)))
- case (PrimOps.Pad, Seq(e), Seq(n)) =>
- val width = getWidth(e)
- if (n <= width) { toSMT(e) }
- else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) }
- case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e); toSMT(e)
- case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e); toSMT(e)
- case (PrimOps.AsFixedPoint, Seq(e), _) => throw new AssertionError("Fixed-Point numbers need to be lowered!")
- case (PrimOps.AsClock, Seq(e), _) => toSMT(e)
- case (PrimOps.AsAsyncReset, Seq(e), _) =>
- checkForClockInCast(PrimOps.AsAsyncReset, e)
- throw new AssertionError(s"Asynchronous resets are not supported! Cannot cast ${e.serialize}.")
- case (PrimOps.Shl, Seq(e), Seq(n)) =>
- if (n == 0) { toSMT(e) }
- else {
- val zeros = BVLiteral(0, n.toInt)
- BVConcat(toSMT(e), zeros)
- }
- case (PrimOps.Shr, Seq(e), Seq(n)) =>
- val width = getWidth(e)
- // "If n is greater than or equal to the bit-width of e,
- // the resulting value will be zero for unsigned types
- // and the sign bit for signed types"
- if (n >= width) {
- if (isSigned(e)) { BVSlice(toSMT(e), width - 1, width - 1) }
- else { BV1BitZero }
- } else {
- BVSlice(toSMT(e), width - 1, n.toInt)
- }
- case (PrimOps.Dshl, Seq(e1, e2), _) =>
- val width = getWidth(e1) + (1 << getWidth(e2)) - 1
- BVOp(Op.ShiftLeft, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Dshr, Seq(e1, e2), _) =>
- val width = getWidth(e1)
- val o = if (isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight
- BVOp(o, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Cvt, Seq(e), _) =>
- if (isSigned(e)) { toSMT(e) }
- else { BVConcat(BV1BitZero, toSMT(e)) }
- case (PrimOps.Neg, Seq(e), _) => BVNegate(BVExtend(toSMT(e), 1, isSigned(e)))
- case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e))
- case (PrimOps.And, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max
- BVAnd(toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Or, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max
- BVOr(toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Xor, Seq(e1, e2), _) =>
- val width = args.map(getWidth).max
- BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e))
- case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e))
- case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e))
- case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2))
- case (PrimOps.Bits, Seq(e), Seq(hi, lo)) => BVSlice(toSMT(e), hi.toInt, lo.toInt)
- case (PrimOps.Head, Seq(e), Seq(n)) =>
- val width = getWidth(e)
- assert(n >= 0 && n <= width)
- BVSlice(toSMT(e), width - 1, width - n.toInt)
- case (PrimOps.Tail, Seq(e), Seq(n)) =>
- val width = getWidth(e)
- assert(n >= 0 && n <= width)
- assert(n < width, "While allowed by the firrtl standard, we do not support 0-bit values in this backend!")
- BVSlice(toSMT(e), width - n.toInt - 1, 0)
- }
- }
-
- /** For now we strictly forbid casting clocks to anything else.
- * Eventually this should be replaced by a more sophisticated clock analysis pass.
- */
- private def checkForClockInCast(cast: ir.PrimOp, signal: ir.Expression): Unit = {
- assert(signal.tpe != ir.ClockType, s"Cannot cast (${cast.serialize}) clock expression ${signal.serialize}!")
- }
-
- private val BV1BitZero = BVLiteral(0, 1)
-
- private def isSigned(e: ir.Expression): Boolean = e.tpe match {
- case _: ir.SIntType => true
- case _ => false
- }
-
- // Helper function
- private def getWidth(e: ir.Expression): Int = firrtl.bitWidth(e.tpe).toInt
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
deleted file mode 100644
index 7da2e1e6..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ /dev/null
@@ -1,379 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation}
-import firrtl._
-import firrtl.backends.experimental.smt.random._
-import firrtl.options.Dependency
-import firrtl.passes.MemPortUtils.memPortField
-import firrtl.passes.PassException
-import firrtl.passes.memlib.VerilogMemDelays
-import firrtl.stage.Forms
-import firrtl.stage.TransformManager.TransformDependency
-import firrtl.transforms.{EnsureNamedStatements, PropagatePresetAnnotations}
-import logger.LazyLogging
-
-import scala.collection.mutable
-
-case class TransitionSystemAnnotation(sys: TransitionSystem) extends NoTargetAnnotation
-
-/** Contains code to convert a flat firrtl module into a functional transition system which
- * can then be exported as SMTLib or Btor2 file.
- */
-object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
- override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++
- Seq(
- Dependency(VerilogMemDelays),
- Dependency(EnsureNamedStatements), // this is required to give assert/assume statements good names
- Dependency[PropagatePresetAnnotations]
- )
- override def invalidates(a: Transform): Boolean = false
- // since this pass only runs on the main module, inlining needs to happen before
- override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances])
-
- override protected def execute(state: CircuitState): CircuitState = {
- val circuit = state.circuit
- val presetRegs = state.annotations.collect {
- case PresetRegAnnotation(target) if target.module == circuit.main => target.ref
- }.toSet
-
- // collect all non-random memory initialization
- val memInit = state.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a }
- .filter(_.target.module == circuit.main)
- .map(a => a.target.ref -> a.initValue)
- .toMap
-
- // module look up table
- val modules = circuit.modules.map(m => m.name -> m).toMap
-
- // collect uninterpreted module annotations
- val uninterpreted = state.annotations.collect {
- case a: UninterpretedModuleAnnotation =>
- UninterpretedModuleAnnotation.checkModule(modules(a.target.module), a)
- a.target.module -> a
- }.toMap
-
- // convert the main module
- val main = modules(circuit.main)
- val sys = main match {
- case _: ir.ExtModule =>
- throw new ExtModuleException(
- "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog."
- )
- case m: ir.Module =>
- new ModuleToTransitionSystem(presetRegs = presetRegs, memInit = memInit, uninterpreted = uninterpreted).run(m)
- }
-
- val sortedSys = TopologicalSort.run(sys)
- val anno = TransitionSystemAnnotation(sortedSys)
- state.copy(circuit = circuit, annotations = state.annotations :+ anno)
- }
-}
-
-private object UnsupportedException {
- val HowToRunStuttering: String =
- """
- |You can run the StutteringClockTransform which
- |replaces all clock inputs with a clock enable signal.
- |This is required not only for multi-clock designs, but also to
- |accurately model asynchronous reset which could happen even if there
- |isn't a clock edge.
- | If you are using the firrtl CLI, please add:
- | -fct firrtl.backends.experimental.smt.StutteringClockTransform
- | If you are calling into firrtl programmatically you can use:
- | RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform])
- | To designate a clock to be the global_clock (i.e. the simulation tick), use:
- | GlobalClockAnnotation(CircuitTarget(...).module(...).ref("your_clock")))
- |""".stripMargin
-}
-
-private class ExtModuleException(s: String) extends PassException(s)
-private class AsyncResetException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering)
-private class MultiClockException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering)
-private class MissingFeatureException(s: String)
- extends PassException("Unfortunately the SMT backend does not yet support: " + s)
-
-private class ModuleToTransitionSystem(
- presetRegs: Set[String],
- memInit: Map[String, MemoryInitValue],
- uninterpreted: Map[String, UninterpretedModuleAnnotation])
- extends LazyLogging {
-
- def run(m: ir.Module): TransitionSystem = {
- // first pass over the module to convert expressions; discover state and I/O
- m.foreachPort(onPort)
- m.foreachStmt(onStatement)
-
- // multi-clock support requires the StutteringClock transform to be run
- if (clocks.size > 1) {
- throw new MultiClockException(s"The module ${m.name} has more than one clock: ${clocks.mkString(", ")}")
- }
-
- // generate comments from infos
- val comments = mutable.HashMap[String, String]()
- infos.foreach {
- case (name, info) =>
- val infoStr = info.serialize.trim
- if (infoStr.nonEmpty) {
- val prefix = comments.get(name).map(_ + ", ").getOrElse("")
- comments(name) = prefix + infoStr
- }
- }
-
- // module info to the comment header
- val header = m.info.serialize.trim
-
- TransitionSystem(m.name, inputs.toList, states.values.toList, signals.toList, comments.toMap, header)
- }
-
- private val inputs = mutable.ArrayBuffer[BVSymbol]()
- private val clocks = mutable.ArrayBuffer[String]()
- private val signals = mutable.ArrayBuffer[Signal]()
- private val states = mutable.LinkedHashMap[String, State]()
- private val infos = mutable.ArrayBuffer[(String, ir.Info)]()
-
- private def onPort(p: ir.Port): Unit = {
- if (isAsyncReset(p.tpe)) {
- throw new AsyncResetException(s"Found AsyncReset ${p.name}.")
- }
- infos.append(p.name -> p.info)
- p.direction match {
- case ir.Input =>
- if (isClock(p.tpe)) {
- clocks.append(p.name)
- } else {
- inputs.append(BVSymbol(p.name, bitWidth(p.tpe).toInt))
- }
- case ir.Output =>
- }
- }
-
- private def onStatement(s: ir.Statement): Unit = s match {
- case DefRandom(info, name, tpe, _, en) =>
- assert(!isClock(tpe), "rand should never be a clock!")
- // we model random sources as inputs and the enable signal as output
- infos.append(name -> info)
- inputs.append(BVSymbol(name, bitWidth(tpe).toInt))
- signals.append(Signal(name + ".en", onExpression(en, 1), IsOutput))
- case w: ir.DefWire =>
- if (!isClock(w.tpe)) {
- // InlineInstances can insert wires without re-running RemoveWires for now we just deal with it when
- // the Wires is connected to (ir.Connect).
- }
- case ir.DefNode(info, name, expr) =>
- if (!isClock(expr.tpe) && !isAsyncReset(expr.tpe)) {
- infos.append(name -> info)
- signals.append(Signal(name, onExpression(expr), IsNode))
- }
- case r: ir.DefRegister =>
- infos.append(r.name -> r.info)
- states(r.name) = onRegister(r)
- case m: ir.DefMemory =>
- infos.append(m.name -> m.info)
- states(m.name) = onMemory(m)
- case ir.Connect(info, loc, expr) =>
- if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
- if (!isClock(loc.tpe) && !isAsyncReset(expr.tpe)) { // we ignore clock connections
- val name = loc.serialize
- val e = onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = false)
- Utils.kind(loc) match {
- case RegKind => states(name) = states(name).copy(next = Some(e))
- case PortKind | InstanceKind => // module output or submodule input
- infos.append(name -> info)
- signals.append(Signal(name, e, IsOutput))
- case MemKind | WireKind =>
- // InlineInstances can insert wires without re-running RemoveWires for now we just deal with it.
- infos.append(name -> info)
- signals.append(Signal(name, e, IsNode))
- }
- }
- case i: ir.IsInvalid =>
- throw new UnsupportedFeatureException(s"IsInvalid statements are not supported: ${i.serialize}")
- case ir.DefInstance(info, name, module, tpe) => onInstance(info, name, module, tpe)
- case s: ir.Verification =>
- if (s.op == ir.Formal.Cover) {
- logger.info(s"[info] Cover statement was ignored: ${s.serialize}")
- } else {
- val name = s.name
- val predicate = onExpression(s.pred)
- val enabled = onExpression(s.en)
- val e = BVImplies(enabled, predicate)
- infos.append(name -> s.info)
- val signal = if (s.op == ir.Formal.Assert) {
- Signal(name, BVNot(e), IsBad)
- } else {
- Signal(name, e, IsConstraint)
- }
- signals.append(signal)
- }
- case s: ir.Conditionally =>
- error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}")
- case s: ir.PartialConnect =>
- error(s"PartialConnects are not supported. Please run ExpandConnects: ${s.serialize}")
- case s: ir.Attach =>
- error(s"Analog wires are not supported in the SMT backend: ${s.serialize}")
- case s: ir.Stop =>
- if (s.ret == 0) {
- logger.info(
- s"[info] Stop statements with a return code of 0 are currently not supported. Ignoring: ${s.serialize}"
- )
- } else {
- // we treat Stop statements with a non-zero exit value as assertions that en will always be false!
- val name = s.name
- infos.append(name -> s.info)
- signals.append(Signal(name, onExpression(s.en), IsBad))
- }
- case s: ir.Print =>
- logger.info(s"Info: ignoring: ${s.serialize}")
- case other => other.foreachStmt(onStatement)
- }
-
- private def onRegister(r: ir.DefRegister): State = {
- val width = bitWidth(r.tpe).toInt
- val resetExpr = onExpression(r.reset, 1)
- assert(resetExpr == False(), s"Expected reset expression of ${r.name} to be 0, not $resetExpr")
- val initExpr = onExpression(r.init, width)
- val sym = BVSymbol(r.name, width)
- val hasReset = initExpr != sym
- val isPreset = presetRegs.contains(r.name)
- assert(!isPreset || hasReset, s"Expected preset register ${r.name} to have a reset value, not just $initExpr!")
- val state = State(sym, if (isPreset) Some(initExpr) else None, None)
- state
- }
-
- private def onInstance(info: ir.Info, name: String, module: String, tpe: ir.Type): Unit = {
- if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}")
- if (uninterpreted.contains(module)) {
- onUninterpretedInstance(info: ir.Info, name: String, module: String, tpe: ir.Type)
- } else {
- // We treat all instances that aren't annotated as uninterpreted as blackboxes
- // this means that their outputs could be any value, no matter what their inputs are.
- logger.warn(
- s"WARN: treating instance $name of $module as blackbox. " +
- "Please flatten your hierarchy if you want to include submodules in the formal model."
- )
- val ports = tpe.asInstanceOf[ir.BundleType].fields
- // skip async reset ports
- ports.filterNot(p => isAsyncReset(p.tpe)).foreach { p =>
- if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p")
- val isOutput = p.flip == ir.Default
- val pName = name + "." + p.name
- infos.append(pName -> info)
- // outputs of the submodule become inputs to our module
- if (isOutput) {
- if (isClock(p.tpe)) {
- clocks.append(pName)
- } else {
- inputs.append(BVSymbol(pName, bitWidth(p.tpe).toInt))
- }
- }
- }
- }
- }
-
- private def onUninterpretedInstance(info: ir.Info, instanceName: String, module: String, tpe: ir.Type): Unit = {
- val anno = uninterpreted(module)
-
- // sanity checks for ports were done already using the UninterpretedModule.checkModule function
- val ports = tpe.asInstanceOf[ir.BundleType].fields
-
- val outputs = ports.filter(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt))
- val inputs = ports.filterNot(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt))
-
- assert(anno.stateBits == 0, "TODO: implement support for uninterpreted stateful modules!")
-
- // for state-less (i.e. combinatorial) circuits, the outputs only depend on the inputs
- val args = inputs.map(i => BVSymbol(instanceName + "." + i.name, i.width)).toList
- outputs.foreach { out =>
- val functionName = anno.prefix + "." + out.name
- val call = BVFunctionCall(functionName, args, out.width)
- val wireName = instanceName + "." + out.name
- signals.append(Signal(wireName, call))
- }
- }
-
- private def onMemory(m: ir.DefMemory): State = {
- checkMem(m)
-
- // derive the type of the memory from the dataType and depth
- val dataWidth = bitWidth(m.dataType).toInt
- val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
- val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth)
-
- // there could be a constant init
- val init = memInit.get(m.name).map(getMemInit(m, indexWidth, dataWidth, _))
- init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth))
-
- // derive next state expression
- val next = if (m.writers.isEmpty) {
- memSymbol
- } else {
- m.writers.foldLeft[ArrayExpr](memSymbol) {
- case (prev, write) =>
- // update
- val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth)
- val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth)
- val update = ArrayStore(prev, index = addr, data = data)
-
- // update guard
- val en = BVSymbol(memPortField(m, write, "en").serialize, 1)
- val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1)
- ArrayIte(BVAnd(en, mask), update, prev)
- }
- }
-
- val state = State(memSymbol, init, Some(next))
-
- // derive read expressions
- val readSignals = m.readers.map { read =>
- val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth)
- Signal(memPortField(m, read, "data").serialize, ArrayRead(memSymbol, addr), IsNode)
- }
- signals ++= readSignals
-
- state
- }
-
- private def getMemInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr =
- initValue match {
- case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth)
- case MemoryArrayInit(values) =>
- assert(
- values.length == m.depth,
- s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!"
- )
- // in order to get a more compact encoding try to find the most common values
- val histogram = mutable.LinkedHashMap[BigInt, Int]()
- values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
- val baseValue = histogram.maxBy(_._2)._1
- val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth)
- values.zipWithIndex
- .filterNot(_._1 == baseValue)
- .foldLeft[ArrayExpr](base) {
- case (array, (value, index)) =>
- ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth))
- }
- case other => throw new RuntimeException(s"Unsupported memory init option: $other")
- }
-
- private def checkMem(m: ir.DefMemory): Unit = {
- assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?")
- assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?")
- assert(
- m.dataType.isInstanceOf[ir.GroundType],
- s"Memory $m is of type ${m.dataType} which is not a ground type!"
- )
- assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
- }
-
- private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr =
- FirrtlExpressionSemantics.toSMT(e, width, allowNarrow)
- private def onExpression(e: ir.Expression): BVExpr = FirrtlExpressionSemantics.toSMT(e)
-
- private def error(msg: String): Unit = throw new RuntimeException(msg)
- private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType]
- private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType
- private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala
deleted file mode 100644
index 7b332b83..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala
+++ /dev/null
@@ -1,12 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-sealed trait SMTCommand
-case class Comment(msg: String) extends SMTCommand
-case class SetLogic(logic: String) extends SMTCommand
-case class DefineFunction(name: String, args: Seq[SMTFunctionArg], e: SMTExpr) extends SMTCommand
-case class DeclareFunction(sym: SMTSymbol, args: Seq[SMTFunctionArg]) extends SMTCommand
-case class DeclareUninterpretedSort(name: String) extends SMTCommand
-case class DeclareUninterpretedSymbol(name: String, tpe: String) extends SMTCommand
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
deleted file mode 100644
index 45ec6898..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-import java.io.Writer
-
-import firrtl._
-import firrtl.annotations.{Annotation, NoTargetAnnotation}
-import firrtl.options.Viewer.view
-import firrtl.options.{CustomFileEmission, Dependency}
-import firrtl.stage.FirrtlOptions
-
-private[firrtl] abstract class SMTEmitter private[firrtl] ()
- extends Transform
- with Emitter
- with DependencyAPIMigration {
- override def prerequisites: Seq[Dependency[Transform]] = Seq(Dependency(FirrtlToTransitionSystem))
- override def invalidates(a: Transform): Boolean = false
-
- override def emit(state: CircuitState, writer: Writer): Unit = error("Deprecated since firrtl 1.0!")
-
- protected def serialize(sys: TransitionSystem): Annotation
-
- override protected def execute(state: CircuitState): CircuitState = {
- val emitCircuit = state.annotations.exists {
- case EmitCircuitAnnotation(a) if this.getClass == a => true
- case EmitAllModulesAnnotation(a) if this.getClass == a => error("EmitAllModulesAnnotation not supported!")
- case _ => false
- }
-
- if (!emitCircuit) { return state }
-
- val sys = state.annotations.collectFirst { case TransitionSystemAnnotation(sys) => sys }.getOrElse {
- error("Could not find the transition system!")
- }
- state.copy(annotations = state.annotations :+ serialize(sys))
- }
-
- protected def generatedHeader(format: String, name: String): String =
- s"; $format description generated by firrtl ${BuildInfo.version} for module $name.\n"
-
- protected def error(msg: String): Nothing = throw new RuntimeException(msg)
-}
-
-case class EmittedSMTModelAnnotation(name: String, src: String, outputSuffix: String)
- extends NoTargetAnnotation
- with CustomFileEmission {
- override protected def baseFileName(annotations: AnnotationSeq): String =
- view[FirrtlOptions](annotations).outputFileName.getOrElse(name)
- override protected def suffix: Option[String] = Some(outputSuffix)
- override def getBytes: Iterable[Byte] = src.getBytes
-}
-
-/** Turns the transition system generated by [[FirrtlToTransitionSystem]] into a btor2 file. */
-object Btor2Emitter extends SMTEmitter {
- override def outputSuffix: String = ".btor2"
- override protected def serialize(sys: TransitionSystem): Annotation = {
- val btor = generatedHeader("BTOR", sys.name) + Btor2Serializer.serialize(sys).mkString("\n") + "\n"
- EmittedSMTModelAnnotation(sys.name, btor, outputSuffix)
- }
-}
-
-/** Turns the transition system generated by [[FirrtlToTransitionSystem]] into an SMTLib file. */
-object SMTLibEmitter extends SMTEmitter {
- override def outputSuffix: String = ".smt2"
- override protected def serialize(sys: TransitionSystem): Annotation = {
- val hasMemory = sys.states.exists(_.sym.isInstanceOf[ArrayExpr])
- val logic = if (hasMemory) "QF_AUFBV" else "QF_UFBV"
- val logicCmd = SMTLibSerializer.serialize(SetLogic(logic)) + "\n"
- val header = if (hasMemory) {
- "; We have to disable the logic for z3 to accept the non-standard \"as const\"\n" +
- "; see https://github.com/Z3Prover/z3/issues/1803\n" +
- "; for CVC4 you probably want to include the logic\n" +
- ";" + logicCmd
- } else { logicCmd }
- val smt = generatedHeader("SMT-LIBv2", sys.name) + header +
- SMTTransitionSystemEncoder.encode(sys).map(SMTLibSerializer.serialize).mkString("\n") + "\n"
- EmittedSMTModelAnnotation(sys.name, smt, outputSuffix)
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
deleted file mode 100644
index f2eae58a..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
+++ /dev/null
@@ -1,342 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-// Inspired by the uclid5 SMT library (https://github.com/uclid-org/uclid).
-// And the btor2 documentation (BTOR2 , BtorMC and Boolector 3.0 by Niemetz et.al.)
-
-package firrtl.backends.experimental.smt
-
-/** base trait for all SMT expressions */
-sealed trait SMTExpr extends SMTFunctionArg {
- def tpe: SMTType
- def children: List[SMTExpr]
-}
-sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr {
- def name: String
-
- /** keeps the type of the symbol while changing the name */
- def rename(newName: String): SMTSymbol
-}
-object SMTSymbol {
-
- /** makes a SMTSymbol of the same type as the expression */
- def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match {
- case b: BVExpr => BVSymbol(name, b.width)
- case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth)
- }
-}
-sealed trait SMTNullaryExpr extends SMTExpr {
- override def children: List[SMTExpr] = List()
-}
-
-/** a SMT bit vector expression: https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml */
-sealed trait BVExpr extends SMTExpr {
- def width: Int
- def tpe: BVType = BVType(width)
- override def toString: String = SMTExprSerializer.serialize(this)
-}
-case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr {
- private def minWidth = value.bitLength + (if (value <= 0) 1 else 0)
- assert(value >= 0, "Negative values are not supported! Please normalize by calculating 2s complement.")
- assert(width > 0, "Zero or negative width literals are not allowed!")
- assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.")
-}
-object BVLiteral {
- def apply(nums: String): BVLiteral = nums.head match {
- case 'b' => BVLiteral(BigInt(nums.drop(1), 2), nums.length - 1)
- }
-}
-case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol {
- assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
- assert(width > 0, "Zero width bit vectors are not supported!")
- override def rename(newName: String) = BVSymbol(newName, width)
-}
-
-sealed trait BVUnaryExpr extends BVExpr {
- def e: BVExpr
-
- /** same function, different child, e.g.: not(x) -- reapply(Y) --> not(Y) */
- def reapply(expr: BVExpr): BVUnaryExpr
- override def children: List[BVExpr] = List(e)
-}
-case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr {
- assert(by >= 0, "Extension must be non-negative!")
- override val width: Int = e.width + by
- override def reapply(expr: BVExpr) = BVExtend(expr, by, signed)
-}
-// also known as bit extract operation
-case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr {
- assert(lo >= 0, s"lo (lsb) must be non-negative!")
- assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo")
- assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi")
- override def width: Int = hi - lo + 1
- override def reapply(expr: BVExpr) = BVSlice(expr, hi, lo)
-}
-case class BVNot(e: BVExpr) extends BVUnaryExpr {
- override val width: Int = e.width
- override def reapply(expr: BVExpr) = new BVNot(expr)
-}
-case class BVNegate(e: BVExpr) extends BVUnaryExpr {
- override val width: Int = e.width
- override def reapply(expr: BVExpr) = BVNegate(expr)
-}
-
-case class BVReduceOr(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
- override def reapply(expr: BVExpr) = BVReduceOr(expr)
-}
-case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
- override def reapply(expr: BVExpr) = BVReduceAnd(expr)
-}
-case class BVReduceXor(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
- override def reapply(expr: BVExpr) = BVReduceXor(expr)
-}
-
-sealed trait BVBinaryExpr extends BVExpr {
- def a: BVExpr
- def b: BVExpr
- override def children: List[BVExpr] = List(a, b)
-
- /** same function, different child, e.g.: add(a,b) -- reapply(a,c) --> add(a,c) */
- def reapply(nA: BVExpr, nB: BVExpr): BVBinaryExpr
-}
-case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
- assert(a.width == b.width, s"Both argument need to be the same width!")
- override def width: Int = 1
- override def reapply(nA: BVExpr, nB: BVExpr) = BVEqual(nA, nB)
-}
-// added as a separate node because it is used a lot in model checking and benefits from pretty printing
-class BVImplies(val a: BVExpr, val b: BVExpr) extends BVBinaryExpr {
- assert(a.width == 1, s"The antecedent needs to be a boolean expression!")
- assert(b.width == 1, s"The consequent needs to be a boolean expression!")
- override def width: Int = 1
- override def reapply(nA: BVExpr, nB: BVExpr) = new BVImplies(nA, nB)
-}
-object BVImplies {
- def apply(a: BVExpr, b: BVExpr): BVExpr = {
- assert(a.width == b.width, s"Both argument need to be the same width!")
- (a, b) match {
- case (True(), b) => b // (!1 || b) = b
- case (False(), _) => True() // (!0 || _) = (1 || _) = 1
- case (_, True()) => True() // (!a || 1) = 1
- case (a, False()) => BVNot(a) // (!a || 0) = !a
- case (a, b) => new BVImplies(a, b)
- }
- }
- def unapply(i: BVImplies): Some[(BVExpr, BVExpr)] = Some((i.a, i.b))
-}
-
-object Compare extends Enumeration {
- val Greater, GreaterEqual = Value
-}
-case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed: Boolean) extends BVBinaryExpr {
- assert(a.width == b.width, s"Both argument need to be the same width!")
- override def width: Int = 1
- override def reapply(nA: BVExpr, nB: BVExpr) = BVComparison(op, nA, nB, signed)
-}
-
-object Op extends Enumeration {
- val Xor = Value("xor")
- val ShiftLeft = Value("logical_shift_left")
- val ArithmeticShiftRight = Value("arithmetic_shift_right")
- val ShiftRight = Value("logical_shift_right")
- val Add = Value("add")
- val Mul = Value("mul")
- val SignedDiv = Value("sdiv")
- val UnsignedDiv = Value("udiv")
- val SignedMod = Value("smod")
- val SignedRem = Value("srem")
- val UnsignedRem = Value("urem")
- val Sub = Value("sub")
-}
-case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr {
- assert(a.width == b.width, s"Both argument need to be the same width!")
- override val width: Int = a.width
- override def reapply(nA: BVExpr, nB: BVExpr) = BVOp(op, nA, nB)
-}
-case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
- override val width: Int = a.width + b.width
- override def reapply(nA: BVExpr, nB: BVExpr) = BVConcat(nA, nB)
-}
-case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr {
- assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
- override val width: Int = array.dataWidth
- override def children: List[SMTExpr] = List(array, index)
-}
-case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr {
- assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!")
- assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}")
- override val width: Int = tru.width
- override def children: List[BVExpr] = List(cond, tru, fals)
-}
-
-case class BVAnd(terms: List[BVExpr]) extends BVExpr {
- require(terms.size > 1)
- override val width: Int = terms.head.width
- require(terms.forall(_.width == width))
- override def children: List[BVExpr] = terms
-}
-
-case class BVOr(terms: List[BVExpr]) extends BVExpr {
- require(terms.size > 1)
- override val width: Int = terms.head.width
- require(terms.forall(_.width == width))
- override def children: List[BVExpr] = terms
-}
-
-sealed trait ArrayExpr extends SMTExpr {
- val indexWidth: Int
- val dataWidth: Int
- def tpe: ArrayType = ArrayType(indexWidth = indexWidth, dataWidth = dataWidth)
- override def toString: String = SMTExprSerializer.serialize(this)
-}
-case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol {
- assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
- assert(!name.contains("\\"), s"Invalid id $name contains `\\`")
- override def rename(newName: String) = ArraySymbol(newName, indexWidth, dataWidth)
-}
-case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr {
- override val dataWidth: Int = e.width
- override def children: List[SMTExpr] = List(e)
-}
-case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr {
- assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!")
- assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!")
- override def width: Int = 1
- override def children: List[SMTExpr] = List(a, b)
-}
-case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr {
- assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
- assert(array.dataWidth == data.width, "Data with does not match expected array data width!")
- override val dataWidth: Int = array.dataWidth
- override val indexWidth: Int = array.indexWidth
- override def children: List[SMTExpr] = List(array, index, data)
-}
-case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr {
- assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!")
- assert(
- tru.indexWidth == fals.indexWidth,
- s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}"
- )
- assert(
- tru.dataWidth == fals.dataWidth,
- s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}"
- )
- override val dataWidth: Int = tru.dataWidth
- override val indexWidth: Int = tru.indexWidth
- override def children: List[SMTExpr] = List(cond, tru, fals)
-}
-
-case class BVForall(variable: BVSymbol, e: BVExpr) extends BVUnaryExpr {
- assert(e.width == 1, "Can only quantify over boolean expressions!")
- override def width = 1
- override def reapply(expr: BVExpr) = BVForall(variable, expr)
-}
-
-/** apply arguments to a function which returns a result of bit vector type */
-case class BVFunctionCall(name: String, args: List[SMTFunctionArg], width: Int) extends BVExpr {
- override def children = args.map(_.asInstanceOf[SMTExpr])
-}
-
-/** apply arguments to a function which returns a result of array type */
-case class ArrayFunctionCall(name: String, args: List[SMTFunctionArg], indexWidth: Int, dataWidth: Int)
- extends ArrayExpr {
- override def children = args.map(_.asInstanceOf[SMTExpr])
-}
-sealed trait SMTFunctionArg
-// we allow symbols with uninterpreted type to be function arguments
-case class UTSymbol(name: String, tpe: String) extends SMTFunctionArg
-
-object BVAnd {
- def apply(a: BVExpr, b: BVExpr): BVExpr = {
- assert(a.width == b.width, s"Both argument need to be the same width!")
- (a, b) match {
- case (True(), b) => b
- case (a, True()) => a
- case (False(), _) => False()
- case (_, False()) => False()
- case (a, b) => new BVAnd(List(a, b))
- }
- }
- def apply(exprs: List[BVExpr]): BVExpr = {
- assert(exprs.nonEmpty, "Don't know what to do with an empty list!")
- val nonTriviallyTrue = exprs.filterNot(_ == True())
- nonTriviallyTrue.distinct match {
- case Seq() => True()
- case Seq(one) => one
- case terms => new BVAnd(terms)
- }
- }
-}
-object BVOr {
- def apply(a: BVExpr, b: BVExpr): BVExpr = {
- assert(a.width == b.width, s"Both argument need to be the same width!")
- (a, b) match {
- case (True(), _) => True()
- case (_, True()) => True()
- case (False(), b) => b
- case (a, False()) => a
- case (a, b) => new BVOr(List(a, b))
- }
- }
- def apply(exprs: List[BVExpr]): BVExpr = {
- assert(exprs.nonEmpty, "Don't know what to do with an empty list!")
- val nonTriviallyFalse = exprs.filterNot(_ == False())
- nonTriviallyFalse.distinct match {
- case Seq() => False()
- case Seq(one) => one
- case terms => new BVOr(terms)
- }
- }
-}
-
-object BVNot {
- def apply(e: BVExpr): BVExpr = e match {
- case True() => False()
- case False() => True()
- case BVNot(inner) => inner
- case other => new BVNot(other)
- }
-}
-
-object SMTEqual {
- def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a, b) match {
- case (ab: BVExpr, bb: BVExpr) => BVEqual(ab, bb)
- case (aa: ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba)
- case _ => throw new RuntimeException(s"Cannot compare $a and $b")
- }
-}
-
-object SMTIte {
- def apply(cond: BVExpr, tru: SMTExpr, fals: SMTExpr): SMTExpr = (tru, fals) match {
- case (ab: BVExpr, bb: BVExpr) => BVIte(cond, ab, bb)
- case (aa: ArrayExpr, ba: ArrayExpr) => ArrayIte(cond, aa, ba)
- case _ => throw new RuntimeException(s"Cannot mux $tru and $fals")
- }
-}
-
-object SMTExpr {
- def serializeType(e: SMTExpr): String = e match {
- case b: BVExpr => s"bv<${b.width}>"
- case a: ArrayExpr => s"bv<${a.indexWidth}> -> bv<${a.dataWidth}>"
- }
-}
-
-// unapply for matching BVLiteral(1, 1)
-object True {
- private val _True = BVLiteral(1, 1)
- def apply(): BVLiteral = _True
- def unapply(l: BVLiteral): Boolean = l.value == 1 && l.width == 1
-}
-
-// unapply for matching BVLiteral(0, 1)
-object False {
- private val _False = BVLiteral(0, 1)
- def apply(): BVLiteral = _False
- def unapply(l: BVLiteral): Boolean = l.value == 0 && l.width == 1
-}
-
-sealed trait SMTType
-case class BVType(width: Int) extends SMTType
-case class ArrayType(indexWidth: Int, dataWidth: Int) extends SMTType
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala
deleted file mode 100644
index 8e035186..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-package firrtl.backends.experimental.smt
-
-object SMTExprMap {
-
- /** maps f over subexpressions of expr and returns expr with the results replaced */
- def mapExpr(expr: SMTExpr, f: SMTExpr => SMTExpr): SMTExpr = {
- val bv = (b: BVExpr) => f(b).asInstanceOf[BVExpr]
- val ar = (a: ArrayExpr) => f(a).asInstanceOf[ArrayExpr]
- expr match {
- case b: BVExpr => mapExpr(b, bv, ar)
- case a: ArrayExpr => mapExpr(a, bv, ar)
- }
- }
-
- /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */
- def mapExpr(expr: BVExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): BVExpr = expr match {
- // nullary
- case old: BVLiteral => old
- case old: BVSymbol => old
- // unary
- case old @ BVExtend(e, by, signed) => val n = bv(e); if (n.eq(e)) old else BVExtend(n, by, signed)
- case old @ BVSlice(e, hi, lo) => val n = bv(e); if (n.eq(e)) old else BVSlice(n, hi, lo)
- case old @ BVNot(e) => val n = bv(e); if (n.eq(e)) old else BVNot(n)
- case old @ BVNegate(e) => val n = bv(e); if (n.eq(e)) old else BVNegate(n)
- case old @ BVForall(variables, e) => val n = bv(e); if (n.eq(e)) old else BVForall(variables, n)
- case old @ BVReduceAnd(e) => val n = bv(e); if (n.eq(e)) old else BVReduceAnd(n)
- case old @ BVReduceOr(e) => val n = bv(e); if (n.eq(e)) old else BVReduceOr(n)
- case old @ BVReduceXor(e) => val n = bv(e); if (n.eq(e)) old else BVReduceXor(n)
- // binary
- case old @ BVEqual(a, b) =>
- val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB)
- case old @ ArrayEqual(a, b) =>
- val (nA, nB) = (ar(a), ar(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB)
- case old @ BVComparison(op, a, b, signed) =>
- val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed)
- case old @ BVOp(op, a, b) =>
- val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB)
- case old @ BVConcat(a, b) =>
- val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB)
- case old @ ArrayRead(a, b) =>
- val (nA, nB) = (ar(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB)
- case old @ BVImplies(a, b) =>
- val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB)
- // ternary
- case old @ BVIte(a, b, c) =>
- val (nA, nB, nC) = (bv(a), bv(b), bv(c))
- if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC)
- // n-ary
- case old @ BVFunctionCall(name, args, width) =>
- val nArgs = args.map {
- case b: BVExpr => bv(b)
- case a: ArrayExpr => ar(a)
- case u: UTSymbol => u
- }
- val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) }
- if (anyNew) BVFunctionCall(name, nArgs, width) else old
- case old @ BVAnd(terms) =>
- val nTerms = terms.map(bv)
- val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) }
- if (anyNew) BVAnd(nTerms) else old
- case old @ BVOr(terms) =>
- val nTerms = terms.map(bv)
- val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) }
- if (anyNew) BVOr(nTerms) else old
- }
-
- /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */
- def mapExpr(expr: ArrayExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): ArrayExpr = expr match {
- case old: ArraySymbol => old
- case old @ ArrayConstant(e, indexWidth) => val n = bv(e); if (n.eq(e)) old else ArrayConstant(n, indexWidth)
- case old @ ArrayStore(a, b, c) =>
- val (nA, nB, nC) = (ar(a), bv(b), bv(c))
- if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC)
- case old @ ArrayIte(a, b, c) =>
- val (nA, nB, nC) = (bv(a), ar(b), ar(c))
- if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC)
- case old @ ArrayFunctionCall(name, args, indexWidth, dataWidth) =>
- val nArgs = args.map {
- case b: BVExpr => bv(b)
- case a: ArrayExpr => ar(a)
- case u: UTSymbol => u
- }
- val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) }
- if (anyNew) ArrayFunctionCall(name, nArgs, indexWidth, dataWidth) else old
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala
deleted file mode 100644
index 4aaf78a2..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala
+++ /dev/null
@@ -1,60 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-private object SMTExprSerializer {
- def serialize(expr: BVExpr): String = expr match {
- // nullary
- case lit: BVLiteral =>
- if (lit.width <= 8) {
- lit.width.toString + "'b" + lit.value.toString(2)
- } else {
- lit.width.toString + "'x" + lit.value.toString(16)
- }
- case BVSymbol(name, _) => name
- // unary
- case BVExtend(e, by, false) => s"zext(${serialize(e)}, $by)"
- case BVExtend(e, by, true) => s"sext(${serialize(e)}, $by)"
- case BVSlice(e, hi, lo) if hi == lo => s"${serialize(e)}[$hi]"
- case BVSlice(e, hi, lo) => s"${serialize(e)}[$hi:$lo]"
- case BVNot(e) => s"not(${serialize(e)})"
- case BVNegate(e) => s"neg(${serialize(e)})"
- case BVForall(variable, e) => s"forall(${variable.name} : bv<${variable.width}, ${serialize(e)})"
- case BVReduceAnd(e) => s"redand(${serialize(e)})"
- case BVReduceOr(e) => s"redor(${serialize(e)})"
- case BVReduceXor(e) => s"redxor(${serialize(e)})"
- // binary
- case BVEqual(a, b) => s"eq(${serialize(a)}, ${serialize(b)})"
- case BVComparison(Compare.Greater, a, b, false) => s"ugt(${serialize(a)}, ${serialize(b)})"
- case BVComparison(Compare.Greater, a, b, true) => s"sgt(${serialize(a)}, ${serialize(b)})"
- case BVComparison(Compare.GreaterEqual, a, b, false) => s"ugeq(${serialize(a)}, ${serialize(b)})"
- case BVComparison(Compare.GreaterEqual, a, b, true) => s"sgeq(${serialize(a)}, ${serialize(b)})"
- case BVOp(op, a, b) => s"$op(${serialize(a)}, ${serialize(b)})"
- case BVConcat(a, b) => s"concat(${serialize(a)}, ${serialize(b)})"
- case ArrayRead(array, index) => s"${serialize(array)}[${serialize(index)}]"
- case ArrayEqual(a, b) => s"eq(${serialize(a)}, ${serialize(b)})"
- case BVImplies(a, b) => s"implies(${serialize(a)}, ${serialize(b)})"
- // ternary
- case BVIte(cond, tru, fals) => s"ite(${serialize(cond)}, ${serialize(tru)}, ${serialize(fals)})"
- // n-ary
- case BVFunctionCall(name, args, _) => name + serialize(args).mkString("(", ",", ")")
- case BVAnd(terms) => terms.map(serialize).mkString("and(", ", ", ")")
- case BVOr(terms) => terms.map(serialize).mkString("or(", ", ", ")")
- }
-
- def serialize(expr: ArrayExpr): String = expr match {
- case ArraySymbol(name, _, _) => name
- case ArrayConstant(e, indexWidth) => s"([${serialize(e)}] x ${(BigInt(1) << indexWidth)})"
- case ArrayStore(array, index, data) => s"${serialize(array)}[${serialize(index)} := ${serialize(data)}]"
- case ArrayIte(cond, tru, fals) => s"ite(${serialize(cond)}, ${serialize(tru)}, ${serialize(fals)})"
- case ArrayFunctionCall(name, args, _, _) => name + serialize(args).mkString("(", ",", ")")
- }
-
- private def serialize(args: Iterable[SMTFunctionArg]): Iterable[String] =
- args.map {
- case b: BVExpr => serialize(b)
- case a: ArrayExpr => serialize(a)
- case u: UTSymbol => u.name
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
deleted file mode 100644
index 20a499b9..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
+++ /dev/null
@@ -1,177 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-import scala.util.matching.Regex
-
-/** Converts STM Expressions to a SMTLib compatible string representation.
- * See http://smtlib.cs.uiowa.edu/
- * Assumes well typed expression, so it is advisable to run the TypeChecker
- * before serializing!
- * Automatically converts 1-bit vectors to bool.
- */
-object SMTLibSerializer {
- def serialize(e: SMTExpr): String = e match {
- case b: BVExpr => serialize(b)
- case a: ArrayExpr => serialize(a)
- }
-
- def serialize(t: SMTType): String = t match {
- case BVType(width) => serializeBitVectorType(width)
- case ArrayType(indexWidth, dataWidth) => serializeArrayType(indexWidth, dataWidth)
- }
-
- private def serialize(e: BVExpr): String = e match {
- case BVLiteral(value, width) =>
- val mask = (BigInt(1) << width) - 1
- val twosComplement = if (value < 0) { ((~(-value)) & mask) + 1 }
- else value
- if (width == 1) {
- if (twosComplement == 1) "true" else "false"
- } else {
- s"(_ bv$twosComplement $width)"
- }
- case BVSymbol(name, _) => escapeIdentifier(name)
- case BVExtend(e, 0, _) => serialize(e)
- case BVExtend(BVLiteral(value, width), by, false) => serialize(BVLiteral(value, width + by))
- case BVExtend(e, by, signed) =>
- val foo = if (signed) "sign_extend" else "zero_extend"
- s"((_ $foo $by) ${asBitVector(e)})"
- case BVSlice(e, hi, lo) =>
- if (lo == 0 && hi == e.width - 1) { serialize(e) }
- else {
- val bits = s"((_ extract $hi $lo) ${asBitVector(e)})"
- // 1-bit extracts need to be turned into a boolean
- if (lo == hi) { toBool(bits) }
- else { bits }
- }
- case BVNot(BVEqual(a, b)) if a.width == 1 => s"(distinct ${serialize(a)} ${serialize(b)})"
- case BVNot(BVNot(e)) => serialize(e)
- case BVNot(e) =>
- if (e.width == 1) { s"(not ${serialize(e)})" }
- else { s"(bvnot ${serialize(e)})" }
- case BVNegate(e) => s"(bvneg ${asBitVector(e)})"
- case r: BVReduceAnd => serialize(Expander.expand(r))
- case r: BVReduceOr => serialize(Expander.expand(r))
- case r: BVReduceXor => serialize(Expander.expand(r))
- case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b)
- case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})"
- case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
- case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
- case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})"
- case BVComparison(Compare.GreaterEqual, a, b, false) => s"(bvuge ${asBitVector(a)} ${asBitVector(b)})"
- case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})"
- case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})"
- // boolean operations get a special treatment for 1-bit vectors aka bools
- case b: BVAnd => serializeVariadic(if (b.width == 1) "and" else "bvand", b.terms)
- case b: BVOr => serializeVariadic(if (b.width == 1) "or" else "bvor", b.terms)
- case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})"
- case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})")
- case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})"
- case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})"
- case ArrayRead(array, index) => s"(select ${serialize(array)} ${serialize(index)})"
- case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case BVFunctionCall(name, args, _) => args.map(serializeArg).mkString(s"($name ", " ", ")")
- case BVForall(variable, e) => s"(forall ((${variable.name} ${serialize(variable.tpe)})) ${serialize(e)})"
- }
-
- private def serializeVariadic(op: String, terms: List[BVExpr]): String = terms match {
- case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op")
- case Seq(a, b) => s"($op ${serialize(a)} ${serialize(b)})"
- case head :: tail => s"($op ${serialize(head)} ${serializeVariadic(op, tail)})"
- }
-
- def serialize(e: ArrayExpr): String = e match {
- case ArraySymbol(name, _, _) => escapeIdentifier(name)
- case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})"
- case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
- case ArrayFunctionCall(name, args, _, _) => args.map(serializeArg).mkString(s"($name ", " ", ")")
- }
-
- def serialize(c: SMTCommand): String = c match {
- case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n")
- case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)"
- case DefineFunction(name, args, e) =>
- val aa = args.map(a => s"(${serializeArg(a)} ${serializeArgTpe(a)})").mkString(" ")
- s"(define-fun ${escapeIdentifier(name)} ($aa) ${serialize(e.tpe)} ${serialize(e)})"
- case DeclareFunction(sym, tpes) =>
- val aa = tpes.map(serializeArgTpe).mkString(" ")
- s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serialize(sym.tpe)})"
- case SetLogic(logic) => s"(set-logic $logic)"
- case DeclareUninterpretedSymbol(name, tpe) =>
- s"(declare-fun ${escapeIdentifier(name)} () ${escapeIdentifier(tpe)})"
- }
-
- private def serializeArgTpe(a: SMTFunctionArg): String =
- a match {
- case u: UTSymbol => escapeIdentifier(u.tpe)
- case s: SMTExpr => serialize(s.tpe)
- }
- private def serializeArg(a: SMTFunctionArg): String =
- a match {
- case u: UTSymbol => escapeIdentifier(u.name)
- case s: SMTExpr => serialize(s)
- }
-
- private def serializeArrayType(indexWidth: Int, dataWidth: Int): String =
- s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})"
- private def serializeBitVectorType(width: Int): String =
- if (width == 1) { "Bool" }
- else { assert(width > 1); s"(_ BitVec $width)" }
-
- private def serialize(op: Op.Value): String = op match {
- case Op.Xor => "bvxor"
- case Op.ArithmeticShiftRight => "bvashr"
- case Op.ShiftRight => "bvlshr"
- case Op.ShiftLeft => "bvshl"
- case Op.Add => "bvadd"
- case Op.Mul => "bvmul"
- case Op.Sub => "bvsub"
- case Op.SignedDiv => "bvsdiv"
- case Op.UnsignedDiv => "bvudiv"
- case Op.SignedMod => "bvsmod"
- case Op.SignedRem => "bvsrem"
- case Op.UnsignedRem => "bvurem"
- }
-
- private def toBool(e: String): String = s"(= $e (_ bv1 1))"
-
- private val bvZero = "(_ bv0 1)"
- private val bvOne = "(_ bv1 1)"
- private def asBitVector(e: BVExpr): String =
- if (e.width > 1) { serialize(e) }
- else { s"(ite ${serialize(e)} $bvOne $bvZero)" }
-
- // See <simple_symbol> definition in the Concrete Syntax Appendix of the SMTLib Spec
- private val simple: Regex = raw"[a-zA-Z\+-/\*\=%\?!\.$$_~&\^<>@][a-zA-Z0-9\+-/\*\=%\?!\.$$_~&\^<>@]*".r
- def escapeIdentifier(name: String): String = name match {
- case simple() => name
- case _ => if (name.startsWith("|") && name.endsWith("|")) name else s"|$name|"
- }
-}
-
-/** Expands expressions that are not natively supported by SMTLib */
-private object Expander {
- def expand(r: BVReduceAnd): BVExpr = {
- if (r.e.width == 1) { r.e }
- else {
- val allOnes = (BigInt(1) << r.e.width) - 1
- BVEqual(r.e, BVLiteral(allOnes, r.e.width))
- }
- }
- def expand(r: BVReduceOr): BVExpr = {
- if (r.e.width == 1) { r.e }
- else {
- BVNot(BVEqual(r.e, BVLiteral(0, r.e.width)))
- }
- }
- def expand(r: BVReduceXor): BVExpr = {
- if (r.e.width == 1) { r.e }
- else {
- val bits = (0 until r.e.width).map(ii => BVSlice(r.e, ii, ii))
- bits.reduce[BVExpr]((a, b) => BVOp(Op.Xor, a, b))
- }
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
deleted file mode 100644
index 4f096c28..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
+++ /dev/null
@@ -1,133 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-import scala.collection.mutable
-
-/** This Transition System encoding is directly inspired by yosys' SMT backend:
- * https://github.com/YosysHQ/yosys/blob/master/backends/smt2/smt2.cc
- * It if fairly compact, but unfortunately, the use of an uninterpreted sort for the state
- * prevents this encoding from working with boolector.
- * For simplicity reasons, we do not support hierarchical designs (no `_h` function).
- */
-object SMTTransitionSystemEncoder {
-
- def encode(sys: TransitionSystem): Iterable[SMTCommand] = {
- val cmds = mutable.ArrayBuffer[SMTCommand]()
- val name = sys.name
-
- // declare UFs if necessary
- cmds ++= TransitionSystem.findUninterpretedFunctions(sys)
-
- // emit header as comments
- if (sys.header.nonEmpty) {
- cmds ++= sys.header.split('\n').map(Comment)
- }
-
- // declare state type
- val stateType = id(name + "_s")
- cmds += DeclareUninterpretedSort(stateType)
-
- // state symbol
- val State = UTSymbol("state", stateType)
- val StateNext = UTSymbol("state_n", stateType)
-
- // inputs and states are modelled as constants
- def declare(sym: SMTSymbol, kind: String): Unit = {
- cmds ++= toDescription(sym, kind, sys.comments.get)
- val s = SMTSymbol.fromExpr(sym.name + SignalSuffix, sym)
- cmds += DeclareFunction(s, List(State))
- }
- sys.inputs.foreach(i => declare(i, "input"))
- sys.states.foreach(s => declare(s.sym, "register"))
-
- // signals are just functions of other signals, inputs and state
- def define(sym: SMTSymbol, e: SMTExpr, suffix: String = SignalSuffix): Unit = {
- val withReplacedSymbols = replaceSymbols(SignalSuffix, State)(e)
- cmds += DefineFunction(sym.name + suffix, List(State), withReplacedSymbols)
- }
- sys.signals.foreach { signal =>
- val sym = signal.sym
- cmds ++= toDescription(sym, lblToKind(signal.lbl), sys.comments.get)
- val e = if (signal.lbl == IsBad) BVNot(signal.e.asInstanceOf[BVExpr]) else signal.e
- define(sym, e)
- }
-
- // define the next and init functions for all states
- sys.states.foreach { state =>
- assert(state.next.nonEmpty, "Next function required")
- define(state.sym, state.next.get, NextSuffix)
- // init is optional
- state.init.foreach { init =>
- define(state.sym, init, InitSuffix)
- }
- }
-
- def defineConjunction(e: List[BVExpr], suffix: String): Unit = {
- define(BVSymbol(name, 1), if (e.isEmpty) True() else BVAnd(e), suffix)
- }
-
- // the transition relation asserts that the value of the next state is the next value from the previous state
- // e.g., (reg state_n) == (reg_next state)
- val transitionRelations = sys.states.map { state =>
- val newState = replaceSymbols(SignalSuffix, StateNext)(state.sym)
- val nextOldState = replaceSymbols(NextSuffix, State)(state.sym)
- SMTEqual(newState, nextOldState)
- }
- // the transition relation is over two states
- val transitionExpr = if (transitionRelations.isEmpty) { True() }
- else {
- replaceSymbols(SignalSuffix, State)(BVAnd(transitionRelations))
- }
- cmds += DefineFunction(name + "_t", List(State, StateNext), transitionExpr)
-
- // The init relation just asserts that all init function hold
- val initRelations = sys.states.filter(_.init.isDefined).map { state =>
- val stateSignal = replaceSymbols(SignalSuffix, State)(state.sym)
- val initSignal = replaceSymbols(InitSuffix, State)(state.sym)
- SMTEqual(stateSignal, initSignal)
- }
- defineConjunction(initRelations, "_i")
-
- // assertions and assumptions
- val assertions = sys.signals.filter(_.lbl == IsBad).map(a => replaceSymbols(SignalSuffix, State)(a.sym))
- defineConjunction(assertions.map(_.asInstanceOf[BVExpr]), AssertionSuffix)
- val assumptions = sys.signals.filter(_.lbl == IsConstraint).map(a => replaceSymbols(SignalSuffix, State)(a.sym))
- defineConjunction(assumptions.map(_.asInstanceOf[BVExpr]), AssumptionSuffix)
-
- cmds
- }
-
- private def id(s: String): String = SMTLibSerializer.escapeIdentifier(s)
- private val SignalSuffix = "_f"
- private val NextSuffix = "_next"
- private val InitSuffix = "_init"
- val AssertionSuffix = "_a"
- val AssumptionSuffix = "_u"
- private def lblToKind(lbl: SignalLabel): String = lbl match {
- case IsNode | IsInit | IsNext => "wire"
- case IsOutput => "output"
- // for the SMT encoding we turn bad state signals back into assertions
- case IsBad => "assert"
- case IsConstraint => "assume"
- case IsFair => "fair"
- }
- private def toDescription(sym: SMTSymbol, kind: String, comments: String => Option[String]): List[Comment] = {
- List(sym match {
- case BVSymbol(name, width) => Comment(s"firrtl-smt2-$kind $name $width")
- case ArraySymbol(name, indexWidth, dataWidth) =>
- Comment(s"firrtl-smt2-$kind $name $indexWidth $dataWidth")
- }) ++ comments(sym.name).map(Comment)
- }
- // All signals are modelled with functions that need to be called with the state as argument,
- // this replaces all Symbols with function applications to the state.
- private def replaceSymbols(suffix: String, arg: SMTFunctionArg, vars: Set[String] = Set())(e: SMTExpr): SMTExpr =
- e match {
- case BVSymbol(name, width) if !vars(name) => BVFunctionCall(id(name + suffix), List(arg), width)
- case ArraySymbol(name, indexWidth, dataWidth) if !vars(name) =>
- ArrayFunctionCall(id(name + suffix), List(arg), indexWidth, dataWidth)
- case fa @ BVForall(variable, _) => SMTExprMap.mapExpr(fa, replaceSymbols(suffix, arg, vars + variable.name))
- case other => SMTExprMap.mapExpr(other, replaceSymbols(suffix, arg, vars))
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
deleted file mode 100644
index 534db217..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
+++ /dev/null
@@ -1,272 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-import firrtl._
-import firrtl.annotations._
-import firrtl.ir.EmptyStmt
-import firrtl.options.Dependency
-import firrtl.passes.PassException
-import firrtl.stage.Forms
-import firrtl.stage.TransformManager.TransformDependency
-import firrtl.transforms.PropagatePresetAnnotations
-import firrtl.renamemap.MutableRenameMap
-
-import scala.collection.mutable
-
-case class GlobalClockAnnotation(target: ReferenceTarget) extends SingleTargetAnnotation[ReferenceTarget] {
- override def duplicate(n: ReferenceTarget): Annotation = this.copy(n)
-}
-
-/** Converts every input clock into a clock enable input and adds a single global clock.
- * - all registers and memory ports will be connected to the new global clock
- * - all registers and memory ports will be guarded by the enable signal of their original clock
- * - the clock enabled signal can be understood as a clock tick or posedge
- * - this transform can be used in order to (formally) verify designs with multiple clocks or asynchronous resets
- */
-class StutteringClockTransform extends Transform with DependencyAPIMigration {
- override def prerequisites: Seq[TransformDependency] = Forms.LowForm
- override def invalidates(a: Transform): Boolean = false
-
- // this pass needs to run *before* converting to a transition system
- override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq(Dependency(FirrtlToTransitionSystem))
- // since this pass only runs on the main module, inlining needs to happen before
- override def optionalPrerequisites: Seq[TransformDependency] = Seq(
- Dependency[firrtl.passes.InlineInstances],
- Dependency[PropagatePresetAnnotations]
- )
-
- override protected def execute(state: CircuitState): CircuitState = {
- if (state.circuit.modules.size > 1) {
- logger.warn(
- "WARN: StutteringClockTransform currently only supports running on a single module.\n" +
- s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want."
- )
- }
-
- // get main module
- val main = state.circuit.modules.find(_.name == state.circuit.main).get match {
- case m: ir.Module => m
- case e: ir.ExtModule => unsupportedError(s"Cannot run on extmodule $e")
- }
- mainName = main.name
-
- val namespace = Namespace(main)
-
- // create a global clock
- val globalClocks = state.annotations.collect { case GlobalClockAnnotation(c) => c }
- assert(globalClocks.size < 2, "There can only be a single global clock: " + globalClocks.mkString(", "))
- val (globalClock, portsWithGlobalClock) = globalClocks.headOption match {
- case Some(clock) =>
- assert(clock.module == main.name, "GlobalClock needs to be an input of the main module!")
- assert(main.ports.exists(_.name == clock.ref), "GlobalClock needs to be an input port!")
- assert(main.ports.find(_.name == clock.ref).get.direction == ir.Input, "GlobalClock needs to be an input port!")
- (clock.ref, main.ports)
- case None =>
- val name = namespace.newName("global_clock")
- (name, ir.Port(ir.NoInfo, name, ir.Input, ir.ClockType) +: main.ports)
- }
-
- // replace all other clocks with enable signals, unless they are the global clock
- val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name)
- val clockToEnable = clocks.map { c =>
- c -> ir.Reference(namespace.newName(c + "_en"), Utils.BoolType, firrtl.PortKind, firrtl.SourceFlow)
- }.toMap
- val portsWithEnableSignals = portsWithGlobalClock.map { p =>
- if (clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Utils.BoolType) }
- else { p }
- }
- // replace async reset with synchronous reset (since everything will we synchronous with the global clock)
- // unless it is a preset reset
- val asyncResets = portsWithEnableSignals.filter(_.tpe == ir.AsyncResetType).map(_.name)
- val isPresetReset = state.annotations.collect { case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet
- val resetsToChange = asyncResets.filterNot(isPresetReset).toSet
- val portsWithSyncReset = portsWithEnableSignals.map { p =>
- if (resetsToChange.contains(p.name)) { p.copy(tpe = Utils.BoolType) }
- else { p }
- }
- val presetRegs = state.annotations.collect {
- case PresetRegAnnotation(target) if target.module == mainName => target.ref
- }.toSet
-
- // discover clock and reset connections
- val scan = scanClocks(main, clockToEnable, resetsToChange)
-
- // rename clocks to clock enable signals
- val mRef = CircuitTarget(state.circuit.main).module(main.name)
- val renameMap = MutableRenameMap()
- scan.clockToEnable.foreach {
- case (clk, en) =>
- renameMap.record(mRef.ref(clk), mRef.ref(en.name))
- }
-
- // make changes
- implicit val ctx: Context = new Context(globalClock, scan, presetRegs)
- val newMain = main.copy(ports = portsWithSyncReset).mapStmt(onStatement)
-
- val nonMainModules = state.circuit.modules.filterNot(_.name == state.circuit.main)
- val newCircuit = state.circuit.copy(modules = nonMainModules :+ newMain)
- state.copy(circuit = newCircuit, renames = Some(renameMap))
- }
-
- private def onStatement(s: ir.Statement)(implicit ctx: Context): ir.Statement = {
- s.foreachExpr(checkExpr)
- s match {
- // memory field connects
- case c @ ir.Connect(_, ir.SubField(ir.SubField(ir.Reference(mem, _, _, _), port, _, _), field, _, _), _)
- if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) =>
- // replace clock with the global clock
- if (field == "clk") {
- c.copy(expr = ctx.globalClock)
- } else if (field == "en") {
- val m = ctx.memInfo(mem)
- val isWritePort = m.writers.contains(port)
- assert(isWritePort || m.readers.contains(port))
-
- // for write ports we guard the write enable with the clock enable signal, similar to registers
- if (isWritePort) {
- val clockEn = ctx.memPortToClockEnable(mem + "." + port)
- val guardedEnable = Utils.and(clockEn, c.expr)
- c.copy(expr = guardedEnable)
- } else { c }
- } else { c }
- // register field connects
- case c @ ir.Connect(_, r: ir.Reference, next) if ctx.registerToEnable.contains(r.name) =>
- val clockEnable = ctx.registerToEnable(r.name)
- val guardedNext = Utils.mux(clockEnable, next, r)
- val withReset = ctx.registerToAsyncReset.get(r.name) match {
- case None => guardedNext
- case Some((asyncReset, init)) => Utils.mux(asyncReset, init, guardedNext)
- }
- c.copy(expr = withReset)
- // remove other clock wires and nodes
- case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt
- case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
- case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
- // change async reset to synchronous reset
- case ir.Connect(info, loc: ir.Reference, expr: ir.Reference)
- if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) =>
- ir.Connect(info, loc.copy(tpe = Utils.BoolType), expr.copy(tpe = Utils.BoolType))
- case d @ ir.DefNode(_, name, value: ir.Reference)
- if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) =>
- d.copy(value = value.copy(tpe = Utils.BoolType))
- case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) =>
- d.copy(tpe = Utils.BoolType)
- // change memory clock and synchronize reset
- case ir.DefRegister(info, name, tpe, _, _, init) if ctx.registerToEnable.contains(name) =>
- val newInit = if (ctx.isPresetReg(name)) init else ir.Reference(name, tpe, RegKind, SourceFlow)
- ir.DefRegister(info, name, tpe, ctx.globalClock, Utils.False(), newInit)
- case other => other.mapStmt(onStatement)
- }
- }
-
- private def scanClocks(
- m: ir.Module,
- initialClockToEnable: Map[String, ir.Reference],
- resetsToChange: Set[String]
- ): ScanCtx = {
- implicit val ctx: ScanCtx = new ScanCtx(initialClockToEnable, resetsToChange)
- m.foreachStmt(scanClocksAndResets)
- ctx
- }
-
- private def scanClocksAndResets(s: ir.Statement)(implicit ctx: ScanCtx): Unit = {
- s.foreachExpr(checkExpr)
- s match {
- // track clock aliases
- case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType =>
- val locName = loc.serialize
- ctx.clockToEnable.get(expr.serialize).foreach { clockEn =>
- ctx.clockToEnable(locName) = clockEn
- // keep track of memory clocks
- if (loc.isInstanceOf[ir.SubField]) {
- val parts = locName.split('.')
- if (ctx.mems.contains(parts.head)) {
- assert(parts.length == 3 && parts.last == "clk")
- ctx.memPortToClockEnable.append(parts.dropRight(1).mkString(".") -> clockEn)
- }
- }
- }
- case ir.DefNode(_, name, value) if value.tpe == ir.ClockType =>
- ctx.clockToEnable.get(value.serialize).foreach(c => ctx.clockToEnable(name) = c)
- // track reset aliases
- case ir.Connect(_, loc, expr) if expr.tpe == ir.AsyncResetType && ctx.resetsToChange(expr.serialize) =>
- ctx.resetsToChange.add(loc.serialize)
- case ir.DefNode(_, name, value) if value.tpe == ir.AsyncResetType && ctx.resetsToChange(value.serialize) =>
- ctx.resetsToChange.add(name)
- // modify clocked elements
- case ir.DefRegister(_, name, _, clock, reset, init) =>
- ctx.clockToEnable.get(clock.serialize).foreach { clockEnable =>
- ctx.registerToEnable.append(name -> clockEnable)
- }
- reset match {
- case Utils.False() =>
- case other => ctx.registerToAsyncReset.append(name -> (other, init))
- }
- case m: ir.DefMemory =>
- assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!")
- assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!")
- assert(m.writeLatency == 1, "Only write-latency 1 is supported!")
- if (m.readers.nonEmpty && m.readLatency == 1) {
- unsupportedError("Registers memory read ports are not properly implemented yet :(")
- }
- ctx.mems(m.name) = m
- case other => other.foreachStmt(scanClocksAndResets)
- }
- }
-
- // we rely on people not casting clocks or async resets
- private def checkExpr(expr: ir.Expression): Unit = expr match {
- case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.ClockType =>
- unsupportedError(s"Clock casts are not supported: ${expr.serialize}")
- case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.ClockType =>
- unsupportedError(s"Clock casts are not supported: ${expr.serialize}")
- case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType =>
- unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}")
- case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType =>
- unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}")
- case ir.DoPrim(PrimOps.AsAsyncReset, _, _, _) =>
- unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}")
- case ir.DoPrim(PrimOps.AsClock, _, _, _) =>
- unsupportedError(s"Clock casts are not supported: ${expr.serialize}")
- case other => other.foreachExpr(checkExpr)
- }
-
- private class ScanCtx(initialClockToEnable: Map[String, ir.Reference], initialResetsToChange: Set[String]) {
- // keeps track of which clock signals will be replaced by which clock enable signal
- val clockToEnable = mutable.HashMap[String, ir.Reference]() ++ initialClockToEnable
- // kepp track of asynchronous resets that need to be changed to bool
- val resetsToChange = mutable.HashSet[String]() ++ initialResetsToChange
- // registers whose next function needs to be guarded with a clock enable
- val registerToEnable = mutable.ArrayBuffer[(String, ir.Reference)]()
- // registers with asynchronous reset
- val registerToAsyncReset = mutable.ArrayBuffer[(String, (ir.Expression, ir.Expression))]()
- // memory enables which need to be guarded with clock enables
- val memPortToClockEnable = mutable.ArrayBuffer[(String, ir.Reference)]()
- // keep track of memory names
- val mems = mutable.HashMap[String, ir.DefMemory]()
- }
-
- private class Context(globalClockName: String, scanResults: ScanCtx, val isPresetReg: String => Boolean) {
- val globalClock: ir.Reference = ir.Reference(globalClockName, ir.ClockType, firrtl.PortKind, firrtl.SourceFlow)
- // keeps track of which clock signals will be replaced by which clock enable signal
- val isRemovedClock: String => Boolean = scanResults.clockToEnable.contains
- // registers whose next function needs to be guarded with a clock enable
- val registerToEnable: Map[String, ir.Reference] = scanResults.registerToEnable.toMap
- // registers with asynchronous reset
- val registerToAsyncReset: Map[String, (ir.Expression, ir.Expression)] = scanResults.registerToAsyncReset.toMap
- // memory enables which need to be guarded with clock enables
- val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap
- // keep track of memory names
- val isMem: String => Boolean = scanResults.mems.contains
- val memInfo: String => ir.DefMemory = scanResults.mems
- val isResetToChange: String => Boolean = scanResults.resetsToChange.contains
- }
-
- private var mainName: String = "" // for debugging
- private def unsupportedError(msg: String): Nothing =
- throw new UnsupportedFeatureException(s"StutteringClockTransform: [$mainName] $msg")
-}
-
-private class UnsupportedFeatureException(s: String) extends PassException(s)
diff --git a/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala
deleted file mode 100644
index bd3ad740..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-import firrtl.graph.MutableDiGraph
-import scala.collection.mutable
-
-case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr]) {
- def name: String = sym.name
-}
-case class Signal(name: String, e: SMTExpr, lbl: SignalLabel = IsNode) {
- def toSymbol: SMTSymbol = SMTSymbol.fromExpr(name, e)
- def sym: SMTSymbol = toSymbol
-}
-case class TransitionSystem(
- name: String,
- inputs: List[BVSymbol],
- states: List[State],
- signals: List[Signal],
- comments: Map[String, String] = Map(),
- header: String = "") {
- def serialize: String = TransitionSystem.serialize(this)
-}
-
-sealed trait SignalLabel
-case object IsNode extends SignalLabel
-case object IsOutput extends SignalLabel
-case object IsConstraint extends SignalLabel
-case object IsBad extends SignalLabel
-case object IsFair extends SignalLabel
-case object IsNext extends SignalLabel
-case object IsInit extends SignalLabel
-
-object SignalLabel {
- private val labels = Seq(IsNode, IsOutput, IsConstraint, IsBad, IsFair, IsNext, IsInit)
- val labelStrings = Seq("node", "output", "constraint", "bad", "fair", "next", "init")
- val labelToString: SignalLabel => String = labels.zip(labelStrings).toMap
- val stringToLabel: String => SignalLabel = labelStrings.zip(labels).toMap
-}
-
-object TransitionSystem {
- def serialize(sys: TransitionSystem): String = {
- (Iterator(sys.name) ++
- sys.inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++
- sys.signals.map(s => s"${SignalLabel.labelToString(s.lbl)} ${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++
- sys.states.map(serialize)).mkString("\n")
- }
-
- def serialize(s: State): String = {
- s"state ${s.sym.name} : ${SMTExpr.serializeType(s.sym)}" +
- s.init.map("\n [init] " + _).getOrElse("") +
- s.next.map("\n [next] " + _).getOrElse("")
- }
-
- def systemExpressions(sys: TransitionSystem): List[SMTExpr] =
- sys.signals.map(_.e) ++ sys.states.flatMap(s => s.init ++ s.next)
-
- def findUninterpretedFunctions(sys: TransitionSystem): List[DeclareFunction] = {
- val calls = systemExpressions(sys).flatMap(findUFCalls)
- // find unique functions
- calls.groupBy(_.sym.name).map(_._2.head).toList
- }
-
- private def findUFCalls(e: SMTExpr): List[DeclareFunction] = {
- val f = e match {
- case BVFunctionCall(name, args, width) =>
- Some(DeclareFunction(BVSymbol(name, width), args))
- case ArrayFunctionCall(name, args, indexWidth, dataWidth) =>
- Some(DeclareFunction(ArraySymbol(name, indexWidth, dataWidth), args))
- case _ => None
- }
- f.toList ++ e.children.flatMap(findUFCalls)
- }
-}
-
-private object TopologicalSort {
-
- /** Ensures that all signals in the resulting system are topologically sorted.
- * This is necessary because [[firrtl.transforms.RemoveWires]] does
- * not sort assignments to outputs, submodule inputs nor memory ports.
- */
- def run(sys: TransitionSystem): TransitionSystem = {
- val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name)
- val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates)
- // TODO: maybe sort init expressions of states (this should not be needed most of the time)
- signalOrder match {
- case None => sys
- case Some(order) =>
- val signalMap = sys.signals.map(s => s.name -> s).toMap
- // we flatMap over `get` in order to ignore inputs/states in the order
- sys.copy(signals = order.flatMap(signalMap.get).toList)
- }
- }
-
- private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = {
- val known = new mutable.HashSet[String]() ++ globalSignals
- var needsReordering = false
- val digraph = new MutableDiGraph[String]
- signals.foreach {
- case (name, expr) =>
- digraph.addVertex(name)
- val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr)
- uniqueDependencies.foreach { d =>
- if (!known.contains(d)) { needsReordering = true }
- digraph.addPairWithEdge(name, d)
- }
- known.add(name)
- }
- if (needsReordering) {
- Some(digraph.linearize.reverse)
- } else { None }
- }
-
- private def findDependencies(expr: SMTExpr): List[String] = expr match {
- case BVSymbol(name, _) => List(name)
- case ArraySymbol(name, _, _) => List(name)
- case other => other.children.flatMap(findDependencies)
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala b/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala
deleted file mode 100644
index c7442f69..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala
+++ /dev/null
@@ -1,86 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
-
-package firrtl.backends.experimental.smt
-
-import firrtl.annotations._
-import firrtl.ir
-import firrtl.passes.PassException
-
-/** ExtModules annotated as UninterpretedModule will be modelled as
- * UninterpretedFunction (SMTLib) or constant arrays (btor2).
- * This can be useful when trying to abstract over a function that the
- * SMT solver or model checker is struggling with.
- *
- * E.g., one could declare an abstract 64bit multiplier like this:
- * ```
- * extmodule Mul64 :
- * input a : UInt<64>
- * input b : UInt<64>
- * output r : UInt<64>
- * ```
- * Now instead of using Chisel to actually implement a multiplication circuit
- * we can instantiate this Mul64 module twice: Once in our implementation
- * and once for our correctness property that might specify how the
- * multiply instruction is supposed to be executed on our CPU.
- * Now instead of having to prove equivalence of multiplication circuits, the
- * solver only has to make sure that the connections to the multiplier are correct,
- * since if `a` and `b` are the same on both instances of `Mul64`, then the `r` output
- * will also be the same. This is a much easier problem and will result in much faster
- * solving due to manual abstraction.
- *
- * When [[stateBits]] is 0, we model the module as purely combinatorial circuit and
- * thus expect there to be no clock wire going into the module.
- * Every output is thus a function of all inputs of the module.
- *
- * When [[stateBits]] is an N greater than zero, we will model the module as having an abstract state of width N.
- * Thus on every clock transition the abstract state is updated and all outputs will take the state
- * as well as the current inputs as arguments.
- * TODO: Support for stateful circuits is work in progress.
- *
- * All output functions well be prefixed with [[prefix]] and end in the name of the output pin.
- * It is the users responsibility to ensure that all function names will be unique by choosing apropriate
- * prefixes.
- *
- * The annotation is consumed by the [[FirrtlToTransitionSystem]] pass.
- */
-case class UninterpretedModuleAnnotation(target: ModuleTarget, prefix: String, stateBits: Int = 0)
- extends SingleTargetAnnotation[ModuleTarget] {
- require(stateBits >= 0, "negative number of bits is forbidden")
- if (stateBits > 0) throw new NotImplementedError("TODO: support for stateful circuits is not implemented yet!")
- override def duplicate(n: ModuleTarget) = copy(n)
-}
-
-object UninterpretedModuleAnnotation {
-
- /** checks to see whether the annotation module can actually be abstracted. Use *after* LowerTypes! */
- def checkModule(m: ir.DefModule, anno: UninterpretedModuleAnnotation): Unit = m match {
- case _: ir.Module =>
- throw new UninterpretedModuleException(s"UninterpretedModuleAnnotation can only be used with extmodule! $anno")
- case m: ir.ExtModule =>
- val clockInputs = m.ports.collect { case p @ ir.Port(_, _, ir.Input, ir.ClockType) => p.name }
- val clockOutput = m.ports.collect { case p @ ir.Port(_, _, ir.Output, ir.ClockType) => p.name }
- val asyncResets = m.ports.collect { case p @ ir.Port(_, _, _, ir.AsyncResetType) => p.name }
- if (clockOutput.nonEmpty) {
- throw new UninterpretedModuleException(
- s"We do not support clock outputs for uninterpreted modules! $clockOutput"
- )
- }
- if (asyncResets.nonEmpty) {
- throw new UninterpretedModuleException(
- s"We do not support async reset I/O for uninterpreted modules! $asyncResets"
- )
- }
- if (anno.stateBits == 0) {
- if (clockInputs.nonEmpty) {
- throw new UninterpretedModuleException(s"A combinatorial module may not have any clock inputs! $clockInputs")
- }
- } else {
- if (clockInputs.size != 1) {
- throw new UninterpretedModuleException(s"A stateful module must have exactly one clock input! $clockInputs")
- }
- }
- }
-}
-
-private class UninterpretedModuleException(s: String) extends PassException(s)
diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala b/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala
deleted file mode 100644
index 7381056e..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package firrtl.backends.experimental.smt.random
-
-import firrtl.Utils
-import firrtl.ir._
-
-/** Named source of random values. If there is no clock expression, than it will be clocked by the global clock. */
-case class DefRandom(
- info: Info,
- name: String,
- tpe: Type,
- clock: Option[Expression],
- en: Expression = Utils.True())
- extends Statement
- with HasInfo
- with IsDeclaration
- with CanBeReferenced
- with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement =
- DefRandom(info, name, tpe, clock.map(f), f(en))
- def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = { clock.foreach(f); f(en) }
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala
deleted file mode 100644
index c7eaad74..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala
+++ /dev/null
@@ -1,125 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package firrtl.backends.experimental.smt.random
-
-import firrtl._
-import firrtl.annotations.NoTargetAnnotation
-import firrtl.ir._
-import firrtl.passes._
-import firrtl.options.Dependency
-import firrtl.stage.Forms
-import firrtl.transforms.RemoveWires
-
-import scala.collection.mutable
-
-/** Chooses how to model explicit and implicit invalid values in the circuit */
-case class InvalidToRandomOptions(
- randomizeInvalidSignals: Boolean = true,
- randomizeDivisionByZero: Boolean = true)
- extends NoTargetAnnotation
-
-/** Replaces all explicit and implicit "invalid" values with random values.
- * Explicit invalids are:
- * - signal is invalid
- * - signal <= valid(..., expr)
- * Implicit invalids are:
- * - a / b when eq(b, 0)
- */
-object InvalidToRandomPass extends Transform with DependencyAPIMigration {
- override def prerequisites = Forms.LowForm
- // once ValidIf has been removed, we can no longer detect and randomize them
- override def optionalPrerequisiteOf = Seq(Dependency(RemoveValidIf))
- override def invalidates(a: Transform) = a match {
- // this pass might destroy SSA form, as we add a wire for the data field of every read port
- case _: RemoveWires => true
- // TODO: should we add some optimization passes here? we could be generating some dead code.
- case _ => false
- }
-
- override protected def execute(state: CircuitState): CircuitState = {
- val opts = state.annotations.collect { case o: InvalidToRandomOptions => o }
- require(opts.size < 2, s"Multiple options: $opts")
- val opt = opts.headOption.getOrElse(InvalidToRandomOptions())
-
- // quick exit if we just want to skip this pass
- if (!opt.randomizeDivisionByZero && !opt.randomizeInvalidSignals) {
- state
- } else {
- val c = state.circuit.mapModule(onModule(_, opt))
- state.copy(circuit = c)
- }
- }
-
- private def onModule(m: DefModule, opt: InvalidToRandomOptions): DefModule = m match {
- case d: DescribedMod =>
- throw new RuntimeException(s"CompilerError: Unexpected internal node: ${d.serialize}")
- case e: ExtModule => e
- case mod: Module =>
- val namespace = Namespace(mod)
- mod.mapStmt(onStmt(namespace, opt, _))
- }
-
- private def onStmt(namespace: Namespace, opt: InvalidToRandomOptions, s: Statement): Statement = s match {
- case IsInvalid(info, loc: RefLikeExpression) if opt.randomizeInvalidSignals =>
- val name = namespace.newName(loc.serialize.replace('.', '_') + "_invalid")
- val rand = DefRandom(info, name, loc.tpe, None)
- Block(List(rand, Connect(info, loc, Reference(rand))))
- case other =>
- val info = other match {
- case h: HasInfo => h.info
- case _ => NoInfo
- }
- val prefix = other match {
- case c: Connect => c.loc.serialize.replace('.', '_')
- case h: HasName => h.name
- case _ => ""
- }
- val ctx = ExprCtx(namespace, opt, prefix, info, mutable.ListBuffer[Statement]())
- val stmt = other.mapExpr(onExpr(ctx, _)).mapStmt(onStmt(namespace, opt, _))
- if (ctx.rands.isEmpty) { stmt }
- else { Block(Block(ctx.rands.toList), stmt) }
- }
-
- private case class ExprCtx(
- namespace: Namespace,
- opt: InvalidToRandomOptions,
- prefix: String,
- info: Info,
- rands: mutable.ListBuffer[Statement])
-
- private def onExpr(ctx: ExprCtx, e: Expression): Expression =
- e.mapExpr(onExpr(ctx, _)) match {
- case ValidIf(_, value, tpe) if tpe == ClockType =>
- // we currently assume that clocks are always valid
- // TODO: is that a good assumption?
- value
- case ValidIf(cond, value, tpe) if ctx.opt.randomizeInvalidSignals =>
- makeRand(ctx, cond, tpe, value, invert = true)
- case d @ DoPrim(PrimOps.Div, Seq(_, den), _, tpe) if ctx.opt.randomizeDivisionByZero =>
- val denIsZero = Utils.eq(den, Utils.getGroundZero(den.tpe.asInstanceOf[GroundType]))
- makeRand(ctx, denIsZero, tpe, d, invert = false)
- case other => other
- }
-
- private def makeRand(
- ctx: ExprCtx,
- cond: Expression,
- tpe: Type,
- value: Expression,
- invert: Boolean
- ): Expression = {
- val name = ctx.namespace.newName(if (ctx.prefix.isEmpty) "invalid" else ctx.prefix + "_invalid")
- // create a condition node if the condition isn't a reference already
- val condRef = cond match {
- case r: RefLikeExpression => if (invert) Utils.not(r) else r
- case other =>
- val cond = if (invert) Utils.not(other) else other
- val condNode = DefNode(ctx.info, ctx.namespace.newName(name + "_cond"), cond)
- ctx.rands.append(condNode)
- Reference(condNode)
- }
- val rand = DefRandom(ctx.info, name, tpe, None, condRef)
- ctx.rands.append(rand)
- Utils.mux(condRef, Reference(rand), value)
- }
-}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala
deleted file mode 100644
index 96582778..00000000
--- a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala
+++ /dev/null
@@ -1,461 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package firrtl.backends.experimental.smt.random
-
-import firrtl.Utils.{isLiteral, BoolType}
-import firrtl._
-import firrtl.annotations.NoTargetAnnotation
-import firrtl.backends.experimental.smt._
-import firrtl.ir._
-import firrtl.options.Dependency
-import firrtl.passes.MemPortUtils.memPortField
-import firrtl.passes.memlib.AnalysisUtils.Connects
-import firrtl.passes.memlib.InferReadWritePass.checkComplement
-import firrtl.passes.memlib.{AnalysisUtils, InferReadWritePass, VerilogMemDelays}
-import firrtl.stage.Forms
-import firrtl.transforms.RemoveWires
-
-import scala.collection.mutable
-
-/** Chooses which undefined memory behaviors should be instrumented. */
-case class UndefinedMemoryBehaviorOptions(
- randomizeWriteWriteConflicts: Boolean = true,
- assertNoOutOfBoundsWrites: Boolean = false,
- randomizeOutOfBoundsRead: Boolean = true,
- randomizeDisabledReads: Boolean = true,
- randomizeReadWriteConflicts: Boolean = true)
- extends NoTargetAnnotation
-
-/** Adds sources of randomness to model the various "undefined behaviors" of firrtl memory.
- * - Write/Write conflict: leads to arbitrary value written to write address
- * - Out-of-bounds write: assertion failure (disabled by default)
- * - Out-Of-bounds read: leads to arbitrary value being read
- * - Read w/ en=0: leads to arbitrary value being read
- * - Read/Write conflict: leads to arbitrary value being read
- */
-object UndefinedMemoryBehaviorPass extends Transform with DependencyAPIMigration {
- override def prerequisites = Forms.LowForm
- override def optionalPrerequisiteOf = Seq(Dependency(VerilogMemDelays))
- override def invalidates(a: Transform) = a match {
- // this pass might destroy SSA form, as we add a wire for the data field of every read port
- case _: RemoveWires => true
- // TODO: should we add some optimization passes here? we could be generating some dead code.
- case _ => false
- }
-
- override protected def execute(state: CircuitState): CircuitState = {
- val opts = state.annotations.collect { case o: UndefinedMemoryBehaviorOptions => o }
- require(opts.size < 2, s"Multiple options: $opts")
- val opt = opts.headOption.getOrElse(UndefinedMemoryBehaviorOptions())
-
- val c = state.circuit.mapModule(onModule(_, opt))
- state.copy(circuit = c)
- }
-
- private def onModule(m: DefModule, opt: UndefinedMemoryBehaviorOptions): DefModule = m match {
- case mod: Module =>
- val mems = findMems(mod)
- if (mems.isEmpty) { mod }
- else {
- val namespace = Namespace(mod)
- val connects = AnalysisUtils.getConnects(mod)
- new InstrumentMems(opt, mems, connects, namespace).run(mod)
- }
- case other => other
- }
-
- /** finds all memory instantiations in a circuit */
- private def findMems(m: Module): List[DefMemory] = {
- val mems = mutable.ListBuffer[DefMemory]()
- m.foreachStmt(findMems(_, mems))
- mems.toList
- }
- private def findMems(s: Statement, mems: mutable.ListBuffer[DefMemory]): Unit = s match {
- case mem: DefMemory => mems.append(mem)
- case other => other.foreachStmt(findMems(_, mems))
- }
-}
-
-private class InstrumentMems(
- opt: UndefinedMemoryBehaviorOptions,
- mems: List[DefMemory],
- connects: Connects,
- namespace: Namespace) {
- def run(m: Module): DefModule = {
- // ensure that all memories are the kind we can support
- mems.foreach(checkSupported(m.name, _))
-
- // transform circuit
- val body = m.body.mapStmt(transform)
- m.copy(body = Block(body +: newStmts.toList))
- }
-
- // used to replace memory signals like `m.r.data` in RHS expressions
- private val exprReplacements = mutable.HashMap[String, Expression]()
- // add new statements at the end of the circuit
- private val newStmts = mutable.ListBuffer[Statement]()
- // disconnect references so that they can be reassigned
- private val doDisconnect = mutable.HashSet[String]()
-
- // generates new expression replacements and immediately uses them
- private def transform(s: Statement): Statement = s.mapStmt(transform) match {
- case mem: DefMemory => onMem(mem)
- case sx: Connect if doDisconnect.contains(sx.loc.serialize) => EmptyStmt // Filter old mem connections
- case sx => sx.mapExpr(swapMemRefs)
- }
- private def swapMemRefs(e: Expression): Expression = e.mapExpr(swapMemRefs) match {
- case sf: RefLikeExpression => exprReplacements.getOrElse(sf.serialize, sf)
- case ex => ex
- }
-
- private def onMem(m: DefMemory): Statement = {
- // collect wire and random statement defines
- implicit val declarations: mutable.ListBuffer[Statement] = mutable.ListBuffer[Statement]()
-
- // cache for the expressions of memory inputs
- implicit val cache: mutable.HashMap[String, Expression] = mutable.HashMap[String, Expression]()
-
- // only for non power of 2 memories do we have to worry about reading or writing out of bounds
- val canBeOutOfBounds = !isPow2(m.depth)
-
- // only if we have at least two write ports, can there be conflicts
- val canHaveWriteWriteConflicts = m.writers.size > 1
-
- // only certain memory types exhibit undefined read/write conflicts
- val readWriteUndefined = (m.readLatency == m.writeLatency) && (m.readUnderWrite == ReadUnderWrite.Undefined)
- assert(
- m.readLatency == 0 || m.readLatency == m.writeLatency,
- "TODO: what happens if a sync read mem has asymmetrical latencies?"
- )
-
- // a write port is enabled iff mask & en
- val writeEn = m.writers.map { write =>
- val enRef = memPortField(m, write, "en")
- val maskRef = memPortField(m, write, "mask")
-
- val prods = getProductTerms(enRef) ++ getProductTerms(maskRef)
- val expr = Utils.and(readInput(m.info, enRef), readInput(m.info, maskRef))
-
- (expr, prods)
- }
-
- // implement the three undefined read behaviors
- m.readers.foreach { read =>
- // many memories have their read enable hard wired to true
- val canBeDisabled = !isTrue(readInput(m, read, "en"))
- val readEn = if (canBeDisabled) readInput(m, read, "en") else Utils.True()
-
- // collect signals that would lead to a randomization
- var doRand = List[Expression]()
-
- // randomize the read value when the address is out of bounds
- if (canBeOutOfBounds && opt.randomizeOutOfBoundsRead) {
- val addr = readInput(m, read, "addr")
- val cond = Utils.and(readEn, Utils.not(isInBounds(m.depth, addr)))
- val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_oob"), cond)
- declarations += node
- doRand = Reference(node) +: doRand
- }
-
- if (readWriteUndefined && opt.randomizeReadWriteConflicts) {
- val cond = readWriteConflict(m, read, writeEn)
- val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_rwc"), cond)
- declarations += node
- doRand = Reference(node) +: doRand
- }
-
- // randomize the read value when the read is disabled
- if (canBeDisabled && opt.randomizeDisabledReads) {
- val cond = Utils.not(readEn)
- val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_disabled"), cond)
- declarations += node
- doRand = Reference(node) +: doRand
- }
-
- // if there are no signals that would require a randomization, there is nothing to do
- if (doRand.isEmpty) {
- // nothing to do
- } else {
- val doRandName = s"${m.name}_${read}_do_rand"
- val doRandNode = if (doRand.size == 1) { doRand.head }
- else {
- val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_do_rand"), doRand.reduce(Utils.or))
- declarations += node
- Reference(node)
- }
- val doRandSignal = if (m.readLatency == 0) { doRandNode }
- else {
- val clock = readInput(m, read, "clk")
- val (signal, regDecls) = pipeline(m.info, clock, doRandName, doRandNode, m.readLatency)
- declarations ++= regDecls
- signal
- }
-
- // all old rhs references to m.r.data need to replace with m_r_data which might be random
- val dataRef = memPortField(m, read, "data")
- val dataWire = DefWire(m.info, namespace.newName(s"${m.name}_${read}_data"), m.dataType)
- declarations += dataWire
- exprReplacements(dataRef.serialize) = Reference(dataWire)
-
- // create a source of randomness and connect the new wire either to the actual data port or to the random value
- val randName = namespace.newName(s"${m.name}_${read}_rand_data")
- val random = DefRandom(m.info, randName, m.dataType, Some(readInput(m, read, "clk")), doRandSignal)
- declarations += random
- val data = Utils.mux(doRandSignal, Reference(random), dataRef)
- newStmts.append(Connect(m.info, Reference(dataWire), data))
- }
- }
-
- // write
- if (opt.randomizeWriteWriteConflicts) {
- writeWriteConflicts(m, writeEn)
- }
-
- // add an assertion that if the write is taking place, then the address must be in range
- if (canBeOutOfBounds && opt.assertNoOutOfBoundsWrites) {
- m.writers.zip(writeEn).foreach {
- case (write, (combinedEn, _)) =>
- val addr = readInput(m, write, "addr")
- val cond = Utils.implies(combinedEn, isInBounds(m.depth, addr))
- val clk = readInput(m, write, "clk")
- val a = Verification(Formal.Assert, m.info, clk, cond, Utils.True(), StringLit("out of bounds read"))
- newStmts.append(a)
- }
- }
-
- Block(m +: declarations.toList)
- }
-
- private def pipeline(
- info: Info,
- clk: Expression,
- prefix: String,
- e: Expression,
- latency: Int
- ): (Expression, Seq[Statement]) = {
- require(latency > 0)
- val regs = (1 to latency).map { i =>
- val name = namespace.newName(prefix + s"_r$i")
- DefRegister(info, name, e.tpe, clk, Utils.False(), Reference(name, e.tpe, RegKind, UnknownFlow))
- }
- val expr = regs.foldLeft(e) {
- case (prev, reg) =>
- newStmts.append(Connect(info, Reference(reg), prev))
- Reference(reg)
- }
- (expr, regs)
- }
-
- private def readWriteConflict(
- m: DefMemory,
- read: String,
- writeEn: Seq[(Expression, ProdTerms)]
- )(
- implicit cache: mutable.HashMap[String, Expression],
- declarations: mutable.ListBuffer[Statement]
- ): Expression = {
- if (m.writers.isEmpty) return Utils.False()
-
- val readEn = readInput(m, read, "en")
- val readProd = getProductTerms(readEn)
-
- // create all conflict signals
- val conflicts = m.writers.zip(writeEn).map {
- case (write, (writeEn, writeProd)) =>
- if (isMutuallyExclusive(readProd, writeProd)) {
- Utils.False()
- } else {
- val name = namespace.newName(s"${m.name}_${read}_${write}_rwc")
- val bothEn = Utils.and(readEn, writeEn)
- val sameAddr = Utils.eq(readInput(m, read, "addr"), readInput(m, write, "addr"))
- // we need a wire because this condition might be used in a random statement
- val wire = DefWire(m.info, name, BoolType)
- declarations += wire
- newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr)))
- Reference(wire)
- }
- }
-
- conflicts.reduce(Utils.or)
- }
-
- private type ProdTerms = Seq[Expression]
- private def writeWriteConflicts(
- m: DefMemory,
- writeEn: Seq[(Expression, ProdTerms)]
- )(
- implicit cache: mutable.HashMap[String, Expression],
- declarations: mutable.ListBuffer[Statement]
- ): Unit = {
- if (m.writers.size < 2) return
-
- // we first create all conflict signals:
- val conflict =
- m.writers
- .zip(writeEn)
- .zipWithIndex
- .flatMap {
- case ((w1, (en1, en1Prod)), i1) =>
- m.writers.zip(writeEn).drop(i1 + 1).map {
- case (w2, (en2, en2Prod)) =>
- if (isMutuallyExclusive(en1Prod, en2Prod)) {
- (w1, w2) -> Utils.False()
- } else {
- val name = namespace.newName(s"${m.name}_${w1}_${w2}_wwc")
- val bothEn = Utils.and(en1, en2)
- val sameAddr = Utils.eq(readInput(m, w1, "addr"), readInput(m, w2, "addr"))
- // we need a wire because this condition might be used in a random statement
- val wire = DefWire(m.info, name, BoolType)
- declarations += wire
- newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr)))
- (w1, w2) -> Reference(wire)
- }
- }
- }
- .toMap
-
- // now we calculate the new enable and data signals
- m.writers.zip(writeEn).zipWithIndex.foreach {
- case ((w1, (en1, _)), i1) =>
- val prev = m.writers.take(i1)
- val next = m.writers.drop(i1 + 1)
-
- // the write is enabled if the original enable is true and there are no prior conflicts
- val en = if (prev.isEmpty) {
- en1
- } else {
- val prevConflicts = prev.map(o => conflict(o, w1)).reduce(Utils.or)
- Utils.and(en1, Utils.not(prevConflicts))
- }
-
- // we write random data if there is a conflict with any of the next ports
- if (next.isEmpty) {
- // nothing to do, leave data as is
- } else {
- val nextConflicts = next.map(n => conflict(w1, n)).reduce(Utils.or)
- // if the conflict expression is more complex, create a node for the signal
- val hasConflict = nextConflicts match {
- case _: DoPrim | _: Mux =>
- val node = DefNode(m.info, namespace.newName(s"${m.name}_${w1}_wwc_active"), nextConflicts)
- declarations += node
- Reference(node)
- case _ => nextConflicts
- }
-
- // create the source of randomness
- val name = namespace.newName(s"${m.name}_${w1}_wwc_data")
- val random = DefRandom(m.info, name, m.dataType, Some(readInput(m, w1, "clk")), hasConflict)
- declarations.append(random)
-
- // generate new data input
- val data = Utils.mux(hasConflict, Reference(random), readInput(m, w1, "data"))
- newStmts.append(Connect(m.info, memPortField(m, w1, "data"), data))
- doDisconnect.add(memPortField(m, w1, "data").serialize)
- }
-
- // connect data enable signals
- val maskIsOne = isTrue(readInput(m, w1, "mask"))
- if (!maskIsOne) {
- newStmts.append(Connect(m.info, memPortField(m, w1, "mask"), Utils.True()))
- doDisconnect.add(memPortField(m, w1, "mask").serialize)
- }
- newStmts.append(Connect(m.info, memPortField(m, w1, "en"), en))
- doDisconnect.add(memPortField(m, w1, "en").serialize)
- }
- }
-
- /** check whether two signals can be proven to be mutually exclusive */
- private def isMutuallyExclusive(prodA: ProdTerms, prodB: ProdTerms): Boolean = {
- // this uses the same approach as the InferReadWrite pass
- val proofOfMutualExclusion = prodA.find(a => prodB.exists(b => checkComplement(a, b)))
- proofOfMutualExclusion.nonEmpty
- }
-
- /** memory inputs my not be read, only assigned to, thus we might need to add a wire to make them accessible */
- private def readInput(
- info: Info,
- signal: RefLikeExpression
- )(
- implicit cache: mutable.HashMap[String, Expression],
- declarations: mutable.ListBuffer[Statement]
- ): Expression =
- cache.getOrElseUpdate(
- signal.serialize, {
- // if it is a literal, we just return it
- val value = connects(signal.serialize)
- if (isLiteral(value)) {
- value
- } else {
- // otherwise we make a wire that refelect the value
- val wire = DefWire(info, copyName(signal), signal.tpe)
- declarations += wire
-
- // connect the old expression to the new wire
- val con = Connect(info, Reference(wire), value)
- newStmts.append(con)
-
- // use a reference to this new wire
- Reference(wire)
- }
- }
- )
- private def readInput(
- m: DefMemory,
- port: String,
- field: String
- )(
- implicit cache: mutable.HashMap[String, Expression],
- declarations: mutable.ListBuffer[Statement]
- ): Expression =
- readInput(m.info, memPortField(m, port, field))
-
- private def copyName(ref: RefLikeExpression): String =
- namespace.newName(ref.serialize.replace('.', '_'))
-
- private def isInBounds(depth: BigInt, addr: Expression): Expression = {
- val width = getWidth(addr)
- // depth > addr (e.g. if the depth is 3, then the address must be in {0, 1, 2})
- DoPrim(PrimOps.Gt, List(UIntLiteral(depth, width), addr), List(), BoolType)
- }
-
- private def isPow2(v: BigInt): Boolean = ((v - 1) & v) == 0
-
- private def checkSupported(modName: String, m: DefMemory): Unit = {
- assert(m.readwriters.isEmpty, s"[$modName] Combined read/write ports are currently not supported!")
- if (m.writeLatency != 1) {
- throw new UnsupportedFeatureException(s"[$modName] memories with write latency > 1 (${m.name})")
- }
- if (m.readLatency > 1) {
- throw new UnsupportedFeatureException(s"[$modName] memories with read latency > 1 (${m.name})")
- }
- }
-
- private def getProductTerms(e: Expression): ProdTerms =
- InferReadWritePass.getProductTerms(connects)(e)
-
- /** tries to expand the expression based on the connects we collected */
- private def expandExpr(e: Expression, fuel: Int): Expression = {
- e match {
- case m @ Mux(cond, tval, fval, _) =>
- m.copy(cond = expandExpr(cond, fuel), tval = expandExpr(tval, fuel), fval = expandExpr(fval, fuel))
- case p @ DoPrim(_, args, _, _) =>
- p.copy(args = args.map(expandExpr(_, fuel)))
- case r: RefLikeExpression =>
- if (fuel > 0) {
- connects.get(r.serialize) match {
- case None => r
- case Some(expr) => expandExpr(expr, fuel - 1)
- }
- } else {
- r
- }
- case other => other
- }
- }
-
- private def isTrue(e: Expression): Boolean = simplifyExpr(expandExpr(e, fuel = 2)) == Utils.True()
-
- private def simplifyExpr(e: Expression): Expression = {
- e // TODO: better simplification could improve the resulting circuit size
- }
-}