aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/CheckFlows.scala
blob: 3a9cc212f7129fec3fc251d3e4e3f674ee18b445 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// See LICENSE for license details.

package firrtl.passes

import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.traversals.Foreachers._
import firrtl.options.Dependency

/** Checks that every expression in the circuit is used consistently with its flow.
  *
  * A source-flow expression may not be driven (used as a sink), and a sink-flow
  * expression may not be read (used as a source) unless it is a port or instance
  * whose type contains no flipped leaves. Violations are accumulated in an
  * [[Errors]] collection and raised by `errors.trigger()` after all modules have
  * been visited; the circuit itself is returned unchanged.
  */
object CheckFlows extends Pass {

  override def prerequisites = Dependency(passes.ResolveFlows) +: firrtl.stage.Forms.WorkingIR

  override def optionalPrerequisiteOf =
    Seq( Dependency[passes.InferBinaryPoints],
         Dependency[passes.TrimIntervals],
         Dependency[passes.InferWidths],
         Dependency[transforms.InferResets] )

  // `run` only inspects the circuit and returns it as-is, so nothing is invalidated.
  override def invalidates(a: Transform) = false

  /** Per-module map from a declared name (port, wire, register, node, memory, instance) to its flow. */
  type FlowMap = collection.mutable.HashMap[String, Flow]

  /** Lower-case, human-readable name of a [[Flow]], used in error messages.
    *
    * Kept `implicit` for source compatibility. Note that string interpolation never
    * triggers this conversion (interpolated arguments are typed `Any`), so it must be
    * called explicitly wherever the readable name is wanted.
    */
  implicit def toStr(g: Flow): String = g match {
    case SourceFlow => "source"
    case SinkFlow => "sink"
    case UnknownFlow => "unknown"
    case DuplexFlow => "duplex"
  }

  /** Reported when `expr` is used with flow `wrong` although its declared flow only permits `right`.
    *
    * The flows are rendered through `toStr` explicitly: a plain `$wrong` in the
    * interpolation would print the case-object name (e.g. `SinkFlow`) instead.
    */
  class WrongFlow(info:Info, mname: String, expr: String, wrong: Flow, right: Flow) extends PassException(
    s"$info: [module $mname]  Expression $expr is used as a ${toStr(wrong)} but can only be used as a ${toStr(right)}.")

  /** Walks every module of `c`, records declared flows and reports each misuse.
    *
    * @param c the circuit to check; it is returned unmodified
    * @return `c` unchanged
    */
  def run (c:Circuit): Circuit = {
    val errors = new Errors()

    /* Flow of `e`, derived from the flow recorded for its root reference.
     * Indexing and dynamic access inherit the flow of the aggregate being accessed;
     * a bundle field combines the aggregate's flow with the field's orientation via
     * `times`. Every other expression (literals, primitive ops, muxes, ...) is a source.
     */
    def get_flow(e: Expression, flows: FlowMap): Flow = e match {
      case (e: WRef) => flows(e.name)
      case (e: WSubIndex) => get_flow(e.expr, flows)
      case (e: WSubAccess) => get_flow(e.expr, flows)
      // NOTE(review): assumes earlier checks guarantee the parent is a BundleType that
      // contains the field; otherwise this throws MatchError / NoSuchElementException.
      case (e: WSubField) => e.expr.tpe match {case t: BundleType =>
        val f = (t.fields find (_.name == e.name)).get
        times(get_flow(e.expr, flows), f.flip)
      }
      case _ => SourceFlow
    }

    /* True if any leaf of `t`, reached through bundle fields and vector elements,
     * ends up with an accumulated orientation of `Flip`.
     */
    def flip_q(t: Type): Boolean = {
      def flip_rec(t: Type, f: Orientation): Boolean = t match {
        case tx: BundleType => tx.fields exists (
          field => flip_rec(field.tpe, times(f, field.flip))
        )
        case tx: VectorType => flip_rec(tx.tpe, f)
        case _ => f == Flip
      }
      flip_rec(t, Default)
    }

    /* Records a WrongFlow error when `e` is used as `desired` but its actual flow forbids it.
     * Duplex and unknown flows are never reported.
     */
    def check_flow(info:Info, mname: String, flows: FlowMap, desired: Flow)(e:Expression): Unit = {
      val flow = get_flow(e,flows)
      (flow, desired) match {
        // Driving a source is always an error.
        case (SourceFlow, SinkFlow) =>
          errors.append(new WrongFlow(info, mname, e.serialize, desired, flow))
        // Reading a sink is tolerated only for ports/instances without flipped leaves.
        case (SinkFlow, SourceFlow) => kind(e) match {
          case PortKind | InstanceKind if !flip_q(e.tpe) => // OK!
          case _ =>
            errors.append(new WrongFlow(info, mname, e.serialize, desired, flow))
        }
        case _ =>
      }
    }

    /* Requires the operands of muxes and primitive ops to be readable, then recurses
     * into every sub-expression.
     */
    def check_flows_e (info:Info, mname: String, flows: FlowMap)(e:Expression): Unit = {
      e match {
        case e: Mux => e foreach check_flow(info, mname, flows, SourceFlow)
        case e: DoPrim => e.args foreach check_flow(info, mname, flows, SourceFlow)
        case _ =>
      }
      e foreach check_flows_e(info, mname, flows)
    }

    /* Registers flows for declarations and checks the flows of the expressions each
     * statement reads or drives, then recurses into nested expressions and statements.
     * A statement without its own info falls back to the module info `minfo`.
     */
    def check_flows_s(minfo: Info, mname: String, flows: FlowMap)(s: Statement): Unit = {
      val info = get_info(s) match { case NoInfo => minfo case x => x }
      s match {
        case (s: DefWire) => flows(s.name) = DuplexFlow
        case (s: DefRegister) => flows(s.name) = DuplexFlow
        case (s: DefMemory) => flows(s.name) = SourceFlow
        case (s: WDefInstance) => flows(s.name) = SourceFlow
        case (s: DefNode) =>
          check_flow(info, mname, flows, SourceFlow)(s.value)
          flows(s.name) = SourceFlow
        case (s: Connect) =>
          check_flow(info, mname, flows, SinkFlow)(s.loc)
          check_flow(info, mname, flows, SourceFlow)(s.expr)
        case (s: Print) =>
          s.args foreach check_flow(info, mname, flows, SourceFlow)
          check_flow(info, mname, flows, SourceFlow)(s.en)
          check_flow(info, mname, flows, SourceFlow)(s.clk)
        case (s: PartialConnect) =>
          check_flow(info, mname, flows, SinkFlow)(s.loc)
          check_flow(info, mname, flows, SourceFlow)(s.expr)
        case (s: Conditionally) =>
          check_flow(info, mname, flows, SourceFlow)(s.pred)
        case (s: Stop) =>
          check_flow(info, mname, flows, SourceFlow)(s.en)
          check_flow(info, mname, flows, SourceFlow)(s.clk)
        case (s: Verification) =>
          check_flow(info, mname, flows, SourceFlow)(s.clk)
          check_flow(info, mname, flows, SourceFlow)(s.pred)
          check_flow(info, mname, flows, SourceFlow)(s.en)
        case _ =>
      }
      s foreach check_flows_e(info, mname, flows)
      s foreach check_flows_s(minfo, mname, flows)
    }

    for (m <- c.modules) {
      // Ports seed the map; declarations add their flows as the body is walked in order.
      val flows = new FlowMap
      flows ++= (m.ports map (p => p.name -> to_flow(p.direction)))
      m foreach check_flows_s(m.info, m.name, flows)
    }
    errors.trigger()
    c
  }
}