diff options
Diffstat (limited to 'src')
5 files changed, 144 insertions, 15 deletions
diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala index 8f5ee3e1..9e6fefca 100644 --- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala +++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala @@ -5,7 +5,7 @@ package firrtl.stage import firrtl._ import firrtl.ir.Circuit import firrtl.annotations.{Annotation, NoTargetAnnotation} -import firrtl.options.{HasShellOptions, OptionsException, ShellOption, Unserializable} +import firrtl.options.{Dependency, HasShellOptions, OptionsException, ShellOption, Unserializable} import java.io.FileNotFoundException import java.nio.file.NoSuchFileException @@ -313,3 +313,46 @@ object DisableFold extends HasShellOptions { ) } + +/** Indicate to the FIRRTL compiler that specific transforms have already been run. + * + * The intended use of this is for advanced users who want to skip specific transforms in the FIRRTL compiler. It is + * far safer for users to use the command line options to the FIRRTL compiler via `--start-from = <form>`. + * @param currentState a sequence of transforms that have already been run on the circuit + */ +case class CurrentFirrtlStateAnnotation(currentState: Seq[TransformDependency]) + extends NoTargetAnnotation + with FirrtlOption + +private[stage] object CurrentFirrtlStateAnnotation extends HasShellOptions { + + /** This is just the transforms necessary for resolving types and checking that everything is okay. */ + private val dontSkip: Set[TransformDependency] = Set( + Dependency[firrtl.stage.transforms.CheckScalaVersion], + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.ResolveFlows) + ) ++ Forms.Checks + + override val options = Seq( + new ShellOption[String]( + longOption = "start-from", + toAnnotationSeq = a => + (a match { + case "chirrtl" => Seq.empty + case "mhigh" => Forms.MinimalHighForm + case "high" => Forms.HighForm + case "middle" => Forms.MidForm + case "low" => Forms.LowForm + case "low-opt" => Forms.LowFormOptimized + case _ => throw new OptionsException(s"Unknown start-from argument '$a'! (Did you misspell it?)") + }).filterNot(dontSkip) match { + case b if a.isEmpty => Seq.empty + case b => Seq(CurrentFirrtlStateAnnotation(b)) + }, + helpText = "", + helpValueName = Some("<chirrtl|mhigh|high|middle|low|low-opt>") + ) + ) + +} diff --git a/src/main/scala/firrtl/stage/FirrtlCli.scala b/src/main/scala/firrtl/stage/FirrtlCli.scala index 9cfa6be9..8f84ff18 100644 --- a/src/main/scala/firrtl/stage/FirrtlCli.scala +++ b/src/main/scala/firrtl/stage/FirrtlCli.scala @@ -23,7 +23,8 @@ trait FirrtlCli { this: Shell => WarnNoScalaVersionDeprecation, PrettyNoExprInlining, DisableFold, - OptimizeForFPGA + OptimizeForFPGA, + CurrentFirrtlStateAnnotation ) .map(_.addOptions(parser)) diff --git a/src/main/scala/firrtl/stage/package.scala b/src/main/scala/firrtl/stage/package.scala index 68e7a9c5..a22d299a 100644 --- a/src/main/scala/firrtl/stage/package.scala +++ b/src/main/scala/firrtl/stage/package.scala @@ -35,6 +35,7 @@ package object stage { case WarnNoScalaVersionDeprecation => c case PrettyNoExprInlining => c case _: DisableFold => c + case CurrentFirrtlStateAnnotation(a) => c } } } diff --git a/src/main/scala/firrtl/stage/phases/Compiler.scala b/src/main/scala/firrtl/stage/phases/Compiler.scala index 24848f36..d24775e7 100644 --- a/src/main/scala/firrtl/stage/phases/Compiler.scala +++ b/src/main/scala/firrtl/stage/phases/Compiler.scala @@ -4,23 +4,31 @@ package firrtl.stage.phases import firrtl.{AnnotationSeq, ChirrtlForm, CircuitState, Compiler => FirrtlCompiler, Transform, seqToAnnoSeq} import firrtl.options.{Dependency, Phase, PhasePrerequisiteException, Translator} -import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, Forms, RunFirrtlTransformAnnotation} +import firrtl.stage.{ + CompilerAnnotation, + CurrentFirrtlStateAnnotation, + FirrtlCircuitAnnotation, + Forms, + RunFirrtlTransformAnnotation +} import firrtl.stage.TransformManager.TransformDependency import scala.collection.mutable /** An encoding of the information necessary to run the FIRRTL compiler once */ private[stage] case class CompilerRun( - stateIn: CircuitState, - stateOut: Option[CircuitState], - transforms: Seq[Transform], - compiler: Option[FirrtlCompiler]) + stateIn: CircuitState, + stateOut: Option[CircuitState], + transforms: Seq[Transform], + compiler: Option[FirrtlCompiler], + currentState: Seq[TransformDependency]) /** An encoding of possible defaults for a [[CompilerRun]] */ private[stage] case class Defaults( - annotations: AnnotationSeq = Seq.empty, - transforms: Seq[Transform] = Seq.empty, - compiler: Option[FirrtlCompiler] = None) + annotations: AnnotationSeq = Seq.empty, + transforms: Seq[Transform] = Seq.empty, + compiler: Option[FirrtlCompiler] = None, + currentState: Seq[TransformDependency] = Seq.empty) /** Runs the FIRRTL compilers on an [[AnnotationSeq]]. If the input [[AnnotationSeq]] contains more than one circuit * (i.e., more than one [[firrtl.stage.FirrtlCircuitAnnotation FirrtlCircuitAnnotation]]), then annotations will be @@ -64,7 +72,13 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { a.foldLeft(Defaults()) { case (d, FirrtlCircuitAnnotation(circuit)) => foundFirstCircuit = true - CompilerRun(CircuitState(circuit, ChirrtlForm, d.annotations, None), None, d.transforms, d.compiler) +=: c + CompilerRun( + CircuitState(circuit, ChirrtlForm, d.annotations, None), + None, + d.transforms, + d.compiler, + d.currentState + ) +=: c d case (d, a) if foundFirstCircuit => a match { @@ -74,6 +88,9 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { case CompilerAnnotation(compiler) => c(0) = c(0).copy(compiler = Some(compiler)) d + case CurrentFirrtlStateAnnotation(currentState) => + c(0) = c(0).copy(currentState = currentState ++ c(0).currentState) + d case annotation => val state = c(0).stateIn c(0) = c(0).copy(stateIn = state.copy(annotations = annotation +: state.annotations)) @@ -81,9 +98,10 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { } case (d, a) if !foundFirstCircuit => a match { - case RunFirrtlTransformAnnotation(transform) => d.copy(transforms = transform +: d.transforms) - case CompilerAnnotation(compiler) => d.copy(compiler = Some(compiler)) - case annotation => d.copy(annotations = annotation +: d.annotations) + case RunFirrtlTransformAnnotation(transform) => d.copy(transforms = transform +: d.transforms) + case CompilerAnnotation(compiler) => d.copy(compiler = Some(compiler)) + case CurrentFirrtlStateAnnotation(currentState) => d.copy(currentState = currentState ++ d.currentState) + case annotation => d.copy(annotations = annotation +: d.annotations) } } c.toSeq @@ -110,7 +128,7 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { c.transforms.reverse.map(Dependency.fromTransform) } } - val tm = new firrtl.stage.transforms.Compiler(targets) + val tm = new firrtl.stage.transforms.Compiler(targets, c.currentState) /* Transform order is lazily evaluated. Force it here to remove its resolution time from actual compilation. */ val (timeResolveDependencies, _) = firrtl.Utils.time { tm.flattenedTransformOrder } logger.info(f"Computed transform order in: $timeResolveDependencies%.1f ms") diff --git a/src/test/scala/firrtl/stage/CurrentFirrtlStateAnnotationSpec.scala b/src/test/scala/firrtl/stage/CurrentFirrtlStateAnnotationSpec.scala new file mode 100644 index 00000000..121a4e81 --- /dev/null +++ b/src/test/scala/firrtl/stage/CurrentFirrtlStateAnnotationSpec.scala @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.stage + +import firrtl.options.Dependency +import firrtl.stage.transforms.Compiler +import firrtl.stage.TransformManager.TransformDependency +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class CurrentFirrtlStateAnnotationSpec extends AnyFlatSpec with Matchers { + + def getTransforms(input: String): Seq[TransformDependency] = { + val currentState = CurrentFirrtlStateAnnotation + .options(0) + .toAnnotationSeq(input) + .collectFirst { + case CurrentFirrtlStateAnnotation(currentState) => currentState + } + .get + new Compiler(Forms.VerilogOptimized, currentState).flattenedTransformOrder.map(Dependency.fromTransform) + } + + behavior.of("CurrentFirrtlStateAnnotation") + + it should "produce an expected transform order for CHIRRTL -> Verilog" in { + getTransforms("chirrtl") should contain(Dependency(firrtl.passes.CheckChirrtl)) + } + + it should "produce an expected transform order for minimum high FIRRTL -> Verilog" in { + val transforms = getTransforms("mhigh") + transforms should not contain noneOf(Dependency(firrtl.passes.CheckChirrtl), Dependency(firrtl.passes.InferTypes)) + transforms should contain(Dependency(firrtl.passes.CheckHighForm)) + } + + it should "produce an expected transform order for high FIRRTL -> Verilog" in { + val transforms = getTransforms("high") + transforms should not contain (Dependency[firrtl.transforms.DedupModules]) + (transforms should contain).allOf( + Dependency(firrtl.passes.InferTypes), + Dependency[firrtl.passes.ExpandWhensAndCheck] + ) + } + + it should "produce an expected transform order for middle FIRRTL -> Verilog" in { + val transforms = getTransforms("middle") + transforms should not contain (Dependency[firrtl.passes.ExpandWhensAndCheck]) + (transforms should contain).allOf(Dependency(firrtl.passes.InferTypes), Dependency(firrtl.passes.LowerTypes)) + } + + it should "produce an expected transform order for low FIRRTL -> Verilog" in { + val transforms = getTransforms("low") + transforms should not contain (Dependency(firrtl.passes.LowerTypes)) + (transforms should contain).allOf( + Dependency(firrtl.passes.InferTypes), + Dependency(firrtl.passes.CommonSubexpressionElimination) + ) + } + + it should "produce an expected transform order for optimized low FIRRTL -> Verilog" in { + val transforms = getTransforms("low-opt") + transforms should not contain (Dependency(firrtl.passes.CommonSubexpressionElimination)) + (transforms should contain).allOf(Dependency(firrtl.passes.InferTypes), Dependency[firrtl.transforms.VerilogRename]) + } + +} |
