aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/CInferMDir.scala
blob: cca8fde4ff56850dd3b50954f1e7ed5d80fc07e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// SPDX-License-Identifier: Apache-2.0

package firrtl.passes

import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.options.Dependency
import Utils.throwInternalError

/** Infers the direction of every CHIRRTL memory port ([[CDefMPort]]) from the way it is used.
  *
  * Each port starts from the direction it was declared with. Every use of the port then widens
  * that direction:
  *  - being the target of a [[Connect]] or [[PartialConnect]] counts as a write;
  *  - any dynamic index expression, and any other appearance, counts as a read.
  *
  * A port that ends up both read and written becomes [[MReadWrite]]. Directions are tracked in a
  * fresh table for each module.
  */
object CInferMDir extends Pass {

  override def prerequisites = firrtl.stage.Forms.ChirrtlForm :+ Dependency(CInferTypes)

  /** This pass never invalidates any other transform. */
  override def invalidates(a: Transform) = false

  /** Mutable table from memory-port name to the direction inferred for it so far. */
  type MPortDirMap = collection.mutable.LinkedHashMap[String, MPortDir]

  /** Combines the direction recorded so far for a port with the direction of a new use.
    *
    * [[MInfer]] means "nothing known yet" and simply adopts the new use. Two equal concrete
    * directions stay as they are. Any mix of distinct concrete directions becomes [[MReadWrite]].
    * Any other combination (notably a use whose own direction is [[MInfer]]) is reported through
    * `throwInternalError`.
    *
    * @param seen direction recorded for the port so far
    * @param use  direction implied by the use currently being visited
    * @return the widened direction
    */
  private def mergeMPortDir(seen: MPortDir, use: MPortDir): MPortDir = (seen, use) match {
    case (MInfer, MRead | MWrite | MReadWrite)                      => use
    case (MRead, MRead)                                             => MRead
    case (MWrite, MWrite)                                           => MWrite
    case (MRead | MWrite | MReadWrite, MRead | MWrite | MReadWrite) => MReadWrite
    case _ =>
      throwInternalError(s"infer_mdir_e: shouldn't be here - $seen, $use")
  }

  /** Records, in `mports`, how the memory ports referenced by `e` are used.
    *
    * @param mports directions inferred so far; updated in place for every referenced port
    * @param dir    how `e` as a whole is being used
    * @param e      expression to scan
    * @return `e` itself for references and dynamic accesses; otherwise the result of mapping
    *         this scan over `e`'s child expressions
    */
  def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match {
    case ref: Reference =>
      // Only names previously registered by a CDefMPort are tracked; any other reference is ignored.
      mports.get(ref.name).foreach { seen =>
        mports(ref.name) = mergeMPortDir(seen, dir)
      }
      ref
    case access: SubAccess =>
      infer_mdir_e(mports, dir)(access.expr)
      // The index is only ever read, even when the accessed element is being written.
      infer_mdir_e(mports, MRead)(access.index)
      access
    case other =>
      // Every other expression passes its usage direction straight down to its children.
      other.map(infer_mdir_e(mports, dir))
  }

  /** Registers the memory ports declared in `s` and records each use of a port within `s`.
    *
    *  - a [[CDefMPort]] is registered with its declared direction, and its own expressions are reads;
    *  - the target of a [[Connect]] / [[PartialConnect]] is a write and its source is a read;
    *  - for any other statement, nested statements are visited recursively and the statement's
    *    expressions are treated as reads.
    *
    * @param mports directions inferred so far; updated in place
    * @param s      statement to scan
    * @return the statement to keep in place of `s` (port directions are not rewritten here)
    */
  def infer_mdir_s(mports: MPortDirMap)(s: Statement): Statement = {
    val markRead: Expression => Expression = infer_mdir_e(mports, MRead)
    val markWrite: Expression => Expression = infer_mdir_e(mports, MWrite)
    val recurse: Statement => Statement = infer_mdir_s(mports)
    s match {
      case port: CDefMPort =>
        mports(port.name) = port.direction
        port.map(markRead)
      case conn: Connect =>
        markRead(conn.expr)
        markWrite(conn.loc)
        conn
      case pconn: PartialConnect =>
        markRead(pconn.expr)
        markWrite(pconn.loc)
        pconn
      case other =>
        other.map(recurse).map(markRead)
    }
  }

  /** Rewrites every [[CDefMPort]] in `s` with the direction recorded for it in `mports`.
    *
    * Each port is expected to have been registered beforehand (`infer_mdir` runs `infer_mdir_s`
    * over the same module first). A port missing from `mports` makes the map lookup throw.
    */
  def set_mdir_s(mports: MPortDirMap)(s: Statement): Statement = {
    val recurse: Statement => Statement = set_mdir_s(mports)
    s match {
      case port: CDefMPort => port.copy(direction = mports(port.name))
      case other           => other.map(recurse)
    }
  }

  /** Infers memory-port directions for one module and writes them back onto its ports.
    *
    * Uses a fresh direction table, so ports are never shared between modules.
    */
  def infer_mdir(m: DefModule): DefModule = {
    val directions = new MPortDirMap
    // Phase 1 gathers every use; phase 2 stamps the resulting directions onto the declarations.
    val scanned = m.map(infer_mdir_s(directions))
    scanned.map(set_mdir_s(directions))
  }

  /** Runs direction inference independently on every module of `c`. */
  def run(c: Circuit): Circuit = {
    val processed = c.modules.map(infer_mdir)
    c.copy(modules = processed)
  }
}