aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/checks/CheckResets.scala
blob: d6337f9e7c20870db1d580b5ebf8d4d3b60c81be (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// See LICENSE for license details.

package firrtl.checks

import firrtl._
import firrtl.passes.{Errors, PassException}
import firrtl.ir._
import firrtl.traversals.Foreachers._
import firrtl.WrappedExpression._

import scala.collection.mutable

object CheckResets {

  /** Work list of (reset-value expression, register) pairs that still need to be validated. */
  private type RegCheckList = mutable.ListBuffer[(Expression, DefRegister)]

  /** Records, for each sink, the expression that *drives* it (from node definitions and
    * connections). It is used to propagate literals backwards from a reset value to its source.
    */
  private type DirectDriverMap = mutable.HashMap[WrappedExpression, Expression]

  /** Raised when a register whose reset is an [[AsyncResetType]] has a reset value that could not
    * be traced back to a literal.
    *
    * @param info  source locator of the offending register
    * @param mname name of the module that contains the register
    * @param reg   name of the register
    * @param init  serialized form of the non-literal reset value
    */
  class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String)
      extends PassException(
        s"${info}: [module ${mname}] AsyncReset Reg '${reg}' reset to non-literal '${init}'"
      )
}

/** Checks that every register with an asynchronous reset is reset to a literal value.
  *
  * Each ground-type component of an [[AsyncResetType]] register's reset value is traced backwards
  * through node definitions and connections. If a component does not resolve to a literal, a
  * [[CheckResets.NonLiteralAsyncResetValueException]] is recorded. All errors collected across
  * the circuit are raised together at the end of [[execute]].
  *
  * Must run after ExpandWhens.
  * Requires:
  *   - static single connections of ground types
  */
class CheckResets extends Transform {
  def inputForm: CircuitForm = MidForm
  def outputForm: CircuitForm = MidForm

  import CheckResets._

  /** Recursively walks `stmt`, recording driver edges and the async-reset registers to check.
    *
    * @param regCheck accumulates (reset value, register) pairs for registers with an async reset
    * @param drivers  accumulates sink -> driving expression, from nodes and connects
    * @param stmt     statement to scan (its sub-statements are scanned too)
    */
  private def onStmt(regCheck: RegCheckList, drivers: DirectDriverMap)(stmt: Statement): Unit = {
    stmt match {
      case DefNode(_, name, expr) => drivers += we(WRef(name)) -> expr
      case Connect(_, lhs, rhs) => drivers += we(lhs) -> rhs
      // A register reset to itself just holds its value while reset is asserted. No value is
      // loaded, so there is nothing that must be a literal. Tracing it would only reach the
      // register's own next-state logic, or loop forever on `r <= r`.
      case DefRegister(_, name, _, _, _, init) if weq(WRef(name), init) => // Self-reset, allowed
      case reg @ DefRegister(_, _, _, _, reset, init) if reset.tpe == AsyncResetType =>
        regCheck += init -> reg
      case _ => // Do nothing
    }
    stmt.foreach(onStmt(regCheck, drivers))
  }

  /** Follows driver edges backwards from `expr` until a literal is reached or no edge applies.
    *
    * Only combinational references are looked through. A register's connection describes its
    * *next* value, not its current one, so a register reference is returned as-is rather than
    * being replaced by its next-state driver. Already-visited references are not revisited, so
    * a combinational loop terminates instead of overflowing the stack.
    *
    * @param drivers sink -> driving expression map built by [[onStmt]]
    * @param expr    expression to resolve
    * @param seen    driver keys already followed on this path (used for cycle detection)
    * @return the driving [[Literal]] if one is reached, otherwise the last expression visited
    */
  @scala.annotation.tailrec
  private def findDriver(drivers: DirectDriverMap)(
      expr: Expression,
      seen: Set[WrappedExpression] = Set.empty
  ): Expression =
    expr match {
      case lit: Literal => lit
      // Never substitute a register's next-state value for its current value.
      case _ if Utils.kind(expr) == RegKind => expr
      case _ =>
        val key = we(expr)
        if (seen.contains(key)) expr // Combinational loop: give up on this path.
        else
          drivers.get(key) match {
            case Some(driver) => findDriver(drivers)(driver, seen + key)
            case None => expr
          }
    }

  /** Checks every async-reset register in `mod`.
    *
    * One exception is appended to `errors` for each ground-type component of a reset value that
    * does not trace back to a literal.
    *
    * @param errors accumulator for the exceptions found
    * @param mod    module to check
    */
  private def onMod(errors: Errors)(mod: DefModule): Unit = {
    val regCheck = new RegCheckList()
    val drivers = new DirectDriverMap()
    mod.foreach(onStmt(regCheck, drivers))
    for {
      (init, reg) <- regCheck
      subInit <- Utils.create_exps(init)
    } findDriver(drivers)(subInit) match {
      case _: Literal => // All good
      case other =>
        errors.append(
          new NonLiteralAsyncResetValueException(reg.info, mod.name, reg.name, other.serialize)
        )
    }
  }

  /** Runs the check on every module of the circuit.
    *
    * After all modules are checked, the accumulated errors are triggered. On success the input
    * state is returned unchanged.
    */
  def execute(state: CircuitState): CircuitState = {
    val errors = new Errors
    state.circuit.foreach(onMod(errors))
    errors.trigger()
    state
  }
}