diff options
Diffstat (limited to 'src')
17 files changed, 876 insertions, 312 deletions
diff --git a/src/main/scala/chisel3/Driver.scala b/src/main/scala/chisel3/Driver.scala index 906ae7fc..66146755 100644 --- a/src/main/scala/chisel3/Driver.scala +++ b/src/main/scala/chisel3/Driver.scala @@ -3,18 +3,17 @@ package chisel3 import chisel3.internal.ErrorLog -import chisel3.internal.firrtl._ -import chisel3.experimental.{RawModule, RunFirrtlTransform} -import chisel3.stage.{ChiselCircuitAnnotation, ChiselGeneratorAnnotation, ChiselStage, ChiselExecutionResultView} -import chisel3.stage.phases.DriverCompatibility - -import java.io._ - +import chisel3.experimental.RawModule +import internal.firrtl._ import firrtl._ -import firrtl.annotations.JsonProtocol import firrtl.options.Phase import firrtl.options.Viewer.view +import firrtl.annotations.JsonProtocol import firrtl.util.{BackendCompilationUtilities => FirrtlBackendCompilationUtilities} +import chisel3.stage.{ChiselExecutionResultView, ChiselGeneratorAnnotation, ChiselStage} +import chisel3.stage.phases.DriverCompatibility +import java.io._ + /** * The Driver provides methods to invoke the chisel3 compiler and the firrtl compiler. @@ -91,7 +90,7 @@ object Driver extends BackendCompilationUtilities { * @param gen A function that creates a Module hierarchy. * @return The resulting Chisel IR in the form of a Circuit. (TODO: Should be FIRRTL IR) */ - def elaborate[T <: RawModule](gen: () => T): Circuit = internal.Builder.build(Module(gen())) + def elaborate[T <: RawModule](gen: () => T): Circuit = internal.Builder.build(Module(gen()))._1 /** * Convert the given Chisel IR Circuit to a FIRRTL Circuit. diff --git a/src/main/scala/chisel3/aop/Select.scala b/src/main/scala/chisel3/aop/Select.scala new file mode 100644 index 00000000..612cdcc7 --- /dev/null +++ b/src/main/scala/chisel3/aop/Select.scala @@ -0,0 +1,418 @@ +// See LICENSE for license details. + +package chisel3.aop + +import chisel3._ +import chisel3.experimental.{BaseModule, FixedPoint} +import chisel3.internal.HasId +import chisel3.internal.firrtl._ +import firrtl.annotations.ReferenceTarget + +import scala.collection.mutable + +/** Use to select Chisel components in a module, after that module has been constructed + * Useful for adding additional Chisel annotations or for use within an [[Aspect]] + */ +object Select { + + /** Return just leaf components of expanded node + * + * @param d Component to find leafs if aggregate typed. Intermediate fields/indicies are not included + * @return + */ + def getLeafs(d: Data): Seq[Data] = d match { + case b: Bundle => b.getElements.flatMap(getLeafs) + case v: Vec[_] => v.getElements.flatMap(getLeafs) + case other => Seq(other) + } + + /** Return all expanded components, including intermediate aggregate nodes + * + * @param d Component to find leafs if aggregate typed. Intermediate fields/indicies ARE included + * @return + */ + def getIntermediateAndLeafs(d: Data): Seq[Data] = d match { + case b: Bundle => b +: b.getElements.flatMap(getIntermediateAndLeafs) + case v: Vec[_] => v +: v.getElements.flatMap(getIntermediateAndLeafs) + case other => Seq(other) + } + + + /** Collects all components selected by collector within module and all children modules it instantiates + * directly or indirectly + * Accepts a collector function, rather than a collector partial function (see [[collectDeep]]) + * @param module Module to collect components, as well as all children module it directly and indirectly instantiates + * @param collector Collector function to pick, given a module, which components to collect + * @param tag Required for generics to work, should ignore this + * @tparam T Type of the component that will be collected + * @return + */ + def getDeep[T](module: BaseModule)(collector: BaseModule => Seq[T]): Seq[T] = { + check(module) + val myItems = collector(module) + val deepChildrenItems = instances(module).flatMap { + i => getDeep(i)(collector) + } + myItems ++ deepChildrenItems + } + + /** Collects all components selected by collector within module and all children modules it instantiates + * directly or indirectly + * Accepts a collector partial function, rather than a collector function (see [[getDeep]]) + * @param module Module to collect components, as well as all children module it directly and indirectly instantiates + * @param collector Collector partial function to pick, given a module, which components to collect + * @param tag Required for generics to work, should ignore this + * @tparam T Type of the component that will be collected + * @return + */ + def collectDeep[T](module: BaseModule)(collector: PartialFunction[BaseModule, T]): Iterable[T] = { + check(module) + val myItems = collector.lift(module) + val deepChildrenItems = instances(module).flatMap { + i => collectDeep(i)(collector) + } + myItems ++ deepChildrenItems + } + + /** Selects all instances directly instantiated within given module + * @param module + * @return + */ + def instances(module: BaseModule): Seq[BaseModule] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case i: DefInstance => i.id + } + } + + /** Selects all registers directly instantiated within given module + * @param module + * @return + */ + def registers(module: BaseModule): Seq[Data] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case r: DefReg => r.id + case r: DefRegInit => r.id + } + } + + /** Selects all ios directly contained within given module + * @param module + * @return + */ + def ios(module: BaseModule): Seq[Data] = { + check(module) + module._component.get.asInstanceOf[DefModule].ports.map(_.id) + } + + /** Selects all SyncReadMems directly contained within given module + * @param module + * @return + */ + def syncReadMems(module: BaseModule): Seq[SyncReadMem[_]] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case r: DefSeqMemory => r.id.asInstanceOf[SyncReadMem[_]] + } + } + + /** Selects all Mems directly contained within given module + * @param module + * @return + */ + def mems(module: BaseModule): Seq[Mem[_]] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case r: DefMemory => r.id.asInstanceOf[Mem[_]] + } + } + + /** Selects all arithmetic or logical operators directly instantiated within given module + * @param module + * @return + */ + def ops(module: BaseModule): Seq[(String, Data)] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case d: DefPrim[_] => (d.op.name, d.id) + } + } + + /** Selects a kind of arithmetic or logical operator directly instantiated within given module + * The kind of operators are contained in [[chisel3.internal.firrtl.PrimOp]] + * @param opKind the kind of operator, e.g. "mux", "add", or "bits" + * @param module + * @return + */ + def ops(opKind: String)(module: BaseModule): Seq[Data] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case d: DefPrim[_] if d.op.name == opKind => d.id + } + } + + /** Selects all wires in a module + * @param module + * @return + */ + def wires(module: BaseModule): Seq[Data] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case r: DefWire => r.id + } + } + + /** Selects all memory ports, including their direction and memory + * @param module + * @return + */ + def memPorts(module: BaseModule): Seq[(Data, MemPortDirection, MemBase[_])] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case r: DefMemPort[_] => (r.id, r.dir, r.source.id.asInstanceOf[MemBase[_ <: Data]]) + } + } + + /** Selects all memory ports of a given direction, including their memory + * @param dir The direction of memory ports to select + * @param module + * @return + */ + def memPorts(dir: MemPortDirection)(module: BaseModule): Seq[(Data, MemBase[_])] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case r: DefMemPort[_] if r.dir == dir => (r.id, r.source.id.asInstanceOf[MemBase[_ <: Data]]) + } + } + + /** Selects all components who have been set to be invalid, even if they are later connected to + * @param module + * @return + */ + def invalids(module: BaseModule): Seq[Data] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case DefInvalid(_, arg) => getData(arg) + } + } + + /** Selects all components who are attached to a given signal, within a module + * @param module + * @return + */ + def attachedTo(module: BaseModule)(signal: Data): Set[Data] = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.collect { + case Attach(_, seq) if seq.contains(signal) => seq + }.flatMap { seq => seq.map(_.id.asInstanceOf[Data]) }.toSet + } + + /** Selects all connections to a signal or its parent signal(s) (if the signal is an element of an aggregate signal) + * The when predicates surrounding each connection are included in the returned values + * + * E.g. if signal = io.foo.bar, connectionsTo will return all connections to io, io.foo, and io.bar + * @param module + * @param signal + * @return + */ + def connectionsTo(module: BaseModule)(signal: Data): Seq[PredicatedConnect] = { + check(module) + val sensitivitySignals = getIntermediateAndLeafs(signal).toSet + val predicatedConnects = mutable.ArrayBuffer[PredicatedConnect]() + val isPort = module._component.get.asInstanceOf[DefModule].ports.flatMap{ p => getIntermediateAndLeafs(p.id) }.contains(signal) + var prePredicates: Seq[Predicate] = Nil + var seenDef = isPort + searchWhens(module, (cmd: Command, preds) => { + cmd match { + case cmd: Definition if cmd.id.isInstanceOf[Data] => + val x = getIntermediateAndLeafs(cmd.id.asInstanceOf[Data]) + if(x.contains(signal)) prePredicates = preds + case Connect(_, loc@Node(d: Data), exp) => + val effected = getEffected(loc).toSet + if(sensitivitySignals.intersect(effected).nonEmpty) { + val expData = getData(exp) + prePredicates.reverse.zip(preds.reverse).foreach(x => assert(x._1 == x._2, s"Prepredicates $x must match for signal $signal")) + predicatedConnects += PredicatedConnect(preds.dropRight(prePredicates.size), d, expData, isBulk = false) + } + case BulkConnect(_, loc@Node(d: Data), exp) => + val effected = getEffected(loc).toSet + if(sensitivitySignals.intersect(effected).nonEmpty) { + val expData = getData(exp) + prePredicates.reverse.zip(preds.reverse).foreach(x => assert(x._1 == x._2, s"Prepredicates $x must match for signal $signal")) + predicatedConnects += PredicatedConnect(preds.dropRight(prePredicates.size), d, expData, isBulk = true) + } + case other => + } + }) + predicatedConnects + } + + /** Selects all stop statements, and includes the predicates surrounding the stop statement + * + * @param module + * @return + */ + def stops(module: BaseModule): Seq[Stop] = { + val stops = mutable.ArrayBuffer[Stop]() + searchWhens(module, (cmd: Command, preds: Seq[Predicate]) => { + cmd match { + case chisel3.internal.firrtl.Stop(_, clock, ret) => stops += Stop(preds, ret, getId(clock).asInstanceOf[Clock]) + case other => + } + }) + stops + } + + /** Selects all printf statements, and includes the predicates surrounding the printf statement + * + * @param module + * @return + */ + def printfs(module: BaseModule): Seq[Printf] = { + val printfs = mutable.ArrayBuffer[Printf]() + searchWhens(module, (cmd: Command, preds: Seq[Predicate]) => { + cmd match { + case chisel3.internal.firrtl.Printf(_, clock, pable) => printfs += Printf(preds, pable, getId(clock).asInstanceOf[Clock]) + case other => + } + }) + printfs + } + + // Checks that a module has finished its construction + private def check(module: BaseModule): Unit = { + require(module.isClosed, "Can't use Selector on modules that have not finished construction!") + require(module._component.isDefined, "Can't use Selector on modules that don't have components!") + } + + // Given a loc, return all subcomponents of id that could be assigned to in connect + private def getEffected(a: Arg): Seq[Data] = a match { + case Node(id: Data) => getIntermediateAndLeafs(id) + case Slot(imm, name) => Seq(imm.id.asInstanceOf[Record].elements(name)) + case Index(imm, value) => getEffected(imm) + } + + // Given an arg, return the corresponding id. Don't use on a loc of a connect. + private def getId(a: Arg): HasId = a match { + case Node(id) => id + case l: ULit => l.num.U(l.w) + case l: SLit => l.num.S(l.w) + case l: FPLit => FixedPoint(l.num, l.w, l.binaryPoint) + case other => + sys.error(s"Something went horribly wrong! I was expecting ${other} to be a lit or a node!") + } + + private def getData(a: Arg): Data = a match { + case Node(data: Data) => data + case other => + sys.error(s"Something went horribly wrong! I was expecting ${other} to be Data!") + } + + // Given an id, either get its name or its value, if its a lit + private def getName(i: HasId): String = try { + i.toTarget match { + case r: ReferenceTarget => + val str = r.serialize + str.splitAt(str.indexOf('>'))._2.drop(1) + } + } catch { + case e: ChiselException => i.getOptionRef.get match { + case l: LitArg => l.num.intValue().toString + } + } + + // Collects when predicates as it searches through a module, then applying processCommand to non-when related commands + private def searchWhens(module: BaseModule, processCommand: (Command, Seq[Predicate]) => Unit) = { + check(module) + module._component.get.asInstanceOf[DefModule].commands.foldLeft((Seq.empty[Predicate], Option.empty[Predicate])) { + (blah, cmd) => + (blah, cmd) match { + case ((preds, o), cmd) => cmd match { + case WhenBegin(_, Node(pred: Bool)) => (When(pred) +: preds, None) + case WhenBegin(_, l: LitArg) if l.num == BigInt(1) => (When(true.B) +: preds, None) + case WhenBegin(_, l: LitArg) if l.num == BigInt(0) => (When(false.B) +: preds, None) + case other: WhenBegin => + sys.error(s"Something went horribly wrong! I was expecting ${other.pred} to be a lit or a bool!") + case _: WhenEnd => (preds.tail, Some(preds.head)) + case AltBegin(_) if o.isDefined => (o.get.not +: preds, o) + case _: AltBegin => + sys.error(s"Something went horribly wrong! I was expecting ${o} to be nonEmpty!") + case OtherwiseEnd(_, _) => (preds.tail, None) + case other => + processCommand(cmd, preds) + (preds, o) + } + } + } + } + + trait Serializeable { + def serialize: String + } + + /** Used to indicates a when's predicate (or its otherwise predicate) + */ + trait Predicate extends Serializeable { + val bool: Bool + def not: Predicate + } + + /** Used to represent [[chisel3.when]] predicate + * + * @param bool the when predicate + */ + case class When(bool: Bool) extends Predicate { + def not: WhenNot = WhenNot(bool) + def serialize: String = s"${getName(bool)}" + } + + /** Used to represent the `otherwise` predicate of a [[chisel3.when]] + * + * @param bool the when predicate corresponding to this otherwise predicate + */ + case class WhenNot(bool: Bool) extends Predicate { + def not: When = When(bool) + def serialize: String = s"!${getName(bool)}" + } + + /** Used to represent a connection or bulk connection + * + * Additionally contains the sequence of when predicates seen when the connection is declared + * + * @param preds + * @param loc + * @param exp + * @param isBulk + */ + case class PredicatedConnect(preds: Seq[Predicate], loc: Data, exp: Data, isBulk: Boolean) extends Serializeable { + def serialize: String = { + val moduleTarget = loc.toTarget.moduleTarget.serialize + s"$moduleTarget: when(${preds.map(_.serialize).mkString(" & ")}): ${getName(loc)} ${if(isBulk) "<>" else ":="} ${getName(exp)}" + } + } + + /** Used to represent a [[chisel3.stop]] + * + * @param preds + * @param ret + * @param clock + */ + case class Stop(preds: Seq[Predicate], ret: Int, clock: Clock) extends Serializeable { + def serialize: String = { + s"stop when(${preds.map(_.serialize).mkString(" & ")}) on ${getName(clock)}: $ret" + } + } + + /** Used to represent a [[chisel3.printf]] + * + * @param preds + * @param pable + * @param clock + */ + case class Printf(preds: Seq[Predicate], pable: Printable, clock: Clock) extends Serializeable { + def serialize: String = { + s"printf when(${preds.map(_.serialize).mkString(" & ")}) on ${getName(clock)}: $pable" + } + } +} diff --git a/src/main/scala/chisel3/aop/injecting/InjectStatement.scala b/src/main/scala/chisel3/aop/injecting/InjectStatement.scala new file mode 100644 index 00000000..c207454d --- /dev/null +++ b/src/main/scala/chisel3/aop/injecting/InjectStatement.scala @@ -0,0 +1,21 @@ +// See LICENSE for license details. + +package chisel3.aop.injecting + +import chisel3.stage.phases.AspectPhase +import firrtl.annotations.{Annotation, ModuleTarget, NoTargetAnnotation, SingleTargetAnnotation} + +/** Contains all information needed to inject statements into a module + * + * Generated when a [[InjectingAspect]] is consumed by a [[AspectPhase]] + * Consumed by [[InjectingTransform]] + * + * @param module Module to inject code into at the end of the module + * @param s Statements to inject + * @param modules Additional modules that may be instantiated by s + * @param annotations Additional annotations that should be passed down compiler + */ +case class InjectStatement(module: ModuleTarget, s: firrtl.ir.Statement, modules: Seq[firrtl.ir.DefModule], annotations: Seq[Annotation]) extends SingleTargetAnnotation[ModuleTarget] { + val target: ModuleTarget = module + override def duplicate(n: ModuleTarget): Annotation = this.copy(module = n) +} diff --git a/src/main/scala/chisel3/aop/injecting/InjectingAspect.scala b/src/main/scala/chisel3/aop/injecting/InjectingAspect.scala new file mode 100644 index 00000000..74cd62f3 --- /dev/null +++ b/src/main/scala/chisel3/aop/injecting/InjectingAspect.scala @@ -0,0 +1,63 @@ +// See LICENSE for license details. + +package chisel3.aop.injecting + +import chisel3.{Module, ModuleAspect, experimental, withClockAndReset} +import chisel3.aop._ +import chisel3.experimental.RawModule +import chisel3.internal.Builder +import chisel3.internal.firrtl.DefModule +import chisel3.stage.DesignAnnotation +import firrtl.annotations.ModuleTarget +import firrtl.stage.RunFirrtlTransformAnnotation +import firrtl.{ir, _} + +import scala.collection.mutable +import scala.reflect.runtime.universe.TypeTag + +/** Aspect to inject Chisel code into a module of type M + * + * @param selectRoots Given top-level module, pick the instances of a module to apply the aspect (root module) + * @param injection Function to generate Chisel hardware that will be injected to the end of module m + * Signals in m can be referenced and assigned to as if inside m (yes, it is a bit magical) + * @param tTag Needed to prevent type-erasure of the top-level module type + * @tparam T Type of top-level module + * @tparam M Type of root module (join point) + */ +case class InjectingAspect[T <: RawModule, + M <: RawModule](selectRoots: T => Iterable[M], + injection: M => Unit + )(implicit tTag: TypeTag[T]) extends Aspect[T] { + final def toAnnotation(top: T): AnnotationSeq = { + toAnnotation(selectRoots(top), top.name) + } + + final def toAnnotation(modules: Iterable[M], circuit: String): AnnotationSeq = { + RunFirrtlTransformAnnotation(new InjectingTransform) +: modules.map { module => + val (chiselIR, _) = Builder.build(Module(new ModuleAspect(module) { + module match { + case x: experimental.MultiIOModule => withClockAndReset(x.clock, x.reset) { injection(module) } + case x: RawModule => injection(module) + } + })) + val comps = chiselIR.components.map { + case x: DefModule if x.name == module.name => x.copy(id = module) + case other => other + } + + val annotations = chiselIR.annotations.map(_.toFirrtl).filterNot{ a => a.isInstanceOf[DesignAnnotation[_]] } + + val stmts = mutable.ArrayBuffer[ir.Statement]() + val modules = Aspect.getFirrtl(chiselIR.copy(components = comps)).modules.flatMap { + case m: firrtl.ir.Module if m.name == module.name => + stmts += m.body + Nil + case other => + Seq(other) + } + + InjectStatement(ModuleTarget(circuit, module.name), ir.Block(stmts), modules, annotations) + }.toSeq + } +} + diff --git a/src/main/scala/chisel3/aop/injecting/InjectingTransform.scala b/src/main/scala/chisel3/aop/injecting/InjectingTransform.scala new file mode 100644 index 00000000..c65bee38 --- /dev/null +++ b/src/main/scala/chisel3/aop/injecting/InjectingTransform.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package chisel3.aop.injecting + +import firrtl.{ChirrtlForm, CircuitForm, CircuitState, Transform, ir} + +import scala.collection.mutable + +/** Appends statements contained in [[InjectStatement]] annotations to the end of their corresponding modules + * + * Implemented with Chisel Aspects and the [[chisel3.aop.injecting]] library + */ +class InjectingTransform extends Transform { + override def inputForm: CircuitForm = ChirrtlForm + override def outputForm: CircuitForm = ChirrtlForm + + override def execute(state: CircuitState): CircuitState = { + + val addStmtMap = mutable.HashMap[String, Seq[ir.Statement]]() + val addModules = mutable.ArrayBuffer[ir.DefModule]() + + // Populate addStmtMap and addModules, return annotations in InjectStatements, and omit InjectStatement annotation + val newAnnotations = state.annotations.flatMap { + case InjectStatement(mt, s, addedModules, annotations) => + addModules ++= addedModules + addStmtMap(mt.module) = s +: addStmtMap.getOrElse(mt.module, Nil) + annotations + case other => Seq(other) + } + + // Append all statements to end of corresponding modules + val newModules = state.circuit.modules.map { m: ir.DefModule => + m match { + case m: ir.Module if addStmtMap.contains(m.name) => + m.copy(body = ir.Block(m.body +: addStmtMap(m.name))) + case m: _root_.firrtl.ir.ExtModule if addStmtMap.contains(m.name) => + ir.Module(m.info, m.name, m.ports, ir.Block(addStmtMap(m.name))) + case other: ir.DefModule => other + } + } + + // Return updated circuit and annotations + val newCircuit = state.circuit.copy(modules = newModules ++ addModules) + state.copy(annotations = newAnnotations, circuit = newCircuit) + } +} diff --git a/src/main/scala/chisel3/internal/firrtl/Converter.scala b/src/main/scala/chisel3/internal/firrtl/Converter.scala deleted file mode 100644 index cdc55b59..00000000 --- a/src/main/scala/chisel3/internal/firrtl/Converter.scala +++ /dev/null @@ -1,267 +0,0 @@ -// See LICENSE for license details. - -package chisel3.internal.firrtl -import chisel3._ -import chisel3.experimental._ -import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine, SourceInfo} -import firrtl.{ir => fir} -import chisel3.internal.{castToInt, throwException} - -import scala.annotation.tailrec -import scala.collection.immutable.Queue - -private[chisel3] object Converter { - // TODO modeled on unpack method on Printable, refactor? - def unpack(pable: Printable, ctx: Component): (String, Seq[Arg]) = pable match { - case Printables(pables) => - val (fmts, args) = pables.map(p => unpack(p, ctx)).unzip - (fmts.mkString, args.flatten.toSeq) - case PString(str) => (str.replaceAll("%", "%%"), List.empty) - case format: FirrtlFormat => - ("%" + format.specifier, List(format.bits.ref)) - case Name(data) => (data.ref.name, List.empty) - case FullName(data) => (data.ref.fullName(ctx), List.empty) - case Percent => ("%%", List.empty) - } - - def convert(info: SourceInfo): fir.Info = info match { - case _: NoSourceInfo => fir.NoInfo - case SourceLine(fn, line, col) => fir.FileInfo(fir.StringLit(s"$fn $line:$col")) - } - - def convert(op: PrimOp): fir.PrimOp = firrtl.PrimOps.fromString(op.name) - - def convert(dir: MemPortDirection): firrtl.MPortDir = dir match { - case MemPortDirection.INFER => firrtl.MInfer - case MemPortDirection.READ => firrtl.MRead - case MemPortDirection.WRITE => firrtl.MWrite - case MemPortDirection.RDWR => firrtl.MReadWrite - } - - // TODO - // * Memoize? - // * Move into the Chisel IR? - def convert(arg: Arg, ctx: Component): fir.Expression = arg match { // scalastyle:ignore cyclomatic.complexity - case Node(id) => - convert(id.getRef, ctx) - case Ref(name) => - fir.Reference(name, fir.UnknownType) - case Slot(imm, name) => - fir.SubField(convert(imm, ctx), name, fir.UnknownType) - case Index(imm, ILit(idx)) => - fir.SubIndex(convert(imm, ctx), castToInt(idx, "Index"), fir.UnknownType) - case Index(imm, value) => - fir.SubAccess(convert(imm, ctx), convert(value, ctx), fir.UnknownType) - case ModuleIO(mod, name) => - // scalastyle:off if.brace - if (mod eq ctx.id) fir.Reference(name, fir.UnknownType) - else fir.SubField(fir.Reference(mod.getRef.name, fir.UnknownType), name, fir.UnknownType) - // scalastyle:on if.brace - case u @ ULit(n, UnknownWidth()) => - fir.UIntLiteral(n, fir.IntWidth(u.minWidth)) - case ULit(n, w) => - fir.UIntLiteral(n, convert(w)) - case slit @ SLit(n, w) => fir.SIntLiteral(n, convert(w)) - val unsigned = if (n < 0) (BigInt(1) << slit.width.get) + n else n - val uint = convert(ULit(unsigned, slit.width), ctx) - fir.DoPrim(firrtl.PrimOps.AsSInt, Seq(uint), Seq.empty, fir.UnknownType) - // TODO Simplify - case fplit @ FPLit(n, w, bp) => - val unsigned = if (n < 0) (BigInt(1) << fplit.width.get) + n else n - val uint = convert(ULit(unsigned, fplit.width), ctx) - val lit = bp.asInstanceOf[KnownBinaryPoint].value - fir.DoPrim(firrtl.PrimOps.AsFixedPoint, Seq(uint), Seq(lit), fir.UnknownType) - case lit: ILit => - throwException(s"Internal Error! Unexpected ILit: $lit") - } - - /** Convert Commands that map 1:1 to Statements */ - def convertSimpleCommand(cmd: Command, ctx: Component): Option[fir.Statement] = cmd match { // scalastyle:ignore cyclomatic.complexity line.size.limit - case e: DefPrim[_] => - val consts = e.args.collect { case ILit(i) => i } - val args = e.args.flatMap { - case _: ILit => None - case other => Some(convert(other, ctx)) - } - val expr = e.op.name match { - case "mux" => - assert(args.size == 3, s"Mux with unexpected args: $args") - fir.Mux(args(0), args(1), args(2), fir.UnknownType) - case _ => - fir.DoPrim(convert(e.op), args, consts, fir.UnknownType) - } - Some(fir.DefNode(convert(e.sourceInfo), e.name, expr)) - case e @ DefWire(info, id) => - Some(fir.DefWire(convert(info), e.name, extractType(id))) - case e @ DefReg(info, id, clock) => - Some(fir.DefRegister(convert(info), e.name, extractType(id), convert(clock, ctx), - firrtl.Utils.zero, convert(id.getRef, ctx))) - case e @ DefRegInit(info, id, clock, reset, init) => - Some(fir.DefRegister(convert(info), e.name, extractType(id), convert(clock, ctx), - convert(reset, ctx), convert(init, ctx))) - case e @ DefMemory(info, id, t, size) => - Some(firrtl.CDefMemory(convert(info), e.name, extractType(t), size, false)) - case e @ DefSeqMemory(info, id, t, size) => - Some(firrtl.CDefMemory(convert(info), e.name, extractType(t), size, true)) - case e: DefMemPort[_] => - Some(firrtl.CDefMPort(convert(e.sourceInfo), e.name, fir.UnknownType, - e.source.fullName(ctx), Seq(convert(e.index, ctx), convert(e.clock, ctx)), convert(e.dir))) - case Connect(info, loc, exp) => - Some(fir.Connect(convert(info), convert(loc, ctx), convert(exp, ctx))) - case BulkConnect(info, loc, exp) => - Some(fir.PartialConnect(convert(info), convert(loc, ctx), convert(exp, ctx))) - case Attach(info, locs) => - Some(fir.Attach(convert(info), locs.map(l => convert(l, ctx)))) - case DefInvalid(info, arg) => - Some(fir.IsInvalid(convert(info), convert(arg, ctx))) - case e @ DefInstance(info, id, _) => - Some(fir.DefInstance(convert(info), e.name, id.name)) - case Stop(info, clock, ret) => - Some(fir.Stop(convert(info), ret, convert(clock, ctx), firrtl.Utils.one)) - case Printf(info, clock, pable) => - val (fmt, args) = unpack(pable, ctx) - Some(fir.Print(convert(info), fir.StringLit(fmt), - args.map(a => convert(a, ctx)), convert(clock, ctx), firrtl.Utils.one)) - case _ => None - } - - /** Internal datastructure to help translate Chisel's flat Command structure to FIRRTL's AST - * - * In particular, when scoping is translated from flat with begin end to a nested datastructure - * - * @param when Current when Statement, holds info, condition, and consequence as they are - * available - * @param outer Already converted Statements that precede the current when block in the scope in - * which the when is defined (ie. 1 level up from the scope inside the when) - * @param alt Indicates if currently processing commands in the alternate (else) of the when scope - */ - // TODO we should probably have a different structure in the IR to close elses - private case class WhenFrame(when: fir.Conditionally, outer: Queue[fir.Statement], alt: Boolean) - - /** Convert Chisel IR Commands into FIRRTL Statements - * - * @note ctx is needed because references to ports translate differently when referenced within - * the module in which they are defined vs. parent modules - * @param cmds Chisel IR Commands to convert - * @param ctx Component (Module) context within which we are translating - * @return FIRRTL Statement that is equivalent to the input cmds - */ - def convert(cmds: Seq[Command], ctx: Component): fir.Statement = { // scalastyle:ignore cyclomatic.complexity - @tailrec - // scalastyle:off if.brace - def rec(acc: Queue[fir.Statement], - scope: List[WhenFrame]) - (cmds: Seq[Command]): Seq[fir.Statement] = { - if (cmds.isEmpty) { - assert(scope.isEmpty) - acc - } else convertSimpleCommand(cmds.head, ctx) match { - // Most Commands map 1:1 - case Some(stmt) => - rec(acc :+ stmt, scope)(cmds.tail) - // When scoping logic does not map 1:1 and requires pushing/popping WhenFrames - // Please see WhenFrame for more details - case None => cmds.head match { - case WhenBegin(info, pred) => - val when = fir.Conditionally(convert(info), convert(pred, ctx), fir.EmptyStmt, fir.EmptyStmt) - val frame = WhenFrame(when, acc, false) - rec(Queue.empty, frame +: scope)(cmds.tail) - case WhenEnd(info, depth, _) => - val frame = scope.head - val when = if (frame.alt) frame.when.copy(alt = fir.Block(acc)) - else frame.when.copy(conseq = fir.Block(acc)) - // Check if this when has an else - cmds.tail.headOption match { - case Some(AltBegin(_)) => - assert(!frame.alt, "Internal Error! Unexpected when structure!") // Only 1 else per when - rec(Queue.empty, frame.copy(when = when, alt = true) +: scope.tail)(cmds.drop(2)) - case _ => // Not followed by otherwise - // If depth > 0 then we need to close multiple When scopes so we add a new WhenEnd - // If we're nested we need to add more WhenEnds to ensure each When scope gets - // properly closed - val cmdsx = if (depth > 0) WhenEnd(info, depth - 1, false) +: cmds.tail else cmds.tail - rec(frame.outer :+ when, scope.tail)(cmdsx) - } - case OtherwiseEnd(info, depth) => - val frame = scope.head - val when = frame.when.copy(alt = fir.Block(acc)) - // TODO For some reason depth == 1 indicates the last closing otherwise whereas - // depth == 0 indicates last closing when - val cmdsx = if (depth > 1) OtherwiseEnd(info, depth - 1) +: cmds.tail else cmds.tail - rec(scope.head.outer :+ when, scope.tail)(cmdsx) - } - } - } - // scalastyle:on if.brace - fir.Block(rec(Queue.empty, List.empty)(cmds)) - } - - def convert(width: Width): fir.Width = width match { - case UnknownWidth() => fir.UnknownWidth - case KnownWidth(value) => fir.IntWidth(value) - } - - def convert(bp: BinaryPoint): fir.Width = bp match { - case UnknownBinaryPoint => fir.UnknownWidth - case KnownBinaryPoint(value) => fir.IntWidth(value) - } - - private def firrtlUserDirOf(d: Data): SpecifiedDirection = d match { - case d: Vec[_] => - SpecifiedDirection.fromParent(d.specifiedDirection, firrtlUserDirOf(d.sample_element)) - case d => d.specifiedDirection - } - - def extractType(data: Data, clearDir: Boolean = false): fir.Type = data match { // scalastyle:ignore cyclomatic.complexity line.size.limit - case _: Clock => fir.ClockType - case d: EnumType => fir.UIntType(convert(d.width)) - case d: UInt => fir.UIntType(convert(d.width)) - case d: SInt => fir.SIntType(convert(d.width)) - case d: FixedPoint => fir.FixedType(convert(d.width), convert(d.binaryPoint)) - case d: Analog => fir.AnalogType(convert(d.width)) - case d: Vec[_] => fir.VectorType(extractType(d.sample_element, clearDir), d.length) - case d: Record => - val childClearDir = clearDir || - d.specifiedDirection == SpecifiedDirection.Input || d.specifiedDirection == SpecifiedDirection.Output - def eltField(elt: Data): fir.Field = (childClearDir, firrtlUserDirOf(elt)) match { - case (true, _) => fir.Field(elt.getRef.name, fir.Default, extractType(elt, true)) - case (false, SpecifiedDirection.Unspecified | SpecifiedDirection.Output) => - fir.Field(elt.getRef.name, fir.Default, extractType(elt, false)) - case (false, SpecifiedDirection.Flip | SpecifiedDirection.Input) => - fir.Field(elt.getRef.name, fir.Flip, extractType(elt, false)) - } - fir.BundleType(d.elements.toIndexedSeq.reverse.map { case (_, e) => eltField(e) }) - } - - def convert(name: String, param: Param): fir.Param = param match { - case IntParam(value) => fir.IntParam(name, value) - case DoubleParam(value) => fir.DoubleParam(name, value) - case StringParam(value) => fir.StringParam(name, fir.StringLit(value)) - case RawParam(value) => fir.RawStringParam(name, value) - } - def convert(port: Port, topDir: SpecifiedDirection = SpecifiedDirection.Unspecified): fir.Port = { - val resolvedDir = SpecifiedDirection.fromParent(topDir, port.dir) - val dir = resolvedDir match { - case SpecifiedDirection.Unspecified | SpecifiedDirection.Output => fir.Output - case SpecifiedDirection.Flip | SpecifiedDirection.Input => fir.Input - } - val clearDir = resolvedDir match { - case SpecifiedDirection.Input | SpecifiedDirection.Output => true - case SpecifiedDirection.Unspecified | SpecifiedDirection.Flip => false - } - val tpe = extractType(port.id, clearDir) - fir.Port(fir.NoInfo, port.id.getRef.name, dir, tpe) - } - - def convert(component: Component): fir.DefModule = component match { - case ctx @ DefModule(_, name, ports, cmds) => - fir.Module(fir.NoInfo, name, ports.map(p => convert(p)), convert(cmds.toList, ctx)) - case ctx @ DefBlackBox(id, name, ports, topDir, params) => - fir.ExtModule(fir.NoInfo, name, ports.map(p => convert(p, topDir)), id.desiredName, - params.map { case (name, p) => convert(name, p) }.toSeq) - } - - def convert(circuit: Circuit): fir.Circuit = - fir.Circuit(fir.NoInfo, circuit.components.map(convert), circuit.name) -} - diff --git a/src/main/scala/chisel3/stage/ChiselAnnotations.scala b/src/main/scala/chisel3/stage/ChiselAnnotations.scala index fb02173b..e722bac2 100644 --- a/src/main/scala/chisel3/stage/ChiselAnnotations.scala +++ b/src/main/scala/chisel3/stage/ChiselAnnotations.scala @@ -4,11 +4,11 @@ package chisel3.stage import firrtl.annotations.{Annotation, NoTargetAnnotation} import firrtl.options.{HasShellOptions, OptionsException, ShellOption, Unserializable} - import chisel3.{ChiselException, Module} import chisel3.experimental.RawModule import chisel3.internal.Builder import chisel3.internal.firrtl.Circuit +import firrtl.AnnotationSeq /** Mixin that indicates that this is an [[firrtl.annotations.Annotation]] used to generate a [[ChiselOptions]] view. */ @@ -46,8 +46,9 @@ case class ChiselGeneratorAnnotation(gen: () => RawModule) extends NoTargetAnnot /** Run elaboration on the Chisel module generator function stored by this [[firrtl.annotations.Annotation]] */ - def elaborate: ChiselCircuitAnnotation = try { - ChiselCircuitAnnotation(Builder.build(Module(gen()))) + def elaborate: AnnotationSeq = try { + val (circuit, dut) = Builder.build(Module(gen())) + Seq(ChiselCircuitAnnotation(circuit), DesignAnnotation(dut)) } catch { case e @ (_: OptionsException | _: ChiselException) => throw e case e: Throwable => @@ -103,3 +104,11 @@ object ChiselOutputFileAnnotation extends HasShellOptions { helpValueName = Some("<file>") ) ) } + +/** Contains the top-level elaborated Chisel design. + * + * By default is created during Chisel elaboration and passed to the FIRRTL compiler. + * @param design top-level Chisel design + * @tparam DUT Type of the top-level Chisel design + */ +case class DesignAnnotation[DUT <: RawModule](design: DUT) extends NoTargetAnnotation with Unserializable diff --git a/src/main/scala/chisel3/stage/ChiselStage.scala b/src/main/scala/chisel3/stage/ChiselStage.scala index 1e92aaf6..0c6512af 100644 --- a/src/main/scala/chisel3/stage/ChiselStage.scala +++ b/src/main/scala/chisel3/stage/ChiselStage.scala @@ -14,6 +14,7 @@ class ChiselStage extends Stage { new chisel3.stage.phases.Elaborate, new chisel3.stage.phases.AddImplicitOutputFile, new chisel3.stage.phases.AddImplicitOutputAnnotationFile, + new chisel3.stage.phases.MaybeAspectPhase, new chisel3.stage.phases.Emitter, new chisel3.stage.phases.Convert, new chisel3.stage.phases.MaybeFirrtlStage ) diff --git a/src/main/scala/chisel3/stage/phases/AspectPhase.scala b/src/main/scala/chisel3/stage/phases/AspectPhase.scala new file mode 100644 index 00000000..f8038a2c --- /dev/null +++ b/src/main/scala/chisel3/stage/phases/AspectPhase.scala @@ -0,0 +1,37 @@ +// See LICENSE for license details. + +package chisel3.stage.phases + +import chisel3.aop.Aspect +import chisel3.experimental.RawModule +import chisel3.stage.DesignAnnotation +import firrtl.AnnotationSeq +import firrtl.options.Phase + +import scala.collection.mutable + +/** Phase that consumes all Aspects and calls their toAnnotationSeq methods. + * + * Consumes the [[chisel3.stage.DesignAnnotation]] and converts every [[Aspect]] into their annotations prior to executing FIRRTL + */ +class AspectPhase extends Phase { + def transform(annotations: AnnotationSeq): AnnotationSeq = { + var dut: Option[RawModule] = None + val aspects = mutable.ArrayBuffer[Aspect[_]]() + + val remainingAnnotations = annotations.flatMap { + case DesignAnnotation(d) => + dut = Some(d) + Nil + case a: Aspect[_] => + aspects += a + Nil + case other => Seq(other) + } + if(dut.isDefined) { + val newAnnotations = aspects.flatMap { _.resolveAspect(dut.get) } + remainingAnnotations ++ newAnnotations + } else annotations + } +} + diff --git a/src/main/scala/chisel3/stage/phases/Convert.scala b/src/main/scala/chisel3/stage/phases/Convert.scala index 174030ae..f08367c6 100644 --- a/src/main/scala/chisel3/stage/phases/Convert.scala +++ b/src/main/scala/chisel3/stage/phases/Convert.scala @@ -5,7 +5,6 @@ package chisel3.stage.phases import chisel3.experimental.RunFirrtlTransform import chisel3.internal.firrtl.Converter import chisel3.stage.ChiselCircuitAnnotation - import firrtl.{AnnotationSeq, Transform} import firrtl.options.Phase import firrtl.stage.{FirrtlCircuitAnnotation, RunFirrtlTransformAnnotation} @@ -18,7 +17,7 @@ import firrtl.stage.{FirrtlCircuitAnnotation, RunFirrtlTransformAnnotation} class Convert extends Phase { def transform(annotations: AnnotationSeq): AnnotationSeq = annotations.flatMap { - case a: ChiselCircuitAnnotation => { + case a: ChiselCircuitAnnotation => /* Convert this Chisel Circuit to a FIRRTL Circuit */ Some(FirrtlCircuitAnnotation(Converter.convert(a.circuit))) ++ /* Convert all Chisel Annotations to FIRRTL Annotations */ @@ -26,15 +25,15 @@ class Convert extends Phase { .circuit .annotations .map(_.toFirrtl) ++ - /* Add requested FIRRTL Transforms for any Chisel Annotations which mixed in RunFirrtlTransform */ a .circuit .annotations - .collect { case b: RunFirrtlTransform => b.transformClass } + .collect { + case anno: RunFirrtlTransform => anno.transformClass + } .distinct .filterNot(_ == classOf[firrtl.Transform]) .map { c: Class[_ <: Transform] => RunFirrtlTransformAnnotation(c.newInstance()) } - } case a => Some(a) } diff --git a/src/main/scala/chisel3/stage/phases/Elaborate.scala b/src/main/scala/chisel3/stage/phases/Elaborate.scala index 0b0d71fb..2ec5f92c 100644 --- a/src/main/scala/chisel3/stage/phases/Elaborate.scala +++ b/src/main/scala/chisel3/stage/phases/Elaborate.scala @@ -21,7 +21,7 @@ class Elaborate extends Phase { def transform(annotations: AnnotationSeq): AnnotationSeq = annotations.flatMap { case a: ChiselGeneratorAnnotation => try { - Some(a.elaborate) + a.elaborate } catch { case e: OptionsException => throw e case e: ChiselException => diff --git a/src/main/scala/chisel3/stage/phases/MaybeAspectPhase.scala b/src/main/scala/chisel3/stage/phases/MaybeAspectPhase.scala new file mode 100644 index 00000000..3e8b8feb --- /dev/null +++ b/src/main/scala/chisel3/stage/phases/MaybeAspectPhase.scala @@ -0,0 +1,18 @@ +// See LICENSE for license details. + +package chisel3.stage.phases + +import chisel3.aop.Aspect +import firrtl.AnnotationSeq +import firrtl.options.Phase + +/** Run [[AspectPhase]] if a [[chisel3.aop.Aspect]] is present. + */ +class MaybeAspectPhase extends Phase { + + def transform(annotations: AnnotationSeq): AnnotationSeq = { + if(annotations.collectFirst { case a: Aspect[_] => annotations }.isDefined) { + new AspectPhase().transform(annotations) + } else annotations + } +} diff --git a/src/main/scala/chisel3/testers/TesterDriver.scala b/src/main/scala/chisel3/testers/TesterDriver.scala index df26e3c3..7e3730a3 100644 --- a/src/main/scala/chisel3/testers/TesterDriver.scala +++ b/src/main/scala/chisel3/testers/TesterDriver.scala @@ -5,7 +5,10 @@ package chisel3.testers import chisel3._ import java.io._ +import chisel3.aop.Aspect import chisel3.experimental.RunFirrtlTransform +import chisel3.stage.phases.AspectPhase +import chisel3.stage.{ChiselCircuitAnnotation, ChiselStage, DesignAnnotation} import firrtl.{Driver => _, _} import firrtl.transforms.BlackBoxSourceHelper.writeResourceToDirectory @@ -14,9 +17,13 @@ object TesterDriver extends BackendCompilationUtilities { /** For use with modules that should successfully be elaborated by the * frontend, and which can be turned into executables with assertions. */ def execute(t: () => BasicTester, - additionalVResources: Seq[String] = Seq()): Boolean = { + additionalVResources: Seq[String] = Seq(), + annotations: AnnotationSeq = Seq() + ): Boolean = { // Invoke the chisel compiler to get the circuit's IR - val circuit = Driver.elaborate(finishWrapper(t)) + val (circuit, dut) = new chisel3.stage.ChiselGeneratorAnnotation(finishWrapper(t)).elaborate.toSeq match { + case Seq(ChiselCircuitAnnotation(cir), d:DesignAnnotation[_]) => (cir, d) + } // Set up a bunch of file handlers based on a random temp filename, // plus the quirks of Verilator's naming conventions @@ -41,13 +48,16 @@ object TesterDriver extends BackendCompilationUtilities { }) // Compile firrtl - val transforms = circuit.annotations.collect { case anno: RunFirrtlTransform => anno.transformClass }.distinct - .filterNot(_ == classOf[Transform]) - .map { transformClass: Class[_ <: Transform] => transformClass.newInstance() } - val annotations = circuit.annotations.map(_.toFirrtl).toList + val transforms = circuit.annotations.collect { + case anno: RunFirrtlTransform => anno.transformClass + }.distinct + .filterNot(_ == classOf[Transform]) + .map { transformClass: Class[_ <: Transform] => transformClass.newInstance() } + val newAnnotations = circuit.annotations.map(_.toFirrtl).toList ++ annotations ++ Seq(dut) + val resolvedAnnotations = new AspectPhase().transform(newAnnotations).toList val optionsManager = new ExecutionOptionsManager("chisel3") with HasChiselExecutionOptions with HasFirrtlOptions { commonOptions = CommonOptions(topName = target, targetDirName = path.getAbsolutePath) - firrtlOptions = FirrtlExecutionOptions(compilerName = "verilog", annotations = annotations, + firrtlOptions = FirrtlExecutionOptions(compilerName = "verilog", annotations = resolvedAnnotations, customTransforms = transforms, firrtlCircuit = Some(firrtlCircuit)) } diff --git a/src/test/scala/chiselTests/ChiselSpec.scala b/src/test/scala/chiselTests/ChiselSpec.scala index 5973cb63..75fa68dd 100644 --- a/src/test/scala/chiselTests/ChiselSpec.scala +++ b/src/test/scala/chiselTests/ChiselSpec.scala @@ -8,25 +8,29 @@ import org.scalacheck._ import chisel3._ import chisel3.experimental.RawModule import chisel3.testers._ -import firrtl.{ - CommonOptions, - ExecutionOptionsManager, - HasFirrtlOptions, - FirrtlExecutionSuccess, - FirrtlExecutionFailure -} +import firrtl.options.OptionsException +import firrtl.{AnnotationSeq, CommonOptions, ExecutionOptionsManager, FirrtlExecutionFailure, FirrtlExecutionSuccess, HasFirrtlOptions} import firrtl.util.BackendCompilationUtilities /** Common utility functions for Chisel unit tests. */ trait ChiselRunners extends Assertions with BackendCompilationUtilities { - def runTester(t: => BasicTester, additionalVResources: Seq[String] = Seq()): Boolean = { - TesterDriver.execute(() => t, additionalVResources) + def runTester(t: => BasicTester, + additionalVResources: Seq[String] = Seq(), + annotations: AnnotationSeq = Seq() + ): Boolean = { + TesterDriver.execute(() => t, additionalVResources, annotations) } - def assertTesterPasses(t: => BasicTester, additionalVResources: Seq[String] = Seq()): Unit = { - assert(runTester(t, additionalVResources)) + def assertTesterPasses(t: => BasicTester, + additionalVResources: Seq[String] = Seq(), + annotations: AnnotationSeq = Seq() + ): Unit = { + assert(runTester(t, additionalVResources, annotations)) } - def assertTesterFails(t: => BasicTester, additionalVResources: Seq[String] = Seq()): Unit = { - assert(!runTester(t, additionalVResources)) + def assertTesterFails(t: => BasicTester, + additionalVResources: Seq[String] = Seq(), + annotations: Seq[chisel3.aop.Aspect[_]] = Seq() + ): Unit = { + assert(!runTester(t, additionalVResources, annotations)) } def elaborate(t: => RawModule): Unit = Driver.elaborate(() => t) @@ -95,11 +99,12 @@ class ChiselTestUtilitiesSpec extends ChiselFlatSpec { import org.scalatest.exceptions.TestFailedException // Who tests the testers? "assertKnownWidth" should "error when the expected width is wrong" in { - a [TestFailedException] shouldBe thrownBy { + val caught = intercept[OptionsException] { assertKnownWidth(7) { Wire(UInt(8.W)) } } + assert(caught.getCause.isInstanceOf[TestFailedException]) } it should "error when the width is unknown" in { @@ -117,11 +122,12 @@ class ChiselTestUtilitiesSpec extends ChiselFlatSpec { } "assertInferredWidth" should "error if the width is known" in { - a [TestFailedException] shouldBe thrownBy { + val caught = intercept[OptionsException] { assertInferredWidth(8) { Wire(UInt(8.W)) } } + assert(caught.getCause.isInstanceOf[TestFailedException]) } it should "error if the expected width is wrong" in { diff --git a/src/test/scala/chiselTests/aop/InjectionSpec.scala b/src/test/scala/chiselTests/aop/InjectionSpec.scala new file mode 100644 index 00000000..6c022d60 --- /dev/null +++ b/src/test/scala/chiselTests/aop/InjectionSpec.scala @@ -0,0 +1,58 @@ +// See LICENSE for license details. + +package chiselTests.aop + +import chisel3.testers.BasicTester +import chiselTests.ChiselFlatSpec +import chisel3._ +import chisel3.aop.injecting.InjectingAspect + +class AspectTester(results: Seq[Int]) extends BasicTester { + val values = VecInit(results.map(_.U)) + val counter = RegInit(0.U(results.length.W)) + counter := counter + 1.U + when(counter >= values.length.U) { + stop() + }.otherwise { + when(reset.asBool() === false.B) { + printf("values(%d) = %d\n", counter, values(counter)) + assert(counter === values(counter)) + } + } +} + +class InjectionSpec extends ChiselFlatSpec { + val correctValueAspect = InjectingAspect( + {dut: AspectTester => Seq(dut)}, + {dut: AspectTester => + for(i <- 0 until dut.values.length) { + dut.values(i) := i.U + } + } + ) + + val wrongValueAspect = InjectingAspect( + {dut: AspectTester => Seq(dut)}, + {dut: AspectTester => + for(i <- 0 until dut.values.length) { + dut.values(i) := (i + 1).U + } + } + ) + + "Test" should "pass if inserted the correct values" in { + assertTesterPasses{ new AspectTester(Seq(0, 1, 2)) } + } + "Test" should "fail if inserted the wrong values" in { + assertTesterFails{ new AspectTester(Seq(9, 9, 9)) } + } + "Test" should "pass if pass wrong values, but correct with aspect" in { + assertTesterPasses({ new AspectTester(Seq(9, 9, 9))} , Nil, Seq(correctValueAspect)) + } + "Test" should "pass if pass wrong values, then wrong aspect, then correct aspect" in { + assertTesterPasses({ new AspectTester(Seq(9, 9, 9))} , Nil, Seq(wrongValueAspect, correctValueAspect)) + } + "Test" should "fail if pass wrong values, then correct aspect, then wrong aspect" in { + assertTesterFails({ new AspectTester(Seq(9, 9, 9))} , Nil, Seq(correctValueAspect, wrongValueAspect)) + } +} diff --git a/src/test/scala/chiselTests/aop/SelectSpec.scala b/src/test/scala/chiselTests/aop/SelectSpec.scala new file mode 100644 index 00000000..d3f72551 --- /dev/null +++ b/src/test/scala/chiselTests/aop/SelectSpec.scala @@ -0,0 +1,144 @@ +// See LICENSE for license details. + +package chiselTests.aop + +import chisel3.testers.BasicTester +import chiselTests.ChiselFlatSpec +import chisel3._ +import chisel3.aop.Select.{PredicatedConnect, When, WhenNot} +import chisel3.aop.{Aspect, Select} +import chisel3.experimental.RawModule +import firrtl.{AnnotationSeq} + +import scala.reflect.runtime.universe.TypeTag + +class SelectTester(results: Seq[Int]) extends BasicTester { + val values = VecInit(results.map(_.U)) + val counter = RegInit(0.U(results.length.W)) + val added = counter + 1.U + counter := added + val overflow = counter >= values.length.U + val nreset = reset.asBool() === false.B + val selected = values(counter) + val zero = 0.U + 0.U + when(overflow) { + counter := zero + stop() + }.otherwise { + when(nreset) { + assert(counter === values(counter)) + printf("values(%d) = %d\n", counter, selected) + } + } +} + +case class SelectAspect[T <: RawModule, X](selector: T => Seq[X], desired: T => Seq[X])(implicit tTag: TypeTag[T]) extends Aspect[T] { + override def toAnnotation(top: T): AnnotationSeq = { + val results = selector(top) + val desiredSeq = desired(top) + assert(results.length == desiredSeq.length, s"Failure! Results $results have different length than desired $desiredSeq!") + val mismatches = results.zip(desiredSeq).flatMap { + case (res, des) if res != des => Seq((res, des)) + case other => Nil + } + assert(mismatches.isEmpty,s"Failure! The following selected items do not match their desired item:\n" + mismatches.map{ + case (res: Select.Serializeable, des: Select.Serializeable) => s" ${res.serialize} does not match:\n ${des.serialize}" + case (res, des) => s" $res does not match:\n $des" + }.mkString("\n")) + Nil + } +} + +class SelectSpec extends ChiselFlatSpec { + + def execute[T <: RawModule, X](dut: () => T, selector: T => Seq[X], desired: T => Seq[X])(implicit tTag: TypeTag[T]): Unit = { + val ret = new chisel3.stage.ChiselStage().run( + Seq( + new chisel3.stage.ChiselGeneratorAnnotation(dut), + SelectAspect(selector, desired), + new chisel3.stage.ChiselOutputFileAnnotation("test_run_dir/Select.fir") + ) + ) + } + + "Test" should "pass if selecting correct registers" in { + execute( + () => new SelectTester(Seq(0, 1, 2)), + { dut: SelectTester => Select.registers(dut) }, + { dut: SelectTester => Seq(dut.counter) } + ) + } + + "Test" should "pass if selecting correct wires" in { + execute( + () => new SelectTester(Seq(0, 1, 2)), + { dut: SelectTester => Select.wires(dut) }, + { dut: SelectTester => Seq(dut.values) } + ) + } + + "Test" should "pass if selecting correct printfs" in { + execute( + () => new SelectTester(Seq(0, 1, 2)), + { dut: SelectTester => Seq(Select.printfs(dut).last) }, + { dut: SelectTester => + Seq(Select.Printf( + Seq( + When(Select.ops("eq")(dut).last.asInstanceOf[Bool]), + When(dut.nreset), + WhenNot(dut.overflow) + ), + Printable.pack("values(%d) = %d\n", dut.counter, dut.selected), + dut.clock + )) + } + ) + } + + "Test" should "pass if selecting correct connections" in { + execute( + () => new SelectTester(Seq(0, 1, 2)), + { dut: SelectTester => Select.connectionsTo(dut)(dut.counter) }, + { dut: SelectTester => + Seq(PredicatedConnect(Nil, dut.counter, dut.added, false), + PredicatedConnect(Seq(When(dut.overflow)), dut.counter, dut.zero, false)) + } + ) + } + + "Test" should "pass if selecting ops by kind" in { + execute( + () => new SelectTester(Seq(0, 1, 2)), + { dut: SelectTester => Select.ops("tail")(dut) }, + { dut: SelectTester => Seq(dut.added, dut.zero) } + ) + } + + "Test" should "pass if selecting ops" in { + execute( + () => new SelectTester(Seq(0, 1, 2)), + { dut: SelectTester => Select.ops(dut).collect { case ("tail", d) => d} }, + { dut: SelectTester => Seq(dut.added, dut.zero) } + ) + } + + "Test" should "pass if selecting correct stops" in { + execute( + () => new SelectTester(Seq(0, 1, 2)), + { dut: SelectTester => Seq(Select.stops(dut).last) }, + { dut: SelectTester => + Seq(Select.Stop( + Seq( + When(Select.ops("eq")(dut).dropRight(1).last.asInstanceOf[Bool]), + When(dut.nreset), + WhenNot(dut.overflow) + ), + 1, + dut.clock + )) + } + ) + } + +} + diff --git a/src/test/scala/chiselTests/stage/ChiselAnnotationsSpec.scala b/src/test/scala/chiselTests/stage/ChiselAnnotationsSpec.scala index c89955f2..63b1001f 100644 --- a/src/test/scala/chiselTests/stage/ChiselAnnotationsSpec.scala +++ b/src/test/scala/chiselTests/stage/ChiselAnnotationsSpec.scala @@ -3,11 +3,9 @@ package chiselTests.stage import org.scalatest.{FlatSpec, Matchers} - import chisel3._ -import chisel3.stage.{ChiselCircuitAnnotation, ChiselGeneratorAnnotation} +import chisel3.stage.{ChiselCircuitAnnotation, ChiselGeneratorAnnotation, DesignAnnotation} import chisel3.experimental.RawModule - import firrtl.options.OptionsException class ChiselAnnotationsSpecFoo extends RawModule { @@ -33,7 +31,9 @@ class ChiselAnnotationsSpec extends FlatSpec with Matchers { it should "elaborate to a ChiselCircuitAnnotation" in { val annotation = ChiselGeneratorAnnotation(() => new ChiselAnnotationsSpecFoo) - annotation.elaborate shouldBe a [ChiselCircuitAnnotation] + val res = annotation.elaborate + res(0) shouldBe a [ChiselCircuitAnnotation] + res(1) shouldBe a [DesignAnnotation[ChiselAnnotationsSpecFoo]] } it should "throw an exception if elaboration fails" in { @@ -45,7 +45,9 @@ class ChiselAnnotationsSpec extends FlatSpec with Matchers { it should "elaborate from a String" in { val annotation = ChiselGeneratorAnnotation("chiselTests.stage.ChiselAnnotationsSpecFoo") - annotation.elaborate shouldBe a [ChiselCircuitAnnotation] + val res = annotation.elaborate + res(0) shouldBe a [ChiselCircuitAnnotation] + res(1) shouldBe a [DesignAnnotation[ChiselAnnotationsSpecFoo]] } it should "throw an exception if elaboration from a String refers to nonexistant class" in { |
