diff options
| author | jackkoenig | 2016-10-20 00:19:01 -0700 |
|---|---|---|
| committer | Jack Koenig | 2016-11-04 13:29:09 -0700 |
| commit | 8fa9429a6e916ab2a789f5d81fa803b022805b52 (patch) | |
| tree | fac2efcbd0a68bfb1916f09afc7f003c7a3d6528 /src/main/scala/firrtl/passes | |
| parent | 62133264a788f46b319ebab9c31424b7e0536101 (diff) | |
Refactor Compilers and Transforms
* Transform Ids now handled by Class[_ <: Transform] instead of magic numbers
* Transforms define inputForm and outputForm
* Custom transforms can be inserted at runtime into compiler or the Driver
* Current "built-in" custom transforms handled via above mechanism
* Verilog-specific passes moved to the Verilog emitter
Diffstat (limited to 'src/main/scala/firrtl/passes')
6 files changed, 131 insertions, 102 deletions
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 5c80baff..c741dc06 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -9,34 +9,37 @@ import firrtl.Annotations._ import scala.collection.mutable // Tags an annotation to be consumed by this pass -case class InlineAnnotation(target: Named, tID: TransID) extends Annotation with Loose with Unstable { +case class InlineAnnotation(target: Named) extends Annotation with Loose with Unstable { def duplicate(n: Named) = this.copy(target=n) + def transform = classOf[InlineInstances] } // Only use on legal Firrtl. Specifically, the restriction of // instance loops must have been checked, or else this pass can // infinitely recurse -class InlineInstances (transID: TransID) extends Transform { +class InlineInstances extends Transform { + def inputForm = LowForm + def outputForm = LowForm val inlineDelim = "$" - def name = "Inline Instances" - def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = { - annotationMap.get(transID) match { - case None => TransformResult(circuit, None, None) - case Some(map) => - val moduleNames = mutable.HashSet[ModuleName]() - val instanceNames = mutable.HashSet[ComponentName]() - map.values.foreach {x: Annotation => x match { - case InlineAnnotation(ModuleName(mod, cir), _) => moduleNames += ModuleName(mod, cir) - case InlineAnnotation(ComponentName(com, mod), _) => instanceNames += ComponentName(com, mod) - case _ => throw new PassException("Annotation must be InlineAnnotation") - }} - check(circuit, moduleNames.toSet, instanceNames.toSet) - run(circuit, moduleNames.toSet, instanceNames.toSet) + override def name = "Inline Instances" - // Default behavior is to error if more than one annotation for inlining - // This could potentially change - case _ => throw new PassException("Found more than one circuit annotation of InlineCAKind!") + private def collectAnns(anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + anns.foldLeft(Set.empty[ModuleName], Set.empty[ComponentName]) { + case ((modNames, instNames), ann) => ann match { + case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) + case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) + case _ => throw new PassException("Annotation must be InlineAnnotation") + } } + + def execute(state: CircuitState): CircuitState = { + // TODO Add error check for more than one annotation for inlining + // TODO Propagate other annotations + val result = for { + myAnnotations <- getMyAnnotations(state) + (modNames, instNames) = collectAnns(myAnnotations.values) + } yield run(state.circuit, modNames, instNames) + result getOrElse state // Return state if nothing to do } // Checks the following properties: @@ -78,7 +81,10 @@ class InlineInstances (transID: TransID) extends Transform { if (errors.nonEmpty) throw new PassExceptions(errors) } - def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): TransformResult = { + def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): CircuitState = { + // Check annotations and circuit match up + check(c, modsToInline, instsToInline) + // ---- Rename functions/data ---- val renameMap = mutable.HashMap[Named,Seq[Named]]() // Updates renameMap with new names @@ -168,6 +174,6 @@ class InlineInstances (transID: TransID) extends Transform { val top = c.modules.find(m => m.name == c.main).get onModule(top) val modulesx = c.modules.map(m => inlinedModules(m.name)) - TransformResult(Circuit(c.info, modulesx, c.main), Some(RenameMap(renameMap.toMap)), None) + CircuitState(Circuit(c.info, modulesx, c.main), LowForm, None, Some(RenameMap(renameMap.toMap))) } } diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 10cc8f88..c98dd4ca 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -5,20 +5,22 @@ import ir._ import Annotations._ import wiring._ -class CreateMemoryAnnotations(reader: Option[YamlFileReader], replaceID: TransID, wiringID: TransID) extends Transform { - def name = "Create Memory Annotations" - def execute(c: Circuit, map: AnnotationMap): TransformResult = reader match { - case None => TransformResult(c) +class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform { + def inputForm = MidForm + def outputForm = MidForm + override def name = "Create Memory Annotations" + def execute(state: CircuitState): CircuitState = reader match { + case None => state case Some(r) => import CustomYAMLProtocol._ r.parse[Config] match { case Seq(config) => - val cN = CircuitName(c.main) - val top = TopAnnotation(ModuleName(config.top.name, cN), wiringID) - val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN)), wiringID) - val pin = PinAnnotation(cN, replaceID, config.pin.name) - TransformResult(c, None, Some(AnnotationMap(Seq(top, source, pin)))) - case Nil => TransformResult(c, None, None) + val cN = CircuitName(state.circuit.main) + val top = TopAnnotation(ModuleName(config.top.name, cN)) + val source = SourceAnnotation(ComponentName(config.source.name, ModuleName(config.source.module, cN))) + val pin = PinAnnotation(cN, config.pin.name) + state.copy(annotations = Some(AnnotationMap(Seq(top, source, pin)))) + case Nil => state case _ => error("Can only have one config in yaml file") } } diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 28291135..2d6f4e96 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -38,10 +38,10 @@ import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin} import WrappedExpression.weq import Annotations._ -case class InferReadWriteAnnotation(t: String, tID: TransID) - extends Annotation with Loose with Unstable { +case class InferReadWriteAnnotation(t: String) extends Annotation with Loose with Unstable { val target = CircuitName(t) def duplicate(n: Named) = this.copy(t=n.name) + def transform = classOf[InferReadWrite] } // This pass examine the enable signals of the read & write ports of memories @@ -168,7 +168,9 @@ object InferReadWritePass extends Pass { // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite(transID: TransID) extends Transform with SimpleRun { +class InferReadWrite extends Transform with PassBased { + def inputForm = MidForm + def outputForm = MidForm def passSeq = Seq( InferReadWritePass, CheckInitialization, @@ -176,11 +178,12 @@ class InferReadWrite(transID: TransID) extends Transform with SimpleRun { ResolveKinds, ResolveGenders ) - def execute(c: Circuit, map: AnnotationMap) = map get transID match { - case Some(p) => p get CircuitName(c.main) match { - case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq) - case _ => sys.error("Unexpected annotation for InferReadWrite") - } - case _ => TransformResult(c) + def execute(state: CircuitState): CircuitState = { + val result = for { + myAnnotations <- getMyAnnotations(state) + InferReadWriteAnnotation(_) <- myAnnotations get CircuitName(state.circuit.main) + resCircuit = runPasses(state.circuit) + } yield state.copy(circuit = resCircuit) + result getOrElse state // Return state if nothing to do } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index 9ab496d2..ae872639 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -16,7 +16,8 @@ import wiring._ /** Annotates the name of the pin to add for WiringTransform */ -case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends Annotation with Loose with Unstable { +case class PinAnnotation(target: CircuitName, pin: String) extends Annotation with Loose with Unstable { + def transform = classOf[ReplaceMemMacros] def duplicate(n: Named) = n match { case n: CircuitName => this.copy(target = n) case _ => throwInternalError @@ -27,8 +28,10 @@ case class PinAnnotation(target: CircuitName, tID: TransID, pin: String) extends * This will not generate wmask ports if not needed. * Creates the minimum # of black boxes needed by the design. */ -class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) extends Transform { - def name = "Replace Memory Macros" +class ReplaceMemMacros(writer: ConfWriter) extends Transform { + override def name = "Replace Memory Macros" + def inputForm = MidForm + def outputForm = MidForm /** Return true if mask granularity is per bit, false if per byte or unspecified */ @@ -206,7 +209,8 @@ class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) ext map updateStmtRefs(memPortMap)) } - def execute(c: Circuit, map: AnnotationMap): TransformResult = { + def execute(state: CircuitState): CircuitState = { + val c = state.circuit val namespace = Namespace(c) val memMods = new Modules val nameMap = new NameMap @@ -214,15 +218,15 @@ class ReplaceMemMacros(writer: ConfWriter, myID: TransID, wiringID: TransID) ext val modules = c.modules map updateMemMods(namespace, nameMap, memMods) // print conf writer.serialize() - val pin = map get myID match { - case Some(p) => + val pin = getMyAnnotations(state) match { + case Some(p) => p.values.head match { - case PinAnnotation(c, _, pin) => pin + case PinAnnotation(c, pin) => pin case _ => error(s"Bad Annotations: ${p.values}") } case None => "pin" } - val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), wiringID, pin) } - TransformResult(c.copy(modules = modules ++ memMods), None, Some(AnnotationMap(annos))) - } + val annos = memMods.collect { case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), pin) } + CircuitState(c.copy(modules = modules ++ memMods), inputForm, Some(AnnotationMap(annos))) + } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 01f020f5..818bd9cc 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -61,8 +61,7 @@ class ConfWriter(filename: String) { } } -case class ReplSeqMemAnnotation(t: String, tID: TransID) - extends Annotation with Loose with Unstable { +case class ReplSeqMemAnnotation(t: String) extends Annotation with Loose with Unstable { val usage = """ [Optional] ReplSeqMem @@ -91,52 +90,60 @@ Optional Arguments: ) val target = CircuitName(passCircuit) def duplicate(n: Named) = this copy (t = t.replace(s"-c:$passCircuit", s"-c:${n.name}")) + def transform = classOf[ReplSeqMem] } -case class SimpleTransform(p: Pass) extends Transform { - def execute(c: Circuit, map: AnnotationMap): TransformResult = - TransformResult(p.run(c)) +class SimpleTransform(p: Pass, form: CircuitForm) extends Transform { + def inputForm = form + def outputForm = form + def execute(state: CircuitState): CircuitState = state.copy(circuit = p.run(state.circuit)) } -class ReplSeqMem(transID: TransID) extends Transform with SimpleRun { +class SimpleMidTransform(p: Pass) extends SimpleTransform(p, MidForm) + +// SimpleRun instead of PassBased because of the arguments to passSeq +class ReplSeqMem extends Transform with SimpleRun { + def inputForm = MidForm + def outputForm = MidForm def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = - Seq(SimpleTransform(Legalize), - SimpleTransform(ToMemIR), - SimpleTransform(ResolveMaskGranularity), - SimpleTransform(RenameAnnotatedMemoryPorts), - SimpleTransform(ResolveMemoryReference), - new CreateMemoryAnnotations(inConfigFile, TransID(-7), TransID(-8)), - new ReplaceMemMacros(outConfigFile, TransID(-7), TransID(-8)), - new WiringTransform(TransID(-8)), - SimpleTransform(RemoveEmpty), - SimpleTransform(CheckInitialization), - SimpleTransform(InferTypes), - SimpleTransform(Uniquify), - SimpleTransform(ResolveKinds), - SimpleTransform(ResolveGenders)) - def run(circuit: Circuit, map: AnnotationMap, xForms: Seq[Transform]): TransformResult = { - (xForms.foldLeft(TransformResult(circuit, None, Some(map)))) { case (tr: TransformResult, xForm: Transform) => - val x = xForm.execute(tr.circuit, tr.annotation.get) - x.annotation match { - case None => TransformResult(x.circuit, None, Some(map)) - case Some(ann) => TransformResult(x.circuit, None, Some( - AnnotationMap(ann.annotations ++ tr.annotation.get.annotations))) + Seq(new SimpleMidTransform(Legalize), + new SimpleMidTransform(ToMemIR), + new SimpleMidTransform(ResolveMaskGranularity), + new SimpleMidTransform(RenameAnnotatedMemoryPorts), + new SimpleMidTransform(ResolveMemoryReference), + new CreateMemoryAnnotations(inConfigFile), + new ReplaceMemMacros(outConfigFile), + new WiringTransform, + new SimpleMidTransform(RemoveEmpty), + new SimpleMidTransform(CheckInitialization), + new SimpleMidTransform(InferTypes), + new SimpleMidTransform(Uniquify), + new SimpleMidTransform(ResolveKinds), + new SimpleMidTransform(ResolveGenders)) + def run(state: CircuitState, xForms: Seq[Transform]): CircuitState = { + xForms.foldLeft(state) { case (curState: CircuitState, xForm: Transform) => + val res = xForm.execute(state) + res.annotations match { + case None => CircuitState(res.circuit, res.form, state.annotations) + case Some(ann) => CircuitState(res.circuit, res.form, Some( + AnnotationMap(ann.annotations ++ curState.annotations.get.annotations))) } } } - def execute(c: Circuit, map: AnnotationMap) = map get transID match { - case Some(p) => p get CircuitName(c.main) match { - case Some(ReplSeqMemAnnotation(t, _)) => - val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") - val inConfigFile = { - if (inputFileName.isEmpty) None - else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) - else error("Input configuration file does not exist!") - } - val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName)) - run(c, map, passSeq(inConfigFile, outConfigFile)) - case _ => error("Unexpected transform annotation") + def execute(state: CircuitState): CircuitState = + getMyAnnotations(state) match { + case Some(p) => p get CircuitName(state.circuit.main) match { + case Some(ReplSeqMemAnnotation(t)) => + val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") + val inConfigFile = { + if (inputFileName.isEmpty) None + else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) + else error("Input configuration file does not exist!") + } + val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName)) + run(state, passSeq(inConfigFile, outConfigFile)) + case _ => error("Unexpected transform annotation") + } + case None => state // Do nothing if there are no annotations } - case _ => TransformResult(c) - } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index 919948b6..59e76d65 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -11,7 +11,8 @@ import WiringUtils._ /** A component, e.g. register etc. Must be declared only once under the TopAnnotation */ -case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotation with Loose with Unstable { +case class SourceAnnotation(target: ComponentName) extends Annotation with Loose with Unstable { + def transform = classOf[WiringTransform] def duplicate(n: Named) = n match { case n: ComponentName => this.copy(target = n) case _ => throwInternalError @@ -20,7 +21,8 @@ case class SourceAnnotation(target: ComponentName, tID: TransID) extends Annotat /** A module, e.g. ExtModule etc., that should add the input pin */ -case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends Annotation with Loose with Unstable { +case class SinkAnnotation(target: ModuleName, pin: String) extends Annotation with Loose with Unstable { + def transform = classOf[WiringTransform] def duplicate(n: Named) = n match { case n: ModuleName => this.copy(target = n) case _ => throwInternalError @@ -30,7 +32,8 @@ case class SinkAnnotation(target: ModuleName, tID: TransID, pin: String) extends /** A module under which all sink module must be declared, and there is only * one source component */ -case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation with Loose with Unstable { +case class TopAnnotation(target: ModuleName) extends Annotation with Loose with Unstable { + def transform = classOf[WiringTransform] def duplicate(n: Named) = n match { case n: ModuleName => this.copy(target = n) case _ => throwInternalError @@ -49,13 +52,15 @@ case class TopAnnotation(target: ModuleName, tID: TransID) extends Annotation wi * Notes: * - No module uniquification occurs (due to imposed restrictions) */ -class WiringTransform(transID: TransID) extends Transform with SimpleRun { +class WiringTransform extends Transform with SimpleRun { + def inputForm = MidForm + def outputForm = MidForm def passSeq(wi: WiringInfo) = Seq(new Wiring(wi), InferTypes, ResolveKinds, ResolveGenders) - def execute(c: Circuit, map: AnnotationMap) = map get transID match { + def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { case Some(p) => val sinks = mutable.HashMap[String, String]() val sources = mutable.Set[String]() @@ -63,18 +68,20 @@ class WiringTransform(transID: TransID) extends Transform with SimpleRun { val comp = mutable.Set[String]() p.values.foreach { a => a match { - case SinkAnnotation(m, _, pin) => sinks(m.name) = pin - case SourceAnnotation(c, _) => + case SinkAnnotation(m, pin) => sinks(m.name) = pin + case SourceAnnotation(c) => sources += c.module.name comp += c.name - case TopAnnotation(m, _) => tops += m.name + case TopAnnotation(m) => tops += m.name } } (sources.size, tops.size, sinks.size, comp.size) match { - case (0, 0, p, 0) => TransformResult(c) - case (1, 1, p, 1) if p > 0 => run(c, passSeq(WiringInfo(sources.head, comp.head, sinks.toMap, tops.head))) + case (0, 0, p, 0) => state + case (1, 1, p, 1) if p > 0 => + val winfo = WiringInfo(sources.head, comp.head, sinks.toMap, tops.head) + state.copy(circuit = runPasses(state.circuit, passSeq(winfo))) case _ => error("Wrong number of sources, tops, or sinks!") } - case None => TransformResult(c) + case None => state } } |
