aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/RemoveWires.scala
blob: 4fa700023dc1740f18a00fa39744ad7fe6f54cfb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// SPDX-License-Identifier: Apache-2.0

package firrtl
package transforms

import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.WrappedExpression._
import firrtl.graph.{CyclicException, MutableDiGraph}
import firrtl.options.Dependency
import firrtl.Utils.getGroundZero
import firrtl.backends.experimental.smt.random.DefRandom
import firrtl.passes.PadWidths

import scala.collection.mutable
import scala.util.{Failure, Success, Try}

/** Replace wires with nodes in a legal, flow-forward order
  *
  *  This pass must run after LowerTypes because Aggregate-type
  *  wires have multiple connections that may be impossible to order in a
  *  flow-forward way.
  *
  *  Nodes, non-Analog wires, registers and random values (`DefRandom`) are collected into a
  *  dependency graph and re-emitted in topological order. Registers and random values are part
  *  of the graph so that the signals they reference (clock, asynchronous reset and init for
  *  registers; clock and enable for random values) are declared before them. If the graph
  *  contains a cycle, the module is left unchanged and a warning is logged.
  */
class RemoveWires extends Transform with DependencyAPIMigration {

  override def prerequisites = firrtl.stage.Forms.MidForm ++
    Seq(
      Dependency(passes.LowerTypes),
      Dependency(passes.ResolveKinds),
      Dependency(transforms.RemoveReset),
      Dependency[transforms.CheckCombLoops],
      Dependency(passes.LegalizeConnects)
    )

  override def optionalPrerequisites = Seq(Dependency[checks.CheckResets])

  override def optionalPrerequisiteOf = Seq.empty

  // References to removed wires are not rewritten by this pass, so they still carry WireKind
  // even though their declaration is now a DefNode; kinds must be re-resolved afterwards.
  override def invalidates(a: Transform) = a match {
    case passes.ResolveKinds => true
    case _                   => false
  }

