aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala')
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala120
1 files changed, 120 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala
new file mode 100644
index 00000000..66a1b385
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala
@@ -0,0 +1,120 @@
+// SPDX-License-Identifier: Apache-2.0
+// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>
+
+package firrtl.backends.experimental.smt
+
+import firrtl.graph.MutableDiGraph
+import scala.collection.mutable
+
+private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr]) {
+ def name: String = sym.name
+}
+private case class Signal(name: String, e: SMTExpr, lbl: SignalLabel = IsNode) {
+ def toSymbol: SMTSymbol = SMTSymbol.fromExpr(name, e)
+ def sym: SMTSymbol = toSymbol
+}
+private case class TransitionSystem(
+ name: String,
+ inputs: List[BVSymbol],
+ states: List[State],
+ signals: List[Signal],
+ comments: Map[String, String] = Map(),
+ header: String = "") {
+ def serialize: String = TransitionSystem.serialize(this)
+}
+
+private sealed trait SignalLabel
+private case object IsNode extends SignalLabel
+private case object IsOutput extends SignalLabel
+private case object IsConstraint extends SignalLabel
+private case object IsBad extends SignalLabel
+private case object IsFair extends SignalLabel
+private case object IsNext extends SignalLabel
+private case object IsInit extends SignalLabel
+
+private object SignalLabel {
+ private val labels = Seq(IsNode, IsOutput, IsConstraint, IsBad, IsFair, IsNext, IsInit)
+ val labelStrings = Seq("node", "output", "constraint", "bad", "fair", "next", "init")
+ val labelToString: SignalLabel => String = labels.zip(labelStrings).toMap
+ val stringToLabel: String => SignalLabel = labelStrings.zip(labels).toMap
+}
+
+private object TransitionSystem {
+ def serialize(sys: TransitionSystem): String = {
+ (Iterator(sys.name) ++
+ sys.inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++
+ sys.signals.map(s => s"${SignalLabel.labelToString(s.lbl)} ${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++
+ sys.states.map(serialize)).mkString("\n")
+ }
+
+ def serialize(s: State): String = {
+ s"state ${s.sym.name} : ${SMTExpr.serializeType(s.sym)}" +
+ s.init.map("\n [init] " + _).getOrElse("") +
+ s.next.map("\n [next] " + _).getOrElse("")
+ }
+
+ def systemExpressions(sys: TransitionSystem): List[SMTExpr] =
+ sys.signals.map(_.e) ++ sys.states.flatMap(s => s.init ++ s.next)
+
+ def findUninterpretedFunctions(sys: TransitionSystem): List[DeclareFunction] = {
+ val calls = systemExpressions(sys).flatMap(findUFCalls)
+ // find unique functions
+ calls.groupBy(_.sym.name).map(_._2.head).toList
+ }
+
+ private def findUFCalls(e: SMTExpr): List[DeclareFunction] = {
+ val f = e match {
+ case BVFunctionCall(name, args, width) =>
+ Some(DeclareFunction(BVSymbol(name, width), args))
+ case ArrayFunctionCall(name, args, indexWidth, dataWidth) =>
+ Some(DeclareFunction(ArraySymbol(name, indexWidth, dataWidth), args))
+ case _ => None
+ }
+ f.toList ++ e.children.flatMap(findUFCalls)
+ }
+}
+
+private object TopologicalSort {
+
+ /** Ensures that all signals in the resulting system are topologically sorted.
+ * This is necessary because [[firrtl.transforms.RemoveWires]] does
+ * not sort assignments to outputs, submodule inputs nor memory ports.
+ */
+ def run(sys: TransitionSystem): TransitionSystem = {
+ val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name)
+ val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates)
+ // TODO: maybe sort init expressions of states (this should not be needed most of the time)
+ signalOrder match {
+ case None => sys
+ case Some(order) =>
+ val signalMap = sys.signals.map(s => s.name -> s).toMap
+ // we flatMap over `get` in order to ignore inputs/states in the order
+ sys.copy(signals = order.flatMap(signalMap.get).toList)
+ }
+ }
+
+ private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = {
+ val known = new mutable.HashSet[String]() ++ globalSignals
+ var needsReordering = false
+ val digraph = new MutableDiGraph[String]
+ signals.foreach {
+ case (name, expr) =>
+ digraph.addVertex(name)
+ val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr)
+ uniqueDependencies.foreach { d =>
+ if (!known.contains(d)) { needsReordering = true }
+ digraph.addPairWithEdge(name, d)
+ }
+ known.add(name)
+ }
+ if (needsReordering) {
+ Some(digraph.linearize.reverse)
+ } else { None }
+ }
+
+ private def findDependencies(expr: SMTExpr): List[String] = expr match {
+ case BVSymbol(name, _) => List(name)
+ case ArraySymbol(name, _, _) => List(name)
+ case other => other.children.flatMap(findDependencies)
+ }
+}