aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/CInferMDirSpec.scala
blob: 299142d966af673f1732fbc9a42be19a2863b151 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// See LICENSE for license details.

package firrtlTests

import firrtl._
import firrtl.ir._
import firrtl.passes._
import firrtl.transforms._
import firrtl.Mappers._
import annotations._

/** Regression spec for memory port direction inference on CHIRRTL `infer mport`.
  *
  * The input circuit declares a `cmem` named `indices` that is accessed through two
  * inferred ports:
  *  - `index` is only used as a subaccess index (`io.ren[index]`, `io.out[index]`).
  *  - `bar` is only ever assigned to (`bar <= io.in`).
  *
  * After compilation, [[CInferMDirCheckPass]] verifies that the resulting memory
  * exposes `index` as a reader, `bar` as a writer, and no read-write ports.
  */
class CInferMDir extends LowTransformSpec {

  /** Check pass that succeeds only if some module contains the memory `indices` with
    * reader `index`, writer `bar`, and no read-write ports. Otherwise it raises a
    * `PassException`.
    */
  object CInferMDirCheckPass extends Pass {

    /** Searches `s` for the memory `indices` and reports whether its ports are as expected.
      *
      * Only `Block` nesting is descended into. Any other statement kind yields `false`.
      * NOTE(review): this presumes the circuit has already been lowered (no `when`s) when
      * the pass runs, which the `LowForm` input form of [[transform]] suggests.
      *
      * @param s statement to inspect
      * @return true if `s` is, or (through nested `Block`s) contains, a `DefMemory` named
      *         `indices` whose readers include `index`, whose writers include `bar`, and
      *         which has no read-write ports
      */
    def checkStmt(s: Statement): Boolean = s match {
      case blk: Block =>
        blk.stmts.exists(checkStmt)
      case mem: DefMemory if mem.name == "indices" =>
        val readsViaIndex = mem.readers.contains("index")
        val writesViaBar = mem.writers.contains("bar")
        readsViaIndex && writesViaBar && mem.readwriters.isEmpty
      case _ =>
        false
    }

    /** Applies [[checkStmt]] to every module body. External modules never match.
      *
      * If no module holds the expected memory, a `PassException` is appended to a fresh
      * `Errors` collector and `trigger` is called on it (expected to throw). Otherwise the
      * circuit is returned unmodified.
      */
    def run(c: Circuit): Circuit = {
      val memFound = c.modules.exists {
        case mod: Module  => checkStmt(mod.body)
        case _: ExtModule => false
      }
      if (!memFound) {
        val errors = new Errors
        errors.append(new PassException(
          "Memory has incorrect port directions!"))
        errors.trigger
      }
      c
    }
  }

  /** Runs constant propagation and then [[CInferMDirCheckPass]]. Input and output are low form. */
  def transform = new SeqTransform {
    def inputForm = LowForm
    def outputForm = LowForm
    def transforms = Seq(new ConstantPropagation, CInferMDirCheckPass)
  }

  "Memory" should "have correct mem port directions" in {
    val chirrtl = """
circuit foo :
  module foo :
    input clock : Clock
    input reset : UInt<1>
    output io : {flip wen : UInt<1>, flip in : UInt<1>, flip counter : UInt<2>, ren: UInt<1>[4], out : UInt<1>[4]}

    io is invalid
    cmem indices : UInt<2>[4]
    node T_0 = add(io.counter, UInt<1>("h01"))
    node temp = tail(T_0, 1)
    infer mport index = indices[temp], clock
    io.out[0] <= UInt<1>("h0")
    io.out[1] <= UInt<1>("h0")
    io.out[2] <= UInt<1>("h0")
    io.out[3] <= UInt<1>("h0")
    when io.ren[index] :
      io.out[index] <= io.in
    else :
      when io.wen :
        infer mport bar = indices[temp], clock
        bar <= io.in
""".stripMargin

    val initialState = CircuitState(parse(chirrtl), ChirrtlForm, Some(AnnotationMap(Nil)))
    // The port-direction check lives inside `transform` (presumably run by compileAndEmit).
    val emitted = compileAndEmit(initialState).getEmittedCircuit.value
    // Re-parsing the emitted text confirms the compiler produced syntactically valid FIRRTL.
    parse(emitted)
  }
}