aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala
blob: 3199cedfe6ff86836c79b0c21f4ce01018df30cf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// SPDX-License-Identifier: Apache-2.0

package firrtl.transforms.formal

import firrtl.ir.{Circuit, Formal, Statement, Verification}
import firrtl.stage.TransformManager.TransformDependency
import firrtl.{CircuitState, DependencyAPIMigration, Transform}
import firrtl.annotations.NoTargetAnnotation
import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption}

/**
  * Assert Submodule Assumptions
  *
  * Rewrites every `assume` verification statement into an `assert` in every
  * module except the circuit's main (top) module. An overly restrictive
  * `assume` inside a child module can stop the model checker from exploring
  * valid inputs and states of its parent. Turning it into an `assert` means the
  * constraint must be proven rather than silently taken for granted.
  *
  * The rewrite is skipped entirely when [[DontAssertSubmoduleAssumptionsAnnotation]]
  * is present. The `--no-asa` command-line option adds that annotation.
  */
class AssertSubmoduleAssumptions
    extends Transform
    with RegisteredTransform
    with DependencyAPIMigration
    with PreservesAll[Transform] {

  // This transform has no hard or soft ordering requirements of its own.
  override def prerequisites:         Seq[TransformDependency] = Seq.empty
  override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty

  // When scheduled alongside them, order this transform before the mid-form emitters.
  override def optionalPrerequisiteOf: Seq[TransformDependency] =
    firrtl.stage.Forms.MidEmitters

  /** Shell option `--no-asa`: emits [[DontAssertSubmoduleAssumptionsAnnotation]] to disable this transform. */
  val options = Seq(
    new ShellOption[Unit](
      longOption = "no-asa",
      toAnnotationSeq = (_: Unit) => Seq(DontAssertSubmoduleAssumptionsAnnotation),
      helpText = "Disable assert submodule assumptions"
    )
  )

  /**
    * Walks `s` and all of its nested statements, replacing each `assume`
    * with an otherwise-identical `assert`.
    *
    * @param s statement (possibly a block or conditional) to rewrite
    * @return the rewritten statement; statements other than `assume` are rebuilt around rewritten children
    */
  def assertAssumption(s: Statement): Statement = s match {
    case assume: Verification if assume.op == Formal.Assume =>
      // A Verification has no sub-statements, so there is nothing further to visit.
      assume.withOp(Formal.Assert)
    case other =>
      other.mapStmt(assertAssumption)
  }

  /**
    * Applies [[assertAssumption]] to every module except the top module.
    *
    * @param c circuit to transform
    * @return circuit whose non-top modules have had their assumptions turned into assertions
    */
  def run(c: Circuit): Circuit = {
    val topName = c.main
    // The top module keeps its assumptions; they constrain the environment under test.
    c.mapModule { m =>
      if (m.name == topName) m else m.mapStmt(assertAssumption)
    }
  }

  /**
    * Runs the rewrite unless [[DontAssertSubmoduleAssumptionsAnnotation]] is present.
    * If that annotation is found, logs a message and returns the input state unchanged.
    */
  def execute(state: CircuitState): CircuitState = {
    val disabled = state.annotations.contains(DontAssertSubmoduleAssumptionsAnnotation)
    if (!disabled) {
      state.copy(circuit = run(state.circuit))
    } else {
      logger.info("Skipping assert submodule assumptions")
      state
    }
  }
}

/**
  * Annotation associated with the [[AssertSubmoduleAssumptions]] transform.
  * It carries an instance of that transform in `transform`. Presumably this
  * lets consumers of the annotation schedule it; confirm against callers.
  */
case object AssertSubmoduleAssumptionsAnnotation extends NoTargetAnnotation {
  /** Eagerly constructed instance of the transform this annotation refers to. */
  val transform: AssertSubmoduleAssumptions = new AssertSubmoduleAssumptions()
}

/**
  * Marker annotation that disables [[AssertSubmoduleAssumptions]]. When it is
  * present, that transform's `execute` returns the state unchanged. The
  * `--no-asa` shell option produces this annotation.
  */
case object DontAssertSubmoduleAssumptionsAnnotation extends NoTargetAnnotation {}