summaryrefslogtreecommitdiff
path: root/src/main/scala/chisel3/aop
diff options
context:
space:
mode:
authorAdam Izraelevitz2019-08-12 15:49:42 -0700
committerGitHub2019-08-12 15:49:42 -0700
commitfddb5943b1d36925a5435d327c3312572e98ca58 (patch)
treeb22e3a544dbb265dead955544c75bf7abddb7c69 /src/main/scala/chisel3/aop
parent466ffbc9ca4fcca73d56f849df9e2753f68c53a8 (diff)
Aspect-Oriented Programming for Chisel (#1077)
Added Aspects to Chisel, enabling a mechanism for dependency injection to hardware modules.
Diffstat (limited to 'src/main/scala/chisel3/aop')
-rw-r--r--src/main/scala/chisel3/aop/Select.scala418
-rw-r--r--src/main/scala/chisel3/aop/injecting/InjectStatement.scala21
-rw-r--r--src/main/scala/chisel3/aop/injecting/InjectingAspect.scala63
-rw-r--r--src/main/scala/chisel3/aop/injecting/InjectingTransform.scala46
4 files changed, 548 insertions, 0 deletions
diff --git a/src/main/scala/chisel3/aop/Select.scala b/src/main/scala/chisel3/aop/Select.scala
new file mode 100644
index 00000000..612cdcc7
--- /dev/null
+++ b/src/main/scala/chisel3/aop/Select.scala
@@ -0,0 +1,418 @@
+// See LICENSE for license details.
+
+package chisel3.aop
+
+import chisel3._
+import chisel3.experimental.{BaseModule, FixedPoint}
+import chisel3.internal.HasId
+import chisel3.internal.firrtl._
+import firrtl.annotations.ReferenceTarget
+
+import scala.collection.mutable
+
+/** Use to select Chisel components in a module, after that module has been constructed
+ * Useful for adding additional Chisel annotations or for use within an [[Aspect]]
+ */
+object Select {
+
+ /** Return just leaf components of expanded node
+ *
+ * @param d Component to find leafs if aggregate typed. Intermediate fields/indicies are not included
+ * @return
+ */
+ def getLeafs(d: Data): Seq[Data] = d match {
+ case b: Bundle => b.getElements.flatMap(getLeafs)
+ case v: Vec[_] => v.getElements.flatMap(getLeafs)
+ case other => Seq(other)
+ }
+
+ /** Return all expanded components, including intermediate aggregate nodes
+ *
+ * @param d Component to find leafs if aggregate typed. Intermediate fields/indicies ARE included
+ * @return
+ */
+ def getIntermediateAndLeafs(d: Data): Seq[Data] = d match {
+ case b: Bundle => b +: b.getElements.flatMap(getIntermediateAndLeafs)
+ case v: Vec[_] => v +: v.getElements.flatMap(getIntermediateAndLeafs)
+ case other => Seq(other)
+ }
+
+
+ /** Collects all components selected by collector within module and all children modules it instantiates
+ * directly or indirectly
+ * Accepts a collector function, rather than a collector partial function (see [[collectDeep]])
+ * @param module Module to collect components, as well as all children module it directly and indirectly instantiates
+ * @param collector Collector function to pick, given a module, which components to collect
+ * @param tag Required for generics to work, should ignore this
+ * @tparam T Type of the component that will be collected
+ * @return
+ */
+ def getDeep[T](module: BaseModule)(collector: BaseModule => Seq[T]): Seq[T] = {
+ check(module)
+ val myItems = collector(module)
+ val deepChildrenItems = instances(module).flatMap {
+ i => getDeep(i)(collector)
+ }
+ myItems ++ deepChildrenItems
+ }
+
+ /** Collects all components selected by collector within module and all children modules it instantiates
+ * directly or indirectly
+ * Accepts a collector partial function, rather than a collector function (see [[getDeep]])
+ * @param module Module to collect components, as well as all children module it directly and indirectly instantiates
+ * @param collector Collector partial function to pick, given a module, which components to collect
+ * @param tag Required for generics to work, should ignore this
+ * @tparam T Type of the component that will be collected
+ * @return
+ */
+ def collectDeep[T](module: BaseModule)(collector: PartialFunction[BaseModule, T]): Iterable[T] = {
+ check(module)
+ val myItems = collector.lift(module)
+ val deepChildrenItems = instances(module).flatMap {
+ i => collectDeep(i)(collector)
+ }
+ myItems ++ deepChildrenItems
+ }
+
+ /** Selects all instances directly instantiated within given module
+ * @param module
+ * @return
+ */
+ def instances(module: BaseModule): Seq[BaseModule] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case i: DefInstance => i.id
+ }
+ }
+
+ /** Selects all registers directly instantiated within given module
+ * @param module
+ * @return
+ */
+ def registers(module: BaseModule): Seq[Data] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case r: DefReg => r.id
+ case r: DefRegInit => r.id
+ }
+ }
+
+ /** Selects all ios directly contained within given module
+ * @param module
+ * @return
+ */
+ def ios(module: BaseModule): Seq[Data] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].ports.map(_.id)
+ }
+
+ /** Selects all SyncReadMems directly contained within given module
+ * @param module
+ * @return
+ */
+ def syncReadMems(module: BaseModule): Seq[SyncReadMem[_]] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case r: DefSeqMemory => r.id.asInstanceOf[SyncReadMem[_]]
+ }
+ }
+
+ /** Selects all Mems directly contained within given module
+ * @param module
+ * @return
+ */
+ def mems(module: BaseModule): Seq[Mem[_]] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case r: DefMemory => r.id.asInstanceOf[Mem[_]]
+ }
+ }
+
+ /** Selects all arithmetic or logical operators directly instantiated within given module
+ * @param module
+ * @return
+ */
+ def ops(module: BaseModule): Seq[(String, Data)] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case d: DefPrim[_] => (d.op.name, d.id)
+ }
+ }
+
+ /** Selects a kind of arithmetic or logical operator directly instantiated within given module
+ * The kind of operators are contained in [[chisel3.internal.firrtl.PrimOp]]
+ * @param opKind the kind of operator, e.g. "mux", "add", or "bits"
+ * @param module
+ * @return
+ */
+ def ops(opKind: String)(module: BaseModule): Seq[Data] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case d: DefPrim[_] if d.op.name == opKind => d.id
+ }
+ }
+
+ /** Selects all wires in a module
+ * @param module
+ * @return
+ */
+ def wires(module: BaseModule): Seq[Data] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case r: DefWire => r.id
+ }
+ }
+
+ /** Selects all memory ports, including their direction and memory
+ * @param module
+ * @return
+ */
+ def memPorts(module: BaseModule): Seq[(Data, MemPortDirection, MemBase[_])] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case r: DefMemPort[_] => (r.id, r.dir, r.source.id.asInstanceOf[MemBase[_ <: Data]])
+ }
+ }
+
+ /** Selects all memory ports of a given direction, including their memory
+ * @param dir The direction of memory ports to select
+ * @param module
+ * @return
+ */
+ def memPorts(dir: MemPortDirection)(module: BaseModule): Seq[(Data, MemBase[_])] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case r: DefMemPort[_] if r.dir == dir => (r.id, r.source.id.asInstanceOf[MemBase[_ <: Data]])
+ }
+ }
+
+ /** Selects all components who have been set to be invalid, even if they are later connected to
+ * @param module
+ * @return
+ */
+ def invalids(module: BaseModule): Seq[Data] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case DefInvalid(_, arg) => getData(arg)
+ }
+ }
+
+ /** Selects all components who are attached to a given signal, within a module
+ * @param module
+ * @return
+ */
+ def attachedTo(module: BaseModule)(signal: Data): Set[Data] = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.collect {
+ case Attach(_, seq) if seq.contains(signal) => seq
+ }.flatMap { seq => seq.map(_.id.asInstanceOf[Data]) }.toSet
+ }
+
+ /** Selects all connections to a signal or its parent signal(s) (if the signal is an element of an aggregate signal)
+ * The when predicates surrounding each connection are included in the returned values
+ *
+ * E.g. if signal = io.foo.bar, connectionsTo will return all connections to io, io.foo, and io.bar
+ * @param module
+ * @param signal
+ * @return
+ */
+ def connectionsTo(module: BaseModule)(signal: Data): Seq[PredicatedConnect] = {
+ check(module)
+ val sensitivitySignals = getIntermediateAndLeafs(signal).toSet
+ val predicatedConnects = mutable.ArrayBuffer[PredicatedConnect]()
+ val isPort = module._component.get.asInstanceOf[DefModule].ports.flatMap{ p => getIntermediateAndLeafs(p.id) }.contains(signal)
+ var prePredicates: Seq[Predicate] = Nil
+ var seenDef = isPort
+ searchWhens(module, (cmd: Command, preds) => {
+ cmd match {
+ case cmd: Definition if cmd.id.isInstanceOf[Data] =>
+ val x = getIntermediateAndLeafs(cmd.id.asInstanceOf[Data])
+ if(x.contains(signal)) prePredicates = preds
+ case Connect(_, loc@Node(d: Data), exp) =>
+ val effected = getEffected(loc).toSet
+ if(sensitivitySignals.intersect(effected).nonEmpty) {
+ val expData = getData(exp)
+ prePredicates.reverse.zip(preds.reverse).foreach(x => assert(x._1 == x._2, s"Prepredicates $x must match for signal $signal"))
+ predicatedConnects += PredicatedConnect(preds.dropRight(prePredicates.size), d, expData, isBulk = false)
+ }
+ case BulkConnect(_, loc@Node(d: Data), exp) =>
+ val effected = getEffected(loc).toSet
+ if(sensitivitySignals.intersect(effected).nonEmpty) {
+ val expData = getData(exp)
+ prePredicates.reverse.zip(preds.reverse).foreach(x => assert(x._1 == x._2, s"Prepredicates $x must match for signal $signal"))
+ predicatedConnects += PredicatedConnect(preds.dropRight(prePredicates.size), d, expData, isBulk = true)
+ }
+ case other =>
+ }
+ })
+ predicatedConnects
+ }
+
+ /** Selects all stop statements, and includes the predicates surrounding the stop statement
+ *
+ * @param module
+ * @return
+ */
+ def stops(module: BaseModule): Seq[Stop] = {
+ val stops = mutable.ArrayBuffer[Stop]()
+ searchWhens(module, (cmd: Command, preds: Seq[Predicate]) => {
+ cmd match {
+ case chisel3.internal.firrtl.Stop(_, clock, ret) => stops += Stop(preds, ret, getId(clock).asInstanceOf[Clock])
+ case other =>
+ }
+ })
+ stops
+ }
+
+ /** Selects all printf statements, and includes the predicates surrounding the printf statement
+ *
+ * @param module
+ * @return
+ */
+ def printfs(module: BaseModule): Seq[Printf] = {
+ val printfs = mutable.ArrayBuffer[Printf]()
+ searchWhens(module, (cmd: Command, preds: Seq[Predicate]) => {
+ cmd match {
+ case chisel3.internal.firrtl.Printf(_, clock, pable) => printfs += Printf(preds, pable, getId(clock).asInstanceOf[Clock])
+ case other =>
+ }
+ })
+ printfs
+ }
+
+ // Checks that a module has finished its construction
+ private def check(module: BaseModule): Unit = {
+ require(module.isClosed, "Can't use Selector on modules that have not finished construction!")
+ require(module._component.isDefined, "Can't use Selector on modules that don't have components!")
+ }
+
+ // Given a loc, return all subcomponents of id that could be assigned to in connect
+ private def getEffected(a: Arg): Seq[Data] = a match {
+ case Node(id: Data) => getIntermediateAndLeafs(id)
+ case Slot(imm, name) => Seq(imm.id.asInstanceOf[Record].elements(name))
+ case Index(imm, value) => getEffected(imm)
+ }
+
+ // Given an arg, return the corresponding id. Don't use on a loc of a connect.
+ private def getId(a: Arg): HasId = a match {
+ case Node(id) => id
+ case l: ULit => l.num.U(l.w)
+ case l: SLit => l.num.S(l.w)
+ case l: FPLit => FixedPoint(l.num, l.w, l.binaryPoint)
+ case other =>
+ sys.error(s"Something went horribly wrong! I was expecting ${other} to be a lit or a node!")
+ }
+
+ private def getData(a: Arg): Data = a match {
+ case Node(data: Data) => data
+ case other =>
+ sys.error(s"Something went horribly wrong! I was expecting ${other} to be Data!")
+ }
+
+ // Given an id, either get its name or its value, if its a lit
+ private def getName(i: HasId): String = try {
+ i.toTarget match {
+ case r: ReferenceTarget =>
+ val str = r.serialize
+ str.splitAt(str.indexOf('>'))._2.drop(1)
+ }
+ } catch {
+ case e: ChiselException => i.getOptionRef.get match {
+ case l: LitArg => l.num.intValue().toString
+ }
+ }
+
+ // Collects when predicates as it searches through a module, then applying processCommand to non-when related commands
+ private def searchWhens(module: BaseModule, processCommand: (Command, Seq[Predicate]) => Unit) = {
+ check(module)
+ module._component.get.asInstanceOf[DefModule].commands.foldLeft((Seq.empty[Predicate], Option.empty[Predicate])) {
+ (blah, cmd) =>
+ (blah, cmd) match {
+ case ((preds, o), cmd) => cmd match {
+ case WhenBegin(_, Node(pred: Bool)) => (When(pred) +: preds, None)
+ case WhenBegin(_, l: LitArg) if l.num == BigInt(1) => (When(true.B) +: preds, None)
+ case WhenBegin(_, l: LitArg) if l.num == BigInt(0) => (When(false.B) +: preds, None)
+ case other: WhenBegin =>
+ sys.error(s"Something went horribly wrong! I was expecting ${other.pred} to be a lit or a bool!")
+ case _: WhenEnd => (preds.tail, Some(preds.head))
+ case AltBegin(_) if o.isDefined => (o.get.not +: preds, o)
+ case _: AltBegin =>
+ sys.error(s"Something went horribly wrong! I was expecting ${o} to be nonEmpty!")
+ case OtherwiseEnd(_, _) => (preds.tail, None)
+ case other =>
+ processCommand(cmd, preds)
+ (preds, o)
+ }
+ }
+ }
+ }
+
+ trait Serializeable {
+ def serialize: String
+ }
+
+ /** Used to indicates a when's predicate (or its otherwise predicate)
+ */
+ trait Predicate extends Serializeable {
+ val bool: Bool
+ def not: Predicate
+ }
+
+ /** Used to represent [[chisel3.when]] predicate
+ *
+ * @param bool the when predicate
+ */
+ case class When(bool: Bool) extends Predicate {
+ def not: WhenNot = WhenNot(bool)
+ def serialize: String = s"${getName(bool)}"
+ }
+
+ /** Used to represent the `otherwise` predicate of a [[chisel3.when]]
+ *
+ * @param bool the when predicate corresponding to this otherwise predicate
+ */
+ case class WhenNot(bool: Bool) extends Predicate {
+ def not: When = When(bool)
+ def serialize: String = s"!${getName(bool)}"
+ }
+
+ /** Used to represent a connection or bulk connection
+ *
+ * Additionally contains the sequence of when predicates seen when the connection is declared
+ *
+ * @param preds
+ * @param loc
+ * @param exp
+ * @param isBulk
+ */
+ case class PredicatedConnect(preds: Seq[Predicate], loc: Data, exp: Data, isBulk: Boolean) extends Serializeable {
+ def serialize: String = {
+ val moduleTarget = loc.toTarget.moduleTarget.serialize
+ s"$moduleTarget: when(${preds.map(_.serialize).mkString(" & ")}): ${getName(loc)} ${if(isBulk) "<>" else ":="} ${getName(exp)}"
+ }
+ }
+
+ /** Used to represent a [[chisel3.stop]]
+ *
+ * @param preds
+ * @param ret
+ * @param clock
+ */
+ case class Stop(preds: Seq[Predicate], ret: Int, clock: Clock) extends Serializeable {
+ def serialize: String = {
+ s"stop when(${preds.map(_.serialize).mkString(" & ")}) on ${getName(clock)}: $ret"
+ }
+ }
+
+ /** Used to represent a [[chisel3.printf]]
+ *
+ * @param preds
+ * @param pable
+ * @param clock
+ */
+ case class Printf(preds: Seq[Predicate], pable: Printable, clock: Clock) extends Serializeable {
+ def serialize: String = {
+ s"printf when(${preds.map(_.serialize).mkString(" & ")}) on ${getName(clock)}: $pable"
+ }
+ }
+}
diff --git a/src/main/scala/chisel3/aop/injecting/InjectStatement.scala b/src/main/scala/chisel3/aop/injecting/InjectStatement.scala
new file mode 100644
index 00000000..c207454d
--- /dev/null
+++ b/src/main/scala/chisel3/aop/injecting/InjectStatement.scala
@@ -0,0 +1,21 @@
+// See LICENSE for license details.
+
+package chisel3.aop.injecting
+
+import chisel3.stage.phases.AspectPhase
+import firrtl.annotations.{Annotation, ModuleTarget, NoTargetAnnotation, SingleTargetAnnotation}
+
+/** Contains all information needed to inject statements into a module
+ *
+ * Generated when a [[InjectingAspect]] is consumed by a [[AspectPhase]]
+ * Consumed by [[InjectingTransform]]
+ *
+ * @param module Module to inject code into at the end of the module
+ * @param s Statements to inject
+ * @param modules Additional modules that may be instantiated by s
+ * @param annotations Additional annotations that should be passed down compiler
+ */
+case class InjectStatement(module: ModuleTarget, s: firrtl.ir.Statement, modules: Seq[firrtl.ir.DefModule], annotations: Seq[Annotation]) extends SingleTargetAnnotation[ModuleTarget] {
+ val target: ModuleTarget = module
+ override def duplicate(n: ModuleTarget): Annotation = this.copy(module = n)
+}
diff --git a/src/main/scala/chisel3/aop/injecting/InjectingAspect.scala b/src/main/scala/chisel3/aop/injecting/InjectingAspect.scala
new file mode 100644
index 00000000..74cd62f3
--- /dev/null
+++ b/src/main/scala/chisel3/aop/injecting/InjectingAspect.scala
@@ -0,0 +1,63 @@
+// See LICENSE for license details.
+
+package chisel3.aop.injecting
+
+import chisel3.{Module, ModuleAspect, experimental, withClockAndReset}
+import chisel3.aop._
+import chisel3.experimental.RawModule
+import chisel3.internal.Builder
+import chisel3.internal.firrtl.DefModule
+import chisel3.stage.DesignAnnotation
+import firrtl.annotations.ModuleTarget
+import firrtl.stage.RunFirrtlTransformAnnotation
+import firrtl.{ir, _}
+
+import scala.collection.mutable
+import scala.reflect.runtime.universe.TypeTag
+
+/** Aspect to inject Chisel code into a module of type M
+ *
+ * @param selectRoots Given top-level module, pick the instances of a module to apply the aspect (root module)
+ * @param injection Function to generate Chisel hardware that will be injected to the end of module m
+ * Signals in m can be referenced and assigned to as if inside m (yes, it is a bit magical)
+ * @param tTag Needed to prevent type-erasure of the top-level module type
+ * @tparam T Type of top-level module
+ * @tparam M Type of root module (join point)
+ */
+case class InjectingAspect[T <: RawModule,
+ M <: RawModule](selectRoots: T => Iterable[M],
+ injection: M => Unit
+ )(implicit tTag: TypeTag[T]) extends Aspect[T] {
+ final def toAnnotation(top: T): AnnotationSeq = {
+ toAnnotation(selectRoots(top), top.name)
+ }
+
+ final def toAnnotation(modules: Iterable[M], circuit: String): AnnotationSeq = {
+ RunFirrtlTransformAnnotation(new InjectingTransform) +: modules.map { module =>
+ val (chiselIR, _) = Builder.build(Module(new ModuleAspect(module) {
+ module match {
+ case x: experimental.MultiIOModule => withClockAndReset(x.clock, x.reset) { injection(module) }
+ case x: RawModule => injection(module)
+ }
+ }))
+ val comps = chiselIR.components.map {
+ case x: DefModule if x.name == module.name => x.copy(id = module)
+ case other => other
+ }
+
+ val annotations = chiselIR.annotations.map(_.toFirrtl).filterNot{ a => a.isInstanceOf[DesignAnnotation[_]] }
+
+ val stmts = mutable.ArrayBuffer[ir.Statement]()
+ val modules = Aspect.getFirrtl(chiselIR.copy(components = comps)).modules.flatMap {
+ case m: firrtl.ir.Module if m.name == module.name =>
+ stmts += m.body
+ Nil
+ case other =>
+ Seq(other)
+ }
+
+ InjectStatement(ModuleTarget(circuit, module.name), ir.Block(stmts), modules, annotations)
+ }.toSeq
+ }
+}
+
diff --git a/src/main/scala/chisel3/aop/injecting/InjectingTransform.scala b/src/main/scala/chisel3/aop/injecting/InjectingTransform.scala
new file mode 100644
index 00000000..c65bee38
--- /dev/null
+++ b/src/main/scala/chisel3/aop/injecting/InjectingTransform.scala
@@ -0,0 +1,46 @@
+// See LICENSE for license details.
+
+package chisel3.aop.injecting
+
+import firrtl.{ChirrtlForm, CircuitForm, CircuitState, Transform, ir}
+
+import scala.collection.mutable
+
+/** Appends statements contained in [[InjectStatement]] annotations to the end of their corresponding modules
+ *
+ * Implemented with Chisel Aspects and the [[chisel3.aop.injecting]] library
+ */
+class InjectingTransform extends Transform {
+ override def inputForm: CircuitForm = ChirrtlForm
+ override def outputForm: CircuitForm = ChirrtlForm
+
+ override def execute(state: CircuitState): CircuitState = {
+
+ val addStmtMap = mutable.HashMap[String, Seq[ir.Statement]]()
+ val addModules = mutable.ArrayBuffer[ir.DefModule]()
+
+ // Populate addStmtMap and addModules, return annotations in InjectStatements, and omit InjectStatement annotation
+ val newAnnotations = state.annotations.flatMap {
+ case InjectStatement(mt, s, addedModules, annotations) =>
+ addModules ++= addedModules
+ addStmtMap(mt.module) = s +: addStmtMap.getOrElse(mt.module, Nil)
+ annotations
+ case other => Seq(other)
+ }
+
+ // Append all statements to end of corresponding modules
+ val newModules = state.circuit.modules.map { m: ir.DefModule =>
+ m match {
+ case m: ir.Module if addStmtMap.contains(m.name) =>
+ m.copy(body = ir.Block(m.body +: addStmtMap(m.name)))
+ case m: _root_.firrtl.ir.ExtModule if addStmtMap.contains(m.name) =>
+ ir.Module(m.info, m.name, m.ports, ir.Block(addStmtMap(m.name)))
+ case other: ir.DefModule => other
+ }
+ }
+
+ // Return updated circuit and annotations
+ val newCircuit = state.circuit.copy(modules = newModules ++ addModules)
+ state.copy(annotations = newAnnotations, circuit = newCircuit)
+ }
+}