aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
blob: 4f096c28adbe515dcb859e1afe992a7fa34e116f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// SPDX-License-Identifier: Apache-2.0
// Author: Kevin Laeufer <laeufer@cs.berkeley.edu>

package firrtl.backends.experimental.smt

import scala.collection.mutable

/** This Transition System encoding is directly inspired by yosys' SMT backend:
  * https://github.com/YosysHQ/yosys/blob/master/backends/smt2/smt2.cc
  * It is fairly compact, but unfortunately the use of an uninterpreted sort for the state
  * prevents this encoding from working with Boolector.
  * For simplicity, we do not support hierarchical designs (no `_h` function).
  */
object SMTTransitionSystemEncoder {

  /** Encodes `sys` as a sequence of SMT-LIB commands.
    *
    * Commands are emitted in this order:
    *  - whatever `TransitionSystem.findUninterpretedFunctions` returns for `sys` (UF declarations)
    *  - the system header, one comment per line (only if the header is non-empty)
    *  - an uninterpreted sort `<sys>_s` whose values stand for one state of the system
    *  - for every input and register `x`: description comment(s) and a declaration of `x_f(state)`
    *  - for every signal `x`: description comment(s) and a definition of `x_f(state)`
    *  - for every register `r`: a definition of `r_next(state)` and, if `r` has an init value, `r_init(state)`
    *  - the system predicates `<sys>_t(state, state_n)`, `<sys>_i(state)`,
    *    `<sys>_a(state)` and `<sys>_u(state)`
    *
    * @param sys the transition system to encode; every register must have a next function
    * @return the SMT commands in emission order
    * @throws java.lang.AssertionError if a register has no next function
    * @throws IllegalArgumentException if a bad-state or constraint signal is not a bit-vector expression
    */
  def encode(sys: TransitionSystem): Iterable[SMTCommand] = {
    val cmds = mutable.ArrayBuffer[SMTCommand]()
    val name = sys.name

    // declare UFs if necessary
    cmds ++= TransitionSystem.findUninterpretedFunctions(sys)

    // emit header as comments
    if (sys.header.nonEmpty) {
      cmds ++= sys.header.split('\n').map(Comment)
    }

    // declare state type
    val stateType = id(name + "_s")
    cmds += DeclareUninterpretedSort(stateType)

    // the current and the next state; every signal function takes one of these as its argument
    val State = UTSymbol("state", stateType)
    val StateNext = UTSymbol("state_n", stateType)

    // inputs and registers are declared (not defined) as functions of the state
    def declare(sym: SMTSymbol, kind: String): Unit = {
      cmds ++= toDescription(sym, kind, sys.comments.get)
      cmds += DeclareFunction(SMTSymbol.fromExpr(sym.name + SignalSuffix, sym), List(State))
    }
    sys.inputs.foreach(i => declare(i, "input"))
    sys.states.foreach(s => declare(s.sym, "register"))

    // defines `<sym><suffix>(state)` as `e`, with every free symbol in `e` turned into a call of its `_f` function
    def define(sym: SMTSymbol, e: SMTExpr, suffix: String = SignalSuffix): Unit =
      cmds += DefineFunction(sym.name + suffix, List(State), replaceSymbols(SignalSuffix, State)(e))

    // signals are just functions of other signals, inputs and state
    sys.signals.foreach { signal =>
      cmds ++= toDescription(signal.sym, lblToKind(signal.lbl), sys.comments.get)
      // bad-state signals are turned back into assertions by negating them;
      // their `_f` functions are what the `<sys>_a` predicate conjoins below
      val e =
        if (signal.lbl == IsBad) BVNot(asBV(signal.e, s"bad state signal ${signal.sym.name}"))
        else signal.e
      define(signal.sym, e)
    }

    // define the next and (optional) init functions for all registers
    sys.states.foreach { state =>
      // every register's `_next` function is referenced by the transition relation below
      val next = state.next.getOrElse(throw new AssertionError("assertion failed: Next function required"))
      define(state.sym, next, NextSuffix)
      state.init.foreach(init => define(state.sym, init, InitSuffix))
    }

    // defines `<sys><suffix>(state)` as the conjunction of `e`, or as true when `e` is empty
    def defineConjunction(e: List[BVExpr], suffix: String): Unit =
      define(BVSymbol(name, 1), if (e.isEmpty) True() else BVAnd(e), suffix)

    // the transition relation asserts that the value of the next state is the next value from the previous state
    // e.g., (reg_f state_n) == (reg_next state)
    val transitionRelations = sys.states.map { state =>
      val newState = replaceSymbols(SignalSuffix, StateNext)(state.sym)
      val nextOldState = replaceSymbols(NextSuffix, State)(state.sym)
      SMTEqual(newState, nextOldState)
    }
    // the transition relation is over two states, so it is emitted directly instead of through `define`
    val transitionExpr =
      if (transitionRelations.isEmpty) True()
      else replaceSymbols(SignalSuffix, State)(BVAnd(transitionRelations))
    cmds += DefineFunction(name + "_t", List(State, StateNext), transitionExpr)

    // the init relation asserts that all init functions hold,
    // i.e. every register that has an init value currently equals it
    val initRelations = sys.states.filter(_.init.isDefined).map { state =>
      val stateSignal = replaceSymbols(SignalSuffix, State)(state.sym)
      val initSignal = replaceSymbols(InitSuffix, State)(state.sym)
      SMTEqual(stateSignal, initSignal)
    }
    defineConjunction(initRelations, "_i")

    // the `_f` calls (on the current state) of all signals carrying label `lbl`
    def signalCalls(lbl: SignalLabel): List[BVExpr] =
      sys.signals.filter(_.lbl == lbl).map { signal =>
        asBV(replaceSymbols(SignalSuffix, State)(signal.sym), s"${lblToKind(lbl)} signal ${signal.sym.name}")
      }

    // assertions and assumptions
    defineConjunction(signalCalls(IsBad), AssertionSuffix)
    defineConjunction(signalCalls(IsConstraint), AssumptionSuffix)

    cmds
  }

