aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ExpandConnects.scala
blob: 518068a16d1765dc123c4379f0850ca8d4aaf9c9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// SPDX-License-Identifier: Apache-2.0

package firrtl.passes

import firrtl.Utils.{create_exps, flow, get_field, get_valid_points, times, to_flip, to_flow}
import firrtl.ir._
import firrtl.options.Dependency
import firrtl.{DuplexFlow, Flow, SinkFlow, SourceFlow, Transform, WDefInstance, WRef, WSubAccess, WSubField, WSubIndex}
import firrtl.Mappers._

/** Expands `Connect`, `PartialConnect` and `IsInvalid` statements into one statement per
  * expression returned by `create_exps` (presumably one per ground-type leaf of the operand —
  * see `firrtl.Utils.create_exps`).
  *
  *  - `Connect`: leaves are paired positionally; when `to_flip(flow(loc))` is `Flip` the two
  *    sides are swapped so every emitted `Connect` drives in the direction of the leaf's flow.
  *  - `PartialConnect`: only the index pairs returned by `get_valid_points` are connected, with
  *    the same flip handling; leaves of `AnalogType` become `Attach` statements instead.
  *  - `IsInvalid`: only leaves whose re-derived flow is `DuplexFlow` or `SinkFlow` are kept.
  */
object ExpandConnects extends Pass {

  override def prerequisites =
    Seq(Dependency(PullMuxes), Dependency(ReplaceAccesses)) ++ firrtl.stage.Forms.Deduped

  // Emitted connects may have their loc/expr swapped, so flows computed earlier by
  // ResolveFlows are presumably no longer trustworthy after this pass.
  override def invalidates(a: Transform) = a match {
    case ResolveFlows => true
    case _            => false
  }

  /** Applies the expansion to the body of every `Module` in `c`. Every other module definition
    * (e.g. `ExtModule`) has no body to rewrite and is returned unchanged.
    */
  def run(c: Circuit): Circuit = {

    // Rewrites a single module. `flows` maps every port and every declaration encountered so far
    // (in traversal order) to its flow; it is consulted when re-deriving flows for IsInvalid.
    def expand_connects(m: Module): Module = {
      val flows = collection.mutable.LinkedHashMap[String, Flow]()

      def expand_s(s: Statement): Statement = {
        // Re-derives the flow of `e` bottom-up: references look up the declaration table,
        // sub-fields combine the parent's flow with the field's orientation, and sub-index /
        // sub-access inherit the parent's flow.
        def set_flow(e: Expression): Expression = e.map(set_flow) match {
          case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name))
          case ex: WSubField =>
            val f = get_field(ex.expr.tpe, ex.name)
            val flowx = times(flow(ex.expr), f.flip)
            WSubField(ex.expr, ex.name, ex.tpe, flowx)
          case ex: WSubIndex  => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr))
          case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr))
          case ex => ex
        }

        s match {
          // Declarations: record their flow so later references can be resolved.
          case sx: DefWire      => flows(sx.name) = DuplexFlow; sx
          case sx: DefRegister  => flows(sx.name) = DuplexFlow; sx
          case sx: WDefInstance => flows(sx.name) = SourceFlow; sx
          case sx: DefMemory    => flows(sx.name) = SourceFlow; sx
          case sx: DefNode      => flows(sx.name) = SourceFlow; sx
          case sx: IsInvalid =>
            // Keep only leaves that can be driven; any other flow is dropped.
            val invalids = create_exps(sx.expr).flatMap { expx =>
              flow(set_flow(expx)) match {
                case DuplexFlow | SinkFlow => Some(IsInvalid(sx.info, expx))
                case _                     => None
              }
            }
            invalids.size match {
              case 0 => EmptyStmt
              case 1 => invalids.head
              case _ => Block(invalids)
            }
          case sx: Connect =>
            val locs = create_exps(sx.loc)
            val exps = create_exps(sx.expr)
            Block(locs.zip(exps).map {
              case (locx, expx) =>
                to_flip(flow(locx)) match {
                  case Default => Connect(sx.info, locx, expx)
                  case Flip    => Connect(sx.info, expx, locx)
                }
            })
          case sx: PartialConnect =>
            val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default)
            // Both sides are indexed once per valid point below; materialize them as IndexedSeq
            // so each lookup is O(1) instead of O(n) on a List-backed Seq.
            val locs = create_exps(sx.loc).toIndexedSeq
            val exps = create_exps(sx.expr).toIndexedSeq
            val stmts = ls.map {
              case (x, y) =>
                val locx = locs(x)
                val expx = exps(y)
                locx.tpe match {
                  case AnalogType(_) => Attach(sx.info, Seq(locx, expx))
                  case _ =>
                    to_flip(flow(locx)) match {
                      case Default => Connect(sx.info, locx, expx)
                      case Flip    => Connect(sx.info, expx, locx)
                    }
                }
            }
            Block(stmts)
          // Any other statement (e.g. blocks, conditionals): recurse into its sub-statements.
          case sx => sx.map(expand_s)
        }
      }

      // Ports are visible throughout the module body, so seed the table with them first.
      m.ports.foreach { p => flows(p.name) = to_flow(p.direction) }
      Module(m.info, m.name, m.ports, expand_s(m.body))
    }

    val modulesx = c.modules.map {
      case m: Module => expand_connects(m)
      // ExtModule and any other body-less module definition pass through untouched.
      case other => other
    }
    Circuit(c.info, modulesx, c.main)
  }
}