diff options
| author | Adam Izraelevitz | 2017-02-23 13:28:49 -0800 |
|---|---|---|
| committer | Adam Izraelevitz | 2017-03-06 16:48:15 -0800 |
| commit | b5ef5b876d4f4ad4a17bc81362b2264970272d63 (patch) | |
| tree | d25820fb2e8c47caef2afc9ea4fc4f302feb156b /src | |
| parent | 2370185a9ba231fe0349091eb7f0926b61b15853 (diff) | |
Addresses #459. Rewords transform annotations API.
Now, any annotation not propagated by a transform is considered deleted.
A new DeletedAnnotation is added in place of it.
Diffstat (limited to 'src')
15 files changed, 98 insertions, 42 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index d1ca6320..780f5eb0 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -14,7 +14,7 @@ import Utils.throwInternalError * RenameMap maps old names to modified names. Generated by transformations * that modify names */ -case class RenameMap(map: Map[Named, Seq[Named]]) +case class RenameMap(map: Map[Named, Seq[Named]] = Map[Named, Seq[Named]]()) /** * Container of all annotations for a Firrtl compiler. @@ -153,7 +153,7 @@ abstract class PassBasedTransform extends Transform with PassBased { def execute(state: CircuitState): CircuitState = { require(state.form <= inputForm, s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") - CircuitState(runPasses(state.circuit), outputForm) + CircuitState(runPasses(state.circuit), outputForm, state.annotations) } } @@ -231,8 +231,7 @@ object CompilerUtils { } -trait Compiler { - // Emitter is still somewhat special because we want to make sure it is run last +trait Compiler extends LazyLogging { def emitter: Emitter /** The sequence of transforms this compiler will execute @@ -314,28 +313,31 @@ trait Compiler { val finalState = allTransforms.foldLeft(state) { (in, xform) => val result = Utils.time(s"***${xform.name}***") { xform.execute(in) } - // Annotation propagation - // TODO: This should be redone - val inAnnotationMap = in.annotations getOrElse AnnotationMap(Seq.empty) - val remappedAnnotations: Seq[Annotation] = result.renames match { - case Some(RenameMap(rmap)) => - // For each key in the rename map (rmap), obtain the - // corresponding annotations (in.annotationMap.get(from)). If any - // annotations exist, for each annotation, create a sequence of - // annotations with the names in rmap's value. - for { - (oldName, newNames) <- rmap.toSeq - oldAnno <- inAnnotationMap.get(oldName) - newAnno <- oldAnno.update(newNames) - } yield newAnno - case _ => inAnnotationMap.annotations + val newAnnotations = { + val inSet = in.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet + val resSet = result.annotations.getOrElse(AnnotationMap(Seq.empty)).annotations.toSet + val deleted = (inSet -- resSet).map { + case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"${xFormName}+${xform.name}", delAnno) + case anno => DeletedAnnotation(xform.name, anno) + } + val created = resSet -- inSet + val unchanged = resSet & inSet + (deleted ++ created ++ unchanged) } - val resultAnnotations: Seq[Annotation] = result.annotations match { - case None => Nil - case Some(p) => p.annotations + + // For each annotation, rename all annotations. + val renames = result.renames.getOrElse(RenameMap()).map + val remappedAnnotations: Seq[Annotation] = for { + anno <- newAnnotations.toSeq + newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target))) + } yield newAnno + logger.debug(s"*** ${xform.name} ***") + logger.debug(s"Form: ${result.form}") + logger.debug(result.circuit.serialize) + remappedAnnotations.foreach { a => + logger.debug(a.serialize) } - val newAnnotations = AnnotationMap(remappedAnnotations ++ resultAnnotations) - CircuitState(result.circuit, result.form, Some(newAnnotations)) + CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations))) } finalState } diff --git a/src/main/scala/firrtl/annotations/Annotation.scala b/src/main/scala/firrtl/annotations/Annotation.scala index 2e361833..9efc6e9b 100644 --- a/src/main/scala/firrtl/annotations/Annotation.scala +++ b/src/main/scala/firrtl/annotations/Annotation.scala @@ -3,6 +3,9 @@ package firrtl package annotations +import net.jcazevedo.moultingyaml._ +import firrtl.annotations.AnnotationYamlProtocol._ + case class AnnotationException(message: String) extends Exception(message) final case class Annotation(target: Named, transform: Class[_ <: Transform], value: String) { @@ -26,3 +29,15 @@ final case class Annotation(target: Named, transform: Class[_ <: Transform], val def check(from: Named, tos: Seq[Named], which: Annotation): Unit = {} def duplicate(n: Named) = Annotation(n, transform, value) } + +object DeletedAnnotation { + def apply(xFormName: String, anno: Annotation): Annotation = + Annotation(anno.target, classOf[Transform], s"""DELETED by $xFormName\n${AnnotationUtils.toYaml(anno)}""") + + private val deletedRegex = """(?s)DELETED by ([^\n]*)\n(.*)""".r + def unapply(a: Annotation): Option[Tuple2[String, Annotation]] = a match { + case Annotation(named, t, deletedRegex(xFormName, annoString)) if t == classOf[Transform] => + Some((xFormName, AnnotationUtils.fromYaml(annoString))) + case _ => None + } +} diff --git a/src/main/scala/firrtl/annotations/AnnotationUtils.scala b/src/main/scala/firrtl/annotations/AnnotationUtils.scala index 6e6af81d..8f55c13e 100644 --- a/src/main/scala/firrtl/annotations/AnnotationUtils.scala +++ b/src/main/scala/firrtl/annotations/AnnotationUtils.scala @@ -3,9 +3,15 @@ package firrtl package annotations +import net.jcazevedo.moultingyaml._ +import firrtl.annotations.AnnotationYamlProtocol._ + import firrtl.ir._ object AnnotationUtils { + def toYaml(a: Annotation): String = a.toYaml.prettyPrint + def fromYaml(s: String): Annotation = s.parseYaml.convertTo[Annotation] + /** Returns true if a valid Module name */ val SerializedModuleName = """([a-zA-Z_][a-zA-Z_0-9~!@#$%^*\-+=?/]*)""".r def validModuleName(s: String): Boolean = s match { @@ -34,6 +40,13 @@ object AnnotationUtils { case None => Seq(s) } + def toNamed(s: String): Named = tokenize(s) match { + case Seq(n) => CircuitName(n) + case Seq(c, m) => ModuleName(m, CircuitName(c)) + case Seq(c, m) => ModuleName(m, CircuitName(c)) + case Seq(c, m, x) => ComponentName(x, ModuleName(m, CircuitName(c))) + } + /** Given a serialized component/subcomponent reference, subindex, subaccess, * or subfield, return the corresponding IR expression. * E.g. "foo.bar" becomes SubField(Reference("foo", UnknownType), "bar", UnknownType) diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 93ec6cea..f4556733 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -49,7 +49,7 @@ class InlineInstances extends Transform { case Nil => CircuitState(state.circuit, state.form) case myAnnotations => val (modNames, instNames) = collectAnns(state.circuit, myAnnotations) - run(state.circuit, modNames, instNames) + run(state.circuit, modNames, instNames, state.annotations) } } @@ -93,7 +93,7 @@ class InlineInstances extends Transform { } - def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName]): CircuitState = { + def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName], annos: Option[AnnotationMap]): CircuitState = { def getInstancesOf(c: Circuit, modules: Set[String]): Set[String] = c.modules.foldLeft(Set[String]()) { (set, d) => d match { @@ -146,6 +146,6 @@ class InlineInstances extends Transform { case m => Some(m map onStmt("", m.name)) }) - CircuitState(flatCircuit, LowForm, None, None) + CircuitState(flatCircuit, LowForm, annos, None) } } diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala index 231afbdd..66139c49 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala @@ -44,7 +44,7 @@ class ClockList(top: String, writer: Writer) extends Pass { // Inline the clock-only circuit up to the specified top module val modulesToInline = (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet val inlineTransform = new InlineInstances - val inlinedCircuit = inlineTransform.run(onlyClockCircuit, modulesToInline, Set()).circuit + val inlinedCircuit = inlineTransform.run(onlyClockCircuit, modulesToInline, Set(), None).circuit val topModule = inlinedCircuit.modules.find(_.name == top).getOrElse(throwInternalError) // Build a hashmap of connections to use for getOrigins diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala index 8b5a0627..b04171a7 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala @@ -69,8 +69,8 @@ class ClockListTransform extends Transform { val outputFile = new PrintWriter(out) val newC = (new ClockList(top, outputFile)).run(state.circuit) outputFile.close() - CircuitState(newC, state.form) - case Nil => CircuitState(state.circuit, state.form) + CircuitState(newC, state.form, state.annotations) + case Nil => state case seq => error(s"Found illegal clock list annotation(s): $seq") } } diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 554c1f0d..b941503f 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -158,8 +158,8 @@ class InferReadWrite extends Transform with PassBased { ResolveGenders ) def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { - case Nil => CircuitState(state.circuit, state.form) + case Nil => state case Seq(InferReadWriteAnnotation(CircuitName(state.circuit.main))) => - CircuitState(runPasses(state.circuit), state.form) + state.copy(circuit = runPasses(state.circuit)) } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 1659cf22..0c12d2aa 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -110,7 +110,7 @@ Optional Arguments: class SimpleTransform(p: Pass, form: CircuitForm) extends Transform { def inputForm = form def outputForm = form - def execute(state: CircuitState): CircuitState = CircuitState(p.run(state.circuit), state.form) + def execute(state: CircuitState): CircuitState = CircuitState(p.run(state.circuit), state.form, state.annotations) } class SimpleMidTransform(p: Pass) extends SimpleTransform(p, MidForm) @@ -143,12 +143,12 @@ class ReplSeqMem extends Transform with SimpleRun { Some(AnnotationMap(ann.annotations ++ curState.annotations.get.annotations)) } CircuitState(res.circuit, res.form, newAnnotations) - }).copy(annotations = None) + }) } def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { - case Nil => state.copy(annotations = None) // Do nothing if there are no annotations + case Nil => state // Do nothing if there are no annotations case p => (p.collectFirst { case a if (a.target == CircuitName(state.circuit.main)) => a }) match { case Some(ReplSeqMemAnnotation(target, inputFileName, outputConfig)) => val inConfigFile = { diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index a4f7245b..5a35d85c 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -69,7 +69,7 @@ class WiringTransform extends Transform with SimpleRun { ResolveKinds, ResolveGenders) def execute(state: CircuitState): CircuitState = getMyAnnotations(state) match { - case Nil => CircuitState(state.circuit, state.form) + case Nil => state case p => val sinks = mutable.HashMap[String, Set[String]]() val sources = mutable.HashMap[String, String]() @@ -84,12 +84,12 @@ class WiringTransform extends Transform with SimpleRun { case TopAnnotation(m, pin) => tops(pin) = m.name } (sources.size, tops.size, sinks.size, comp.size) match { - case (0, 0, p, 0) => state.copy(annotations = None) + case (0, 0, p, 0) => state case (s, t, p, c) if (p > 0) & (s == t) & (t == c) => val wis = tops.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, top)) => seq :+ WiringInfo(sources(pin), comp(pin), sinks("pin:" + pin), pin, top) } - state.copy(circuit = runPasses(state.circuit, passSeq(wis)), annotations = None) + state.copy(circuit = runPasses(state.circuit, passSeq(wis))) case _ => error("Wrong number of sources, tops, or sinks!") } } diff --git a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala index aaae284b..9fff30fe 100644 --- a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala +++ b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala @@ -122,7 +122,7 @@ class BlackBoxSourceHelper extends firrtl.Transform { writer.close() } - CircuitState(resultState.circuit, resultState.form) + resultState } } diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index 0ca471af..5fa2c036 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -32,10 +32,10 @@ class DedupModules extends Transform { def outputForm = HighForm def execute(state: CircuitState): CircuitState = { getMyAnnotations(state) match { - case Nil => CircuitState(run(state.circuit, Seq.empty), state.form) + case Nil => state.copy(circuit = run(state.circuit, Seq.empty)) case annos => val noDedups = annos.collect { case NoDedupAnnotation(ModuleName(m, c)) => m } - CircuitState(run(state.circuit, noDedups), state.form) + state.copy(circuit = run(state.circuit, noDedups)) } } // Orders the modules of a circuit from leaves to root diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 29f8f51a..534b6540 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -121,4 +121,25 @@ class AnnotationTests extends AnnotationSpec with Matchers { beforeAnno should be (afterAnno) } } + + "Deleting annotations" should "create a DeletedAnnotation" in { + val compiler = new VerilogCompiler + val input = + """circuit Top : + | module Top : + | input in: UInt<3> + |""".stripMargin + class DeletingTransform extends Transform { + val inputForm = LowForm + val outputForm = LowForm + def execute(state: CircuitState) = state.copy(annotations = None) + } + val anno = InlineAnnotation(CircuitName("Top")) + val annoOpt = Some(AnnotationMap(Seq(anno))) + val writer = new StringWriter() + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, annoOpt), writer, Seq(new DeletingTransform)) + result.annotations.get.annotations.head should matchPattern { + case DeletedAnnotation(x, anno) => + } + } } diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 25a194d4..a3b7386d 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -17,6 +17,8 @@ import firrtl.annotations.{ Annotation } import firrtl.passes.{InlineInstances, InlineAnnotation} +import logger.Logger +import logger.LogLevel.Debug /** @@ -24,6 +26,8 @@ import firrtl.passes.{InlineInstances, InlineAnnotation} */ class InlineInstancesTests extends LowTransformSpec { def transform = new InlineInstances + // Set this to debug + // Logger.setClassLogLevels(Map(this.getClass.getName -> Debug)) "The module Inline" should "be inlined" in { val input = """circuit Top : diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index df56c097..8e5d74ad 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -2,7 +2,6 @@ package firrtlTests -import com.typesafe.scalalogging.LazyLogging import java.io.{StringWriter,Writer} import org.scalatest.{FlatSpec, Matchers} import org.scalatest.junit.JUnitRunner diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 1a5b44e6..93aec7f4 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -7,6 +7,8 @@ import firrtl.ir._ import firrtl.passes._ import firrtl.passes.memlib._ import annotations._ +import logger.Logger +import logger.LogLevel.Debug class ReplSeqMemSpec extends SimpleTransformSpec { def emitter = new LowFirrtlEmitter |
