aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtlTests/CInferMDirSpec.scala
blob: 349353d1ed671765d06f337e7436a16328aff1cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// SPDX-License-Identifier: Apache-2.0

package firrtlTests

import firrtl._
import firrtl.ir._
import firrtl.passes._
import firrtl.transforms._
import firrtl.testutils._

/** Checks that CHIRRTL `infer mport` statements get the right memory port directions
  * after lowering.
  *
  * In the input circuit, mport `index` is only read from and mport `bar` is only written
  * to (inside a `when`). The lowered memory is therefore expected to have `index` as a
  * reader, `bar` as a writer, and no readwriter ports.
  */
class CInferMDirSpec extends LowTransformSpec {

  /** Verification pass. It reports an error unless the circuit contains a memory named
    * `indices` with reader `index`, writer `bar`, and no readwriters.
    */
  object CInferMDirCheckPass extends Pass {

    /** Returns true if `s` is the expected `indices` memory declaration, or contains it
      * inside (possibly nested) [[Block]]s. Every other statement kind yields false.
      */
    def checkStmt(s: Statement): Boolean = s match {
      case mem: DefMemory if mem.name == "indices" =>
        mem.readers.contains("index") &&
        mem.writers.contains("bar") &&
        mem.readwriters.isEmpty
      // Only Block is descended into. NOTE(review): this assumes no when/Conditionally
      // survives in LowForm — confirm if the pass is ever run on a higher form.
      case block: Block => block.stmts.exists(checkStmt)
      case _            => false
    }

    /** Searches every module body for the expected memory.
      *
      * @param c circuit to inspect; it is returned unchanged
      * @return `c` itself when the memory is found. Otherwise a [[PassException]] is
      *         recorded and [[Errors]] is triggered to report it.
      */
    def run(c: Circuit): Circuit = {
      // Only Module has a body to search. A total `collect` also skips ExtModule and any
      // other DefModule kind without risking a MatchError.
      val found = c.modules.collect { case m: Module => m.body }.exists(checkStmt)
      if (!found) {
        val errors = new Errors
        errors.append(new PassException("Memory has incorrect port directions!"))
        // Explicit `()`: auto-application of empty-paren methods is deprecated in
        // Scala 2.13 and removed in Scala 3.
        errors.trigger()
      }
      c
    }
  }

  /** LowForm-to-LowForm transform under test: ConstantPropagation, then the port-direction
    * check above.
    */
  def transform: Transform = new SeqTransform {
    def inputForm = LowForm
    def outputForm = LowForm
    def transforms = Seq(new ConstantPropagation, CInferMDirCheckPass)
  }

  "Memory" should "have correct mem port directions" in {
    val input = """
circuit foo :
  module foo :
    input clock : Clock
    input reset : UInt<1>
    output io : {flip wen : UInt<1>, flip in : UInt<1>, flip counter : UInt<2>, ren: UInt<1>[4], out : UInt<1>[4]}

    io is invalid
    cmem indices : UInt<2>[4]
    node T_0 = add(io.counter, UInt<1>("h01"))
    node temp = tail(T_0, 1)
    infer mport index = indices[temp], clock
    io.out[0] <= UInt<1>("h0")
    io.out[1] <= UInt<1>("h0")
    io.out[2] <= UInt<1>("h0")
    io.out[3] <= UInt<1>("h0")
    when io.ren[index] :
      io.out[index] <= io.in
    else :
      when io.wen :
        infer mport bar = indices[temp], clock
        bar <= io.in
""".stripMargin

    // The port-direction check runs inside `transform` during compilation.
    val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm))
    // The emitted low FIRRTL must also parse back successfully.
    parse(res.getEmittedCircuit.value)
  }
}