diff options
Diffstat (limited to 'src/main/scala/firrtl/Compiler.scala')
| -rw-r--r-- | src/main/scala/firrtl/Compiler.scala | 148 |
1 files changed, 86 insertions, 62 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index b0d42332..ea801086 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -110,61 +110,117 @@ final case object MidForm extends CircuitForm(1) * - All implicit truncations must be made explicit */ final case object LowForm extends CircuitForm(0) +/** Unknown Form + * + * Often passes may modify a circuit (e.g. InferTypes), but return + * a circuit in the same form it was given. + * + * For this use case, use UnknownForm. It cannot be compared against other + * forms. + * + * TODO(azidar): Replace with PreviousForm, which more explicitly encodes + * this requirement. + */ +final case object UnknownForm extends CircuitForm(-1) { + override def compare(that: CircuitForm): Int = { error("Illegal to compare UnknownForm"); 0 } +} /** The basic unit of operating on a Firrtl AST */ -abstract class Transform { +abstract class Transform extends LazyLogging { /** A convenience function useful for debugging and error messages */ def name: String = this.getClass.getSimpleName /** The [[CircuitForm]] that this transform requires to operate on */ def inputForm: CircuitForm /** The [[CircuitForm]] that this transform outputs */ def outputForm: CircuitForm - /** Perform the transform + /** Perform the transform, encode renaming with RenameMap, and can + * delete annotations + * Called by [[runTransform]]. * * @param state Input Firrtl AST * @return A transformed Firrtl AST */ - def execute(state: CircuitState): CircuitState + protected def execute(state: CircuitState): CircuitState /** Convenience method to get annotations relevant to this Transform * * @param state The [[CircuitState]] form which to extract annotations * @return A collection of annotations */ final def getMyAnnotations(state: CircuitState): Seq[Annotation] = state.annotations match { - case Some(annotations) => annotations.get(this.getClass) + case Some(annotations) => annotations.get(this.getClass) //TODO(azidar): ++ annotations.get(classOf[Transform]) case None => Nil } -} + /** Perform the transform and update annotations. + * + * @param state Input Firrtl AST + * @return A transformed Firrtl AST + */ + final def runTransform(state: CircuitState): CircuitState = { + logger.info(s"======== Starting Transform $name ========") + + val (timeMillis, result) = Utils.time { execute(state) } + + logger.info(s"""----------------------------${"-" * name.size}---------\n""") + logger.info(f"Time: $timeMillis%.1f ms") + + val remappedAnnotations = propagateAnnotations(state.annotations, result.annotations, result.renames) + + logger.info(s"Form: ${result.form}") + logger.debug(s"Annotations:") + remappedAnnotations.foreach { a => + logger.debug(a.serialize) + } + logger.debug(s"Circuit:\n${result.circuit.serialize}") + logger.info(s"======== Finished Transform $name ========\n") + + CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations)), None) + } -trait SimpleRun extends LazyLogging { - def runPasses(circuit: Circuit, passSeq: Seq[Pass]): Circuit = - passSeq.foldLeft(circuit) { (c: Circuit, pass: Pass) => - val name = pass.name - logger.info(s"-------- Starting Pass $name --------") - val (timeMillis, x) = Utils.time { pass.run(c) } - logger.info(f"Time: $timeMillis%.1f ms") - logger.debug(s"Circuit:\n${c.serialize}") - logger.info(s"-------- Finished Pass $name --------") - x + /** Propagate annotations and update their names. + * + * @param inAnno input AnnotationMap + * @param resAnno result AnnotationMap + * @param renameOpt result RenameMap + * @return the updated annotations + */ + final private def propagateAnnotations( + inAnno: Option[AnnotationMap], + resAnno: Option[AnnotationMap], + renameOpt: Option[RenameMap]): Seq[Annotation] = { + val newAnnotations = { + val inSet = inAnno.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet + val resSet = resAnno.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet + val deleted = (inSet -- resSet).map { + case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"$xFormName+$name", delAnno) + case anno => DeletedAnnotation(name, anno) + } + val created = resSet -- inSet + val unchanged = resSet & inSet + (deleted ++ created ++ unchanged) } + + // For each annotation, rename all annotations. + val renames = renameOpt.getOrElse(RenameMap()).map + for { + anno <- newAnnotations.toSeq + newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target))) + } yield newAnno + } } -/** For PassBased Transforms and Emitters - * - * @note passSeq accepts no arguments - * @todo make passes accept CircuitState so annotations can pass data between them - */ -trait PassBased extends SimpleRun { - def passSeq: Seq[Pass] - def runPasses(circuit: Circuit): Circuit = runPasses(circuit, passSeq) +trait SeqTransformBased { + def transforms: Seq[Transform] + protected def runTransforms(state: CircuitState): CircuitState = + transforms.foldLeft(state) { (in, xform) => xform.runTransform(in) } } -/** For transformations that are simply a sequence of passes */ -abstract class PassBasedTransform extends Transform with PassBased { +/** For transformations that are simply a sequence of transforms */ +abstract class SeqTransform extends Transform with SeqTransformBased { def execute(state: CircuitState): CircuitState = { require(state.form <= inputForm, s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") - CircuitState(runPasses(state.circuit), outputForm, state.annotations) + val ret = runTransforms(state) + CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames) } } @@ -174,7 +230,7 @@ trait Emitter extends Transform { def emit(state: CircuitState, writer: Writer): Unit } -object CompilerUtils { +object CompilerUtils extends LazyLogging { /** Generates a sequence of [[Transform]]s to lower a Firrtl circuit * * @param inputForm [[CircuitForm]] to lower from @@ -194,6 +250,7 @@ object CompilerUtils { new HighFirrtlToMiddleFirrtl) ++ getLoweringTransforms(MidForm, outputForm) case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) case LowForm => throwInternalError // should be caught by if above + case UnknownForm => throwInternalError // should be caught by if above } } } @@ -318,42 +375,9 @@ trait Compiler extends LazyLogging { def compile(state: CircuitState, customTransforms: Seq[Transform]): CircuitState = { val allTransforms = CompilerUtils.mergeTransforms(transforms, customTransforms) :+ emitter - val finalState = allTransforms.foldLeft(state) { (in, xform) => - logger.info(s"======== Starting Transform ${xform.name} ========") - val (timeMillis, result) = Utils.time { xform.execute(in) } - - logger.info(f"Time: $timeMillis%.1f ms") - - val newAnnotations = { - val inSet = in.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet - val resSet = result.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet - val deleted = (inSet -- resSet).map { - case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"${xFormName}+${xform.name}", delAnno) - case anno => DeletedAnnotation(xform.name, anno) - } - val created = resSet -- inSet - val unchanged = resSet & inSet - (deleted ++ created ++ unchanged) - } - - // For each annotation, rename all annotations. - val renames = result.renames.getOrElse(RenameMap()).map - val remappedAnnotations: Seq[Annotation] = for { - anno <- newAnnotations.toSeq - newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target))) - } yield newAnno - - logger.info(s"Form: ${result.form}") - logger.debug(s"Annotations:") - remappedAnnotations.foreach { a => - logger.debug(a.serialize) - } - logger.debug(s"Circuit:\n${result.circuit.serialize}") - logger.info(s"======== Finished Transform ${xform.name} ========\n") - - CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations))) - } + val finalState = allTransforms.foldLeft(state) { (in, xform) => xform.runTransform(in) } finalState } + } |
