aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/CInferMDir.scala
blob: b4819751ae4f1852354913d531f16b1e2080027d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// See LICENSE for license details.

package firrtl.passes

import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.options.Dependency
import Utils.throwInternalError

/** Infers the direction of CHIRRTL memory ports (`CDefMPort`) from how they are used.
  *
  * Directions are tracked per module in a fresh `MPortDirMap`. Each port is seeded with its
  * declared direction when its `CDefMPort` is visited. It is then widened by every later
  * reference to its name:
  *  - on the left-hand side of a `Connect` / `PartialConnect` it counts as a write;
  *  - anywhere else it counts as a read. This includes connect sources, `SubAccess` indices,
  *    the port's own argument expressions and the expressions of any other statement.
  *
  * A port that ends up both read and written becomes `MReadWrite`. References to names that
  * have not (yet) been recorded as memory ports are ignored.
  */
object CInferMDir extends Pass {

  // Must run on CHIRRTL form, after CInferTypes.
  override def prerequisites = firrtl.stage.Forms.ChirrtlForm :+ Dependency(CInferTypes)

  // Declares that running this pass invalidates no other transform.
  override def invalidates(a: Transform) = false

  /** Mutable table from memory-port name to the direction inferred so far. */
  type MPortDirMap = collection.mutable.LinkedHashMap[String, MPortDir]

  /** Joins the direction recorded so far (`p`) with a newly observed use (`dir`).
    *
    * `MInfer` acts as "no information yet". Two equal directions stay as they are, and any
    * other pair of distinct directions widens to `MReadWrite`. A use of `MInfer` is never
    * expected and raises an internal error.
    */
  private def merge_mdir(p: MPortDir, dir: MPortDir): MPortDir =
    if (dir == MInfer) throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
    else if (p == MInfer || p == dir) dir
    else MReadWrite

  /** Widens the direction of every memory port referenced in `e` by `dir`.
    *
    * @param mports port-direction table, updated in place
    * @param dir    how `e` is being used (`MRead` or `MWrite`)
    * @param e      expression to scan
    * @return `e` for references and sub-accesses; otherwise `e` rebuilt via the mappers
    */
  def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match {
    case ref: Reference =>
      mports.get(ref.name).foreach { current =>
        mports(ref.name) = merge_mdir(current, dir)
      }
      ref
    case sub: SubAccess =>
      infer_mdir_e(mports, dir)(sub.expr)
      // The index is only ever read, even when the access itself is being written.
      infer_mdir_e(mports, MRead)(sub.index)
      sub
    case other =>
      other.map(infer_mdir_e(mports, dir))
  }

  /** Marks `expr` as read and then `loc` as written, as for the two sides of a connect. */
  private def infer_mdir_connect(mports: MPortDirMap, loc: Expression, expr: Expression): Unit = {
    infer_mdir_e(mports, MRead)(expr)
    infer_mdir_e(mports, MWrite)(loc)
  }

  /** Records each memory port declared in `s` and widens its direction by each use in `s`.
    *
    * @param mports port-direction table, updated in place
    * @param s      statement to scan
    * @return `s` for connects; otherwise `s` rebuilt via the mappers
    */
  def infer_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match {
    case port: CDefMPort =>
      // Seed with the declared direction before visiting the port's own expressions.
      mports(port.name) = port.direction
      port.map(infer_mdir_e(mports, MRead))
    case con: Connect =>
      infer_mdir_connect(mports, con.loc, con.expr)
      con
    case pcon: PartialConnect =>
      infer_mdir_connect(mports, pcon.loc, pcon.expr)
      pcon
    case other =>
      // Nested statements first, then this statement's own expressions (all treated as reads).
      val withSubStmts = other.map(infer_mdir_s(mports))
      withSubStmts.map(infer_mdir_e(mports, MRead))
  }

  /** Writes the inferred directions back onto every `CDefMPort` in `s`.
    *
    * Every port is expected to have been recorded by `infer_mdir_s`. Looking up an
    * unrecorded name with `mports(...)` throws.
    */
  def set_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match {
    case port: CDefMPort => port.copy(direction = mports(port.name))
    case other => other.map(set_mdir_s(mports))
  }

  /** Infers and then applies memory-port directions for a single module. */
  def infer_mdir(m: DefModule): DefModule = {
    val portDirs = new MPortDirMap
    val scanned = m.map(infer_mdir_s(portDirs))
    scanned.map(set_mdir_s(portDirs))
  }

  /** Runs direction inference independently on every module of the circuit. */
  def run(c: Circuit): Circuit =
    c.copy(modules = c.modules.map(infer_mdir))
}