  /** Escapes `s` with the SMT-LIB serializer's identifier escaping. */
  private def id(s: String): String = SMTLibSerializer.escapeIdentifier(s)

  /** Suffix of the per-signal function that returns a signal's value in a given state. */
  private val SignalSuffix = "_f"

  /** Suffix of the per-register function that computes the register's next value. */
  private val NextSuffix = "_next"

  /** Suffix of the per-register function that computes the register's init value. */
  private val InitSuffix = "_init"

  /** Suffix of the system predicate `<sys>_a(state)`, the conjunction of all assertions. */
  val AssertionSuffix = "_a"

  /** Suffix of the system predicate `<sys>_u(state)`, the conjunction of all assumptions. */
  val AssumptionSuffix = "_u"

  /** Narrows `e` to a bit-vector expression, failing with a descriptive message otherwise. */
  private def asBV(e: SMTExpr, what: => String): BVExpr = e match {
    case bv: BVExpr => bv
    case other      => throw new IllegalArgumentException(s"$what must be a bit-vector expression, got: $other")
  }

  /** Maps a signal label to the `kind` word used in `firrtl-smt2-<kind>` description comments. */
  private def lblToKind(lbl: SignalLabel): String = lbl match {
    case IsNode | IsInit | IsNext => "wire"
    case IsOutput                 => "output"
    // for the SMT encoding we turn bad state signals back into assertions
    case IsBad        => "assert"
    case IsConstraint => "assume"
    case IsFair       => "fair"
  }

  /** Builds the `firrtl-smt2-<kind> <name> <widths...>` comment for `sym`, followed by the
    * user comment registered for `sym.name` (if any).
    */
  private def toDescription(sym: SMTSymbol, kind: String, comments: String => Option[String]): List[Comment] = {
    List(sym match {
      case BVSymbol(name, width) => Comment(s"firrtl-smt2-$kind $name $width")
      case ArraySymbol(name, indexWidth, dataWidth) =>
        Comment(s"firrtl-smt2-$kind $name $indexWidth $dataWidth")
    }) ++ comments(sym.name).map(Comment)
  }

  // All signals are modelled with functions that need to be called with the state as argument,
  // this replaces all Symbols with function applications to the state.
  /** Replaces every free bit-vector/array symbol `x` in `e` with the call `x<suffix>(arg)`.
    * Names in `vars`, and variables bound by an enclosing `BVForall`, are left unchanged.
    */
  private def replaceSymbols(suffix: String, arg: SMTFunctionArg, vars: Set[String] = Set())(e: SMTExpr): SMTExpr =
    e match {
      case BVSymbol(name, width) if !vars(name) => BVFunctionCall(id(name + suffix), List(arg), width)
      case ArraySymbol(name, indexWidth, dataWidth) if !vars(name) =>
        ArrayFunctionCall(id(name + suffix), List(arg), indexWidth, dataWidth)
      case fa @ BVForall(variable, _) => SMTExprMap.mapExpr(fa, replaceSymbols(suffix, arg, vars + variable.name))
      case other                      => SMTExprMap.mapExpr(other, replaceSymbols(suffix, arg, vars))
    }
}