diff options
| author | Adam Izraelevitz | 2018-10-24 20:40:27 -0700 |
|---|---|---|
| committer | GitHub | 2018-10-24 20:40:27 -0700 |
| commit | 7e2f787e125227dc389d5cf1d09717748ecfed2e (patch) | |
| tree | 2c654726a5c9850440792cf673e91ed01e0bdfe4 /src | |
| parent | f2c50e11c0e1ff3ed7b8ca3ae3d2d3b16f157453 (diff) | |
Instance Annotations (#865)
Added Target, which now supports Instance Annotations. See #865 for details.
Diffstat (limited to 'src')
29 files changed, 2773 insertions, 338 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 9044c5a8..80ba42c4 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -4,121 +4,23 @@ package firrtl import logger._ import java.io.Writer -import annotations._ -import scala.collection.mutable - -import firrtl.annotations._ // Note that wildcard imports are not great.... -import firrtl.ir.Circuit -import firrtl.Utils.{error, throwInternalError} - -object RenameMap { - def apply(map: Map[Named, Seq[Named]]) = { - val rm = new RenameMap - rm.addMap(map) - rm - } - def apply() = new RenameMap -} -/** Map old names to new names - * - * Transforms that modify names should return a [[RenameMap]] with the [[CircuitState]] - * These are mutable datastructures for convenience - */ -// TODO This should probably be refactored into immutable and mutable versions -final class RenameMap private () { - private val underlying = mutable.HashMap[Named, Seq[Named]]() - /** Get renames of a [[CircuitName]] - * @note A [[CircuitName]] can only be renamed to a single [[CircuitName]] - */ - def get(key: CircuitName): Option[CircuitName] = underlying.get(key).map { - case Seq(c: CircuitName) => c - case other => error(s"Unsupported Circuit rename to $other!") - } - /** Get renames of a [[ModuleName]] - * @note A [[ModuleName]] can only be renamed to one-or-more [[ModuleName]]s - */ - def get(key: ModuleName): Option[Seq[ModuleName]] = { - def nestedRename(m: ModuleName): Option[Seq[ModuleName]] = - this.get(m.circuit).map(cname => Seq(ModuleName(m.name, cname))) - underlying.get(key) match { - case Some(names) => Some(names.flatMap { - case m: ModuleName => - nestedRename(m).getOrElse(Seq(m)) - case other => error(s"Unsupported Module rename of $key to $other") - }) - case None => nestedRename(key) - } - } - /** Get renames of a [[ComponentName]] - * @note A [[ComponentName]] can only be renamed to one-or-more [[ComponentName]]s - */ - def get(key: ComponentName): Option[Seq[ComponentName]] = { - def nestedRename(c: ComponentName): Option[Seq[ComponentName]] = - this.get(c.module).map { modules => - modules.map(mname => ComponentName(c.name, mname)) - } - underlying.get(key) match { - case Some(names) => Some(names.flatMap { - case c: ComponentName => - nestedRename(c).getOrElse(Seq(c)) - case other => error(s"Unsupported Component rename of $key to $other") - }) - case None => nestedRename(key) - } - } - /** Get new names for an old name - * - * This is analogous to get on standard Scala collection Maps - * None indicates the key was not renamed - * Empty indicates the name was deleted - */ - def get(key: Named): Option[Seq[Named]] = key match { - case c: ComponentName => this.get(c) - case m: ModuleName => this.get(m) - // The CircuitName version returns Option[CircuitName] - case c: CircuitName => this.get(c).map(Seq(_)) - } +import firrtl.RenameMap.{CircularRenameException, IllegalRenameException} - // Mutable helpers - private var circuitName: String = "" - private var moduleName: String = "" - def setModule(s: String) = - moduleName = s - def setCircuit(s: String) = - circuitName = s - def rename(from: String, to: String): Unit = rename(from, Seq(to)) - def rename(from: String, tos: Seq[String]): Unit = { - val fromName = ComponentName(from, ModuleName(moduleName, CircuitName(circuitName))) - val tosName = tos map { to => - ComponentName(to, ModuleName(moduleName, CircuitName(circuitName))) - } - rename(fromName, tosName) - } - def rename(from: Named, to: Named): Unit = rename(from, Seq(to)) - def rename(from: Named, tos: Seq[Named]): Unit = (from, tos) match { - case (x, Seq(y)) if x == y => // TODO is this check expensive in common case? - case _ => - underlying(from) = underlying.getOrElse(from, Seq.empty) ++ tos - } - def delete(names: Seq[String]): Unit = names.foreach(delete(_)) - def delete(name: String): Unit = - delete(ComponentName(name, ModuleName(moduleName, CircuitName(circuitName)))) - def delete(name: Named): Unit = - underlying(name) = Seq.empty - def addMap(map: Map[Named, Seq[Named]]) = - underlying ++= map - def serialize: String = underlying.map { case (k, v) => - k.serialize + "=>" + v.map(_.serialize).mkString(", ") - }.mkString("\n") -} +import scala.collection.mutable +import firrtl.annotations._ +import firrtl.ir.{Circuit, Expression} +import firrtl.Utils.{error, throwInternalError} +import firrtl.annotations.TargetToken +import firrtl.annotations.TargetToken.{Field, Index} +import firrtl.annotations.transforms.{EliminateTargetPaths, ResolvePaths} /** Container of all annotations for a Firrtl compiler */ class AnnotationSeq private (private[firrtl] val underlying: List[Annotation]) { def toSeq: Seq[Annotation] = underlying.toSeq } object AnnotationSeq { - def apply(xs: Seq[Annotation]) = new AnnotationSeq(xs.toList) + def apply(xs: Seq[Annotation]): AnnotationSeq = new AnnotationSeq(xs.toList) } /** Current State of the Circuit @@ -145,15 +47,45 @@ case class CircuitState( case None => throw new FIRRTLException(s"No EmittedCircuit found! Did you delete any annotations?\n$deletedAnnotations") } + /** Helper function for extracting emitted components from annotations */ def emittedComponents: Seq[EmittedComponent] = annotations.collect { case emitted: EmittedAnnotation[_] => emitted.value } def deletedAnnotations: Seq[Annotation] = annotations.collect { case anno: DeletedAnnotation => anno } + + /** Returns a new CircuitState with all targets being resolved. + * Paths through instances are replaced with a uniquified final target + * Includes modifying the circuit and annotations + * @param targets + * @return + */ + def resolvePaths(targets: Seq[CompleteTarget]): CircuitState = { + val newCS = new EliminateTargetPaths().runTransform(this.copy(annotations = ResolvePaths(targets) +: annotations )) + newCS.copy(form = form) + } + + /** Returns a new CircuitState with the targets of every annotation of a type in annoClasses + * @param annoClasses + * @return + */ + def resolvePathsOf(annoClasses: Class[_]*): CircuitState = { + val targets = getAnnotationsOf(annoClasses:_*).flatMap(_.getTargets) + if(targets.nonEmpty) resolvePaths(targets.flatMap{_.getComplete}) else this + } + + /** Returns all annotations which are of a class in annoClasses + * @param annoClasses + * @return + */ + def getAnnotationsOf(annoClasses: Class[_]*): AnnotationSeq = { + annotations.collect { case a if annoClasses.contains(a.getClass) => a } + } } + object CircuitState { def apply(circuit: Circuit, form: CircuitForm): CircuitState = apply(circuit, form, Seq()) - def apply(circuit: Circuit, form: CircuitForm, annotations: AnnotationSeq) = + def apply(circuit: Circuit, form: CircuitForm, annotations: AnnotationSeq): CircuitState = new CircuitState(circuit, form, annotations, None) } @@ -170,6 +102,9 @@ sealed abstract class CircuitForm(private val value: Int) extends Ordered[Circui // Note that value is used only to allow comparisons def compare(that: CircuitForm): Int = this.value - that.value } + +// scalastyle:off magic.number +// These magic numbers give an ordering to CircuitForm /** Chirrtl Form * * The form of the circuit emitted by Chisel. Not a true Firrtl form. @@ -178,7 +113,7 @@ sealed abstract class CircuitForm(private val value: Int) extends Ordered[Circui * * See [[CDefMemory]] and [[CDefMPort]] */ -final case object ChirrtlForm extends CircuitForm(3) +final case object ChirrtlForm extends CircuitForm(value = 3) /** High Form * * As detailed in the Firrtl specification @@ -216,6 +151,7 @@ final case object LowForm extends CircuitForm(0) final case object UnknownForm extends CircuitForm(-1) { override def compare(that: CircuitForm): Int = { sys.error("Illegal to compare UnknownForm"); 0 } } +// scalastyle:on magic.number /** The basic unit of operating on a Firrtl AST */ abstract class Transform extends LazyLogging { @@ -246,6 +182,12 @@ abstract class Transform extends LazyLogging { state.annotations.collect { case a: LegacyAnnotation if a.transform == this.getClass => a } } + /** Executes before any transform's execute method + * @param state + * @return + */ + private[firrtl] def prepare(state: CircuitState): CircuitState = state + /** Perform the transform and update annotations. * * @param state Input Firrtl AST @@ -254,7 +196,7 @@ abstract class Transform extends LazyLogging { final def runTransform(state: CircuitState): CircuitState = { logger.info(s"======== Starting Transform $name ========") - val (timeMillis, result) = Utils.time { execute(state) } + val (timeMillis, result) = Utils.time { execute(prepare(state)) } logger.info(s"""----------------------------${"-" * name.size}---------\n""") logger.info(f"Time: $timeMillis%.1f ms") @@ -296,10 +238,23 @@ abstract class Transform extends LazyLogging { // For each annotation, rename all annotations. val renames = renameOpt.getOrElse(RenameMap()) - for { - anno <- newAnnotations.toSeq - newAnno <- anno.update(renames) - } yield newAnno + val remapped2original = mutable.LinkedHashMap[Annotation, mutable.LinkedHashSet[Annotation]]() + val keysOfNote = mutable.LinkedHashSet[Annotation]() + val finalAnnotations = newAnnotations.flatMap { anno => + val remappedAnnos = anno.update(renames) + remappedAnnos.foreach { remapped => + val set = remapped2original.getOrElseUpdate(remapped, mutable.LinkedHashSet.empty[Annotation]) + set += anno + if(set.size > 1) keysOfNote += remapped + } + remappedAnnos + }.toSeq + keysOfNote.foreach { key => + logger.debug(s"""The following original annotations are renamed to the same new annotation.""") + logger.debug(s"""Original Annotations:\n ${remapped2original(key).mkString("\n ")}""") + logger.debug(s"""New Annotation:\n $key""") + } + finalAnnotations } } @@ -321,6 +276,19 @@ abstract class SeqTransform extends Transform with SeqTransformBased { } } +/** Extend for transforms that require resolved targets in their annotations + * Ensures all targets in annotations of a class in annotationClasses are resolved before the execute method + */ +trait ResolvedAnnotationPaths { + this: Transform => + + val annotationClasses: Traversable[Class[_]] + + override def prepare(state: CircuitState): CircuitState = { + state.resolvePathsOf(annotationClasses.toSeq:_*) + } +} + /** Defines old API for Emission. Deprecated */ trait Emitter extends Transform { @deprecated("Use emission annotations instead", "firrtl 1.0") @@ -406,8 +374,8 @@ trait Compiler extends LazyLogging { def transforms: Seq[Transform] // Similar to (input|output)Form on [[Transform]] but derived from this Compiler's transforms - def inputForm = transforms.head.inputForm - def outputForm = transforms.last.outputForm + def inputForm: CircuitForm = transforms.head.inputForm + def outputForm: CircuitForm = transforms.last.outputForm private def transformsLegal(xforms: Seq[Transform]): Boolean = if (xforms.size < 2) { diff --git a/src/main/scala/firrtl/RenameMap.scala b/src/main/scala/firrtl/RenameMap.scala new file mode 100644 index 00000000..e95260af --- /dev/null +++ b/src/main/scala/firrtl/RenameMap.scala @@ -0,0 +1,424 @@ +// See LICENSE for license details. + +package firrtl + +import annotations._ +import firrtl.RenameMap.{CircularRenameException, IllegalRenameException} +import firrtl.annotations.TargetToken.{Field, Index} + +import scala.collection.mutable + +object RenameMap { + @deprecated("Use create with CompleteTarget instead, this will be removed in 1.3", "1.2") + def apply(map: collection.Map[Named, Seq[Named]]): RenameMap = { + val rm = new RenameMap + rm.addMap(map) + rm + } + + def create(map: collection.Map[CompleteTarget, Seq[CompleteTarget]]): RenameMap = { + val rm = new RenameMap + rm.recordAll(map) + rm + } + + def apply(): RenameMap = new RenameMap + + abstract class RenameTargetException(reason: String) extends Exception(reason) + case class IllegalRenameException(reason: String) extends RenameTargetException(reason) + case class CircularRenameException(reason: String) extends RenameTargetException(reason) +} + +/** Map old names to new names + * + * Transforms that modify names should return a [[RenameMap]] with the [[CircuitState]] + * These are mutable datastructures for convenience + */ +// TODO This should probably be refactored into immutable and mutable versions +final class RenameMap private () { + + /** Record that the from [[CircuitTarget]] is renamed to another [[CircuitTarget]] + * @param from + * @param to + */ + def record(from: CircuitTarget, to: CircuitTarget): Unit = completeRename(from, Seq(to)) + + /** Record that the from [[CircuitTarget]] is renamed to another sequence of [[CircuitTarget]]s + * @param from + * @param tos + */ + def record(from: CircuitTarget, tos: Seq[CircuitTarget]): Unit = completeRename(from, tos) + + /** Record that the from [[IsMember]] is renamed to another [[IsMember]] + * @param from + * @param to + */ + def record(from: IsMember, to: IsMember): Unit = completeRename(from, Seq(to)) + + /** Record that the from [[IsMember]] is renamed to another sequence of [[IsMember]]s + * @param from + * @param tos + */ + def record(from: IsMember, tos: Seq[IsMember]): Unit = completeRename(from, tos) + + /** Records that the keys in map are also renamed to their corresponding value seqs. + * Only ([[CircuitTarget]] -> Seq[ [[CircuitTarget]] ]) and ([[IsMember]] -> Seq[ [[IsMember]] ]) key/value allowed + * @param map + */ + def recordAll(map: collection.Map[CompleteTarget, Seq[CompleteTarget]]): Unit = + map.foreach{ + case (from: IsComponent, tos: Seq[IsMember]) => completeRename(from, tos) + case (from: IsModule, tos: Seq[IsMember]) => completeRename(from, tos) + case (from: CircuitTarget, tos: Seq[CircuitTarget]) => completeRename(from, tos) + case other => Utils.throwInternalError(s"Illegal rename: ${other._1} -> ${other._2}") + } + + /** Records that a [[CompleteTarget]] is deleted + * @param name + */ + def delete(name: CompleteTarget): Unit = underlying(name) = Seq.empty + + /** Renames a [[CompleteTarget]] + * @param t target to rename + * @return renamed targets + */ + def apply(t: CompleteTarget): Seq[CompleteTarget] = completeGet(t).getOrElse(Seq(t)) + + /** Get renames of a [[CircuitTarget]] + * @param key Target referencing the original circuit + * @return Optionally return sequence of targets that key remaps to + */ + def get(key: CompleteTarget): Option[Seq[CompleteTarget]] = completeGet(key) + + /** Get renames of a [[CircuitTarget]] + * @param key Target referencing the original circuit + * @return Optionally return sequence of targets that key remaps to + */ + def get(key: CircuitTarget): Option[Seq[CircuitTarget]] = completeGet(key).map( _.map { case x: CircuitTarget => x } ) + + /** Get renames of a [[IsMember]] + * @param key Target referencing the original member of the circuit + * @return Optionally return sequence of targets that key remaps to + */ + def get(key: IsMember): Option[Seq[IsMember]] = completeGet(key).map { _.map { case x: IsMember => x } } + + + /** Create new [[RenameMap]] that merges this and renameMap + * @param renameMap + * @return + */ + def ++ (renameMap: RenameMap): RenameMap = RenameMap(underlying ++ renameMap.getUnderlying) + + /** Returns the underlying map of rename information + * @return + */ + def getUnderlying: collection.Map[CompleteTarget, Seq[CompleteTarget]] = underlying + + /** @return Whether this [[RenameMap]] has collected any changes */ + def hasChanges: Boolean = underlying.nonEmpty + + def getReverseRenameMap: RenameMap = { + val reverseMap = mutable.HashMap[CompleteTarget, Seq[CompleteTarget]]() + underlying.keysIterator.foreach{ key => + apply(key).foreach { v => + reverseMap(v) = key +: reverseMap.getOrElse(v, Nil) + } + } + RenameMap.create(reverseMap) + } + + def keys: Iterator[CompleteTarget] = underlying.keysIterator + + /** Serialize the underlying remapping of keys to new targets + * @return + */ + def serialize: String = underlying.map { case (k, v) => + k.serialize + "=>" + v.map(_.serialize).mkString(", ") + }.mkString("\n") + + /** Maps old names to new names. New names could still require renaming parts of their name + * Old names must refer to existing names in the old circuit + */ + private val underlying = mutable.HashMap[CompleteTarget, Seq[CompleteTarget]]() + + /** Records which local InstanceTargets will require modification. + * Used to reduce time to rename nonlocal targets who's path does not require renaming + */ + private val sensitivity = mutable.HashSet[IsComponent]() + + /** Caches results of recursiveGet. Is cleared any time a new rename target is added + */ + private val getCache = mutable.HashMap[CompleteTarget, Seq[CompleteTarget]]() + + /** Updates [[sensitivity]] + * @param from original target + * @param to new target + */ + private def recordSensitivity(from: CompleteTarget, to: CompleteTarget): Unit = { + (from, to) match { + case (f: IsMember, t: IsMember) => + val fromSet = f.pathAsTargets.toSet + val toSet = t.pathAsTargets + sensitivity ++= (fromSet -- toSet) + sensitivity ++= (fromSet.map(_.asReference) -- toSet.map(_.asReference)) + case other => + } + } + + /** Get renames of a [[CompleteTarget]] + * @param key Target referencing the original circuit + * @return Optionally return sequence of targets that key remaps to + */ + private def completeGet(key: CompleteTarget): Option[Seq[CompleteTarget]] = { + val errors = mutable.ArrayBuffer[String]() + val ret = if(hasChanges) { + val ret = recursiveGet(mutable.LinkedHashSet.empty[CompleteTarget], errors)(key) + if(errors.nonEmpty) { throw IllegalRenameException(errors.mkString("\n")) } + if(ret.size == 1 && ret.head == key) { None } else { Some(ret) } + } else { None } + ret + } + + // scalastyle:off + // This function requires a large cyclomatic complexity, and is best naturally expressed as a large function + /** Recursively renames a target so the returned targets are complete renamed + * @param set Used to detect circular renames + * @param errors Used to record illegal renames + * @param key Target to rename + * @return Renamed targets + */ + private def recursiveGet(set: mutable.LinkedHashSet[CompleteTarget], + errors: mutable.ArrayBuffer[String] + )(key: CompleteTarget): Seq[CompleteTarget] = { + if(getCache.contains(key)) { + getCache(key) + } else { + // First, check if whole key is remapped + // Note that remapped could hold stale parent targets that require renaming + val remapped = underlying.getOrElse(key, Seq(key)) + + // If we've seen this key before in recursive calls to parentTargets, then we know a circular renaming + // mapping has occurred, and no legal name exists + if(set.contains(key) && !key.isInstanceOf[CircuitTarget]) { + throw CircularRenameException(s"Illegal rename: circular renaming is illegal - ${set.mkString(" -> ")}") + } + + // Add key to set to detect circular renaming + set += key + + // Curry recursiveGet for cleaner syntax below + val getter = recursiveGet(set, errors)(_) + + // For each remapped key, call recursiveGet on their parentTargets + val ret = remapped.flatMap { + + // If t is a CircuitTarget, return it because it has no parent target + case t: CircuitTarget => Seq(t) + + // If t is a ModuleTarget, try to rename parent target, then update t's parent + case t: ModuleTarget => getter(t.targetParent).map { + case CircuitTarget(c) => ModuleTarget(c, t.module) + } + + /** If t is an InstanceTarget (has a path) but has no references: + * 1) Check whether the instance has been renamed (asReference) + * 2) Check whether the ofModule of the instance has been renamed (only 1:1 renaming is ok) + */ + case t: InstanceTarget => + getter(t.asReference).map { + case t2:InstanceTarget => t2 + case t2@ReferenceTarget(c, m, p, r, Nil) => + val t3 = InstanceTarget(c, m, p, r, t.ofModule) + val ofModuleTarget = t3.ofModuleTarget + getter(ofModuleTarget) match { + case Seq(ModuleTarget(newCircuit, newOf)) if newCircuit == t3.circuit => t3.copy(ofModule = newOf) + case other => + errors += s"Illegal rename: ofModule of $t is renamed to $other - must rename $t directly." + t + } + case other => + errors += s"Illegal rename: $t has new instance reference $other" + t + } + + /** If t is a ReferenceTarget: + * 1) Check parentTarget to tokens + * 2) Check ReferenceTarget with one layer stripped from its path hierarchy (i.e. a new root module) + */ + case t: ReferenceTarget => + val ret: Seq[CompleteTarget] = if(t.component.nonEmpty) { + val last = t.component.last + getter(t.targetParent).map{ x => + (x, last) match { + case (t2: ReferenceTarget, Field(f)) => t2.field(f) + case (t2: ReferenceTarget, Index(i)) => t2.index(i) + case other => + errors += s"Illegal rename: ${t.targetParent} cannot be renamed to ${other._1} - must rename $t directly" + t + } + } + } else { + val pathTargets = sensitivity.empty ++ (t.pathAsTargets ++ t.pathAsTargets.map(_.asReference)) + if(t.pathAsTargets.nonEmpty && sensitivity.intersect(pathTargets).isEmpty) Seq(t) else { + getter(t.pathTarget).map { + case newPath: IsModule => t.setPathTarget(newPath) + case other => + errors += s"Illegal rename: path ${t.pathTarget} of $t cannot be renamed to $other - must rename $t directly" + t + } + } + } + ret.flatMap { + case y: IsComponent if !y.isLocal => + val encapsulatingInstance = y.path.head._1.value + getter(y.stripHierarchy(1)).map { + _.addHierarchy(y.moduleOpt.get, encapsulatingInstance) + } + case other => Seq(other) + } + } + + // Remove key from set as visiting the same key twice is ok, as long as its not during the same recursive call + set -= key + + // Cache result + getCache(key) = ret + + // Return result + ret + + } + } + // scalastyle:on + + /** Fully renames from to tos + * @param from + * @param tos + */ + private def completeRename(from: CompleteTarget, tos: Seq[CompleteTarget]): Unit = { + def check(from: CompleteTarget, to: CompleteTarget)(t: CompleteTarget): Unit = { + require(from != t, s"Cannot record $from to $to, as it is a circular constraint") + t match { + case _: CircuitTarget => + case other: IsMember => check(from, to)(other.targetParent) + } + } + tos.foreach { to => if(from != to) check(from, to)(to) } + (from, tos) match { + case (x, Seq(y)) if x == y => + case _ => + tos.foreach{recordSensitivity(from, _)} + val existing = underlying.getOrElse(from, Seq.empty) + val updated = existing ++ tos + underlying(from) = updated + getCache.clear() + } + } + + /* DEPRECATED ACCESSOR/SETTOR METHODS WITH [[Named]] */ + + @deprecated("Use record with CircuitTarget instead, this will be removed in 1.3", "1.2") + def rename(from: Named, to: Named): Unit = rename(from, Seq(to)) + + @deprecated("Use record with IsMember instead, this will be removed in 1.3", "1.2") + def rename(from: Named, tos: Seq[Named]): Unit = recordAll(Map(from.toTarget -> tos.map(_.toTarget))) + + @deprecated("Use record with IsMember instead, this will be removed in 1.3", "1.2") + def rename(from: ComponentName, to: ComponentName): Unit = record(from, to) + + @deprecated("Use record with IsMember instead, this will be removed in 1.3", "1.2") + def rename(from: ComponentName, tos: Seq[ComponentName]): Unit = record(from, tos.map(_.toTarget)) + + @deprecated("Use delete with CircuitTarget instead, this will be removed in 1.3", "1.2") + def delete(name: CircuitName): Unit = underlying(name) = Seq.empty + + @deprecated("Use delete with IsMember instead, this will be removed in 1.3", "1.2") + def delete(name: ModuleName): Unit = underlying(name) = Seq.empty + + @deprecated("Use delete with IsMember instead, this will be removed in 1.3", "1.2") + def delete(name: ComponentName): Unit = underlying(name) = Seq.empty + + @deprecated("Use recordAll with CompleteTarget instead, this will be removed in 1.3", "1.2") + def addMap(map: collection.Map[Named, Seq[Named]]): Unit = + recordAll(map.map { case (key, values) => (Target.convertNamed2Target(key), values.map(Target.convertNamed2Target)) }) + + @deprecated("Use get with CircuitTarget instead, this will be removed in 1.3", "1.2") + def get(key: CircuitName): Option[Seq[CircuitName]] = { + get(Target.convertCircuitName2CircuitTarget(key)).map(_.collect{ case c: CircuitTarget => c.toNamed }) + } + + @deprecated("Use get with IsMember instead, this will be removed in 1.3", "1.2") + def get(key: ModuleName): Option[Seq[ModuleName]] = { + get(Target.convertModuleName2ModuleTarget(key)).map(_.collect{ case m: ModuleTarget => m.toNamed }) + } + + @deprecated("Use get with IsMember instead, this will be removed in 1.3", "1.2") + def get(key: ComponentName): Option[Seq[ComponentName]] = { + get(Target.convertComponentName2ReferenceTarget(key)).map(_.collect{ case c: IsComponent => c.toNamed }) + } + + @deprecated("Use get with IsMember instead, this will be removed in 1.3", "1.2") + def get(key: Named): Option[Seq[Named]] = key match { + case t: CompleteTarget => get(t) + case other => get(key.toTarget).map(_.collect{ case c: IsComponent => c.toNamed }) + } + + + // Mutable helpers - APIs that set these are deprecated! + private var circuitName: String = "" + private var moduleName: String = "" + + /** Sets mutable state to record current module we are visiting + * @param module + */ + @deprecated("Use typesafe rename defs instead, this will be removed in 1.3", "1.2") + def setModule(module: String): Unit = moduleName = module + + /** Sets mutable state to record current circuit we are visiting + * @param circuit + */ + @deprecated("Use typesafe rename defs instead, this will be removed in 1.3", "1.2") + def setCircuit(circuit: String): Unit = circuitName = circuit + + /** Records how a reference maps to a new reference + * @param from + * @param to + */ + @deprecated("Use typesafe rename defs instead, this will be removed in 1.3", "1.2") + def rename(from: String, to: String): Unit = rename(from, Seq(to)) + + /** Records how a reference maps to a new reference + * The reference's root module and circuit are determined by whomever called setModule or setCircuit last + * @param from + * @param tos + */ + @deprecated("Use typesafe rename defs instead, this will be removed in 1.3", "1.2") + def rename(from: String, tos: Seq[String]): Unit = { + val mn = ModuleName(moduleName, CircuitName(circuitName)) + val fromName = ComponentName(from, mn).toTarget + val tosName = tos map { to => ComponentName(to, mn).toTarget } + record(fromName, tosName) + } + + /** Records named reference is deleted + * The reference's root module and circuit are determined by whomever called setModule or setCircuit last + * @param name + */ + @deprecated("Use typesafe rename defs instead, this will be removed in 1.3", "1.2") + def delete(name: String): Unit = { + Target(Some(circuitName), Some(moduleName), AnnotationUtils.toSubComponents(name)).getComplete match { + case Some(t: CircuitTarget) => delete(t) + case Some(m: IsMember) => delete(m) + case other => + } + } + + /** Records that references in names are all deleted + * The reference's root module and circuit are determined by whomever called setModule or setCircuit last + * @param names + */ + @deprecated("Use typesafe rename defs instead, this will be removed in 1.3", "1.2") + def delete(names: Seq[String]): Unit = names.foreach(delete(_)) +} + + diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 32893411..f2cebe56 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -178,7 +178,7 @@ object Utils extends LazyLogging { error("Internal Error! %sPlease file an issue at https://github.com/ucb-bar/firrtl/issues".format(string), throwable) } - private[firrtl] def time[R](block: => R): (Double, R) = { + def time[R](block: => R): (Double, R) = { val t0 = System.nanoTime() val result = block val t1 = System.nanoTime() diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala index 6eb67938..00689a51 100644 --- a/src/main/scala/firrtl/analyses/InstanceGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala @@ -3,12 +3,12 @@ package firrtl.analyses import scala.collection.mutable - import firrtl._ import firrtl.ir._ import firrtl.graph._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.annotations.TargetToken.{Instance, OfModule} /** A class representing the instance hierarchy of a working IR Circuit @@ -99,6 +99,12 @@ class InstanceGraph(c: Circuit) { */ def getChildrenInstances: mutable.LinkedHashMap[String, mutable.LinkedHashSet[WDefInstance]] = childInstances + /** Given a circuit, returns a map from module name to children + * instance/module [[firrtl.annotations.TargetToken]]s + */ + def getChildrenInstanceOfModule: mutable.LinkedHashMap[String, mutable.LinkedHashSet[(Instance, OfModule)]] = + childInstances.map(kv => kv._1 -> kv._2.map(i => (Instance(i.name), OfModule(i.module)))) + } diff --git a/src/main/scala/firrtl/annotations/Annotation.scala b/src/main/scala/firrtl/annotations/Annotation.scala index 4b0591bf..62c2b335 100644 --- a/src/main/scala/firrtl/annotations/Annotation.scala +++ b/src/main/scala/firrtl/annotations/Annotation.scala @@ -5,11 +5,15 @@ package annotations import net.jcazevedo.moultingyaml._ import firrtl.annotations.AnnotationYamlProtocol._ +import firrtl.Utils.throwInternalError + +import scala.collection.mutable case class AnnotationException(message: String) extends Exception(message) /** Base type of auxiliary information */ -trait Annotation { +trait Annotation extends Product { + /** Update the target based on how signals are renamed */ def update(renames: RenameMap): Seq[Annotation] @@ -18,13 +22,29 @@ trait Annotation { * @note In [[logger.LogLevel.Debug]] this is called on every Annotation after every Transform */ def serialize: String = this.toString + + /** Recurses through ls to find all [[Target]] instances + * @param ls + * @return + */ + private def extractComponents(ls: scala.collection.Traversable[_]): Seq[Target] = { + ls.collect { + case c: Target => Seq(c) + case ls: scala.collection.Traversable[_] => extractComponents(ls) + }.foldRight(Seq.empty[Target])((seq, c) => c ++ seq) + } + + /** Returns all [[Target]] members in this annotation + * @return + */ + def getTargets: Seq[Target] = extractComponents(productIterator.toSeq) } /** If an Annotation does not target any [[Named]] thing in the circuit, then all updates just * return the Annotation itself */ trait NoTargetAnnotation extends Annotation { - def update(renames: RenameMap) = Seq(this) + def update(renames: RenameMap): Seq[NoTargetAnnotation] = Seq(this) } /** An Annotation that targets a single [[Named]] thing */ @@ -37,18 +57,27 @@ trait SingleTargetAnnotation[T <: Named] extends Annotation { // This mess of @unchecked and try-catch is working around the fact that T is unknown due to type // erasure. We cannot that newTarget is of type T, but a CastClassException will be thrown upon // invoking duplicate if newTarget cannot be cast to T (only possible in the concrete subclass) - def update(renames: RenameMap): Seq[Annotation] = - renames.get(target).map(_.map(newT => (newT: @unchecked) match { - case newTarget: T @unchecked => - try { - duplicate(newTarget) - } catch { - case _: java.lang.ClassCastException => - val msg = s"${this.getClass.getName} target ${target.getClass.getName} " + - s"cannot be renamed to ${newTarget.getClass}" - throw AnnotationException(msg) - } - })).getOrElse(List(this)) + def update(renames: RenameMap): Seq[Annotation] = { + target match { + case c: Target => + val x = renames.get(c) + x.map(newTargets => newTargets.map(t => duplicate(t.asInstanceOf[T]))).getOrElse(List(this)) + case _: Named => + val ret = renames.get(Target.convertNamed2Target(target)) + ret.map(_.map(newT => Target.convertTarget2Named(newT: @unchecked) match { + case newTarget: T @unchecked => + try { + duplicate(newTarget) + } + catch { + case _: java.lang.ClassCastException => + val msg = s"${this.getClass.getName} target ${target.getClass.getName} " + + s"cannot be renamed to ${newTarget.getClass}" + throw AnnotationException(msg) + } + })).getOrElse(List(this)) + } + } } @deprecated("Just extend NoTargetAnnotation", "1.1") @@ -58,7 +87,7 @@ trait SingleStringAnnotation extends NoTargetAnnotation { object Annotation { @deprecated("This returns a LegacyAnnotation, use an explicit Annotation type", "1.1") - def apply(target: Named, transform: Class[_ <: Transform], value: String) = + def apply(target: Named, transform: Class[_ <: Transform], value: String): LegacyAnnotation = new LegacyAnnotation(target, transform, value) @deprecated("This uses LegacyAnnotation, use an explicit Annotation type", "1.1") def unapply(a: LegacyAnnotation): Option[(Named, Class[_ <: Transform], String)] = @@ -92,13 +121,13 @@ final case class LegacyAnnotation private[firrtl] ( } def propagate(from: Named, tos: Seq[Named], dup: Named=>Annotation): Seq[Annotation] = tos.map(dup(_)) def check(from: Named, tos: Seq[Named], which: Annotation): Unit = {} - def duplicate(n: Named) = new LegacyAnnotation(n, transform, value) + def duplicate(n: Named): LegacyAnnotation = new LegacyAnnotation(n, transform, value) } // Private so that LegacyAnnotation can only be constructed via deprecated Annotation.apply private[firrtl] object LegacyAnnotation { // ***** Everything below here is to help people migrate off of old annotations ***** - def errorIllegalAnno(name: String) = + def errorIllegalAnno(name: String): Annotation = throw new Exception(s"Old-style annotations that look like $name are no longer supported") private val OldDeletedRegex = """(?s)DELETED by ([^\n]*)\n(.*)""".r @@ -111,7 +140,9 @@ private[firrtl] object LegacyAnnotation { import firrtl.passes.memlib._ import firrtl.passes.wiring._ import firrtl.passes.clocklist._ + // Attempt to convert common Annotations and error on the rest of old-style build-in annotations + // scalastyle:off def convertLegacyAnno(anno: LegacyAnnotation): Annotation = anno match { // All old-style Emitter annotations are illegal case LegacyAnnotation(_,_,"emitCircuit") => errorIllegalAnno("EmitCircuitAnnotation") @@ -144,7 +175,8 @@ private[firrtl] object LegacyAnnotation { case LegacyAnnotation(c: ComponentName, _, SourceRegex(pin)) => SourceAnnotation(c, pin) case LegacyAnnotation(n, _, SinkRegex(pin)) => SinkAnnotation(n, pin) case LegacyAnnotation(m: ModuleName, t, text) if t == classOf[BlackBoxSourceHelper] => - text.split("\n", 3).toList match { + val nArgs = 3 + text.split("\n", nArgs).toList match { case "resource" :: id :: _ => BlackBoxResourceAnno(m, id) case "inline" :: name :: text :: _ => BlackBoxInlineAnno(m, name, text) case "targetDir" :: targetDir :: _ => BlackBoxTargetDirAnno(targetDir) @@ -152,11 +184,12 @@ private[firrtl] object LegacyAnnotation { } case LegacyAnnotation(_, transform, "noDCE!") if transform == classOf[DeadCodeElimination] => NoDCEAnnotation - case LegacyAnnotation(c: ComponentName, _, "DONTtouch!") => DontTouchAnnotation(c) + case LegacyAnnotation(c: ComponentName, _, "DONTtouch!") => DontTouchAnnotation(c.toTarget) case LegacyAnnotation(c: ModuleName, _, "optimizableExtModule!") => OptimizableExtModuleAnnotation(c) case other => other } + // scalastyle:on def convertLegacyAnnos(annos: Seq[Annotation]): Seq[Annotation] = { var warned: Boolean = false annos.map { diff --git a/src/main/scala/firrtl/annotations/AnnotationUtils.scala b/src/main/scala/firrtl/annotations/AnnotationUtils.scala index 517cea26..ba9220f7 100644 --- a/src/main/scala/firrtl/annotations/AnnotationUtils.scala +++ b/src/main/scala/firrtl/annotations/AnnotationUtils.scala @@ -64,13 +64,29 @@ object AnnotationUtils { case Array(c, m, x) => ComponentName(x, ModuleName(m, CircuitName(c))) } + /** Converts a serialized FIRRTL component into a sequence of target tokens + * @param s + * @return + */ + def toSubComponents(s: String): Seq[TargetToken] = { + import TargetToken._ + def exp2subcomp(e: ir.Expression): Seq[TargetToken] = e match { + case ir.Reference(name, _) => Seq(Ref(name)) + case ir.SubField(expr, name, _) => exp2subcomp(expr) :+ Field(name) + case ir.SubIndex(expr, idx, _) => exp2subcomp(expr) :+ Index(idx) + case ir.SubAccess(expr, idx, _) => Utils.throwInternalError(s"For string $s, cannot convert a subaccess $e into a Target") + } + exp2subcomp(toExp(s)) + } + + /** Given a serialized component/subcomponent reference, subindex, subaccess, * or subfield, return the corresponding IR expression. * E.g. "foo.bar" becomes SubField(Reference("foo", UnknownType), "bar", UnknownType) */ def toExp(s: String): Expression = { def parse(tokens: Seq[String]): Expression = { - val DecPattern = """([1-9]\d*)""".r + val DecPattern = """(\d+)""".r def findClose(tokens: Seq[String], index: Int, nOpen: Int): Seq[String] = { if(index >= tokens.size) { Utils.error("Cannot find closing bracket ]") diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index 7b2617f5..36699151 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -35,11 +35,39 @@ object JsonProtocol { { case named: ComponentName => JString(named.serialize) } )) + + class TargetSerializer extends CustomSerializer[Target](format => ( + { case JString(s) => Target.deserialize(s) }, + { case named: Target => JString(named.serialize) } + )) + class GenericTargetSerializer extends CustomSerializer[GenericTarget](format => ( + { case JString(s) => Target.deserialize(s).asInstanceOf[GenericTarget] }, + { case named: GenericTarget => JString(named.serialize) } + )) + class CircuitTargetSerializer extends CustomSerializer[CircuitTarget](format => ( + { case JString(s) => Target.deserialize(s).asInstanceOf[CircuitTarget] }, + { case named: CircuitTarget => JString(named.serialize) } + )) + class ModuleTargetSerializer extends CustomSerializer[ModuleTarget](format => ( + { case JString(s) => Target.deserialize(s).asInstanceOf[ModuleTarget] }, + { case named: ModuleTarget => JString(named.serialize) } + )) + class InstanceTargetSerializer extends CustomSerializer[InstanceTarget](format => ( + { case JString(s) => Target.deserialize(s).asInstanceOf[InstanceTarget] }, + { case named: InstanceTarget => JString(named.serialize) } + )) + class ReferenceTargetSerializer extends CustomSerializer[ReferenceTarget](format => ( + { case JString(s) => Target.deserialize(s).asInstanceOf[ReferenceTarget] }, + { case named: ReferenceTarget => JString(named.serialize) } + )) + /** Construct Json formatter for annotations */ def jsonFormat(tags: Seq[Class[_ <: Annotation]]) = { Serialization.formats(FullTypeHints(tags.toList)).withTypeHintFieldName("class") + new TransformClassSerializer + new NamedSerializer + new CircuitNameSerializer + - new ModuleNameSerializer + new ComponentNameSerializer + new ModuleNameSerializer + new ComponentNameSerializer + new TargetSerializer + + new GenericTargetSerializer + new CircuitTargetSerializer + new ModuleTargetSerializer + + new InstanceTargetSerializer + new ReferenceTargetSerializer } /** Serialize annotations to a String for emission */ diff --git a/src/main/scala/firrtl/annotations/Named.scala b/src/main/scala/firrtl/annotations/Named.scala deleted file mode 100644 index 3da75884..00000000 --- a/src/main/scala/firrtl/annotations/Named.scala +++ /dev/null @@ -1,30 +0,0 @@ -// See LICENSE for license details. - -package firrtl -package annotations - -import firrtl.ir.Expression -import AnnotationUtils.{validModuleName, validComponentName, toExp} - -/** - * Named classes associate an annotation with a component in a Firrtl circuit - */ -sealed trait Named { - def serialize: String -} - -final case class CircuitName(name: String) extends Named { - if(!validModuleName(name)) throw AnnotationException(s"Illegal circuit name: $name") - def serialize: String = name -} - -final case class ModuleName(name: String, circuit: CircuitName) extends Named { - if(!validModuleName(name)) throw AnnotationException(s"Illegal module name: $name") - def serialize: String = circuit.serialize + "." + name -} - -final case class ComponentName(name: String, module: ModuleName) extends Named { - if(!validComponentName(name)) throw AnnotationException(s"Illegal component name: $name") - def expr: Expression = toExp(name) - def serialize: String = module.serialize + "." + name -} diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala new file mode 100644 index 00000000..dcf5cb02 --- /dev/null +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -0,0 +1,652 @@ +// See LICENSE for license details. + +package firrtl +package annotations + +import firrtl.ir.Expression +import AnnotationUtils.{toExp, validComponentName, validModuleName} +import TargetToken._ + +import scala.collection.mutable + +/** Refers to something in a FIRRTL [[firrtl.ir.Circuit]]. Used for Annotation targets. + * + * Can be in various states of completion/resolved: + * - Legal: [[TargetToken]]'s in tokens are in an order that makes sense + * - Complete: circuitOpt and moduleOpt are non-empty, and all Instance(_) are followed by OfModule(_) + * - Local: tokens does not refer to things through an instance hierarchy (no Instance(_) or OfModule(_) tokens) + */ +sealed trait Target extends Named { + + /** @return Circuit name, if it exists */ + def circuitOpt: Option[String] + + /** @return Module name, if it exists */ + def moduleOpt: Option[String] + + /** @return [[Target]] tokens */ + def tokens: Seq[TargetToken] + + /** @return Returns a new [[GenericTarget]] with new values */ + def modify(circuitOpt: Option[String] = circuitOpt, + moduleOpt: Option[String] = moduleOpt, + tokens: Seq[TargetToken] = tokens): GenericTarget = GenericTarget(circuitOpt, moduleOpt, tokens.toVector) + + /** @return Human-readable serialization */ + def serialize: String = { + val circuitString = "~" + circuitOpt.getOrElse("???") + val moduleString = "|" + moduleOpt.getOrElse("???") + val tokensString = tokens.map { + case Ref(r) => s">$r" + case Instance(i) => s"/$i" + case OfModule(o) => s":$o" + case Field(f) => s".$f" + case Index(v) => s"[$v]" + case Clock => s"@clock" + case Reset => s"@reset" + case Init => s"@init" + }.mkString("") + if(moduleOpt.isEmpty && tokens.isEmpty) { + circuitString + } else if(tokens.isEmpty) { + circuitString + moduleString + } else { + circuitString + moduleString + tokensString + } + } + + /** @return Converts this [[Target]] into a [[GenericTarget]] */ + def toGenericTarget: GenericTarget = GenericTarget(circuitOpt, moduleOpt, tokens.toVector) + + /** @return Converts this [[Target]] into either a [[CircuitName]], [[ModuleName]], or [[ComponentName]] */ + @deprecated("Use Target instead, will be removed in 1.3", "1.2") + def toNamed: Named = toGenericTarget.toNamed + + /** @return If legal, convert this [[Target]] into a [[CompleteTarget]] */ + def getComplete: Option[CompleteTarget] + + /** @return Converts this [[Target]] into a [[CompleteTarget]] */ + def complete: CompleteTarget = getComplete.get + + /** @return Converts this [[Target]] into a [[CompleteTarget]], or if it can't, return original [[Target]] */ + def tryToComplete: Target = getComplete.getOrElse(this) + + /** Whether the target is directly instantiated in its root module */ + def isLocal: Boolean +} + +object Target { + + def apply(circuitOpt: Option[String], moduleOpt: Option[String], reference: Seq[TargetToken]): GenericTarget = + GenericTarget(circuitOpt, moduleOpt, reference.toVector) + + def unapply(t: Target): Option[(Option[String], Option[String], Seq[TargetToken])] = + Some((t.circuitOpt, t.moduleOpt, t.tokens)) + + case class NamedException(message: String) extends Exception(message) + + implicit def convertCircuitTarget2CircuitName(c: CircuitTarget): CircuitName = c.toNamed + implicit def convertModuleTarget2ModuleName(c: ModuleTarget): ModuleName = c.toNamed + implicit def convertIsComponent2ComponentName(c: IsComponent): ComponentName = c.toNamed + implicit def convertTarget2Named(c: Target): Named = c.toNamed + implicit def convertCircuitName2CircuitTarget(c: CircuitName): CircuitTarget = c.toTarget + implicit def convertModuleName2ModuleTarget(c: ModuleName): ModuleTarget = c.toTarget + implicit def convertComponentName2ReferenceTarget(c: ComponentName): ReferenceTarget = c.toTarget + implicit def convertNamed2Target(n: Named): CompleteTarget = n.toTarget + + /** Converts [[ComponentName]]'s name into TargetTokens + * @param name + * @return + */ + def toTargetTokens(name: String): Seq[TargetToken] = { + val tokens = AnnotationUtils.tokenize(name) + val subComps = mutable.ArrayBuffer[TargetToken]() + subComps += Ref(tokens.head) + if(tokens.tail.nonEmpty) { + tokens.tail.zip(tokens.tail.tail).foreach { + case (".", value: String) => subComps += Field(value) + case ("[", value: String) => subComps += Index(value.toInt) + case other => + } + } + subComps + } + + /** Checks if seq only contains [[TargetToken]]'s with select keywords + * @param seq + * @param keywords + * @return + */ + def isOnly(seq: Seq[TargetToken], keywords:String*): Boolean = { + seq.map(_.is(keywords:_*)).foldLeft(false)(_ || _) && keywords.nonEmpty + } + + /** @return [[Target]] from human-readable serialization */ + def deserialize(s: String): Target = { + val regex = """(?=[~|>/:.\[@])""" + s.split(regex).foldLeft(GenericTarget(None, None, Vector.empty)) { (t, tokenString) => + val value = tokenString.tail + tokenString(0) match { + case '~' if t.circuitOpt.isEmpty && t.moduleOpt.isEmpty && t.tokens.isEmpty => + if(value == "???") t else t.copy(circuitOpt = Some(value)) + case '|' if t.moduleOpt.isEmpty && t.tokens.isEmpty => + if(value == "???") t else t.copy(moduleOpt = Some(value)) + case '/' => t.add(Instance(value)) + case ':' => t.add(OfModule(value)) + case '>' => t.add(Ref(value)) + case '.' => t.add(Field(value)) + case '[' if value.dropRight(1).toInt >= 0 => t.add(Index(value.dropRight(1).toInt)) + case '@' if value == "clock" => t.add(Clock) + case '@' if value == "init" => t.add(Init) + case '@' if value == "reset" => t.add(Reset) + case other => throw NamedException(s"Cannot deserialize Target: $s") + } + }.tryToComplete + } +} + +/** Represents incomplete or non-standard [[Target]]s + * @param circuitOpt Optional circuit name + * @param moduleOpt Optional module name + * @param tokens [[TargetToken]]s to represent the target in a circuit and module + */ +case class GenericTarget(circuitOpt: Option[String], + moduleOpt: Option[String], + tokens: Vector[TargetToken]) extends Target { + + override def toGenericTarget: GenericTarget = this + + override def toNamed: Named = getComplete match { + case Some(c: IsComponent) if c.isLocal => c.toNamed + case Some(c: ModuleTarget) => c.toNamed + case Some(c: CircuitTarget) => c.toNamed + case other => throw Target.NamedException(s"Cannot convert $this to [[Named]]") + } + + override def toTarget: CompleteTarget = getComplete.get + + override def getComplete: Option[CompleteTarget] = { + if(!isComplete) None else { + val target = this match { + case GenericTarget(Some(c), None, Vector()) => CircuitTarget(c) + case GenericTarget(Some(c), Some(m), Vector()) => ModuleTarget(c, m) + case GenericTarget(Some(c), Some(m), Ref(r) +: component) => ReferenceTarget(c, m, Nil, r, component) + case GenericTarget(Some(c), Some(m), Instance(i) +: OfModule(o) +: Vector()) => InstanceTarget(c, m, Nil, i, o) + case GenericTarget(Some(c), Some(m), component) => + val path = getPath.getOrElse(Nil) + (getRef, getInstanceOf) match { + case (Some((r, comps)), _) => ReferenceTarget(c, m, path, r, comps) + case (None, Some((i, o))) => InstanceTarget(c, m, path, i, o) + } + } + Some(target) + } + } + + override def isLocal: Boolean = !(getPath.nonEmpty && getPath.get.nonEmpty) + + /** If complete, return this [[GenericTarget]]'s path + * @return + */ + def getPath: Option[Seq[(Instance, OfModule)]] = if(isComplete) { + val allInstOfs = tokens.grouped(2).collect { case Seq(i: Instance, o:OfModule) => (i, o)}.toSeq + if(tokens.nonEmpty && tokens.last.isInstanceOf[OfModule]) Some(allInstOfs.dropRight(1)) else Some(allInstOfs) + } else { + None + } + + /** If complete and a reference, return the reference and subcomponents + * @return + */ + def getRef: Option[(String, Seq[TargetToken])] = if(isComplete) { + val (optRef, comps) = tokens.foldLeft((None: Option[String], Vector.empty[TargetToken])) { + case ((None, v), Ref(r)) => (Some(r), v) + case ((r: Some[String], comps), c) => (r, comps :+ c) + case ((r, v), other) => (None, v) + } + optRef.map(x => (x, comps)) + } else { + None + } + + /** If complete and an instance target, return the instance and ofmodule + * @return + */ + def getInstanceOf: Option[(String, String)] = if(isComplete) { + tokens.grouped(2).foldLeft(None: Option[(String, String)]) { + case (instOf, Seq(i: Instance, o: OfModule)) => Some((i.value, o.value)) + case (instOf, _) => None + } + } else { + None + } + + /** Requires the last [[TargetToken]] in tokens to be one of the [[TargetToken]] keywords + * @param default Return value if tokens is empty + * @param keywords + */ + private def requireLast(default: Boolean, keywords: String*): Unit = { + val isOne = if (tokens.isEmpty) default else tokens.last.is(keywords: _*) + require(isOne, s"${tokens.last} is not one of $keywords") + } + + /** Appends a target token to tokens, asserts legality + * @param token + * @return + */ + def add(token: TargetToken): GenericTarget = { + token match { + case _: Instance => requireLast(true, "inst", "of") + case _: OfModule => requireLast(false, "inst") + case _: Ref => requireLast(true, "inst", "of") + case _: Field => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + case _: Index => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + case Init => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + case Clock => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + case Reset => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + } + this.copy(tokens = tokens :+ token) + } + + /** Removes n number of target tokens from the right side of [[tokens]] */ + def remove(n: Int): GenericTarget = this.copy(tokens = tokens.dropRight(n)) + + /** Optionally tries to append token to tokens, fails return is not a legal Target */ + def optAdd(token: TargetToken): Option[Target] = { + try{ + Some(add(token)) + } catch { + case _: IllegalArgumentException => None + } + } + + /** Checks whether the component is legal (incomplete is ok) + * @return + */ + def isLegal: Boolean = { + try { + var comp: GenericTarget = this.copy(tokens = Vector.empty) + for(token <- tokens) { + comp = comp.add(token) + } + true + } catch { + case _: IllegalArgumentException => false + } + } + + /** Checks whether the component is legal and complete, meaning the circuitOpt and moduleOpt are nonEmpty and + * all Instance(_) are followed by OfModule(_) + * @return + */ + def isComplete: Boolean = { + isLegal && (isCircuitTarget || isModuleTarget || (isComponentTarget && tokens.tails.forall { + case Instance(_) +: OfModule(_) +: tail => true + case Instance(_) +: x +: tail => false + case x +: OfModule(_) +: tail => false + case _ => true + } )) + } + + + def isCircuitTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.isEmpty && tokens.isEmpty + def isModuleTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.isEmpty + def isComponentTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.nonEmpty +} + +/** Concretely points to a FIRRTL target, no generic selectors + * IsLegal + */ +trait CompleteTarget extends Target { + + /** @return The circuit of this target */ + def circuit: String + + /** @return The [[CircuitTarget]] of this target's circuit */ + def circuitTarget: CircuitTarget = CircuitTarget(circuitOpt.get) + + def getComplete: Option[CompleteTarget] = Some(this) + + /** Adds another level of instance hierarchy + * Example: Given root=A and instance=b, transforms (Top, B)/c:C -> (Top, A)/b:B/c:C + * @param root + * @param instance + * @return + */ + def addHierarchy(root: String, instance: String): IsComponent + + override def toTarget: CompleteTarget = this +} + + +/** A member of a FIRRTL Circuit (e.g. cannot point to a CircuitTarget) + * Concrete Subclasses are: [[ModuleTarget]], [[InstanceTarget]], and [[ReferenceTarget]] + */ +trait IsMember extends CompleteTarget { + + /** @return Root module, e.g. top-level module of this target */ + def module: String + + /** @return Returns the instance hierarchy path, if one exists */ + def path: Seq[(Instance, OfModule)] + + /** @return Creates a path, assuming all Instance and OfModules in this [[IsMember]] is used as a path */ + def asPath: Seq[(Instance, OfModule)] + + /** @return Tokens of just this member's path */ + def justPath: Seq[TargetToken] + + /** @return Local tokens of what this member points (not a path) */ + def notPath: Seq[TargetToken] + + /** @return Same target without a path */ + def pathlessTarget: IsMember + + /** @return Member's path target */ + def pathTarget: CompleteTarget + + /** @return Member's top-level module target */ + def moduleTarget: ModuleTarget = ModuleTarget(circuitOpt.get, moduleOpt.get) + + /** @return Member's parent target */ + def targetParent: CompleteTarget + + /** @return List of local Instance Targets refering to each instance/ofModule in this member's path */ + def pathAsTargets: Seq[InstanceTarget] = { + val targets = mutable.ArrayBuffer[InstanceTarget]() + path.foldLeft((module, Vector.empty[InstanceTarget])) { + case ((m, vec), (Instance(i), OfModule(o))) => + (o, vec :+ InstanceTarget(circuit, m, Nil, i, o)) + }._2 + } + + /** Resets this target to have a new path + * @param newPath + * @return + */ + def setPathTarget(newPath: IsModule): CompleteTarget +} + +/** References a module-like target (e.g. a [[ModuleTarget]] or an [[InstanceTarget]]) + */ +trait IsModule extends IsMember { + + /** @return Creates a new Target, appending a ref */ + def ref(value: String): ReferenceTarget + + /** @return Creates a new Target, appending an instance and ofmodule */ + def instOf(instance: String, of: String): InstanceTarget +} + +/** A component of a FIRRTL Module (e.g. cannot point to a CircuitTarget or ModuleTarget) + */ +trait IsComponent extends IsMember { + + /** @return The [[ModuleTarget]] of the module that directly contains this component */ + def encapsulatingModule: String = if(path.isEmpty) module else path.last._2.value + + /** Removes n levels of instance hierarchy + * + * Example: n=1, transforms (Top, A)/b:B/c:C -> (Top, B)/c:C + * @param n + * @return + */ + def stripHierarchy(n: Int): IsMember + + override def toNamed: ComponentName = { + if(isLocal){ + val mn = ModuleName(module, CircuitName(circuit)) + Seq(tokens:_*) match { + case Seq(Ref(name)) => ComponentName(name, mn) + case Ref(_) :: tail if Target.isOnly(tail, ".", "[]") => + val name = tokens.foldLeft(""){ + case ("", Ref(name)) => name + case (string, Field(value)) => s"$string.$value" + case (string, Index(value)) => s"$string[$value]" + } + ComponentName(name, mn) + case Seq(Instance(name), OfModule(o)) => ComponentName(name, mn) + } + } else { + throw new Exception(s"Cannot convert $this to [[ComponentName]]") + } + } + + override def justPath: Seq[TargetToken] = path.foldLeft(Vector.empty[TargetToken]) { + case (vec, (i, o)) => vec ++ Seq(i, o) + } + + override def pathTarget: IsModule = { + if(path.isEmpty) moduleTarget else { + val (i, o) = path.last + InstanceTarget(circuit, module, path.dropRight(1), i.value, o.value) + } + } + + override def tokens = justPath ++ notPath + + override def isLocal = path.isEmpty +} + + +/** Target pointing to a FIRRTL [[firrtl.ir.Circuit]] + * @param circuit Name of a FIRRTL circuit + */ +case class CircuitTarget(circuit: String) extends CompleteTarget { + + /** Creates a [[ModuleTarget]] of provided name and this circuit + * @param m + * @return + */ + def module(m: String): ModuleTarget = ModuleTarget(circuit, m) + + override def circuitOpt: Option[String] = Some(circuit) + + override def moduleOpt: Option[String] = None + + override def tokens = Nil + + override def isLocal = true + + override def addHierarchy(root: String, instance: String): ReferenceTarget = + ReferenceTarget(circuit, root, Nil, instance, Nil) + + override def toNamed: CircuitName = CircuitName(circuit) +} + +/** Target pointing to a FIRRTL [[firrtl.ir.DefModule]] + * @param circuit Circuit containing the module + * @param module Name of the module + */ +case class ModuleTarget(circuit: String, module: String) extends IsModule { + + override def circuitOpt: Option[String] = Some(circuit) + + override def moduleOpt: Option[String] = Some(module) + + override def tokens: Seq[TargetToken] = Nil + + override def targetParent: CircuitTarget = CircuitTarget(circuit) + + override def addHierarchy(root: String, instance: String): InstanceTarget = InstanceTarget(circuit, root, Nil, instance, module) + + override def ref(value: String): ReferenceTarget = ReferenceTarget(circuit, module, Nil, value, Nil) + + override def instOf(instance: String, of: String): InstanceTarget = InstanceTarget(circuit, module, Nil, instance, of) + + override def asPath = Nil + + override def path: Seq[(Instance, OfModule)] = Nil + + override def justPath: Seq[TargetToken] = Nil + + override def notPath: Seq[TargetToken] = Nil + + override def pathlessTarget: ModuleTarget = this + + override def pathTarget: ModuleTarget = this + + override def isLocal = true + + override def setPathTarget(newPath: IsModule): IsModule = newPath + + override def toNamed: ModuleName = ModuleName(module, CircuitName(circuit)) +} + +/** Target pointing to a declared named component in a [[firrtl.ir.DefModule]] + * This includes: [[firrtl.ir.Port]], [[firrtl.ir.DefWire]], [[firrtl.ir.DefRegister]], [[firrtl.ir.DefInstance]], + * [[firrtl.ir.DefMemory]], [[firrtl.ir.DefNode]] + * @param circuit Name of the encapsulating circuit + * @param module Name of the root module of this reference + * @param path Path through instance/ofModules + * @param ref Name of component + * @param component Subcomponent of this reference, e.g. field or index + */ +case class ReferenceTarget(circuit: String, + module: String, + override val path: Seq[(Instance, OfModule)], + ref: String, + component: Seq[TargetToken]) extends IsComponent { + + /** @param value Index value of this target + * @return A new [[ReferenceTarget]] to the specified index of this [[ReferenceTarget]] + */ + def index(value: Int): ReferenceTarget = ReferenceTarget(circuit, module, path, ref, component :+ Index(value)) + + /** @param value Field name of this target + * @return A new [[ReferenceTarget]] to the specified field of this [[ReferenceTarget]] + */ + def field(value: String): ReferenceTarget = ReferenceTarget(circuit, module, path, ref, component :+ Field(value)) + + /** @return The initialization value of this reference, must be to a [[firrtl.ir.DefRegister]] */ + def init: ReferenceTarget = ReferenceTarget(circuit, module, path, ref, component :+ Init) + + /** @return The reset signal of this reference, must be to a [[firrtl.ir.DefRegister]] */ + def reset: ReferenceTarget = ReferenceTarget(circuit, module, path, ref, component :+ Reset) + + /** @return The clock signal of this reference, must be to a [[firrtl.ir.DefRegister]] */ + def clock: ReferenceTarget = ReferenceTarget(circuit, module, path, ref, component :+ Clock) + + override def circuitOpt: Option[String] = Some(circuit) + + override def moduleOpt: Option[String] = Some(module) + + override def targetParent: CompleteTarget = component match { + case Nil => + if(path.isEmpty) moduleTarget else { + val (i, o) = path.last + InstanceTarget(circuit, module, path.dropRight(1), i.value, o.value) + } + case other => ReferenceTarget(circuit, module, path, ref, component.dropRight(1)) + } + + override def notPath: Seq[TargetToken] = Ref(ref) +: component + + override def addHierarchy(root: String, instance: String): ReferenceTarget = + ReferenceTarget(circuit, root, (Instance(instance), OfModule(module)) +: path, ref, component) + + override def stripHierarchy(n: Int): ReferenceTarget = { + require(path.size >= n, s"Cannot strip $n levels of hierarchy from $this") + if(n == 0) this else { + val newModule = path(n - 1)._2.value + ReferenceTarget(circuit, newModule, path.drop(n), ref, component) + } + } + + override def pathlessTarget: ReferenceTarget = ReferenceTarget(circuit, encapsulatingModule, Nil, ref, component) + + override def setPathTarget(newPath: IsModule): ReferenceTarget = + ReferenceTarget(newPath.circuit, newPath.module, newPath.asPath, ref, component) + + override def asPath: Seq[(Instance, OfModule)] = path +} + +/** Points to an instance declaration of a module (termed an ofModule) + * @param circuit Encapsulating circuit + * @param module Root module (e.g. the base module of this target) + * @param path Path through instance/ofModules + * @param instance Name of the instance + * @param ofModule Name of the instance's module + */ +case class InstanceTarget(circuit: String, + module: String, + override val path: Seq[(Instance, OfModule)], + instance: String, + ofModule: String) extends IsModule with IsComponent { + + /** @return a [[ReferenceTarget]] referring to this declaration of this instance */ + def asReference: ReferenceTarget = ReferenceTarget(circuit, module, path, instance, Nil) + + /** @return a [[ReferenceTarget]] referring to declaration of this ofModule */ + def ofModuleTarget: ModuleTarget = ModuleTarget(circuit, ofModule) + + override def circuitOpt: Option[String] = Some(circuit) + + override def moduleOpt: Option[String] = Some(module) + + override def targetParent: IsModule = { + if(isLocal) ModuleTarget(circuit, module) else { + val (newInstance, newOfModule) = path.last + InstanceTarget(circuit, module, path.dropRight(1), newInstance.value, newOfModule.value) + } + } + + override def addHierarchy(root: String, inst: String): InstanceTarget = + InstanceTarget(circuit, root, (Instance(inst), OfModule(module)) +: path, instance, ofModule) + + override def ref(value: String): ReferenceTarget = ReferenceTarget(circuit, module, asPath, value, Nil) + + override def instOf(inst: String, of: String): InstanceTarget = InstanceTarget(circuit, module, asPath, inst, of) + + override def stripHierarchy(n: Int): IsModule = { + require(path.size >= n, s"Cannot strip $n levels of hierarchy from $this") + if(n == 0) this else { + val newModule = path(n - 1)._2.value + InstanceTarget(circuit, newModule, path.drop(n), instance, ofModule) + } + } + + override def asPath: Seq[(Instance, OfModule)] = path :+ (Instance(instance), OfModule(ofModule)) + + override def pathlessTarget: InstanceTarget = InstanceTarget(circuit, encapsulatingModule, Nil, instance, ofModule) + + override def notPath = Seq(Instance(instance), OfModule(ofModule)) + + override def setPathTarget(newPath: IsModule): InstanceTarget = + InstanceTarget(newPath.circuit, newPath.module, newPath.asPath, instance, ofModule) +} + + +/** Named classes associate an annotation with a component in a Firrtl circuit */ +@deprecated("Use Target instead, will be removed in 1.3", "1.2") +sealed trait Named { + def serialize: String + def toTarget: CompleteTarget +} + +@deprecated("Use Target instead, will be removed in 1.3", "1.2") +final case class CircuitName(name: String) extends Named { + if(!validModuleName(name)) throw AnnotationException(s"Illegal circuit name: $name") + def serialize: String = name + def toTarget: CircuitTarget = CircuitTarget(name) +} + +@deprecated("Use Target instead, will be removed in 1.3", "1.2") +final case class ModuleName(name: String, circuit: CircuitName) extends Named { + if(!validModuleName(name)) throw AnnotationException(s"Illegal module name: $name") + def serialize: String = circuit.serialize + "." + name + def toTarget: ModuleTarget = ModuleTarget(circuit.name, name) +} + +@deprecated("Use Target instead, will be removed in 1.3", "1.2") +final case class ComponentName(name: String, module: ModuleName) extends Named { + if(!validComponentName(name)) throw AnnotationException(s"Illegal component name: $name") + def expr: Expression = toExp(name) + def serialize: String = module.serialize + "." + name + def toTarget: ReferenceTarget = { + Target.toTargetTokens(name).toList match { + case Ref(r) :: components => ReferenceTarget(module.circuit.name, module.name, Nil, r, components) + case other => throw Target.NamedException(s"Cannot convert $this into [[ReferenceTarget]]: $other") + } + } +} diff --git a/src/main/scala/firrtl/annotations/TargetToken.scala b/src/main/scala/firrtl/annotations/TargetToken.scala new file mode 100644 index 00000000..587f30eb --- /dev/null +++ b/src/main/scala/firrtl/annotations/TargetToken.scala @@ -0,0 +1,46 @@ +// See LICENSE for license details. + +package firrtl.annotations + +/** Building block to represent a [[Target]] of a FIRRTL component */ +sealed trait TargetToken { + def keyword: String + def value: Any + + /** Returns whether this token is one of the type of tokens whose keyword is passed as an argument + * @param keywords + * @return + */ + def is(keywords: String*): Boolean = { + keywords.map { kw => + require(TargetToken.keyword2targettoken.keySet.contains(kw), + s"Keyword $kw must be in set ${TargetToken.keyword2targettoken.keys}") + val lastClass = this.getClass + lastClass == TargetToken.keyword2targettoken(kw)("0").getClass + }.reduce(_ || _) + } +} + +/** Object containing all [[TargetToken]] subclasses */ +case object TargetToken { + case class Instance(value: String) extends TargetToken { override def keyword: String = "inst" } + case class OfModule(value: String) extends TargetToken { override def keyword: String = "of" } + case class Ref(value: String) extends TargetToken { override def keyword: String = "ref" } + case class Index(value: Int) extends TargetToken { override def keyword: String = "[]" } + case class Field(value: String) extends TargetToken { override def keyword: String = "." } + case object Clock extends TargetToken { override def keyword: String = "clock"; val value = "" } + case object Init extends TargetToken { override def keyword: String = "init"; val value = "" } + case object Reset extends TargetToken { override def keyword: String = "reset"; val value = "" } + + val keyword2targettoken = Map( + "inst" -> ((value: String) => Instance(value)), + "of" -> ((value: String) => OfModule(value)), + "ref" -> ((value: String) => Ref(value)), + "[]" -> ((value: String) => Index(value.toInt)), + "." -> ((value: String) => Field(value)), + "clock" -> ((value: String) => Clock), + "init" -> ((value: String) => Init), + "reset" -> ((value: String) => Reset) + ) +} + diff --git a/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala b/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala new file mode 100644 index 00000000..ba3ca9a9 --- /dev/null +++ b/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala @@ -0,0 +1,146 @@ +// See LICENSE for license details. + +package firrtl.annotations.analysis + +import firrtl.Namespace +import firrtl.annotations._ +import firrtl.annotations.TargetToken.{Instance, OfModule, Ref} +import firrtl.Utils.throwInternalError + +import scala.collection.mutable + +/** Used by [[firrtl.annotations.transforms.EliminateTargetPaths]] to eliminate target paths + * Calculates needed modifications to a circuit's module/instance hierarchy + */ +case class DuplicationHelper(existingModules: Set[String]) { + // Maps instances to the module it instantiates (an ofModule) + type InstanceOfModuleMap = mutable.HashMap[Instance, OfModule] + + // Maps a module to the instance/ofModules it instantiates + type ModuleHasInstanceOfModuleMap = mutable.HashMap[String, InstanceOfModuleMap] + + // Maps original module names to new duplicated modules and their encapsulated instance/ofModules + type DupMap = mutable.HashMap[String, ModuleHasInstanceOfModuleMap] + + // Internal state to keep track of how paths duplicate + private val dupMap = new DupMap() + + // Internal record of which paths are renamed to which new names, in the case of a collision + private val cachedNames = mutable.HashMap[(String, Seq[(Instance, OfModule)]), String]() ++ + existingModules.map(m => (m, Nil) -> m) + + // Internal record of all paths to ensure unique name generation + private val allModules = mutable.HashSet[String]() ++ existingModules + + /** Updates internal state (dupMap) to calculate instance hierarchy modifications so t's tokens in an instance can be + * expressed as a tokens in a module (e.g. uniquify/duplicate the instance path in t's tokens) + * @param t An instance-resolved component + */ + def expandHierarchy(t: IsMember): Unit = { + val path = t.asPath + path.reverse.tails.map { _.reverse }.foreach { duplicate(t.module, _) } + } + + /** Updates dupMap with how original module names map to new duplicated module names + * @param top Root module of a component + * @param path Path down instance hierarchy of a component + */ + private def duplicate(top: String, path: Seq[(Instance, OfModule)]): Unit = { + val (originalModule, instance, ofModule) = path.size match { + case 0 => return + case 1 => (top, path.head._1, path.head._2) + case _ => (path(path.length - 2)._2.value, path.last._1, path.last._2) + } + val originalModuleToDupedModule = dupMap.getOrElseUpdate(originalModule, new ModuleHasInstanceOfModuleMap()) + val dupedModule = getModuleName(top, path.dropRight(1)) + val dupedModuleToInstances = originalModuleToDupedModule.getOrElseUpdate(dupedModule, new InstanceOfModuleMap()) + val dupedInstanceModule = getModuleName(top, path) + dupedModuleToInstances += ((instance, OfModule(dupedInstanceModule))) + + val originalInstanceModuleToDupedModule = dupMap.getOrElseUpdate(ofModule.value, new ModuleHasInstanceOfModuleMap()) + originalInstanceModuleToDupedModule.getOrElseUpdate(dupedInstanceModule, new InstanceOfModuleMap()) + } + + /** Deterministic name-creation of a duplicated module + * @param top + * @param path + * @return + */ + def getModuleName(top: String, path: Seq[(Instance, OfModule)]): String = { + cachedNames.get((top, path)) match { + case None => // Need a new name + val prefix = path.last._2.value + "___" + val postfix = top + "_" + path.map { case (i, m) => i.value }.mkString("_") + val ns = mutable.HashSet(allModules.toSeq: _*) + val finalName = firrtl.passes.Uniquify.findValidPrefix(prefix, Seq(postfix), ns) + postfix + allModules += finalName + cachedNames((top, path)) = finalName + finalName + case Some(newName) => newName + } + } + + /** Return the duplicated module (formerly originalOfModule) instantiated by instance in newModule (formerly + * originalModule) + * @param originalModule original encapsulating module + * @param newModule new name of encapsulating module + * @param instance instance name being declared in encapsulating module + * @param originalOfModule original module being instantiated in originalModule + * @return + */ + def getNewOfModule(originalModule: String, + newModule: String, + instance: Instance, + originalOfModule: OfModule): OfModule = { + dupMap.get(originalModule) match { + case None => // No duplication, can return originalOfModule + originalOfModule + case Some(newDupedModules) => + newDupedModules.get(newModule) match { + case None if newModule != originalModule => throwInternalError("BAD") + case None => // No duplication, can return originalOfModule + originalOfModule + case Some(newDupedModule) => + newDupedModule.get(instance) match { + case None => // Not duped, can return originalOfModule + originalOfModule + case Some(newOfModule) => + newOfModule + } + } + } + } + + /** Returns the names of this module's duplicated (including the original name) + * @param module + * @return + */ + def getDuplicates(module: String): Set[String] = { + dupMap.get(module).map(_.keys.toSet[String]).getOrElse(Set.empty[String]) ++ Set(module) + } + + /** Rewrites t with new module/instance hierarchy calculated after repeated calls to [[expandHierarchy]] + * @param t A target + * @return t rewritten, is a seq because if the t.module has been duplicated, it must now refer to multiple modules + */ + def makePathless(t: IsMember): Seq[IsMember] = { + val top = t.module + val path = t.asPath + val newTops = getDuplicates(top) + newTops.map { newTop => + val newPath = mutable.ArrayBuffer[TargetToken]() + path.foldLeft((top, newTop)) { case ((originalModule, newModule), (instance, ofModule)) => + val newOfModule = getNewOfModule(originalModule, newModule, instance, ofModule) + newPath ++= Seq(instance, newOfModule) + (ofModule.value, newOfModule.value) + } + val module = if(newPath.nonEmpty) newPath.last.value.toString else newTop + t.notPath match { + case Seq() => ModuleTarget(t.circuit, module) + case Instance(i) +: OfModule(m) +: Seq() => ModuleTarget(t.circuit, module) + case Ref(r) +: components => ReferenceTarget(t.circuit, module, Nil, r, components) + } + }.toSeq + } +} + diff --git a/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala new file mode 100644 index 00000000..8f604c9f --- /dev/null +++ b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala @@ -0,0 +1,167 @@ +// See LICENSE for license details. + +package firrtl.annotations.transforms + +import firrtl.Mappers._ +import firrtl.analyses.InstanceGraph +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.annotations.analysis.DuplicationHelper +import firrtl.annotations._ +import firrtl.ir._ +import firrtl.{CircuitForm, CircuitState, FIRRTLException, HighForm, RenameMap, Transform, WDefInstance} + +import scala.collection.mutable + + +/** Group of targets that should become local targets + * @param targets + */ +case class ResolvePaths(targets: Seq[CompleteTarget]) extends Annotation { + override def update(renames: RenameMap): Seq[Annotation] = { + val newTargets = targets.flatMap(t => renames.get(t).getOrElse(Seq(t))) + Seq(ResolvePaths(newTargets)) + } +} + +case class NoSuchTargetException(message: String) extends FIRRTLException(message) + +/** For a set of non-local targets, modify the instance/module hierarchy of the circuit such that + * the paths in each non-local target can be removed + * + * In other words, if targeting a specific instance of a module, duplicate that module with a unique name + * and instantiate the new module instead. + * + * Consumes [[ResolvePaths]] + * + * E.g. for non-local target A/b:B/c:C/d, rename the following + * A/b:B/c:C/d -> C_/d + * A/b:B/c:C -> B_/c:C_ + * A/b:B -> A/b:B_ + * B/x -> (B/x, B_/x) // where x is any reference in B + * C/x -> (C/x, C_/x) // where x is any reference in C + */ +class EliminateTargetPaths extends Transform { + + def inputForm: CircuitForm = HighForm + + def outputForm: CircuitForm = HighForm + + /** Replaces old ofModules with new ofModules by calling dupMap methods + * Updates oldUsedOfModules, newUsedOfModules + * @param originalModule Original name of this module + * @param newModule New name of this module + * @param s + * @return + */ + private def onStmt(dupMap: DuplicationHelper, + oldUsedOfModules: mutable.HashSet[String], + newUsedOfModules: mutable.HashSet[String]) + (originalModule: String, newModule: String) + (s: Statement): Statement = s match { + case d@DefInstance(_, name, module) => + val ofModule = dupMap.getNewOfModule(originalModule, newModule, Instance(name), OfModule(module)).value + newUsedOfModules += ofModule + oldUsedOfModules += module + d.copy(module = ofModule) + case d@WDefInstance(_, name, module, _) => + val ofModule = dupMap.getNewOfModule(originalModule, newModule, Instance(name), OfModule(module)).value + newUsedOfModules += ofModule + oldUsedOfModules += module + d.copy(module = ofModule) + case other => other map onStmt(dupMap, oldUsedOfModules, newUsedOfModules)(originalModule, newModule) + } + + /** Returns a modified circuit and [[RenameMap]] containing the associated target remapping + * @param cir + * @param targets + * @return + */ + def run(cir: Circuit, targets: Seq[IsMember]): (Circuit, RenameMap) = { + + val dupMap = DuplicationHelper(cir.modules.map(_.name).toSet) + + // For each target, record its path and calculate the necessary modifications to circuit + targets.foreach { t => dupMap.expandHierarchy(t) } + + // Records original list of used ofModules + val oldUsedOfModules = mutable.HashSet[String]() + oldUsedOfModules += cir.main + + // Records new list of used ofModules + val newUsedOfModules = mutable.HashSet[String]() + newUsedOfModules += cir.main + + // Contains new list of module declarations + val duplicatedModuleList = mutable.ArrayBuffer[DefModule]() + + // Foreach module, calculate the unique names of its duplicates + // Then, update the ofModules of instances that it encapsulates + cir.modules.foreach { m => + dupMap.getDuplicates(m.name).foreach { newName => + val newM = m match { + case e: ExtModule => e.copy(name = newName) + case o: Module => + o.copy(name = newName, body = onStmt(dupMap, oldUsedOfModules, newUsedOfModules)(m.name, newName)(o.body)) + } + duplicatedModuleList += newM + } + } + + // Calculate the final module list + // A module is in the final list if: + // 1) it is a module that is instantiated (new or old) + // 2) it is an old module that was not instantiated and is still not instantiated + val finalModuleList = duplicatedModuleList.filter(m => + newUsedOfModules.contains(m.name) || (!newUsedOfModules.contains(m.name) && !oldUsedOfModules.contains(m.name)) + ) + + // Records how targets have been renamed + val renameMap = RenameMap() + + // Foreach target, calculate the pathless version and only rename targets that are instantiated + targets.foreach { t => + val newTsx = dupMap.makePathless(t) + val newTs = newTsx.filter(c => newUsedOfModules.contains(c.moduleOpt.get)) + if(newTs.nonEmpty) { + renameMap.record(t, newTs) + } + } + + // Return modified circuit and associated renameMap + (cir.copy(modules = finalModuleList), renameMap) + } + + override protected def execute(state: CircuitState): CircuitState = { + + val annotations = state.annotations.collect { case a: ResolvePaths => a } + + // Collect targets that are not local + val targets = annotations.flatMap(_.targets.collect { case x: IsMember => x }) + + // Check validity of paths in targets + val instanceOfModules = new InstanceGraph(state.circuit).getChildrenInstanceOfModule + val targetsWithInvalidPaths = mutable.ArrayBuffer[IsMember]() + targets.foreach { t => + val path = t match { + case m: ModuleTarget => Nil + case i: InstanceTarget => i.asPath + case r: ReferenceTarget => r.path + } + path.foldLeft(t.module) { case (module, (inst: Instance, of: OfModule)) => + val childrenOpt = instanceOfModules.get(module) + if(childrenOpt.isEmpty || !childrenOpt.get.contains((inst, of))) { + targetsWithInvalidPaths += t + } + of.value + } + } + if(targetsWithInvalidPaths.nonEmpty) { + val string = targetsWithInvalidPaths.mkString(",") + throw NoSuchTargetException(s"""Some targets have illegal paths that cannot be resolved/eliminated: $string""") + } + + val (newCircuit, renameMap) = run(state.circuit, targets) + + state.copy(circuit = newCircuit, renames = Some(renameMap)) + } +} diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index f963e762..d6af69c1 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -128,11 +128,13 @@ class InlineInstances extends Transform { val port = ComponentName(s"$ref.$field", currentModule) val inst = ComponentName(s"$ref", currentModule) (renames.get(port), renames.get(inst)) match { - case (Some(p :: Nil), None) => WRef(p.name, tpe, WireKind, gen) + case (Some(p :: Nil), _) => + p.toTarget match { + case ReferenceTarget(_, _, Seq(), r, Seq(TargetToken.Field(f))) => wsf.copy(expr = wr.copy(name = r), name = f) + case ReferenceTarget(_, _, Seq(), r, Seq()) => WRef(r, tpe, WireKind, gen) + } case (None, Some(i :: Nil)) => wsf.map(appendRefPrefix(currentModule, renames)) case (None, None) => wsf - case (Some(p), Some(i)) => throw new PassException( - s"Inlining found multiple renames for ports ($p) and/or instances ($i). This should be impossible...") } case wr@ WRef(name, _, _, _) => val comp = ComponentName(name, currentModule) diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 10d4e97f..73f967f4 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -36,7 +36,7 @@ object Uniquify extends Transform { def outputForm = UnknownForm private case class UniquifyException(msg: String) extends FIRRTLException(msg) private def error(msg: String)(implicit sinfo: Info, mname: String) = - throw new UniquifyException(s"$sinfo: [module $mname] $msg") + throw new UniquifyException(s"$sinfo: [moduleOpt $mname] $msg") // For creation of rename map private case class NameMapNode(name: String, elts: Map[String, NameMapNode]) @@ -45,7 +45,7 @@ object Uniquify extends Transform { // We don't add an _ in the collision check because elts could be Seq("") // In this case, we're just really checking if prefix itself collides @tailrec - private [firrtl] def findValidPrefix( + def findValidPrefix( prefix: String, elts: Seq[String], namespace: collection.mutable.HashSet[String]): String = { diff --git a/src/main/scala/firrtl/passes/VerilogRename.scala b/src/main/scala/firrtl/passes/VerilogRename.scala new file mode 100644 index 00000000..4d51128c --- /dev/null +++ b/src/main/scala/firrtl/passes/VerilogRename.scala @@ -0,0 +1,11 @@ +package firrtl.passes +import firrtl.ir.Circuit +import firrtl.transforms.VerilogRename + +@deprecated("Use transforms.VerilogRename, will be removed in 1.3", "1.2") +object VerilogRename extends Pass { + override def run(c: Circuit): Circuit = new VerilogRename().run(c) + @deprecated("Use transforms.VerilogRename, will be removed in 1.3", "1.2") + def verilogRenameN(n: String): String = + if (firrtl.Utils.v_keywords(n)) "%s$".format(n) else n +} diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index bb73beb4..6927075e 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -26,7 +26,7 @@ case class SinkAnnotation(target: Named, pin: String) extends def duplicate(n: Named) = this.copy(target = n) } -/** Wires a Module's Source Component to one or more Sink +/** Wires a Module's Source Target to one or more Sink * Modules/Components * * Sinks are wired to their closest source through their lowest diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala index b89649d3..c5a7f21b 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala @@ -182,7 +182,7 @@ object WiringUtils { .collect { case (k, v) if sinkInsts.contains(k) => (k, v.flatten) }.toMap } - /** Helper script to extract a module name from a named Module or Component */ + /** Helper script to extract a module name from a named Module or Target */ def getModuleName(n: Named): String = { n match { case ModuleName(m, _) => m diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index 98033a2f..44785c62 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -26,9 +26,9 @@ case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) extends Annotation { override def update(renames: RenameMap): Seq[Annotation] = { - val newSources = sources.flatMap { s => renames.get(s).getOrElse(Seq(s)) } - val newSinks = renames.get(sink).getOrElse(Seq(sink)) - newSinks.map(snk => CombinationalPath(snk, newSources)) + val newSources: Seq[IsComponent] = sources.flatMap { s => renames.get(s).getOrElse(Seq(s.toTarget)) }.collect {case x: IsComponent if x.isLocal => x} + val newSinks = renames.get(sink).getOrElse(Seq(sink.toTarget)).collect { case x: IsComponent if x.isLocal => x} + newSinks.map(snk => CombinationalPath(snk.toNamed, newSources.map(_.toNamed))) } } diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 0d30446c..da7f1a46 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -12,6 +12,7 @@ import firrtl.PrimOps._ import firrtl.graph.DiGraph import firrtl.WrappedExpression.weq import firrtl.analyses.InstanceGraph +import firrtl.annotations.TargetToken.Ref import annotation.tailrec import collection.mutable @@ -46,11 +47,13 @@ object ConstantPropagation { } } -class ConstantPropagation extends Transform { +class ConstantPropagation extends Transform with ResolvedAnnotationPaths { import ConstantPropagation._ def inputForm = LowForm def outputForm = LowForm + override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation]) + trait FoldCommutativeOp { def fold(c1: Literal, c2: Literal): Expression def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression @@ -520,7 +523,7 @@ class ConstantPropagation extends Transform { def execute(state: CircuitState): CircuitState = { val dontTouches: Seq[(String, String)] = state.annotations.collect { - case DontTouchAnnotation(ComponentName(c, ModuleName(m, _))) => m -> c + case DontTouchAnnotation(Target(_, Some(m), Seq(Ref(c)))) => m -> c } // Map from module name to component names val dontTouchMap: Map[String, Set[String]] = diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index c98b892c..523c997b 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -30,7 +30,7 @@ import java.io.{File, FileWriter} * circumstances of their instantiation in their parent module, they will still not be removed. To * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication. */ -class DeadCodeElimination extends Transform { +class DeadCodeElimination extends Transform with ResolvedAnnotationPaths { def inputForm = LowForm def outputForm = LowForm @@ -321,9 +321,12 @@ class DeadCodeElimination extends Transform { state.copy(circuit = newCircuit, renames = Some(renames)) } + override val annotationClasses: Traversable[Class[_]] = + Seq(classOf[DontTouchAnnotation], classOf[OptimizableExtModuleAnnotation]) + def execute(state: CircuitState): CircuitState = { val dontTouches: Seq[LogicNode] = state.annotations.collect { - case DontTouchAnnotation(component) => LogicNode(component) + case DontTouchAnnotation(component: ReferenceTarget) if component.isLocal => LogicNode(component) } val doTouchExtMods: Seq[String] = state.annotations.collect { case OptimizableExtModuleAnnotation(ModuleName(name, _)) => name diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index 5630cecf..1c20b448 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -6,17 +6,18 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ import firrtl.analyses.InstanceGraph +import firrtl.annotations.TargetToken.{Instance, OfModule, Ref} import firrtl.annotations._ import firrtl.passes.{InferTypes, MemPortUtils} +import firrtl.Utils.throwInternalError // Datastructures import scala.collection.mutable -/** A component, e.g. register etc. Must be declared only once under the TopAnnotation - */ +/** A component, e.g. register etc. Must be declared only once under the TopAnnotation */ case class NoDedupAnnotation(target: ModuleName) extends SingleTargetAnnotation[ModuleName] { - def duplicate(n: ModuleName) = NoDedupAnnotation(n) + def duplicate(n: ModuleName): NoDedupAnnotation = NoDedupAnnotation(n) } /** Only use on legal Firrtl. @@ -28,62 +29,63 @@ class DedupModules extends Transform { def inputForm: CircuitForm = HighForm def outputForm: CircuitForm = HighForm - /** - * Deduplicate a Circuit + /** Deduplicate a Circuit * @param state Input Firrtl AST * @return A transformed Firrtl AST */ def execute(state: CircuitState): CircuitState = { val noDedups = state.annotations.collect { case NoDedupAnnotation(ModuleName(m, c)) => m } - val (newC, renameMap) = run(state.circuit, noDedups) + val (newC, renameMap) = run(state.circuit, noDedups, state.annotations) state.copy(circuit = newC, renames = Some(renameMap)) } - /** - * Deduplicates a circuit, and records renaming + /** Deduplicates a circuit, and records renaming * @param c Circuit to dedup * @param noDedups Modules not to dedup * @return Deduped Circuit and corresponding RenameMap */ - def run(c: Circuit, noDedups: Seq[String]): (Circuit, RenameMap) = { + def run(c: Circuit, noDedups: Seq[String], annos: Seq[Annotation]): (Circuit, RenameMap) = { // RenameMap val renameMap = RenameMap() renameMap.setCircuit(c.main) // Maps module name to corresponding dedup module - val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, renameMap) + val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, annos, renameMap) // Use old module list to preserve ordering val dedupedModules = c.modules.map(m => dedupMap(m.name)).distinct val cname = CircuitName(c.main) - renameMap.addMap(dedupMap.map { case (from, to) => + val map = dedupMap.map { case (from, to) => logger.debug(s"[Dedup] $from -> ${to.name}") ModuleName(from, cname) -> List(ModuleName(to.name, cname)) - }) + } + renameMap.recordAll( + map.map { + case (k: ModuleName, v: List[ModuleName]) => Target.convertNamed2Target(k) -> v.map(Target.convertNamed2Target) + } + ) (InferTypes.run(c.copy(modules = dedupedModules)), renameMap) } } -/** - * Utility functions for [[DedupModules]] - */ +/** Utility functions for [[DedupModules]] */ object DedupModules { - /** - * Change's a module's internal signal names, types, infos, and modules. + + /** Change's a module's internal signal names, types, infos, and modules. * @param rename Function to rename a signal. Called on declaration and references. * @param retype Function to retype a signal. Called on declaration, references, and subfields * @param reinfo Function to re-info a statement - * @param renameModule Function to rename an instance's module + * @param renameOfModule Function to rename an instance's module * @param module Module to change internals * @return Changed Module */ def changeInternals(rename: String=>String, retype: String=>Type=>Type, reinfo: Info=>Info, - renameModule: String=>String + renameOfModule: (String, String)=>String )(module: DefModule): DefModule = { def onPort(p: Port): Port = Port(reinfo(p.info), rename(p.name), p.direction, retype(p.name)(p.tpe)) def onExp(e: Expression): Expression = e match { @@ -99,9 +101,9 @@ object DedupModules { } def onStmt(s: Statement): Statement = s match { case WDefInstance(i, n, m, t) => - val newmod = renameModule(m) + val newmod = renameOfModule(n, m) WDefInstance(reinfo(i), rename(n), newmod, retype(n)(t)) - case DefInstance(i, n, m) => DefInstance(reinfo(i), rename(n), renameModule(m)) + case DefInstance(i, n, m) => DefInstance(reinfo(i), rename(n), renameOfModule(n, m)) case d: DefMemory => val oldType = MemPortUtils.memType(d) val newType = retype(d.name)(oldType) @@ -129,49 +131,79 @@ object DedupModules { module map onPort map onStmt } - /** - * Turns a module into a name-agnostic module + def uniquifyField(ref: String, depth: Int, field: String): String = ref + depth + field + + /** Turns a module into a name-agnostic module * @param module module to change * @return name-agnostic module */ - def agnostify(module: DefModule, name2tag: mutable.HashMap[String, String], tag2name: mutable.HashMap[String, String]): DefModule = { + def agnostify(top: CircuitTarget, + module: DefModule, + renameMap: RenameMap + ): DefModule = { + + val namespace = Namespace() - val nameMap = mutable.HashMap[String, String]() val typeMap = mutable.HashMap[String, Type]() + + renameMap.setCircuit(top.circuitOpt.get) + renameMap.setModule(module.name) + def rename(name: String): String = { - if (nameMap.contains(name)) nameMap(name) else { - val newName = namespace.newTemp - nameMap(name) = newName - newName + val ret = renameMap.get(top.module(module.name).ref(name)) + ret match { + case Some(Seq(Target(_, _, Seq(Ref(x))))) => x + case None => + val newName = namespace.newTemp + renameMap.rename(name, newName) + newName + case other => throwInternalError(other.toString) } } + def retype(name: String)(tpe: Type): Type = { if (typeMap.contains(name)) typeMap(name) else { - def onType(tpe: Type): Type = tpe map onType match { - case BundleType(fields) => BundleType(fields.map(f => Field(rename(f.name), f.flip, f.tpe))) + def onType(depth: Int)(tpe: Type): Type = tpe map onType(depth + 1) match { + //TODO bugfix: ref.data.data and ref.datax.data will not rename to the right tags, even if they should be + case BundleType(fields) => + BundleType(fields.map(f => Field(rename(uniquifyField(name, depth, f.name)), f.flip, f.tpe))) case other => other } - val newType = onType(tpe) + val newType = onType(0)(tpe) typeMap(name) = newType newType } } - def remodule(name: String): String = tag2name(name2tag(name)) - changeInternals(rename, retype, {i: Info => NoInfo}, remodule)(module) + + def reOfModule(instance: String, ofModule: String): String = { + renameMap.get(top.module(ofModule)) match { + case Some(Seq(Target(_, Some(ofModuleTag), Nil))) => ofModuleTag + case None => ofModule + case other => throwInternalError(other.toString) + } + } + + val renamedModule = changeInternals(rename, retype, {i: Info => NoInfo}, reOfModule)(module) + renamedModule } /** Dedup a module's instances based on dedup map * * Will fixes up module if deduped instance's ports are differently named * - * @param moduleName Module name who's instances will be deduped + * @param top CircuitTarget of circuit + * @param originalModule Module name who's instances will be deduped * @param moduleMap Map of module name to its original module * @param name2name Map of module name to the module deduping it. Not mutated in this function. * @param renameMap Will be modified to keep track of renames in this function * @return fixed up module deduped instances */ - def dedupInstances(moduleName: String, moduleMap: Map[String, DefModule], name2name: mutable.Map[String, String], renameMap: RenameMap): DefModule = { - val module = moduleMap(moduleName) + def dedupInstances(top: CircuitTarget, + originalModule: String, + moduleMap: Map[String, DefModule], + name2name: Map[String, String], + renameMap: RenameMap): DefModule = { + val module = moduleMap(originalModule) // If black box, return it (it has no instances) if (module.isInstanceOf[ExtModule]) return module @@ -187,7 +219,14 @@ object DedupModules { moduleMap(name2name(old)) } // Define rename functions - def renameModule(name: String): String = getNewModule(name).name + def renameOfModule(instance: String, ofModule: String): String = { + val newOfModule = name2name(ofModule) + renameMap.record( + top.module(originalModule).instOf(instance, ofModule), + top.module(originalModule).instOf(instance, newOfModule) + ) + newOfModule + } val typeMap = mutable.HashMap[String, Type]() def retype(name: String)(tpe: Type): Type = { if (typeMap.contains(name)) typeMap(name) else { @@ -198,96 +237,155 @@ object DedupModules { case (old, nuu) => renameMap.rename(old.serialize, nuu.serialize) } newType - } else tpe + } else { + tpe + } } } renameMap.setModule(module.name) // Change module internals - changeInternals({n => n}, retype, {i => i}, renameModule)(module) + changeInternals({n => n}, retype, {i => i}, renameOfModule)(module) } - /** - * Deduplicate - * @param circuit Circuit - * @param noDedups list of modules to not dedup - * @param renameMap rename map to populate when deduping - * @return Map of original Module name -> Deduped Module + //scalastyle:off + /** Returns + * 1) map of tag to all matching module names, + * 2) renameMap of module name to tag (agnostic name) + * 3) maps module name to agnostic renameMap + * @param top CircuitTarget + * @param moduleLinearization Sequence of modules from leaf to top + * @param noDedups Set of modules to not dedup + * @param annotations All annotations to check if annotations are identical + * @return */ - def deduplicate(circuit: Circuit, - noDedups: Set[String], - renameMap: RenameMap): Map[String, DefModule] = { - - // Order of modules, from leaf to top - val moduleLinearization = new InstanceGraph(circuit).moduleOrder.map(_.name).reverse + def buildRTLTags(top: CircuitTarget, + moduleLinearization: Seq[DefModule], + noDedups: Set[String], + annotations: Seq[Annotation] + ): (collection.Map[String, collection.Set[String]], RenameMap, collection.Map[String, RenameMap]) = { - // Maps module name to original module - val moduleMap = circuit.modules.map(m => m.name -> m).toMap - - // Maps a module's tag to its deduplicated module - val tag2name = mutable.HashMap.empty[String, String] - // Maps a module's name to its tag - val name2tag = mutable.HashMap.empty[String, String] + // Maps a module name to its agnostic name + val tagMap = RenameMap() // Maps a tag to all matching module names - val tag2all = mutable.HashMap.empty[String, mutable.Set[String]] + val tag2all = mutable.HashMap.empty[String, mutable.HashSet[String]] - // Build dedupMap - moduleLinearization.foreach { moduleName => - // Get original module - val originalModule = moduleMap(moduleName) + val module2Annotations = mutable.HashMap.empty[String, mutable.HashSet[Annotation]] + annotations.foreach { a => + a.getTargets.foreach { t => + val annos = module2Annotations.getOrElseUpdate(t.moduleOpt.get, mutable.HashSet.empty[Annotation]) + annos += a + } + } + val agnosticModuleMap = RenameMap() + val agnosticRenames = mutable.HashMap[String, RenameMap]() + + moduleLinearization.foreach { originalModule => // Replace instance references to new deduped modules val dontcare = RenameMap() dontcare.setCircuit("dontcare") - //val fixedModule = DedupModules.dedupInstances(originalModule, tag2module, name2tag, name2module, dontcare) + + val agnosticRename = RenameMap.create(agnosticModuleMap.getUnderlying) + agnosticRenames(originalModule.name) = agnosticRename if (noDedups.contains(originalModule.name)) { // Don't dedup. Set dedup module to be the same as fixed module - name2tag(originalModule.name) = originalModule.name - tag2name(originalModule.name) = originalModule.name - //templateModules += originalModule.name + tag2all(originalModule.name) = mutable.HashSet(originalModule.name) } else { // Try to dedup // Build name-agnostic module - val agnosticModule = DedupModules.agnostify(originalModule, name2tag, tag2name) + val agnosticModule = DedupModules.agnostify(top, originalModule, agnosticRename) + agnosticRename.record(top.module(originalModule.name), top.module("thisModule")) + val agnosticAnnos = module2Annotations.getOrElse( + originalModule.name, mutable.HashSet.empty[Annotation] + ).map(_.update(agnosticRename)) + agnosticRename.delete(top.module(originalModule.name)) // Build tag - val tag = (agnosticModule match { - case Module(i, n, ps, b) => - ps.map(_.serialize).mkString + b.serialize + val builder = new mutable.ArrayBuffer[Any]() + agnosticModule.ports.foreach { builder ++= _.serialize } + builder ++= agnosticAnnos + + agnosticModule match { + case Module(i, n, ps, b) => builder ++= b.serialize case ExtModule(i, n, ps, dn, p) => - ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString - }).hashCode().toString + builder ++= dn + p.foreach { builder ++= _.serialize } + } + val tag = builder.hashCode().toString // Match old module name to its tag - name2tag(originalModule.name) = tag + agnosticRename.record(top.module(originalModule.name), top.module(tag)) + agnosticModuleMap.record(top.module(originalModule.name), top.module(tag)) + tagMap.record(top.module(originalModule.name), top.module(tag)) // Set tag's module to be the first matching module - if (!tag2name.contains(tag)) { - tag2name(tag) = originalModule.name - tag2all(tag) = mutable.Set(originalModule.name) - } else { - tag2all(tag) += originalModule.name - } + val all = tag2all.getOrElseUpdate(tag, mutable.HashSet.empty[String]) + all += originalModule.name } } + (tag2all, tagMap, agnosticRenames) + } + //scalastyle:on + /** Deduplicate + * @param circuit Circuit + * @param noDedups list of modules to not dedup + * @param renameMap rename map to populate when deduping + * @return Map of original Module name -> Deduped Module + */ + def deduplicate(circuit: Circuit, + noDedups: Set[String], + annotations: Seq[Annotation], + renameMap: RenameMap): Map[String, DefModule] = { + + val (moduleMap, moduleLinearization) = { + val iGraph = new InstanceGraph(circuit) + (iGraph.moduleMap, iGraph.moduleOrder.reverse) + } + val top = CircuitTarget(circuit.main) + + val (tag2all, tagMap, agnosticRenames) = buildRTLTags(top, moduleLinearization, noDedups, annotations) // Set tag2name to be the best dedup module name val moduleIndex = circuit.modules.zipWithIndex.map{case (m, i) => m.name -> i}.toMap def order(l: String, r: String): String = if (moduleIndex(l) < moduleIndex(r)) l else r + + // Maps a module's tag to its deduplicated module + val tag2name = mutable.HashMap.empty[String, String] tag2all.foreach { case (tag, all) => tag2name(tag) = all.reduce(order)} // Create map from original to dedup name - val name2name = name2tag.map({ case (name, tag) => name -> tag2name(tag) }) + val name2name = moduleMap.keysIterator.map{ originalModule => + tagMap.get(top.module(originalModule)) match { + case Some(Seq(Target(_, Some(tag), Nil))) => originalModule -> tag2name(tag) + case None => originalModule -> originalModule + case other => throwInternalError(other.toString) + } + }.toMap // Build Remap for modules with deduped module references - val tag2module = tag2name.map({ case (tag, name) => tag -> DedupModules.dedupInstances(name, moduleMap, name2name, renameMap) }) + val dedupedName2module = tag2name.map({ case (tag, name) => name -> DedupModules.dedupInstances(top, name, moduleMap, name2name, renameMap) }) // Build map from original name to corresponding deduped module - val name2module = name2tag.map({ case (name, tag) => name -> tag2module(tag) }) + val name2module = tag2all.flatMap({ case (tag, names) => names.map(n => n -> dedupedName2module(tag2name(tag))) }) + + val reversedAgnosticRenames = mutable.HashMap[String, RenameMap]() + name2module.foreach { case (originalModuleName, dedupedModule) => + if(!reversedAgnosticRenames.contains(dedupedModule.name)) { + reversedAgnosticRenames(dedupedModule.name) = agnosticRenames(dedupedModule.name).getReverseRenameMap + } + agnosticRenames(originalModuleName).keys.foreach { key => + if(key.isInstanceOf[IsComponent]) { + val tag = agnosticRenames(originalModuleName)(key).head + val newKey = reversedAgnosticRenames(dedupedModule.name).apply(tag) + renameMap.record(key.asInstanceOf[IsMember], newKey.asInstanceOf[Seq[IsMember]]) + } + } + } name2module.toMap } diff --git a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala index fab540da..a66bd4ce 100644 --- a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala +++ b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala @@ -4,6 +4,7 @@ package transforms import firrtl.annotations._ import firrtl.passes.PassException +import firrtl.transforms /** Indicate that DCE should not be run */ case object NoDCEAnnotation extends NoTargetAnnotation @@ -12,13 +13,14 @@ case object NoDCEAnnotation extends NoTargetAnnotation * * DCE treats the component as a top-level sink of the circuit */ -case class DontTouchAnnotation(target: ComponentName) extends SingleTargetAnnotation[ComponentName] { - def duplicate(n: ComponentName) = this.copy(n) +case class DontTouchAnnotation(target: ReferenceTarget) extends SingleTargetAnnotation[ReferenceTarget] { + def targets = Seq(target) + def duplicate(n: ReferenceTarget) = this.copy(n) } object DontTouchAnnotation { class DontTouchNotFoundException(module: String, component: String) extends PassException( - s"Component marked dontTouch ($module.$component) not found!\n" + + s"Target marked dontTouch ($module.$component) not found!\n" + "It was probably accidentally deleted. Please check that your custom transforms are not" + "responsible and then file an issue on Github." ) diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 10414786..6b492148 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -68,7 +68,6 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers { val tname = transform.name val inlineAnn = InlineAnnotation(CircuitName("Top")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, Seq(inlineAnn)), Seq(transform)) - println(result.annotations.head) result.annotations.head should matchPattern { case DeletedAnnotation(`tname`, `inlineAnn`) => } @@ -192,7 +191,7 @@ abstract class AnnotationTests extends AnnotationSpec with Matchers { anno("w.a"), anno("w.b[0]"), anno("w.b[1]"), anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), anno("write.a"), anno("write.b[0]"), anno("write.b[1]"), - dontTouch("Top.r"), dontTouch("Top.w") + dontTouch("Top.r"), dontTouch("Top.w"), dontTouch("Top.mem") ) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, annos), Nil) val resultAnno = result.annotations.toSeq diff --git a/src/test/scala/firrtlTests/ClockListTests.scala b/src/test/scala/firrtlTests/ClockListTests.scala index 20718a71..48d6dfd3 100644 --- a/src/test/scala/firrtlTests/ClockListTests.scala +++ b/src/test/scala/firrtlTests/ClockListTests.scala @@ -1,13 +1,11 @@ +// See LICENSE for license details. + package firrtlTests import java.io._ -import org.scalatest._ -import org.scalatest.prop._ import firrtl._ import firrtl.ir.Circuit import firrtl.passes._ -import firrtl.Parser.IgnoreInfo -import annotations._ import clocklist._ class ClockListTests extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 6d386d48..4affd64d 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -346,6 +346,43 @@ class InlineInstancesTests extends LowTransformSpec { | b <= a""".stripMargin failingexecute(input, Seq(inline("A"))) } + + "Jack's Bug" should "not fail" in { + + val input = """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst child of InlineChild + | child.a <= a + | b <= child.b + | module InlineChild : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | inst i_child of InlineChild + | i_b <= i_child.b + | i_child.a <= i_a + | b <= i_b + | i_a <= a + | module InlineChild : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(inline("Inline"))) + } } // Execution driven tests for inlining modules diff --git a/src/test/scala/firrtlTests/RenameMapSpec.scala b/src/test/scala/firrtlTests/RenameMapSpec.scala index 9e305b70..b063599b 100644 --- a/src/test/scala/firrtlTests/RenameMapSpec.scala +++ b/src/test/scala/firrtlTests/RenameMapSpec.scala @@ -4,26 +4,27 @@ package firrtlTests import firrtl.RenameMap import firrtl.FIRRTLException -import firrtl.annotations.{ - Named, - CircuitName, - ModuleName, - ComponentName -} +import firrtl.RenameMap.{CircularRenameException, IllegalRenameException} +import firrtl.annotations._ class RenameMapSpec extends FirrtlFlatSpec { - val cir = CircuitName("Top") - val cir2 = CircuitName("Pot") - val cir3 = CircuitName("Cir3") - val modA = ModuleName("A", cir) - val modA2 = ModuleName("A", cir2) - val modB = ModuleName("B", cir) - val foo = ComponentName("foo", modA) - val foo2 = ComponentName("foo", modA2) - val bar = ComponentName("bar", modA) - val fizz = ComponentName("fizz", modA) - val fooB = ComponentName("foo", modB) - val barB = ComponentName("bar", modB) + val cir = CircuitTarget("Top") + val cir2 = CircuitTarget("Pot") + val cir3 = CircuitTarget("Cir3") + val modA = cir.module("A") + val modA2 = cir2.module("A") + val modB = cir.module("B") + val foo = modA.ref("foo") + val foo2 = modA2.ref("foo") + val bar = modA.ref("bar") + val fizz = modA.ref("fizz") + val fooB = modB.ref("foo") + val barB = modB.ref("bar") + + val tmb = cir.module("Top").instOf("mid", "Middle").instOf("bot", "Bottom") + val tm2b = cir.module("Top").instOf("mid", "Middle2").instOf("bot", "Bottom") + val middle = cir.module("Middle") + val middle2 = cir.module("Middle2") behavior of "RenameMap" @@ -35,82 +36,268 @@ class RenameMapSpec extends FirrtlFlatSpec { it should "return a Seq of renamed things if it does rename something" in { val renames = RenameMap() - renames.rename(foo, bar) + renames.record(foo, bar) renames.get(foo) should be (Some(Seq(bar))) } it should "allow something to be renamed to multiple things" in { val renames = RenameMap() - renames.rename(foo, bar) - renames.rename(foo, fizz) + renames.record(foo, bar) + renames.record(foo, fizz) renames.get(foo) should be (Some(Seq(bar, fizz))) } it should "allow something to be renamed to nothing (ie. deleted)" in { val renames = RenameMap() - renames.rename(foo, Seq()) + renames.record(foo, Seq()) renames.get(foo) should be (Some(Seq())) } it should "return None if something is renamed to itself" in { val renames = RenameMap() - renames.rename(foo, foo) + renames.record(foo, foo) renames.get(foo) should be (None) } - it should "allow components to change module" in { + it should "allow targets to change module" in { val renames = RenameMap() - renames.rename(foo, fooB) + renames.record(foo, fooB) renames.get(foo) should be (Some(Seq(fooB))) } - it should "rename components if their module is renamed" in { + it should "rename targets if their module is renamed" in { val renames = RenameMap() - renames.rename(modA, modB) + renames.record(modA, modB) renames.get(foo) should be (Some(Seq(fooB))) renames.get(bar) should be (Some(Seq(barB))) } - it should "rename renamed components if the module of the target component is renamed" in { + it should "rename renamed targets if the module of the target is renamed" in { val renames = RenameMap() - renames.rename(modA, modB) - renames.rename(foo, bar) + renames.record(modA, modB) + renames.record(foo, bar) renames.get(foo) should be (Some(Seq(barB))) } it should "rename modules if their circuit is renamed" in { val renames = RenameMap() - renames.rename(cir, cir2) + renames.record(cir, cir2) renames.get(modA) should be (Some(Seq(modA2))) } - it should "rename components if their circuit is renamed" in { + it should "rename targets if their circuit is renamed" in { val renames = RenameMap() - renames.rename(cir, cir2) + renames.record(cir, cir2) renames.get(foo) should be (Some(Seq(foo2))) } - // Renaming `from` to each of the `tos` at the same time should error - case class BadRename(from: Named, tos: Seq[Named]) - val badRenames = - Seq(BadRename(foo, Seq(cir)), - BadRename(foo, Seq(modA)), - BadRename(modA, Seq(foo)), - BadRename(modA, Seq(cir)), - BadRename(cir, Seq(foo)), - BadRename(cir, Seq(modA)), - BadRename(cir, Seq(cir2, cir3)) - ) - // Run all BadRename tests - for (BadRename(from, tos) <- badRenames) { - val fromN = from.getClass.getSimpleName - val tosN = tos.map(_.getClass.getSimpleName).mkString(", ") - it should s"error if a $fromN is renamed to $tosN" in { + val TopCircuit = cir + val Top = cir.module("Top") + val Top_m = Top.instOf("m", "Middle") + val Top_m_l = Top_m.instOf("l", "Leaf") + val Top_m_l_a = Top_m_l.ref("a") + val Top_m_la = Top_m.ref("l").field("a") + val Middle = cir.module("Middle") + val Middle2 = cir.module("Middle2") + val Middle_la = Middle.ref("l").field("a") + val Middle_l_a = Middle.instOf("l", "Leaf").ref("a") + + it should "rename targets if modules in the path are renamed" in { + val renames = RenameMap() + renames.record(Middle, Middle2) + renames.get(Top_m) should be (Some(Seq(Top.instOf("m", "Middle2")))) + } + + it should "rename targets if instance and module in the path are renamed" in { + val renames = RenameMap() + renames.record(Middle, Middle2) + renames.record(Top.ref("m"), Top.ref("m2")) + renames.get(Top_m) should be (Some(Seq(Top.instOf("m2", "Middle2")))) + } + + it should "rename targets if instance in the path are renamed" in { + val renames = RenameMap() + renames.record(Top.ref("m"), Top.ref("m2")) + renames.get(Top_m) should be (Some(Seq(Top.instOf("m2", "Middle")))) + } + + it should "rename targets if instance and ofmodule in the path are renamed" in { + val renames = RenameMap() + val Top_m2 = Top.instOf("m2", "Middle2") + renames.record(Top_m, Top_m2) + renames.get(Top_m) should be (Some(Seq(Top_m2))) + } + + it should "properly do nothing if no remaps" in { + val renames = RenameMap() + renames.get(Top_m_l_a) should be (None) + } + + it should "properly rename if leaf is inlined" in { + val renames = RenameMap() + renames.record(Middle_l_a, Middle_la) + renames.get(Top_m_l_a) should be (Some(Seq(Top_m_la))) + } + + it should "properly rename if middle is inlined" in { + val renames = RenameMap() + renames.record(Top_m.ref("l"), Top.ref("m_l")) + renames.get(Top_m_l_a) should be (Some(Seq(Top.instOf("m_l", "Leaf").ref("a")))) + } + + it should "properly rename if leaf and middle are inlined" in { + val renames = RenameMap() + val inlined = Top.ref("m_l_a") + renames.record(Top_m_l_a, inlined) + renames.record(Top_m_l, Nil) + renames.record(Top_m, Nil) + renames.get(Top_m_l_a) should be (Some(Seq(inlined))) + } + + it should "quickly rename a target with a long path" in { + (0 until 50 by 10).foreach { endIdx => + val renames = RenameMap() + renames.record(TopCircuit.module("Y0"), TopCircuit.module("X0")) + val deepTarget = (0 until endIdx).foldLeft(Top: IsModule) { (t, idx) => + t.instOf("a", "A" + idx) + }.ref("ref") + val (millis, rename) = firrtl.Utils.time(renames.get(deepTarget)) + println(s"${(deepTarget.tokens.size - 1) / 2} -> $millis") + //rename should be(None) + } + } + + it should "rename with multiple renames" in { + val renames = RenameMap() + val Middle2 = cir.module("Middle2") + renames.record(Middle, Middle2) + renames.record(Middle.ref("l"), Middle.ref("lx")) + renames.get(Middle.ref("l")) should be (Some(Seq(Middle2.ref("lx")))) + } + + it should "rename with fields" in { + val Middle_o = Middle.ref("o") + val Middle_i = Middle.ref("i") + val Middle_o_f = Middle.ref("o").field("f") + val Middle_i_f = Middle.ref("i").field("f") + val renames = RenameMap() + renames.record(Middle_o, Middle_i) + renames.get(Middle_o_f) should be (Some(Seq(Middle_i_f))) + } + + it should "rename instances with same ofModule" in { + val Middle_o = Middle.ref("o") + val Middle_i = Middle.ref("i") + val renames = RenameMap() + renames.record(Middle_o, Middle_i) + renames.get(Middle.instOf("o", "O")) should be (Some(Seq(Middle.instOf("i", "O")))) + } + + it should "detect circular renames" in { + case class BadRename(from: IsMember, tos: Seq[IsMember]) + val badRenames = + Seq( + BadRename(foo, Seq(foo.field("bar"))), + BadRename(modA, Seq(foo)) + //BadRename(cir, Seq(foo)), + //BadRename(cir, Seq(modA)) + ) + // Run all BadRename tests + for (BadRename(from, tos) <- badRenames) { + val fromN = from + val tosN = tos.mkString(", ") + //it should s"error if a $fromN is renamed to $tosN" in { val renames = RenameMap() - for (to <- tos) { renames.rename(from, to) } - a [FIRRTLException] shouldBe thrownBy { - renames.get(foo) + for (to <- tos) { + a [IllegalArgumentException] shouldBe thrownBy { + renames.record(from, to) + } } + //} + } + } + + it should "be able to rename weird stuff" in { + // Renaming `from` to each of the `tos` at the same time should be ok + case class BadRename(from: CompleteTarget, tos: Seq[CompleteTarget]) + val badRenames = + Seq(//BadRename(foo, Seq(cir)), + BadRename(foo, Seq(modB)), + BadRename(modA, Seq(fooB)), + //BadRename(modA, Seq(cir)), + //BadRename(cir, Seq(foo)), + //BadRename(cir, Seq(modA)), + BadRename(cir, Seq(cir2, cir3)) + ) + // Run all BadRename tests + for (BadRename(from, tos) <- badRenames) { + val fromN = from + val tosN = tos.mkString(", ") + //it should s"error if a $fromN is renamed to $tosN" in { + val renames = RenameMap() + for (to <- tos) { + (from, to) match { + case (f: CircuitTarget, t: CircuitTarget) => renames.record(f, t) + case (f: IsMember, t: IsMember) => renames.record(f, t) + } + } + //a [FIRRTLException] shouldBe thrownBy { + renames.get(from) + //} + //} + } + } + + it should "error if a circular rename occurs" in { + val renames = RenameMap() + val top = CircuitTarget("Top") + renames.record(top.module("A"), top.module("B").instOf("c", "C")) + renames.record(top.module("B"), top.module("A").instOf("c", "C")) + a [CircularRenameException] shouldBe thrownBy { + renames.get(top.module("A")) + } + } + + it should "not error if a swapping rename occurs" in { + val renames = RenameMap() + val top = CircuitTarget("Top") + renames.record(top.module("A"), top.module("B")) + renames.record(top.module("B"), top.module("A")) + renames.get(top.module("A")) should be (Some(Seq(top.module("B")))) + renames.get(top.module("B")) should be (Some(Seq(top.module("A")))) + } + + it should "error if a reference is renamed to a module, and then we try to rename the reference's field" in { + val renames = RenameMap() + val top = CircuitTarget("Top") + renames.record(top.module("A").ref("ref"), top.module("B")) + renames.get(top.module("A").ref("ref")) should be(Some(Seq(top.module("B")))) + a [IllegalRenameException] shouldBe thrownBy { + renames.get(top.module("A").ref("ref").field("field")) + } + a [IllegalRenameException] shouldBe thrownBy { + renames.get(top.module("A").instOf("ref", "R")) + } + } + + it should "error if we rename an instance's ofModule into a non-module" in { + val renames = RenameMap() + val top = CircuitTarget("Top") + + renames.record(top.module("C"), top.module("D").ref("x")) + a [IllegalRenameException] shouldBe thrownBy { + renames.get(top.module("A").instOf("c", "C")) + } + } + + it should "error if path is renamed into a non-path" ignore { + val renames = RenameMap() + val top = CircuitTarget("Top") + + renames.record(top.module("E").instOf("f", "F"), top.module("E").ref("g")) + + a [IllegalRenameException] shouldBe thrownBy { + println(renames.get(top.module("E").instOf("f", "F").ref("g"))) } } } diff --git a/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala new file mode 100644 index 00000000..de84d79d --- /dev/null +++ b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala @@ -0,0 +1,357 @@ +// See LICENSE for license details. + +package firrtlTests.annotationTests + +import firrtl._ +import firrtl.annotations._ +import firrtl.annotations.analysis.DuplicationHelper +import firrtl.annotations.transforms.NoSuchTargetException +import firrtl.transforms.DontTouchAnnotation +import firrtlTests.{FirrtlMatchers, FirrtlPropSpec} + +class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { + val input = + """circuit Top: + | module Leaf: + | input i: UInt<1> + | output o: UInt<1> + | o <= i + | node a = i + | module Middle: + | input i: UInt<1> + | output o: UInt<1> + | inst l1 of Leaf + | inst l2 of Leaf + | l1.i <= i + | l2.i <= l1.o + | o <= l2.o + | module Top: + | input i: UInt<1> + | output o: UInt<1> + | inst m1 of Middle + | inst m2 of Middle + | m1.i <= i + | m2.i <= m1.o + | o <= m2.o + """.stripMargin + + val TopCircuit = CircuitTarget("Top") + val Top = TopCircuit.module("Top") + val Middle = TopCircuit.module("Middle") + val Leaf = TopCircuit.module("Leaf") + + val Top_m1_l1_a = Top.instOf("m1", "Middle").instOf("l1", "Leaf").ref("a") + val Top_m2_l1_a = Top.instOf("m2", "Middle").instOf("l1", "Leaf").ref("a") + val Top_m1_l2_a = Top.instOf("m1", "Middle").instOf("l2", "Leaf").ref("a") + val Top_m2_l2_a = Top.instOf("m2", "Middle").instOf("l2", "Leaf").ref("a") + val Middle_l1_a = Middle.instOf("l1", "Leaf").ref("a") + val Middle_l2_a = Middle.instOf("l2", "Leaf").ref("a") + val Leaf_a = Leaf.ref("a") + + case class DummyAnnotation(target: Target) extends SingleTargetAnnotation[Target] { + override def duplicate(n: Target): Annotation = DummyAnnotation(n) + } + class DummyTransform() extends Transform with ResolvedAnnotationPaths { + override def inputForm: CircuitForm = LowForm + override def outputForm: CircuitForm = LowForm + + override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DummyAnnotation]) + + override def execute(state: CircuitState): CircuitState = state + } + val customTransforms = Seq(new DummyTransform()) + + val inputState = CircuitState(parse(input), ChirrtlForm) + property("Hierarchical tokens should be expanded properly") { + val dupMap = new DuplicationHelper(inputState.circuit.modules.map(_.name).toSet) + + + // Only a few instance references + dupMap.expandHierarchy(Top_m1_l1_a) + dupMap.expandHierarchy(Top_m2_l1_a) + dupMap.expandHierarchy(Middle_l1_a) + + dupMap.makePathless(Top_m1_l1_a).foreach {Set(TopCircuit.module("Leaf___Top_m1_l1").ref("a")) should contain (_)} + dupMap.makePathless(Top_m2_l1_a).foreach {Set(TopCircuit.module("Leaf___Top_m2_l1").ref("a")) should contain (_)} + dupMap.makePathless(Top_m1_l2_a).foreach {Set(Leaf_a) should contain (_)} + dupMap.makePathless(Top_m2_l2_a).foreach {Set(Leaf_a) should contain (_)} + dupMap.makePathless(Middle_l1_a).foreach {Set( + TopCircuit.module("Leaf___Top_m1_l1").ref("a"), + TopCircuit.module("Leaf___Top_m2_l1").ref("a"), + TopCircuit.module("Leaf___Middle_l1").ref("a") + ) should contain (_) } + dupMap.makePathless(Middle_l2_a).foreach {Set(Leaf_a) should contain (_)} + dupMap.makePathless(Leaf_a).foreach {Set( + TopCircuit.module("Leaf___Top_m1_l1").ref("a"), + TopCircuit.module("Leaf___Top_m2_l1").ref("a"), + TopCircuit.module("Leaf___Middle_l1").ref("a"), + Leaf_a + ) should contain (_)} + dupMap.makePathless(Top).foreach {Set(Top) should contain (_)} + dupMap.makePathless(Middle).foreach {Set( + TopCircuit.module("Middle___Top_m1"), + TopCircuit.module("Middle___Top_m2"), + Middle + ) should contain (_)} + dupMap.makePathless(Leaf).foreach {Set( + TopCircuit.module("Leaf___Top_m1_l1"), + TopCircuit.module("Leaf___Top_m2_l1"), + TopCircuit.module("Leaf___Middle_l1"), + Leaf + ) should contain (_) } + } + + property("Hierarchical donttouch should be resolved properly") { + val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DontTouchAnnotation(Top_m1_l1_a))) + val customTransforms = Seq(new LowFirrtlOptimization()) + val outputState = new LowFirrtlCompiler().compile(inputState, customTransforms) + val check = + """circuit Top : + | module Leaf___Top_m1_l1 : + | input i : UInt<1> + | output o : UInt<1> + | + | node a = i + | o <= i + | + | module Leaf : + | input i : UInt<1> + | output o : UInt<1> + | + | skip + | o <= i + | + | module Middle___Top_m1 : + | input i : UInt<1> + | output o : UInt<1> + | + | inst l1 of Leaf___Top_m1_l1 + | inst l2 of Leaf + | o <= l2.o + | l1.i <= i + | l2.i <= l1.o + | + | module Middle : + | input i : UInt<1> + | output o : UInt<1> + | + | inst l1 of Leaf + | inst l2 of Leaf + | o <= l2.o + | l1.i <= i + | l2.i <= l1.o + | + | module Top : + | input i : UInt<1> + | output o : UInt<1> + | + | inst m1 of Middle___Top_m1 + | inst m2 of Middle + | o <= m2.o + | m1.i <= i + | m2.i <= m1.o + | + """.stripMargin + canonicalize(outputState.circuit).serialize should be (canonicalize(parse(check)).serialize) + outputState.annotations.collect { + case x: DontTouchAnnotation => x.target + } should be (Seq(Top.circuitTarget.module("Leaf___Top_m1_l1").ref("a"))) + } + + property("No name conflicts between old and new modules") { + val input = + """circuit Top: + | module Middle: + | input i: UInt<1> + | output o: UInt<1> + | o <= i + | module Top: + | input i: UInt<1> + | output o: UInt<1> + | inst m1 of Middle + | inst m2 of Middle + | inst x of Middle___Top_m1 + | x.i <= i + | m1.i <= i + | m2.i <= m1.o + | o <= m2.o + | module Middle___Top_m1: + | input i: UInt<1> + | output o: UInt<1> + | o <= i + | node a = i + """.stripMargin + val checks = + """circuit Top : + | module Middle : + | module Top : + | module Middle___Top_m1 : + | module Middle____Top_m1 :""".stripMargin.split("\n") + val Top_m1 = Top.instOf("m1", "Middle") + val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m1))) + val outputState = new LowFirrtlCompiler().compile(inputState, customTransforms) + val outputLines = outputState.circuit.serialize.split("\n") + checks.foreach { line => + outputLines should contain (line) + } + } + + property("Previously unused modules should remain, but newly unused modules should be eliminated") { + val input = + """circuit Top: + | module Leaf: + | input i: UInt<1> + | output o: UInt<1> + | o <= i + | node a = i + | module Middle: + | input i: UInt<1> + | output o: UInt<1> + | o <= i + | module Top: + | input i: UInt<1> + | output o: UInt<1> + | inst m1 of Middle + | inst m2 of Middle + | m1.i <= i + | m2.i <= m1.o + | o <= m2.o + """.stripMargin + + val checks = + """circuit Top : + | module Leaf : + | module Top : + | module Middle___Top_m1 : + | module Middle___Top_m2 :""".stripMargin.split("\n") + + val Top_m1 = Top.instOf("m1", "Middle") + val Top_m2 = Top.instOf("m2", "Middle") + val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m1), DummyAnnotation(Top_m2))) + val outputState = new LowFirrtlCompiler().compile(inputState, customTransforms) + val outputLines = outputState.circuit.serialize.split("\n") + + checks.foreach { line => + outputLines should contain (line) + } + checks.foreach { line => + outputLines should not contain (" module Middle :") + } + } + + property("Paths with incorrect names should error") { + val input = + """circuit Top: + | module Leaf: + | input i: UInt<1> + | output o: UInt<1> + | o <= i + | node a = i + | module Middle: + | input i: UInt<1> + | output o: UInt<1> + | o <= i + | module Top: + | input i: UInt<1> + | output o: UInt<1> + | inst m1 of Middle + | inst m2 of Middle + | m1.i <= i + | m2.i <= m1.o + | o <= m2.o + """.stripMargin + intercept[NoSuchTargetException] { + val Top_m1 = Top.instOf("m1", "MiddleX") + val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m1))) + new LowFirrtlCompiler().compile(inputState, customTransforms) + } + intercept[NoSuchTargetException] { + val Top_m2 = Top.instOf("x2", "Middle") + val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m2))) + new LowFirrtlCompiler().compile(inputState, customTransforms) + } + } + + property("No name conflicts between two new modules") { + val input = + """circuit Top: + | module Top: + | input i: UInt<1> + | output o: UInt<1> + | inst m1 of Middle_ + | inst m2 of Middle + | m1.i <= i + | m2.i <= m1.o + | o <= m2.o + | module Middle: + | input i: UInt<1> + | output o: UInt<1> + | inst _l of Leaf + | _l.i <= i + | o <= _l.o + | module Middle_: + | input i: UInt<1> + | output o: UInt<1> + | inst l of Leaf + | l.i <= i + | node x = i + | o <= l.o + | module Leaf: + | input i: UInt<1> + | output o: UInt<1> + | o <= i + """.stripMargin + val checks = + """circuit Top : + | module Middle : + | module Top : + | module Leaf___Middle__l : + | module Leaf____Middle__l :""".stripMargin.split("\n") + val Middle_l1 = CircuitTarget("Top").module("Middle").instOf("_l", "Leaf") + val Middle_l2 = CircuitTarget("Top").module("Middle_").instOf("l", "Leaf") + val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Middle_l1), DummyAnnotation(Middle_l2))) + val outputState = new LowFirrtlCompiler().compile(inputState, customTransforms) + val outputLines = outputState.circuit.serialize.split("\n") + checks.foreach { line => + outputLines should contain (line) + } + } + + property("Keep annotations of modules not instantiated") { + val input = + """circuit Top: + | module Top: + | input i: UInt<1> + | output o: UInt<1> + | inst m1 of Middle + | inst m2 of Middle + | m1.i <= i + | m2.i <= m1.o + | o <= m2.o + | module Middle: + | input i: UInt<1> + | output o: UInt<1> + | inst _l of Leaf + | _l.i <= i + | o <= _l.o + | module Middle_: + | input i: UInt<1> + | output o: UInt<1> + | o <= UInt(0) + | module Leaf: + | input i: UInt<1> + | output o: UInt<1> + | o <= i + """.stripMargin + val checks = + """circuit Top : + | module Middle_ :""".stripMargin.split("\n") + val Middle_ = CircuitTarget("Top").module("Middle_").ref("i") + val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DontTouchAnnotation(Middle_))) + val outputState = new VerilogCompiler().compile(inputState, customTransforms) + val outputLines = outputState.circuit.serialize.split("\n") + checks.foreach { line => + outputLines should contain (line) + } + } +} diff --git a/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala b/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala new file mode 100644 index 00000000..4ae4e036 --- /dev/null +++ b/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala @@ -0,0 +1,59 @@ +// See LICENSE for license details. + +package firrtlTests.annotationTests + +import firrtl.annotations.{CircuitTarget, GenericTarget, ModuleTarget, Target} +import firrtl.annotations.TargetToken._ +import firrtlTests.FirrtlPropSpec + +class TargetSpec extends FirrtlPropSpec { + def check(comp: Target): Unit = { + val named = Target.convertTarget2Named(comp) + println(named) + val comp2 = Target.convertNamed2Target(named) + assert(comp.toGenericTarget.complete == comp2) + } + property("Serialization of Targets should work") { + val circuit = CircuitTarget("Circuit") + val top = circuit.module("Top") + val targets: Seq[(Target, String)] = + Seq( + (circuit, "~Circuit"), + (top, "~Circuit|Top"), + (top.instOf("i", "I"), "~Circuit|Top/i:I"), + (top.ref("r"), "~Circuit|Top>r"), + (top.ref("r").index(1).field("hi").clock, "~Circuit|Top>r[1].hi@clock"), + (GenericTarget(None, None, Vector(Ref("r"))), "~???|???>r") + ) + targets.foreach { case (t, str) => + assert(t.serialize == str, s"$t does not properly serialize") + } + } + property("Should convert to/from Named") { + check(Target(Some("Top"), None, Nil)) + check(Target(Some("Top"), Some("Top"), Nil)) + check(Target(Some("Top"), Some("Other"), Nil)) + val r1 = Seq(Ref("r1"), Field("I")) + val r2 = Seq(Ref("r2"), Index(0)) + check(Target(Some("Top"), Some("Top"), r1)) + check(Target(Some("Top"), Some("Top"), r2)) + } + property("Should enable creating from API") { + val top = ModuleTarget("Top","Top") + val x_reg0_data = top.instOf("x", "X").ref("reg0").field("data") + top.instOf("x", "x") + top.ref("y") + println(x_reg0_data) + } + property("Should serialize and deserialize") { + val circuit = CircuitTarget("Circuit") + val top = circuit.module("Top") + val targets: Seq[Target] = + Seq(circuit, top, top.instOf("i", "I"), top.ref("r"), + top.ref("r").index(1).field("hi").clock, GenericTarget(None, None, Vector(Ref("r")))) + targets.foreach { t => + assert(Target.deserialize(t.serialize) == t, s"$t does not properly serialize/deserialize") + } + } +} + diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala index b66f7f9d..5ee2b927 100644 --- a/src/test/scala/firrtlTests/transforms/DedupTests.scala +++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala @@ -3,14 +3,24 @@ package firrtlTests package transforms +import firrtl.RenameMap import firrtl.annotations._ -import firrtl.transforms.{DedupModules} +import firrtl.transforms.DedupModules /** * Tests inline instances transformation */ class DedupModuleTests extends HighTransformSpec { + case class MultiTargetDummyAnnotation(targets: Seq[Target], tag: Int) extends Annotation { + override def update(renames: RenameMap): Seq[Annotation] = { + val newTargets = targets.flatMap(renames(_)) + Seq(MultiTargetDummyAnnotation(newTargets, tag)) + } + } + case class SingleTargetDummyAnnotation(target: ComponentName) extends SingleTargetAnnotation[ComponentName] { + override def duplicate(n: ComponentName): Annotation = SingleTargetDummyAnnotation(n) + } def transform = new DedupModules "The module A" should "be deduped" in { val input = @@ -135,7 +145,7 @@ class DedupModuleTests extends HighTransformSpec { """.stripMargin execute(input, check, Seq(dontDedup("A"))) } - "The module A and A_" should "be deduped even with different port names and info, and annotations should remap" in { + "The module A and A_" should "be deduped even with different port names and info, and annotations should remapped" in { val input = """circuit Top : | module Top : @@ -161,16 +171,12 @@ class DedupModuleTests extends HighTransformSpec { | output x: UInt<1> @[yy 2:2] | x <= UInt(1) """.stripMargin - case class DummyAnnotation(target: ComponentName) extends SingleTargetAnnotation[ComponentName] { - override def duplicate(n: ComponentName): Annotation = DummyAnnotation(n) - } val mname = ModuleName("Top", CircuitName("Top")) - val finalState = execute(input, check, Seq(DummyAnnotation(ComponentName("a2.y", mname)))) - - finalState.annotations.collect({ case d: DummyAnnotation => d }).head should be(DummyAnnotation(ComponentName("a2.x", mname))) - + val finalState = execute(input, check, Seq(SingleTargetDummyAnnotation(ComponentName("a2.y", mname)))) + finalState.annotations.collect({ case d: SingleTargetDummyAnnotation => d }).head should be(SingleTargetDummyAnnotation(ComponentName("a2.x", mname))) } + "Extmodules" should "with the same defname and parameters should dedup" in { val input = """circuit Top : @@ -215,6 +221,7 @@ class DedupModuleTests extends HighTransformSpec { execute(input, check, Seq.empty) } + "Extmodules" should "with the different defname or parameters should NOT dedup" in { def mkfir(defnames: (String, String), params: (String, String)) = s"""circuit Top : @@ -253,12 +260,188 @@ class DedupModuleTests extends HighTransformSpec { | inst a2 of A_ | module A : | output x: UInt<1> + | inst b of B + | x <= b.x + | module A_ : + | output x: UInt<1> + | inst b of B_ + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + execute(input, check, Seq.empty) + } + + "The module A and A_" should "be deduped with fields that sort of match" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | wire b: {c: UInt<1>} + | x <= b.c + | module A_ : + | output x: UInt<1> + | wire b: {b: UInt<1>} + | x <= b.b + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | wire b: {c: UInt<1>} + | x <= b.c + """.stripMargin + execute(input, check, Seq.empty) + } + + "The module A and A_" should "not be deduped with different annotation targets" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | wire b: UInt<1> + | x <= b + | module A_ : + | output x: UInt<1> + | wire b: UInt<1> + | x <= b + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | wire b: UInt<1> + | x <= b + | module A_ : + | output x: UInt<1> + | wire b: UInt<1> + | x <= b + """.stripMargin + execute(input, check, Seq(dontTouch("A.b"))) + } + + "The module A and A_" should "be deduped with same annotation targets" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | wire b: UInt<1> + | x <= b + | module A_ : + | output x: UInt<1> + | wire b: UInt<1> + | x <= b + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | wire b: UInt<1> + | x <= b + """.stripMargin + execute(input, check, Seq(dontTouch("A.b"), dontTouch("A_.b"))) + } + "The module A and A_" should "not be deduped with same annotations with same multi-targets, but which have different root modules" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module A_ : + | output x: UInt<1> | inst b of B_ | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x | module A_ : | output x: UInt<1> + | inst b of B_ + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val Top = CircuitTarget("Top") + val A = Top.module("A") + val B = Top.module("B") + val A_ = Top.module("A_") + val B_ = Top.module("B_") + val annoAB = MultiTargetDummyAnnotation(Seq(A, B), 0) + val annoA_B_ = MultiTargetDummyAnnotation(Seq(A_, B_), 0) + val cs = execute(input, check, Seq(annoAB, annoA_B_)) + cs.annotations.toSeq should contain (annoAB) + cs.annotations.toSeq should contain (annoA_B_) + } + "The module A and A_" should "be deduped with same annotations with same multi-targets, that share roots" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> | inst b of B | x <= b.x + | module A_ : + | output x: UInt<1> + | inst b of B_ + | x <= b.x | module B : | output x: UInt<1> | x <= UInt(1) @@ -279,7 +462,47 @@ class DedupModuleTests extends HighTransformSpec { | output x: UInt<1> | x <= UInt(1) """.stripMargin - execute(input, check, Seq.empty) + val Top = CircuitTarget("Top") + val A = Top.module("A") + val A_ = Top.module("A_") + val annoA = MultiTargetDummyAnnotation(Seq(A, A.instOf("b", "B")), 0) + val annoA_ = MultiTargetDummyAnnotation(Seq(A_, A_.instOf("b", "B_")), 0) + val cs = execute(input, check, Seq(annoA, annoA_)) + cs.annotations.toSeq should contain (annoA) + cs.annotations.toSeq should not contain (annoA_) + cs.deletedAnnotations.isEmpty should be (true) + } + "The deduping module A and A_" should "renamed internal signals that have different names" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output y: UInt<1> + | y <= UInt(1) + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output y: UInt<1> + | y <= UInt<1>("h1") + """.stripMargin + val Top = CircuitTarget("Top") + val A = Top.module("A") + val A_ = Top.module("A_") + val annoA = SingleTargetDummyAnnotation(A.ref("y")) + val annoA_ = SingleTargetDummyAnnotation(A_.ref("x")) + val cs = execute(input, check, Seq(annoA, annoA_)) + cs.annotations.toSeq should contain (annoA) + cs.annotations.toSeq should not contain (SingleTargetDummyAnnotation(A.ref("x"))) + cs.deletedAnnotations.isEmpty should be (true) } } |
