diff options
| author | John's Brew | 2020-03-13 02:35:10 +0100 |
|---|---|---|
| committer | GitHub | 2020-03-12 18:35:10 -0700 |
| commit | 5c0c0018d812d57270035a9d3bd82e2289acf4ec (patch) | |
| tree | 3e9c319c0e98566b42540a5f31d043d5d0287c17 /src | |
| parent | 7e8d21e7f5fe3469eada53e6a6c60e38c134c403 (diff) | |
Add Support for FPGA Bitstream Preset-registers (#1050)
Introduce Preset Register Specialized Emission
- Introduce EmissionOption trait
- Introduce PresetAnnotation & PresetRegAnnotation
- Enable the collection of Annotations in the Emitter
- Introduce collection mechanism for EmissionOptions in the Emitter
- Add PropagatePresetAnnotation transform to annotate register for emission and clean-up the useless reset tree (no DCE involved)
- Add corresponding tests spec and tester
Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/AddDescriptionNodes.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/EmissionOption.scala | 51 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 114 | ||||
| -rw-r--r-- | src/main/scala/firrtl/annotations/PresetAnnotations.scala | 33 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/InlineCasts.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala | 451 | ||||
| -rw-r--r-- | src/test/resources/features/PresetTester.fir | 51 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/FirrtlSpec.scala | 4 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/LoweringCompilersSpec.scala | 2 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/PresetSpec.scala | 239 |
10 files changed, 928 insertions, 21 deletions
diff --git a/src/main/scala/firrtl/AddDescriptionNodes.scala b/src/main/scala/firrtl/AddDescriptionNodes.scala index dac0e513..213bfad6 100644 --- a/src/main/scala/firrtl/AddDescriptionNodes.scala +++ b/src/main/scala/firrtl/AddDescriptionNodes.scala @@ -77,6 +77,7 @@ class AddDescriptionNodes extends Transform with PreservesAll[Transform] { Dependency[firrtl.transforms.FixAddingNegativeLiterals], Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.PropagatePresetAnnotations], Dependency[firrtl.transforms.InlineCastsTransform], Dependency[firrtl.transforms.LegalizeClocksTransform], Dependency[firrtl.transforms.FlattenRegUpdate], diff --git a/src/main/scala/firrtl/EmissionOption.scala b/src/main/scala/firrtl/EmissionOption.scala new file mode 100644 index 00000000..11c16e2c --- /dev/null +++ b/src/main/scala/firrtl/EmissionOption.scala @@ -0,0 +1,51 @@ +// See LICENSE for license details. + +package firrtl + +/** + * Base type for emission customization options + * NOTE: all the following traits must be mixed with SingleTargetAnnotation[T <: Named] + * in order to be taken into account in the Emitter + */ +trait EmissionOption + + +/** Emission customization options for registers */ +trait RegisterEmissionOption extends EmissionOption { + /** when true the reset init value will be used to emit a bitstream preset */ + def useInitAsPreset : Boolean = false + + /** when true the initial randomization is disabled for this register */ + def disableRandomization : Boolean = false +} + +/** default Emitter behavior for registers */ +case object RegisterEmissionOptionDefault extends RegisterEmissionOption + + +/** Emission customization options for IO ports */ +trait PortEmissionOption extends EmissionOption + +/** default Emitter behavior for IO ports */ +case object PortEmissionOptionDefault extends PortEmissionOption + + +/** Emission customization options for wires */ +trait WireEmissionOption extends EmissionOption + +/** default Emitter behavior for wires */ +case object WireEmissionOptionDefault extends WireEmissionOption + + +/** Emission customization options for nodes */ +trait NodeEmissionOption extends EmissionOption + +/** default Emitter behavior for nodes */ +case object NodeEmissionOptionDefault extends NodeEmissionOption + + +/** Emission customization options for connect */ +trait ConnectEmissionOption extends EmissionOption + +/** default Emitter behavior for connect */ +case object ConnectEmissionOptionDefault extends ConnectEmissionOption diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 36734d81..5ab385be 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -8,7 +8,6 @@ import scala.collection.mutable import firrtl.ir._ import firrtl.passes._ -import firrtl.transforms._ import firrtl.annotations._ import firrtl.traversals.Foreachers._ import firrtl.PrimOps._ @@ -448,10 +447,61 @@ class VerilogEmitter extends SeqTransform with Emitter { val newMod = new AddDescriptionNodes().executeModule(m, descriptions) newMod match { - case DescribedMod(d, pds, m: Module) => new VerilogRender(d, pds, m, moduleMap)(writer) + case DescribedMod(d, pds, m: Module) => new VerilogRender(d, pds, m, moduleMap, "", new EmissionOptions(Seq.empty))(writer) case m: Module => new VerilogRender(m, moduleMap)(writer) } } + + /** + * Store Emission option per Target + * Guarantee only one emission option per Target + */ + private[firrtl] class EmissionOptionMap[V <: EmissionOption](val df : V) extends collection.mutable.HashMap[ReferenceTarget, V] { + override def default(key: ReferenceTarget) = df + override def +=(elem : (ReferenceTarget, V)) : EmissionOptionMap.this.type = { + if (this.contains(elem._1)) + throw EmitterException(s"Multiple EmissionOption for the target ${elem._1} (${this(elem._1)} ; ${elem._2})") + super.+=(elem) + } + } + + /** + * Provide API to retrieve EmissionOptions based on the provided AnnotationSeq + * + * @param annotations : AnnotationSeq to be searched for EmissionOptions + * + */ + private[firrtl] class EmissionOptions(annotations: AnnotationSeq) { + // Private so that we can present an immutable API + private val registerEmissionOption = new EmissionOptionMap[RegisterEmissionOption](RegisterEmissionOptionDefault) + private val wireEmissionOption = new EmissionOptionMap[WireEmissionOption](WireEmissionOptionDefault) + private val portEmissionOption = new EmissionOptionMap[PortEmissionOption](PortEmissionOptionDefault) + private val nodeEmissionOption = new EmissionOptionMap[NodeEmissionOption](NodeEmissionOptionDefault) + private val connectEmissionOption = new EmissionOptionMap[ConnectEmissionOption](ConnectEmissionOptionDefault) + + def getRegisterEmissionOption(target: ReferenceTarget): RegisterEmissionOption = + registerEmissionOption(target) + + def getWireEmissionOption(target: ReferenceTarget): WireEmissionOption = + wireEmissionOption(target) + + def getPortEmissionOption(target: ReferenceTarget): PortEmissionOption = + portEmissionOption(target) + + def getNodeEmissionOption(target: ReferenceTarget): NodeEmissionOption = + nodeEmissionOption(target) + + def getConnectEmissionOption(target: ReferenceTarget): ConnectEmissionOption = + connectEmissionOption(target) + + private val emissionAnnos = annotations.collect{ case m : SingleTargetAnnotation[ReferenceTarget] with EmissionOption => m } + // using multiple foreach instead of a single partial function as an Annotation can gather multiple EmissionOptions for simplicity + emissionAnnos.foreach { case a :RegisterEmissionOption => registerEmissionOption += ((a.target,a)) case _ => } + emissionAnnos.foreach { case a :WireEmissionOption => wireEmissionOption += ((a.target,a)) case _ => } + emissionAnnos.foreach { case a :PortEmissionOption => portEmissionOption += ((a.target,a)) case _ => } + emissionAnnos.foreach { case a :NodeEmissionOption => nodeEmissionOption += ((a.target,a)) case _ => } + emissionAnnos.foreach { case a :ConnectEmissionOption => connectEmissionOption += ((a.target,a)) case _ => } + } /** * Used by getRenderer, it has machinery to produce verilog from IR. @@ -466,10 +516,15 @@ class VerilogEmitter extends SeqTransform with Emitter { class VerilogRender(description: Description, portDescriptions: Map[String, Description], m: Module, - moduleMap: Map[String, DefModule])(implicit writer: Writer) { + moduleMap: Map[String, DefModule], + circuitName: String, + emissionOptions: EmissionOptions)(implicit writer: Writer) { + def this(m: Module, moduleMap: Map[String, DefModule], circuitName: String, emissionOptions: EmissionOptions)(implicit writer: Writer) { + this(EmptyDescription, Map.empty, m, moduleMap, circuitName, emissionOptions)(writer) + } def this(m: Module, moduleMap: Map[String, DefModule])(implicit writer: Writer) { - this(EmptyDescription, Map.empty, m, moduleMap)(writer) + this(EmptyDescription, Map.empty, m, moduleMap, "", new EmissionOptions(Seq.empty))(writer) } val netlist = mutable.LinkedHashMap[WrappedExpression, Expression]() @@ -520,7 +575,21 @@ class VerilogEmitter extends SeqTransform with Emitter { def declareVectorType(b: String, n: String, tpe: Type, size: BigInt, info: Info) = { declares += Seq(b, " ", tpe, " ", n, " [0:", bigIntToVLit(size - 1), "];", info) } - + def declareVectorType(b: String, n: String, tpe: Type, size: BigInt, info: Info, preset: Expression) = { + declares += Seq(b, " ", tpe, " ", n, " [0:", bigIntToVLit(size - 1), "] = ", preset, ";", info) + } + + val moduleTarget = CircuitTarget(circuitName).module(m.name) + + // declare with initial value + def declare(b: String, n: String, t: Type, info: Info, preset: Expression) = t match { + case tx: VectorType => + declareVectorType(b, n, tx.tpe, tx.size, info, preset) + case tx => + declares += Seq(b, " ", tx, " ", n, " = ", preset, ";", info) + } + + // original declare without initial value def declare(b: String, n: String, t: Type, info: Info) = t match { case tx: VectorType => declareVectorType(b, n, tx.tpe, tx.size, info) @@ -743,10 +812,17 @@ class VerilogEmitter extends SeqTransform with Emitter { case sx: DefWire => declare("wire", sx.name, sx.tpe, sx.info) case sx: DefRegister => - declare("reg", sx.name, sx.tpe, sx.info) + val options = emissionOptions.getRegisterEmissionOption(moduleTarget.ref(sx.name)) val e = wref(sx.name, sx.tpe) - regUpdate(e, sx.clock, sx.reset, sx.init) - initialize(e, sx.reset, sx.init) + if (options.useInitAsPreset){ + declare("reg", sx.name, sx.tpe, sx.info, sx.init) + regUpdate(e, sx.clock, sx.reset, e) + } else { + declare("reg", sx.name, sx.tpe, sx.info) + regUpdate(e, sx.clock, sx.reset, sx.init) + } + if (!options.disableRandomization) + initialize(e, sx.reset, sx.init) case sx: DefNode => declare("wire", sx.name, sx.value.tpe, sx.info) assign(WRef(sx.name, sx.value.tpe, NodeKind, SourceFlow), sx.value, sx.info) @@ -981,14 +1057,15 @@ class VerilogEmitter extends SeqTransform with Emitter { def transforms = new TransformManager(firrtl.stage.Forms.VerilogOptimized, prerequisites).flattenedTransformOrder def emit(state: CircuitState, writer: Writer): Unit = { - val circuit = runTransforms(state).circuit - val moduleMap = circuit.modules.map(m => m.name -> m).toMap - circuit.modules.foreach { + val cs = runTransforms(state) + val emissionOptions = new EmissionOptions(cs.annotations) + val moduleMap = cs.circuit.modules.map(m => m.name -> m).toMap + cs.circuit.modules.foreach { case dm @ DescribedMod(d, pds, m: Module) => - val renderer = new VerilogRender(d, pds, m, moduleMap)(writer) + val renderer = new VerilogRender(d, pds, m, moduleMap, cs.circuit.main, emissionOptions)(writer) renderer.emit_verilog() case m: Module => - val renderer = new VerilogRender(m, moduleMap)(writer) + val renderer = new VerilogRender(m, moduleMap, cs.circuit.main, emissionOptions)(writer) renderer.emit_verilog() case _ => // do nothing } @@ -1002,18 +1079,19 @@ class VerilogEmitter extends SeqTransform with Emitter { Seq(EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit(state.circuit.main, writer.toString, outputSuffix))) case EmitAllModulesAnnotation(_) => - val circuit = runTransforms(state).circuit - val moduleMap = circuit.modules.map(m => m.name -> m).toMap + val cs = runTransforms(state) + val emissionOptions = new EmissionOptions(cs.annotations) + val moduleMap = cs.circuit.modules.map(m => m.name -> m).toMap - circuit.modules flatMap { + cs.circuit.modules flatMap { case dm @ DescribedMod(d, pds, module: Module) => val writer = new java.io.StringWriter - val renderer = new VerilogRender(d, pds, module, moduleMap)(writer) + val renderer = new VerilogRender(d, pds, module, moduleMap, cs.circuit.main, emissionOptions)(writer) renderer.emit_verilog() Some(EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writer.toString, outputSuffix))) case module: Module => val writer = new java.io.StringWriter - val renderer = new VerilogRender(module, moduleMap)(writer) + val renderer = new VerilogRender(module, moduleMap, cs.circuit.main, emissionOptions)(writer) renderer.emit_verilog() Some(EmittedVerilogModuleAnnotation(EmittedVerilogModule(module.name, writer.toString, outputSuffix))) case _ => None diff --git a/src/main/scala/firrtl/annotations/PresetAnnotations.scala b/src/main/scala/firrtl/annotations/PresetAnnotations.scala new file mode 100644 index 00000000..727417c1 --- /dev/null +++ b/src/main/scala/firrtl/annotations/PresetAnnotations.scala @@ -0,0 +1,33 @@ +// See LICENSE for license details. + +package firrtl +package annotations + +/** + * Transform all registers connected to the targeted AsyncReset tree into bitstream preset registers + * Impacts all registers connected to any child (cross module) of the target AsyncReset + * + * @param target ReferenceTarget to an AsyncReset + */ +case class PresetAnnotation(target: ReferenceTarget) + extends SingleTargetAnnotation[ReferenceTarget] with firrtl.transforms.DontTouchAllTargets { + override def duplicate(n: ReferenceTarget) = this.copy(target = n) +} + + +/** + * Transform the targeted asynchronously-reset Reg into a bitstream preset Reg + * Used internally to annotate all registers associated to an AsyncReset tree + * + * @param target ReferenceTarget to a Reg + */ +private[firrtl] case class PresetRegAnnotation( + target: ReferenceTarget +) extends SingleTargetAnnotation[ReferenceTarget] with RegisterEmissionOption { + def duplicate(n: ReferenceTarget) = this.copy(target = n) + override def useInitAsPreset = true + override def disableRandomization = true +} + + + diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala index 91ba7578..eeafb0e4 100644 --- a/src/main/scala/firrtl/transforms/InlineCasts.scala +++ b/src/main/scala/firrtl/transforms/InlineCasts.scala @@ -68,7 +68,8 @@ class InlineCastsTransform extends Transform with PreservesAll[Transform] { Seq( Dependency[BlackBoxSourceHelper], Dependency[FixAddingNegativeLiterals], Dependency[ReplaceTruncatingArithmetic], - Dependency[InlineBitExtractionsTransform] ) + Dependency[InlineBitExtractionsTransform], + Dependency[PropagatePresetAnnotations] ) override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala new file mode 100644 index 00000000..92022278 --- /dev/null +++ b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala @@ -0,0 +1,451 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.{Utils} + + +import firrtl.PrimOps._ +import firrtl.ir._ +import firrtl.ir.{AsyncResetType} +import firrtl.annotations._ +import firrtl.options.{Dependency, PreservesAll} + +import scala.collection.mutable + +object PropagatePresetAnnotations { + val advice = "Please Note that a Preset-annotated AsyncReset shall NOT be casted to other types with any of the following functions: asInterval, asUInt, asSInt, asClock, asFixedPoint, asAsyncReset." + case class TreeCleanUpOrphanException(message: String) extends FirrtlUserException(s"Node left an orphan during tree cleanup: $message $advice") +} + +/** Propagate PresetAnnotations to all children of targeted AsyncResets + * Leaf Registers are annotated with PresetRegAnnotation + * All wires, nodes and connectors along the way are suppressed + * + * Processing of multiples targets are NOT isolated from one another as the expected outcome does not differ + * Annotations of leaf registers, wires, nodes & connectors does indeed not depend on the initial AsyncReset reference + * The set of created annotation based on multiple initial AsyncReset PresetAnnotation + * + * This transform consists of 2 successive walk of the AST + * I./ Propagate + * - 1./ Create all AsyncResetTrees + * - 2./ Leverage them to annotate register for specialized emission & PresetTree for cleanUp + * II./ CleanUpTree + * - clean up all the intermediate nodes (replaced with EmptyStmt) + * - raise Error on orphans (typically cast of Annotated Reset) + * - disconnect Registers from their reset nodes (replaced with UInt(0)) + * + * Thanks to the clean-up phase, this transform does not rely on DCE + * + * @note This pass must run before InlineCastsTransform + */ +class PropagatePresetAnnotations extends Transform with PreservesAll[Transform] { + def inputForm = UnknownForm + def outputForm = UnknownForm + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic]) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty + + + import PropagatePresetAnnotations._ + + private type TargetSet = mutable.HashSet[ReferenceTarget] + private type TargetMap = mutable.HashMap[ReferenceTarget,String] + private type TargetSetMap = mutable.HashMap[ReferenceTarget, TargetSet] + + private val toCleanUp = new TargetSet() + + /** + * Logic of the propagation, divided in two main phases: + * 1./ Walk all the Circuit looking for annotated AsyncResets : + * - Store all Annotated AsyncReset reference + * - Build all AsyncReset Trees (whether annotated or not) + * - Store all async-reset-registers (whether annotated or not) + * 2./ Walk the AsyncReset Tree based on Annotated AsyncReset as entry points + * - Annotate all leaf register with PresetRegAnnotation + * - Annotate all intermediate wire, node, connect with PresetConnectorAnnotation + * + * @param circuit the circuit + * @param annotations all the annotations + * @return updated annotations + */ + private def propagate(cs: CircuitState, presetAnnos: Seq[PresetAnnotation]): AnnotationSeq = { + val presets = presetAnnos.groupBy(_.target) + // store all annotated asyncreset references + val asyncToAnnotate = new TargetSet() + // store all async-reset-registers + val asyncRegMap = new TargetSetMap() + // store async-reset trees + val asyncCoMap = new TargetSetMap() + // Annotations to be appended and returned as result of the transform + val annos = cs.annotations.to[mutable.ArrayBuffer] -- presetAnnos + + val circuitTarget = CircuitTarget(cs.circuit.main) + + /* + * WALK I PHASE 1 FUNCTIONS + */ + + /** + * Walk current module + * - process ports + * - store connections & entry points for PHASE 2 + * - process statements + * - Instances => record local instances for cross module AsyncReset Tree Buidling + * - Registers => store AsyncReset bound registers for PHASE 2 + * - Wire => store AsyncReset Connections & entry points for PHASE 2 + * - Connect => store AsyncReset Connections & entry points for PHASE 2 + * + * @param m module + */ + def processModule(m: DefModule): Unit = { + val moduleTarget = circuitTarget.module(m.name) + val localInstances = new TargetMap() + + /** + * Recursively process a given type + * Recursive on Bundle and Vector Type only + * Store Register and Connections for AsyncResetType + * + * @param tpe Type to be processed + * @param target ReferenceTarget associated to the tpe + * @param all Boolean indicating whether all subelements of the current tpe should also be stored as Annotated AsyncReset entry points + */ + def processType(tpe: Type, target: ReferenceTarget, all: Boolean): Unit = { + if(tpe == AsyncResetType){ + asyncRegMap(target) = new TargetSet() + asyncCoMap(target) = new TargetSet() + if (presets.contains(target) || all) { + asyncToAnnotate += target + } + } else { + tpe match { + case b: BundleType => + b.fields.foreach{ + (x: Field) => + val tar = target.field(x.name) + processType(x.tpe, tar, (presets.contains(tar) || all)) + } + + case v: VectorType => + for(i <- 0 until v.size) { + val tar = target.index(i) + processType(v.tpe, tar, (presets.contains(tar) || all)) + } + case _ => + } + } + } + + def processWire(w: DefWire): Unit = { + val target = moduleTarget.ref(w.name) + processType(w.tpe, target, presets.contains(target)) + } + + /** + * Recursively search for the ReferenceTarget of a given Expression + * + * @param e Targeted Expression + * @param ta Local ReferenceTarget of the Targeted Expression + * @return a ReferenceTarget in case of success, a GenericTarget otherwise + * @throw Internal Error on unexpected recursive path return results + */ + def getRef(e: Expression, ta: ReferenceTarget, annoCo: Boolean = false) : Target = { + e match { + case w: WRef => moduleTarget.ref(w.name) + case w: WSubField => + getRef(w.expr, ta, annoCo) match { + case rt: ReferenceTarget => + if(localInstances.contains(rt)){ + val remote_ref = circuitTarget.module(localInstances(rt)) + if (annoCo) + asyncCoMap(ta) += rt.field(w.name) + remote_ref.ref(w.name) + } else { + rt.field(w.name) + } + case remote_target => remote_target + } + case w: WSubIndex => + getRef(w.expr, ta, annoCo) match { + case remote_target: ReferenceTarget => + if (annoCo) + asyncCoMap(ta) += remote_target + remote_target.index(w.value) + case _ => Utils.throwInternalError("Unexpected Reference kind") + } + + case _ => Target(None, None, Seq.empty) + } + } + + def processRegister(r: DefRegister): Unit = { + getRef(r.reset, moduleTarget.ref(r.name), false) match { + case rt : ReferenceTarget => + if (asyncRegMap.contains(rt)) { + asyncRegMap(rt) += moduleTarget.ref(r.name) + } + case _ => + } + + } + + def processConnect(c: Connect): Unit = { + getRef(c.expr, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match { + case rhs: ReferenceTarget => + if (presets.contains(rhs) || asyncRegMap.contains(rhs)) { + getRef(c.loc, rhs, true) match { + case lhs : ReferenceTarget => + if(asyncRegMap.contains(rhs)){ + asyncRegMap(rhs) += lhs + } else { + asyncToAnnotate += lhs + } + case _ => // + } + } + case rhs: GenericTarget => //nothing to do + case _ => Utils.throwInternalError("Unexpected Reference kind") + } + } + + def processNode(n: DefNode): Unit = { + val target = moduleTarget.ref(n.name) + processType(n.value.tpe, target, presets.contains(target)) + + getRef(n.value, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match { + case rhs: ReferenceTarget => + if (presets.contains(rhs) || asyncRegMap.contains(rhs)) { + if(asyncRegMap.contains(rhs)){ + asyncRegMap(rhs) += target + } else { + asyncToAnnotate += target + } + } + case rhs: GenericTarget => //nothing to do + case _ => Utils.throwInternalError("Unexpected Reference kind") + } + } + + def processStatements(statement: Statement): Unit = { + statement match { + case i : WDefInstance => + localInstances(moduleTarget.ref(i.name)) = i.module + case r : DefRegister => processRegister(r) + case w : DefWire => processWire(w) + case n : DefNode => processNode(n) + case c : Connect => processConnect(c) + case s => s.foreachStmt(processStatements) + } + } + + def processPorts(port: Port): Unit = { + if(port.tpe == AsyncResetType){ + val target = moduleTarget.ref(port.name) + asyncRegMap(target) = new TargetSet() + asyncCoMap(target) = new TargetSet() + if (presets.contains(target)) { + asyncToAnnotate += target + toCleanUp += target + } + } + } + + m match { + case module: firrtl.ir.Module => + module.foreachPort(processPorts) + processStatements(module.body) + case _ => + } + } + + /* + * WALK I PHASE 2 FUNCTIONS + */ + + /** Annotate a given target and all its children according to the asyncCoMap */ + def annotateCo(ta: ReferenceTarget){ + if (asyncCoMap.contains(ta)){ + toCleanUp += ta + asyncCoMap(ta) foreach( (t: ReferenceTarget) => { + toCleanUp += t + }) + } + } + + /** Annotate all registers somehow connected to the orignal annotated async reset */ + def annotateRegSet(set: TargetSet) : Unit = { + set foreach ( (ta: ReferenceTarget) => { + annotateCo(ta) + if (asyncRegMap.contains(ta)) { + annotateRegSet(asyncRegMap(ta)) + } else { + annos += new PresetRegAnnotation(ta) + } + }) + } + + /** + * Walk AsyncReset Trees with all Annotated AsyncReset as entry points + * Annotate all leaf registers and intermediate wires, nodes, connectors along the way + */ + def annotateAsyncSet(set: TargetSet) : Unit = { + set foreach ((t: ReferenceTarget) => { + annotateCo(t) + if (asyncRegMap.contains(t)) + annotateRegSet(asyncRegMap(t)) + }) + } + + /* + * MAIN + */ + + cs.circuit.foreachModule(processModule) // PHASE 1 : Initialize + annotateAsyncSet(asyncToAnnotate) // PHASE 2 : Annotate + annos + } + + /* + * WALK II FUNCTIONS + */ + + /** + * Clean-up useless reset tree (not relying on DCE) + * Disconnect preset registers from their reset tree + */ + private def cleanUpPresetTree(circuit: Circuit, annos: AnnotationSeq) : Circuit = { + val presetRegs = annos.collect {case a : PresetRegAnnotation => a}.groupBy(_.target) + val circuitTarget = CircuitTarget(circuit.main) + + def processModule(m: DefModule): DefModule = { + val moduleTarget = circuitTarget.module(m.name) + val localInstances = new TargetMap() + + def getRef(e: Expression) : Target = { + e match { + case w: WRef => moduleTarget.ref(w.name) + case w: WSubField => + getRef(w.expr) match { + case rt: ReferenceTarget => + if(localInstances.contains(rt)){ + circuitTarget.module(localInstances(rt)).ref(w.name) + } else { + rt.field(w.name) + } + case remote_target => remote_target + } + case w: WSubIndex => + getRef(w.expr) match { + case remote_target: ReferenceTarget => remote_target.index(w.value) + case _ => Utils.throwInternalError("Unexpected Reference kind") + } + case DoPrim(op, args, _, _) => + op match { + case AsInterval | AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset => getRef(args.head) + case _ => Target(None, None, Seq.empty) + } + case _ => Target(None, None, Seq.empty) + } + } + + + def processRegister(r: DefRegister) : DefRegister = { + if (presetRegs.contains(moduleTarget.ref(r.name))) { + r.copy(reset = UIntLiteral(0)) + } else { + r + } + } + + def processWire(w: DefWire) : Statement = { + if (toCleanUp.contains(moduleTarget.ref(w.name))) { + EmptyStmt + } else { + w + } + } + + def processNode(n: DefNode) : Statement = { + if (toCleanUp.contains(moduleTarget.ref(n.name))) { + EmptyStmt + } else { + getRef(n.value) match { + case rt : ReferenceTarget if(toCleanUp.contains(rt)) => + throw TreeCleanUpOrphanException(s"Orphan (${moduleTarget.ref(n.name)}) the way.") + case _ => n + } + } + } + + def processConnect(c: Connect): Statement = { + getRef(c.expr) match { + case rhs: ReferenceTarget if (toCleanUp.contains(rhs)) => + getRef(c.loc) match { + case lhs : ReferenceTarget if(!toCleanUp.contains(lhs)) => + throw TreeCleanUpOrphanException(s"Orphan ${lhs} connected deleted node $rhs.") + case _ => EmptyStmt + } + case _ => c + } + } + + def processInstance(i: WDefInstance) : WDefInstance = { + localInstances(moduleTarget.ref(i.name)) = i.module + val tpe = i.tpe match { + case b: BundleType => + val inst = moduleTarget.instOf(i.name, i.module).asReference + BundleType(b.fields.filterNot(p => toCleanUp.contains(inst.field(p.name)))) + case other => other + } + i.copy(tpe = tpe) + } + + def processStatements(statement: Statement): Statement = { + statement match { + case i : WDefInstance => processInstance(i) + case r : DefRegister => processRegister(r) + case w : DefWire => processWire(w) + case n : DefNode => processNode(n) + case c : Connect => processConnect(c) + case s => s.mapStmt(processStatements) + } + } + + m match { + case module: firrtl.ir.Module => + val ports = module.ports.filterNot(p => toCleanUp.contains(moduleTarget.ref(p.name))) + module.copy(body = processStatements(module.body), ports = ports) + case _ => m + } + } + circuit.mapModule(processModule) + } + + def execute(state: CircuitState): CircuitState = { + // Collect all user-defined PresetAnnotation + val presets = state.annotations + .collect{ case m : PresetAnnotation => m } + + // No PresetAnnotation => no need to walk the IR + if (presets.size == 0){ + state + } else { + // PHASE I - Propagate + val annos = propagate(state, presets) + // PHASE II - CleanUp + val cleanCircuit = cleanUpPresetTree(state.circuit, annos) + // Because toCleanup is a class field, we need to clear it + // TODO refactor so that toCleanup is not a class field + toCleanUp.clear() + state.copy(annotations = annos, circuit = cleanCircuit) + } + } +} diff --git a/src/test/resources/features/PresetTester.fir b/src/test/resources/features/PresetTester.fir new file mode 100644 index 00000000..a2395c99 --- /dev/null +++ b/src/test/resources/features/PresetTester.fir @@ -0,0 +1,51 @@ + +circuit PresetTester : + + module Test : + input clock : Clock + input reset : AsyncReset + input x : UInt<4> + output z : UInt<4> + reg r : UInt<4>, clock with : (reset => (reset, UInt(12))) + r <= x + z <= r + + module PresetTester : + input clock : Clock + input reset : UInt<1> + + reg div : UInt<2>, clock with : (reset => (reset, UInt(0))) + div <= tail(add(div, UInt(1)), 1) + + reg slowClkReg : UInt<1>, clock with : (reset => (reset, UInt(0))) + slowClkReg <= eq(div, UInt(0)) + node slowClk = asClock(slowClkReg) + + reg counter : UInt<4>, clock with : (reset => (reset, UInt(0))) + counter <= tail(add(counter, UInt(1)), 1) + + reg x : UInt<5>, slowClk with : (reset => (reset, UInt(9))) + wire z : UInt<5> + + wire preset : AsyncReset + preset <= asAsyncReset(UInt(0)) ; should be annotated as Preset + + inst i of Test + i.clock <= slowClk + i.reset <= preset + i.x <= x + z <= i.z + + when eq(counter, UInt(0)) : + when neq(z, UInt(12)) : + printf(clock, UInt(1), "Assertion 1 failed! z=%d \n",z) + stop(clock, UInt(1), 1) + ; Do the async reset + when eq(counter, UInt(1)) : + when neq(z, UInt(9)) : + printf(clock, UInt(1), "Assertion 2 failed! z=%d \n",z) + stop(clock, UInt(1), 1) + ; Success! + when eq(counter, UInt(3)) : + stop(clock, UInt(1), 0) + diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala index fe94a643..1eea3671 100644 --- a/src/test/scala/firrtlTests/FirrtlSpec.scala +++ b/src/test/scala/firrtlTests/FirrtlSpec.scala @@ -310,9 +310,9 @@ class TestFirrtlFlatSpec extends FirrtlFlatSpec { } /** Super class for execution driven Firrtl tests */ -abstract class ExecutionTest(name: String, dir: String, vFiles: Seq[String] = Seq.empty) extends FirrtlPropSpec { +abstract class ExecutionTest(name: String, dir: String, vFiles: Seq[String] = Seq.empty, annotations: AnnotationSeq = Seq.empty) extends FirrtlPropSpec { property(s"$name should execute correctly") { - runFirrtlTest(name, dir, vFiles) + runFirrtlTest(name, dir, vFiles, annotations = annotations) } } /** Super class for compilation driven Firrtl tests */ diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index f3183599..dcc4e48d 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -203,6 +203,7 @@ class LoweringCompilersSpec extends FlatSpec with Matchers { new firrtl.transforms.FixAddingNegativeLiterals, new firrtl.transforms.ReplaceTruncatingArithmetic, new firrtl.transforms.InlineBitExtractionsTransform, + new firrtl.transforms.PropagatePresetAnnotations, new firrtl.transforms.InlineCastsTransform, new firrtl.transforms.LegalizeClocksTransform, new firrtl.transforms.FlattenRegUpdate, @@ -222,6 +223,7 @@ class LoweringCompilersSpec extends FlatSpec with Matchers { new firrtl.transforms.FixAddingNegativeLiterals, new firrtl.transforms.ReplaceTruncatingArithmetic, new firrtl.transforms.InlineBitExtractionsTransform, + new firrtl.transforms.PropagatePresetAnnotations, new firrtl.transforms.InlineCastsTransform, new firrtl.transforms.LegalizeClocksTransform, new firrtl.transforms.FlattenRegUpdate, diff --git a/src/test/scala/firrtlTests/PresetSpec.scala b/src/test/scala/firrtlTests/PresetSpec.scala new file mode 100644 index 00000000..d35aa69f --- /dev/null +++ b/src/test/scala/firrtlTests/PresetSpec.scala @@ -0,0 +1,239 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import FirrtlCheckers._ +import firrtl.annotations._ + +class PresetSpec extends FirrtlFlatSpec { + type Mod = Seq[String] + type ModuleSeq = Seq[Mod] + def compile(input: String, annos: AnnotationSeq): CircuitState = + (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos), List.empty) + def compileBody(modules: ModuleSeq) = { + val annos = Seq(new PresetAnnotation(CircuitTarget("Test").module("Test").ref("reset")), firrtl.transforms.NoDCEAnnotation) + var str = """ + |circuit Test : + |""".stripMargin + modules foreach ((m: Mod) => { + val header = "|module " + m(0) + " :" + str += header.stripMargin.stripMargin.split("\n").mkString(" ", "\n ", "") + str += m(1).split("\n").mkString(" ", "\n ", "") + str += """ + |""".stripMargin + }) + compile(str,annos) + } + + "Preset" should """behave properly given a `Preset` annotated `AsyncReset` INPUT reset: + - replace AsyncReset specific blocks by standard Register blocks + - add inline declaration of all registers connected to reset + - remove the useless input port""" in { + val result = compileBody(Seq(Seq("Test",s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |r <= x + |z <= r""".stripMargin)) + ) + result shouldNot containLine ("always @(posedge clock or posedge reset) begin") + result shouldNot containLines ( + "if (reset) begin", + "r = 1'h0;", + "end") + result shouldNot containLine ("input reset,") + result should containLine ("always @(posedge clock) begin") + result should containLine ("reg r = 1'h0;") + } + + it should """behave properly given a `Preset` annotated `AsyncReset` WIRE reset: + - replace AsyncReset specific blocks by standard Register blocks + - add inline declaration of all registers connected to reset + - remove the useless wire declaration and assignation""" in { + val result = compileBody(Seq(Seq("Test",s""" + |input clock : Clock + |input x : UInt<1> + |output z : UInt<1> + |wire reset : AsyncReset + |reset <= asAsyncReset(UInt(0)) + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |r <= x + |z <= r""".stripMargin)) + ) + result shouldNot containLine ("always @(posedge clock or posedge reset) begin") + result shouldNot containLines ( + "if (reset) begin", + "r = 1'h0;", + "end") + result should containLine ("always @(posedge clock) begin") + result should containLine ("reg r = 1'h0;") + // it should also remove useless asyncReset signal, all along the path down to registers + result shouldNot containLine ("wire reset;") + result shouldNot containLine ("assign reset = 1'h0;") + } + it should "raise TreeCleanUpOrphantException on cast of annotated AsyncReset" in { + an [firrtl.transforms.PropagatePresetAnnotations.TreeCleanUpOrphanException] shouldBe thrownBy { + compileBody(Seq(Seq("Test",s""" + |input clock : Clock + |input x : UInt<1> + |output z : UInt<1> + |output sz : UInt<1> + |wire reset : AsyncReset + |reset <= asAsyncReset(UInt(0)) + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |wire sreset : UInt<1> + |sreset <= asUInt(reset) ; this is FORBIDDEN + |reg s : UInt<1>, clock with : (reset => (sreset, UInt(0))) + |r <= x + |s <= x + |z <= r + |sz <= s""".stripMargin)) + ) + } + } + + it should "propagate through bundles" in { + val result = compileBody(Seq(Seq("Test",s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} + |bundle.in_rst <= reset + |bundle.out_rst <= bundle.in_rst + |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) + |r <= x + |z <= r""".stripMargin)) + ) + result shouldNot containLine ("input reset,") + result shouldNot containLine ("always @(posedge clock or posedge reset) begin") + result shouldNot containLines ( + "if (reset) begin", + "r = 1'h0;", + "end") + result should containLine ("always @(posedge clock) begin") + result should containLine ("reg r = 1'h0;") + } + it should "propagate through vectors" in { + val result = compileBody(Seq(Seq("Test",s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire vector : AsyncReset[2] + |vector[0] <= reset + |vector[1] <= vector[0] + |reg r : UInt<1>, clock with : (reset => (vector[1], UInt(0))) + |r <= x + |z <= r""".stripMargin)) + ) + result shouldNot containLine ("input reset,") + result shouldNot containLine ("always @(posedge clock or posedge reset) begin") + result shouldNot containLines ( + "if (reset) begin", + "r = 1'h0;", + "end") + result should containLine ("always @(posedge clock) begin") + result should containLine ("reg r = 1'h0;") + } + + it should "propagate through bundles of vectors" in { + val result = compileBody(Seq(Seq("Test",s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire bundle : {in_rst: AsyncReset[2], out_rst:AsyncReset} + |bundle.in_rst[0] <= reset + |bundle.in_rst[1] <= bundle.in_rst[0] + |bundle.out_rst <= bundle.in_rst[1] + |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) + |r <= x + |z <= r""".stripMargin)) + ) + result shouldNot containLine ("input reset,") + result shouldNot containLine ("always @(posedge clock or posedge reset) begin") + result shouldNot containLines ( + "if (reset) begin", + "r = 1'h0;", + "end") + result should containLine ("always @(posedge clock) begin") + result should containLine ("reg r = 1'h0;") + } + it should """propagate properly accross modules: + - replace AsyncReset specific blocks by standard Register blocks + - add inline declaration of all registers connected to reset + - remove the useless input port of instanciated module + - remove the useless instance connections + - remove wires and assignations used in instance connections + """ in { + val result = compileBody(Seq( + Seq("TestA",s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |r <= x + |z <= r + |""".stripMargin), + Seq("Test",s""" + |input clock : Clock + |input x : UInt<1> + |output z : UInt<1> + |wire reset : AsyncReset + |reset <= asAsyncReset(UInt(0)) + |inst i of TestA + |i.clock <= clock + |i.reset <= reset + |i.x <= x + |z <= i.z""".stripMargin) + )) + // assess that all useless connections are not emitted + result shouldNot containLine ("wire i_reset;") + result shouldNot containLine (".reset(i_reset),") + result shouldNot containLine ("assign i_reset = reset;") + result shouldNot containLine ("input reset,") + + result shouldNot containLine ("always @(posedge clock or posedge reset) begin") + result shouldNot containLines ( + "if (reset) begin", + "r = 1'h0;", + "end") + result should containLine ("always @(posedge clock) begin") + result should containLine ("reg r = 1'h0;") + } + + it should "propagate even through disordonned statements" in { + val result = compileBody(Seq(Seq("Test",s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} + |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) + |bundle.out_rst <= bundle.in_rst + |bundle.in_rst <= reset + |r <= x + |z <= r""".stripMargin)) + ) + result shouldNot containLine ("input reset,") + result shouldNot containLine ("always @(posedge clock or posedge reset) begin") + result shouldNot containLines ( + "if (reset) begin", + "r = 1'h0;", + "end") + result should containLine ("always @(posedge clock) begin") + result should containLine ("reg r = 1'h0;") + } + +} + +class PresetExecutionTest extends ExecutionTest( + "PresetTester", + "/features", + annotations = Seq(new PresetAnnotation(CircuitTarget("PresetTester").module("PresetTester").ref("preset"))) +) |