  /** Collect every reference to a node, wire, register or random value inside `expr`.
    *
    * Since we are operating on LowForm, such references can only be WRefs. Recursion only
    * descends through Mux, DoPrim and ValidIf; any other expression (port references,
    * instance/memory subfields, literals, ...) contributes no dependency.
    *
    * @param expr expression to scan
    * @return the matching references, in the order they were encountered
    */
  private def extractNodeWireRegRefs(expr: Expression): Seq[WRef] = {
    val refs = mutable.ArrayBuffer.empty[WRef]
    def rec(e: Expression): Expression = {
      e match {
        case ref @ WRef(_, _, WireKind | NodeKind | RegKind | RandomKind, _) => refs += ref
        case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec)
        case _ => // Do nothing
      }
      e
    }
    rec(expr)
    refs.toSeq
  }

  /** Order the netlist so that every signal is declared after the signals it depends on.
    *
    * An edge `sink -> source` is added for every node/wire/reg/rand reference found in the
    * expressions a sink depends on. The graph is then linearized and the result is turned
    * back into statements: registers and random values are re-emitted as their original
    * declarations, while wires and nodes become DefNodes.
    *
    * @param netlist  per-signal dependency expressions plus the Info to attach to the result
    * @param regInfo  original DefRegister for every register key in `netlist`
    * @param randInfo original DefRandom for every random-value key in `netlist`
    * @return the ordered statements, or a Failure carrying any exception thrown while
    *         linearizing or rebuilding (the caller treats CyclicException as "graph has a cycle")
    */
  private def getOrderedNodes(
    netlist:  mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)],
    regInfo:  mutable.Map[WrappedExpression, DefRegister],
    randInfo: mutable.Map[WrappedExpression, DefRandom]
  ): Try[Seq[Statement]] = {
    val digraph = new MutableDiGraph[WrappedExpression]
    for ((sink, (exprs, _)) <- netlist) {
      digraph.addVertex(sink)
      for (expr <- exprs) {
        for (source <- extractNodeWireRegRefs(expr)) {
          digraph.addPairWithEdge(sink, source)
        }
      }
    }

    // We could reverse edge directions and not have to do this reverse, but doing it this way does
    // a MUCH better job of preserving the logic order as expressed by the designer
    // See RemoveWireTests for illustration
    Try {
      val ordered = digraph.linearize.reverse
      ordered.map { key =>
        val WRef(name, _, kind, _) = key.e1
        kind match {
          case RegKind    => regInfo(key)
          case RandomKind => randInfo(key)
          case WireKind | NodeKind =>
            // NOTE(review): assumes every wire/node reached here has exactly one netlist entry
            // (i.e. it was connected or invalidated); otherwise this throws inside the Try.
            val (Seq(rhs), info) = netlist(key)
            DefNode(info, name, rhs)
        }
      }
    }
  }

  /** Remove all non-Analog wires from a single module.
    *
    * Output order is: remaining declarations, then the ordered nodes/registers/random values,
    * then all other statements. ExtModules and modules whose dependency graph is cyclic are
    * returned unchanged.
    */
  private def onModule(m: DefModule): DefModule = {
    // Declarations kept verbatim: anything referenceable that is not a node, non-Analog wire,
    // register or random value (e.g. instances, memories and Analog wires)
    val decls = mutable.ArrayBuffer.empty[Statement]
    // Store all "other" statements here, non-wire, non-node connections, printfs, etc.
    val otherStmts = mutable.ArrayBuffer.empty[Statement]
    // Nodes, wire drivers, registers and random values, each with its dependency expressions
    val netlist = mutable.LinkedHashMap.empty[WrappedExpression, (Seq[Expression], Info)]
    // Info at definition of wires for combining into node
    val wireInfo = mutable.HashMap.empty[WrappedExpression, Info]
    // Additional info about registers
    val regInfo = mutable.HashMap.empty[WrappedExpression, DefRegister]
    // Additional info about rand statements
    val randInfo = mutable.HashMap.empty[WrappedExpression, DefRandom]

    def onStmt(stmt: Statement): Statement = {
      stmt match {
        case node: DefNode =>
          netlist(we(WRef(node))) = (Seq(node.value), node.info)
        case wire: DefWire if !wire.tpe.isInstanceOf[AnalogType] => // Remove all non-Analog wires
          wireInfo(WRef(wire)) = wire.info
        case reg: DefRegister =>
          // Only an asynchronous reset is an ordering dependency of the register declaration
          val resetDep = reg.reset.tpe match {
            case AsyncResetType => Some(reg.reset)
            case _              => None
          }
          val initDep = Some(reg.init).filter(we(WRef(reg)) != we(_)) // Dependency exists IF reg doesn't init itself
          regInfo(we(WRef(reg))) = reg
          netlist(we(WRef(reg))) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info)
        case rand: DefRandom =>
          randInfo(we(Reference(rand))) = rand
          netlist(we(Reference(rand))) = (rand.clock ++: rand.en +: List(), rand.info)
        case decl: CanBeReferenced =>
          // Keep all declarations except for nodes and non-Analog wires and "other" statements.
          // Thus this is expected to match DefInstance and DefMemory which both do not connect to
          // any signals directly (instead a separate Connect is used).
          decls += decl
        case con @ Connect(cinfo, lhs, rhs) =>
          kind(lhs) match {
            case WireKind =>
              // be sure that connects have the same bit widths on rhs and lhs
              assert(
                bitWidth(lhs.tpe) == bitWidth(rhs.tpe),
                "Connection widths should have been taken care of by LegalizeConnects!"
              )
              // The resulting node carries both the wire's declaration info and the connect's info.
              // A later connection to the same wire would overwrite this entry.
              val dinfo = wireInfo(lhs)
              netlist(we(lhs)) = (Seq(rhs), MultiInfo(dinfo, cinfo))
            case _ => otherStmts += con // Other connections just pass through
          }
        case invalid @ IsInvalid(info, expr) =>
          kind(expr) match {
            case WireKind =>
              // An invalidated wire becomes a node that is never valid: ValidIf(0, <zero of its type>)
              val tpe = expr.tpe match { case g: GroundType => g } // LowFirrtl
              netlist(we(expr)) = (Seq(ValidIf(Utils.zero, getGroundZero(tpe), tpe)), info)
            case _ => otherStmts += invalid
          }
        case other @ (_: Print | _: Stop | _: Attach | _: Verification) =>
          otherStmts += other
        case EmptyStmt => // Dont bother keeping EmptyStmts around
        case block: Block => block.foreach(onStmt)
        case unexpected =>
          throwInternalError(s"RemoveWires: unexpected statement in LowForm: ${unexpected.serialize}")
      }
      stmt
    }

    m match {
      case mod @ Module(info, name, ports, body) =>
        onStmt(body)
        getOrderedNodes(netlist, regInfo, randInfo) match {
          case Success(logic) =>
            Module(info, name, ports, Block(List() ++ decls ++ logic ++ otherStmts))
          // If we hit a CyclicException, just abort removing wires
          case Failure(c: CyclicException) =>
            val problematicNode = c.node
            logger.warn(
              s"Cycle found in module $name, " +
                s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode"
            )
            mod
          case Failure(other) => throw other
        }
      case m: ExtModule => m
    }
  }

  /** Apply wire removal to every module of the circuit. */
  def execute(state: CircuitState): CircuitState =
    state.copy(circuit = state.circuit.map(onModule))
}