From 396ee7ca63eb8a9e201dcdea965cbfc3e9d36783 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Wed, 28 Mar 2018 12:52:11 -0700 Subject: Enhance RenameMap to support circuit renaming (#775) Also delete CircuitTopName. It will not work with updated RenameMap --- src/main/scala/firrtl/Compiler.scala | 74 ++++++++++++++++---------- src/main/scala/firrtl/annotations/Named.scala | 6 --- src/test/scala/firrtlTests/RenameMapSpec.scala | 42 +++++++++++++++ 3 files changed, 87 insertions(+), 35 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index ca74e5e1..603c05f6 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -27,42 +27,58 @@ object RenameMap { // TODO This should probably be refactored into immutable and mutable versions final class RenameMap private () { private val underlying = mutable.HashMap[Named, Seq[Named]]() - /** Get new names for an old name - * - * This is analogous to get on standard Scala collection Maps - * None indicates the key was not renamed - * Empty indicates the name was deleted + + /** Get renames of a [[CircuitName]] + * @note A [[CircuitName]] can only be renamed to a single [[CircuitName]] */ - // TODO Is there a better way to express this? - def get(key: Named): Option[Seq[Named]] = { + def get(key: CircuitName): Option[CircuitName] = underlying.get(key).map { + case Seq(c: CircuitName) => c + case other => error(s"Unsupported Circuit rename to $other!") + } + /** Get renames of a [[ModuleName]] + * @note A [[ModuleName]] can only be renamed to one-or-more [[ModuleName]]s + */ + def get(key: ModuleName): Option[Seq[ModuleName]] = { + def nestedRename(m: ModuleName): Option[Seq[ModuleName]] = + this.get(m.circuit).map(cname => Seq(ModuleName(m.name, cname))) underlying.get(key) match { - // If the key was renamed, check if anything it renamed to is a component - // If so, check if nested modules were renamed case Some(names) => Some(names.flatMap { - case comp @ ComponentName(cname, mod) => - underlying.get(mod) match { - case Some(mods) => mods.map { - case modx: ModuleName => - ComponentName(cname, modx) - case _ => error("Unexpected rename of Module to non-Module!") - } - case None => List(comp) - } - case other => List(other) + case m: ModuleName => + nestedRename(m).getOrElse(Seq(m)) + case other => error(s"Unsupported Module rename of $key to $other") }) - // If key wans't renamed, still check if it's a component - // If so, check if nexted modules were renamed - case None => key match { - case ComponentName(cname, mod) => - underlying.get(mod).map(_.map { - case modx: ModuleName => - ComponentName(cname, modx) - case _ => error("Unexpected rename of Module to non-Module!") - }) - case other => None + case None => nestedRename(key) + } + } + /** Get renames of a [[ComponentName]] + * @note A [[ComponentName]] can only be renamed to one-or-more [[ComponentName]]s + */ + def get(key: ComponentName): Option[Seq[ComponentName]] = { + def nestedRename(c: ComponentName): Option[Seq[ComponentName]] = + this.get(c.module).map { modules => + modules.map(mname => ComponentName(c.name, mname)) } + underlying.get(key) match { + case Some(names) => Some(names.flatMap { + case c: ComponentName => + nestedRename(c).getOrElse(Seq(c)) + case other => error(s"Unsupported Component rename of $key to $other") + }) + case None => nestedRename(key) } } + /** Get new names for an old name + * + * This is analogous to get on standard Scala collection Maps + * None indicates the key was not renamed + * Empty indicates the name was deleted + */ + def get(key: Named): Option[Seq[Named]] = key match { + case c: ComponentName => this.get(c) + case m: ModuleName => this.get(m) + // The CircuitName version returns Option[CircuitName] + case c: CircuitName => this.get(c).map(Seq(_)) + } // Mutable helpers private var circuitName: String = "" diff --git a/src/main/scala/firrtl/annotations/Named.scala b/src/main/scala/firrtl/annotations/Named.scala index d3a70643..3da75884 100644 --- a/src/main/scala/firrtl/annotations/Named.scala +++ b/src/main/scala/firrtl/annotations/Named.scala @@ -13,12 +13,6 @@ sealed trait Named { def serialize: String } -/** Name referring to the top of the circuit */ -final case object CircuitTopName extends Named { - def name: String = "CircuitTop" - def serialize: String = name -} - final case class CircuitName(name: String) extends Named { if(!validModuleName(name)) throw AnnotationException(s"Illegal circuit name: $name") def serialize: String = name diff --git a/src/test/scala/firrtlTests/RenameMapSpec.scala b/src/test/scala/firrtlTests/RenameMapSpec.scala index 9d19bb72..9e305b70 100644 --- a/src/test/scala/firrtlTests/RenameMapSpec.scala +++ b/src/test/scala/firrtlTests/RenameMapSpec.scala @@ -3,7 +3,9 @@ package firrtlTests import firrtl.RenameMap +import firrtl.FIRRTLException import firrtl.annotations.{ + Named, CircuitName, ModuleName, ComponentName @@ -11,9 +13,13 @@ import firrtl.annotations.{ class RenameMapSpec extends FirrtlFlatSpec { val cir = CircuitName("Top") + val cir2 = CircuitName("Pot") + val cir3 = CircuitName("Cir3") val modA = ModuleName("A", cir) + val modA2 = ModuleName("A", cir2) val modB = ModuleName("B", cir) val foo = ComponentName("foo", modA) + val foo2 = ComponentName("foo", modA2) val bar = ComponentName("bar", modA) val fizz = ComponentName("fizz", modA) val fooB = ComponentName("foo", modB) @@ -71,4 +77,40 @@ class RenameMapSpec extends FirrtlFlatSpec { renames.rename(foo, bar) renames.get(foo) should be (Some(Seq(barB))) } + + it should "rename modules if their circuit is renamed" in { + val renames = RenameMap() + renames.rename(cir, cir2) + renames.get(modA) should be (Some(Seq(modA2))) + } + + it should "rename components if their circuit is renamed" in { + val renames = RenameMap() + renames.rename(cir, cir2) + renames.get(foo) should be (Some(Seq(foo2))) + } + + // Renaming `from` to each of the `tos` at the same time should error + case class BadRename(from: Named, tos: Seq[Named]) + val badRenames = + Seq(BadRename(foo, Seq(cir)), + BadRename(foo, Seq(modA)), + BadRename(modA, Seq(foo)), + BadRename(modA, Seq(cir)), + BadRename(cir, Seq(foo)), + BadRename(cir, Seq(modA)), + BadRename(cir, Seq(cir2, cir3)) + ) + // Run all BadRename tests + for (BadRename(from, tos) <- badRenames) { + val fromN = from.getClass.getSimpleName + val tosN = tos.map(_.getClass.getSimpleName).mkString(", ") + it should s"error if a $fromN is renamed to $tosN" in { + val renames = RenameMap() + for (to <- tos) { renames.rename(from, to) } + a [FIRRTLException] shouldBe thrownBy { + renames.get(foo) + } + } + } } -- cgit v1.2.3