summaryrefslogtreecommitdiff
path: root/src/main/scala/chisel3/aop/injecting/InjectingTransform.scala
blob: 8a0b6ecb2540e8a1d1294e97868f4599bb21fd18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
// SPDX-License-Identifier: Apache-2.0

package chisel3.aop.injecting

import firrtl.{ir, ChirrtlForm, CircuitForm, CircuitState, Transform}

import scala.collection.mutable

/** Appends statements contained in [[InjectStatement]] annotations to the end of their corresponding modules
  *
  * For every [[InjectStatement]] annotation in the input state:
  *  - its statement is appended to the body of the targeted module; when several annotations target the same
  *    module, their statements are appended in the order the annotations appear in the annotation sequence,
  *  - if the targeted module is an [[firrtl.ir.ExtModule]], it is replaced by an [[firrtl.ir.Module]] with the
  *    same info, name and ports whose body consists only of the injected statements,
  *  - the module definitions it carries are added after the circuit's existing modules,
  *  - the annotation itself is dropped and replaced by the annotations it carries.
  *
  * Implemented with Chisel Aspects and the [[chisel3.aop.injecting]] library
  */
class InjectingTransform extends Transform {
  override def inputForm:  CircuitForm = ChirrtlForm
  override def outputForm: CircuitForm = ChirrtlForm

  override def execute(state: CircuitState): CircuitState = {

    // Statements to inject, keyed by target module name. Appending to a Vector keeps them in annotation
    // order; prepending would reverse the order of multiple injections into the same module.
    val addStmtMap = mutable.HashMap[String, Vector[ir.Statement]]()
    // Module definitions introduced by the injected code, in annotation order
    val addModules = mutable.ArrayBuffer[ir.DefModule]()

    // Populate addStmtMap and addModules, return annotations in InjectStatements, and omit InjectStatement annotation
    val newAnnotations = state.annotations.flatMap {
      case InjectStatement(mt, s, addedModules, annotations) =>
        addModules ++= addedModules
        addStmtMap(mt.module) = addStmtMap.getOrElse(mt.module, Vector.empty[ir.Statement]) :+ s
        annotations
      case other => Seq(other)
    }

    // Append all statements to end of corresponding modules
    val newModules = state.circuit.modules.map {
      case m: ir.Module if addStmtMap.contains(m.name) =>
        m.copy(body = ir.Block(m.body +: addStmtMap(m.name)))
      // An external module has no body to extend, so it becomes a regular module holding only the injected code
      case m: ir.ExtModule if addStmtMap.contains(m.name) =>
        ir.Module(m.info, m.name, m.ports, ir.Block(addStmtMap(m.name)))
      case other => other
    }

    // Return updated circuit and annotations; added modules are placed after the existing ones
    val newCircuit = state.circuit.copy(modules = newModules ++ addModules)
    state.copy(annotations = newAnnotations, circuit = newCircuit)
  }
}