aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/checks/CheckResets.scala
blob: e5a3e77afc63de9a8ca54ad80748d220fe4c0f70 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
// SPDX-License-Identifier: Apache-2.0

package firrtl.checks

import firrtl._
import firrtl.options.Dependency
import firrtl.passes.{Errors, PassException}
import firrtl.ir._
import firrtl.Utils.isCast
import firrtl.traversals.Foreachers._
import firrtl.WrappedExpression._

import scala.collection.mutable
import scala.annotation.tailrec

/** Shared definitions used by the [[CheckResets]] transform. */
object CheckResets {
  /** Exception describing an AsyncReset register whose reset value is not a literal.
    *
    * @param info  the `Info` printed at the start of the message
    * @param mname name of the module that contains the register
    * @param reg   name of the offending register
    * @param init  text of the non-literal reset value shown in the message
    */
  class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String)
      extends PassException(s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'")

  // Registers that still need checking, each paired with its initialization expression
  // (stored as an ordered list of (init expression, register) pairs)
  private type RegCheckList = mutable.ListBuffer[(Expression, DefRegister)]
  // Record driving for literal propagation
  // Indicates *driven by*: the key is an expression, the value is the expression driving it
  private type DirectDriverMap = mutable.HashMap[WrappedExpression, Expression]

}

// Must run after ExpandWhens
// Requires
//   - static single connections of ground types
/** Transform that reports every register with an asynchronous reset whose reset value cannot be
  * traced back to a literal.
  *
  * The circuit itself is never modified; problems are collected per module and raised together
  * at the end of [[execute]].
  */
class CheckResets extends Transform with DependencyAPIMigration {

  override def prerequisites =
    Seq(
      Dependency(passes.LowerTypes),
      Dependency(firrtl.transforms.RemoveReset)
    ) ++ firrtl.stage.Forms.MidForm

  override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.CheckCombLoops])

  override def optionalPrerequisiteOf = Seq.empty

  // Nothing is invalidated: the circuit state is returned unchanged.
  override def invalidates(a: Transform) = false

  import CheckResets._

  /** Walks `stmt` and all of its child statements, recording what is needed for the check.
    *
    * @param regCheck collects (init expression, register) pairs that must be verified later
    * @param drivers  collects, for every node and connection, the expression that drives it
    * @param stmt     the statement to inspect
    */
  private def onStmt(regCheck: RegCheckList, drivers: DirectDriverMap)(stmt: Statement): Unit = {
    stmt match {
      case DefNode(_, nodeName, value) =>
        drivers.update(we(WRef(nodeName)), value)
      case Connect(_, sink, source) =>
        drivers.update(we(sink), source)
      case reg @ DefRegister(_, regName, _, _, reset, init) =>
        // A register whose init value is the register itself is allowed and never queued.
        val resetsToItself = weq(WRef(regName), init)
        if (!resetsToItself && reset.tpe == AsyncResetType) {
          regCheck += ((init, reg))
        }
      case _ => // Nothing to record for other statements
    }
    // Recurse into nested statements
    stmt.foreach(onStmt(regCheck, drivers))
  }

  /** True only for the wire and node kinds; those are the references that get looked through. */
  private def wireOrNode(kind: Kind): Boolean = kind match {
    case WireKind | NodeKind => true
    case _                   => false
  }

  /** Follows the recorded drivers from `expr` until a literal or an untraceable expression is hit.
    *
    * Cast operations are skipped by continuing with their first argument. A non-literal
    * expression is followed only if it has a recorded driver and is of wire or node kind.
    *
    * @param drivers the "driven by" map built by [[onStmt]]
    * @param expr    the expression to resolve
    * @return the literal that ultimately drives `expr`, or the first expression that cannot be
    *         followed any further
    */
  @tailrec
  private def findDriver(drivers: DirectDriverMap)(expr: Expression): Expression = expr match {
    case literal: Literal => literal
    case DoPrim(castOp, operands, _, _) if isCast(castOp) => findDriver(drivers)(operands.head)
    case other =>
      // Look the driver up first; the kind is inspected only when a driver was recorded.
      drivers.get(we(other)).filter(_ => wireOrNode(Utils.kind(other))) match {
        case Some(upstream) => findDriver(drivers)(upstream)
        case None           => other
      }
  }

  /** Checks one module and appends an exception to `errors` for every offending init value.
    *
    * @param errors accumulator for the exceptions found
    * @param mod    the module to check
    */
  private def onMod(errors: Errors)(mod: DefModule): Unit = {
    val pendingRegs = new RegCheckList()
    val driverMap = new DirectDriverMap()
    mod.foreach(onStmt(pendingRegs, driverMap))
    pendingRegs.foreach {
      case (init, reg) =>
        // Each expression produced by create_exps is resolved and checked on its own.
        Utils.create_exps(init).foreach { leaf =>
          findDriver(driverMap)(leaf) match {
            case _: Literal => // Resolved to a literal, nothing to report
            case nonLiteral =>
              errors.append(
                new NonLiteralAsyncResetValueException(reg.info, mod.name, reg.name, nonLiteral.serialize)
              )
          }
        }
    }
  }

  /** Runs the check over every module of the circuit.
    *
    * @param state the circuit state to check
    * @return `state`, unchanged, after the collected errors have been triggered
    */
  def execute(state: CircuitState): CircuitState = {
    val collected = new Errors
    state.circuit.foreach(onMod(collected))
    collected.trigger()
    state
  }
}