diff options
| author | Abert Chen | 2019-07-19 09:01:57 -0700 |
|---|---|---|
| committer | Schuyler Eldridge | 2019-09-12 12:41:58 -0400 |
| commit | 750ee776978fa1fdcfa64aa04f218b0c70c3e85e (patch) | |
| tree | 0f2f13ead2a12d919809c71be891043fcb34d2bd /src | |
| parent | 7929768a99eb93eea1c1ff0f71ab7d16a59abaa0 (diff) | |
update inline transform and testcases
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/analyses/InstanceGraph.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 139 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InlineInstancesTests.scala | 79 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/PassTests.scala | 14 |
4 files changed, 195 insertions, 45 deletions
diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala index 1c6ada4e..22f40359 100644 --- a/src/main/scala/firrtl/analyses/InstanceGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala @@ -70,8 +70,12 @@ class InstanceGraph(c: Circuit) { * @return a Seq[ Seq[WDefInstance] ] of absolute instance paths */ def findInstancesInHierarchy(module: String): Seq[Seq[WDefInstance]] = { - val instances = graph.getVertices.filter(_.module == module).toSeq - instances flatMap { i => fullHierarchy(i) } + if (instantiated(module)) { + val instances = graph.getVertices.filter(_.module == module).toSeq + instances flatMap { i => fullHierarchy(i) } + } else { + Nil + } } /** An [[firrtl.graph.EulerTour EulerTour]] representation of the [[firrtl.graph.DiGraph DiGraph]] */ diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 7f8913c6..86d7bb22 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -132,44 +132,96 @@ class InlineInstances extends Transform with RegisteredTransform { val iGraph = new InstanceGraph(c) val namespaceMap = collection.mutable.Map[String, Namespace]() + def getInstMap(mod: DefModule): Map[String, String] = mod match { + case m: Module => getInstMapBody(m.body) + case _ => Map.empty[String, String] + } + def getInstMapBody(stmt: Statement): Map[String, String] = { + val instMap = mutable.Map.empty[String, String] + def onStmt(s: Statement): Statement = s.map(onStmt) match { + case wDef@ WDefInstance(_, instName, modName, _) => + instMap += (instName -> modName) + wDef + case other => other + } + onStmt(stmt) + instMap.toMap + } + /** Add a prefix to all declarations updating a [[Namespace]] and appending to a [[RenameMap]] */ - def appendNamePrefix(prefix: String, ns: Namespace, renames: RenameMap)(name:String): String = { - if (prefix.nonEmpty && !ns.tryName(prefix + name)) - throw new Exception(s"Inlining failed. Inlined name '${prefix + name}' already exists") - renames.rename(name, prefix + name) - prefix + name + def appendNamePrefix( + currentModule: IsModule, + prefix: String, + ns: Namespace, + renames: mutable.HashMap[String, String], + renameMap: RenameMap)(s: Statement): Statement = { + def onName(ofModuleOpt: Option[String])(name: String) = { + if (prefix.nonEmpty && !ns.tryName(prefix + name)) { + throw new Exception(s"Inlining failed. Inlined name '${prefix + name}' already exists") + } + ofModuleOpt match { + case None => + renameMap.record(currentModule.ref(name), currentModule.ref(prefix + name)) + case Some(ofModule) => + renameMap.record(currentModule.instOf(name, ofModule), currentModule.instOf(prefix + name, ofModule)) + } + renames(name) = prefix + name + prefix + name + } + + s match { + case s: WDefInstance => s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, prefix, ns, renames, renameMap)) + case other => s.map(onName(None)).map(appendNamePrefix(currentModule, prefix, ns, renames, renameMap)) + } } /** Modify all references */ - def appendRefPrefix(currentModule: ModuleName, renames: RenameMap) - (e: Expression): Expression = e match { + def appendRefPrefix( + currentModule: IsModule, + renames: mutable.HashMap[String, String])(s: Statement): Statement = { + def onExpr(e: Expression): Expression = e match { + case wr@ WRef(name, _, _, _) => + renames.get(name) match { + case Some(prefixedName) => wr.copy(name = prefixedName) + case None => wr + } + case ex => ex.map(onExpr) + } + s.map(onExpr).map(appendRefPrefix(currentModule, renames)) + } + + def fixupRefs( + instMap: Map[String, String], + currentModule: IsModule, + renames: RenameMap)(e: Expression): Expression = e match { case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) => - val port = ComponentName(s"$ref.$field", currentModule) - val inst = ComponentName(s"$ref", currentModule) - (renames.get(port), renames.get(inst)) match { - case (Some(Seq(p)), _) => - p.toTarget match { - case ReferenceTarget(_, _, Seq(), r, Seq(TargetToken.Field(f))) => wsf.copy(expr = wr.copy(name = r), name = f) - case ReferenceTarget(_, _, Seq(), r, Seq()) => WRef(r, tpe, WireKind, gen) + val inst = currentModule.instOf(ref, instMap(ref)) + val port = inst.ref(field) + renames.get(port) match { + case Some(Seq(p)) => + p match { + case ReferenceTarget(_, _, Seq((TargetToken.Instance(r), TargetToken.OfModule(_))), f, Nil) => + wsf.copy(expr = wr.copy(name = r), name = f) + case ReferenceTarget(_, _, Nil, r, Nil) => WRef(r, tpe, WireKind, gen) } - case (None, Some(Seq(i))) => wsf.map(appendRefPrefix(currentModule, renames)) - case (None, None) => wsf + case None => wsf } case wr@ WRef(name, _, _, _) => - val comp = ComponentName(name, currentModule) - renames.get(comp).orElse(Some(Seq(comp))) match { - case Some(Seq(car)) => wr.copy(name=car.name) - case c@ Some(_) => throw new PassException( - s"Inlining found mlutiple renames for ref $comp -> $c. This should be impossible...") + val comp = currentModule.ref(name) //ComponentName(name, currentModule) + renames.get(comp).getOrElse(Seq(comp)) match { + case Seq(car: ReferenceTarget) => wr.copy(name=car.ref) } - case ex => ex.map(appendRefPrefix(currentModule, renames)) + case ex => ex.map(fixupRefs(instMap, currentModule, renames)) } - def onStmt(currentModule: ModuleName, renames: RenameMap)(s: Statement): Statement = { - val ns = namespaceMap.getOrElseUpdate(currentModule.name, Namespace(iGraph.moduleMap(currentModule.name))) - renames.setModule(currentModule.name) + var renames = RenameMap() + + def onStmt(currentModule: ModuleTarget)(s: Statement): Statement = { + val currentModuleName = currentModule.module + val ns = namespaceMap.getOrElseUpdate(currentModuleName, Namespace(iGraph.moduleMap(currentModuleName))) + val instMap = getInstMap(iGraph.moduleMap(currentModuleName)) s match { - case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(s"${currentModule.name}.$instName") => + case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(s"${currentModuleName}.$instName") => val toInline = iGraph.moduleMap(modName) match { case m: ExtModule => throw new PassException(s"Cannot inline external module ${m.name}") case m: Module => m @@ -177,9 +229,7 @@ class InlineInstances extends Transform with RegisteredTransform { val ports = toInline.ports.map(p => DefWire(p.info, p.name, p.tpe)) - val subRenames = RenameMap() - subRenames.setCircuit(currentModule.circuit.name) - val bodyx = Block(ports :+ toInline.body) map onStmt(currentModule.copy(name=modName), subRenames) + val bodyx = Block(ports :+ toInline.body) map onStmt(currentModule.copy(module = modName)) val names = "" +: Uniquify .enumerateNames(Uniquify.stmtToType(bodyx)(NoInfo, "")) @@ -192,26 +242,31 @@ class InlineInstances extends Transform with RegisteredTransform { */ val safePrefix = Uniquify.findValidPrefix(instName + inlineDelim, names, ns.cloneUnderlying - instName) - ports.foreach( p => renames.rename(s"$instName.${p.name}", safePrefix + p.name) ) + val prefixRenames = RenameMap() + val prefixMap = mutable.HashMap.empty[String, String] + val inlineTarget = currentModule.instOf(instName, modName) + val renamedBody = bodyx + .map(appendNamePrefix(inlineTarget, safePrefix, ns, prefixMap, prefixRenames)) + .map(appendRefPrefix(inlineTarget, prefixMap)) - def recName(s: Statement): Statement = s.map(recName).map(appendNamePrefix(safePrefix, ns, subRenames)) - def recRef(s: Statement): Statement = s.map(recRef).map(appendRefPrefix(currentModule.copy(name=modName), subRenames)) + val inlineRenames = RenameMap() + inlineRenames.record(inlineTarget, currentModule) - bodyx - .map(recName) - .map(recRef) - case sx => sx - .map(appendRefPrefix(currentModule, renames)) - .map(onStmt(currentModule, renames)) + renames = renames.andThen(prefixRenames).andThen(inlineRenames) + + renamedBody + case sx => + sx + .map(fixupRefs(instMap, currentModule, renames)) + .map(onStmt(currentModule)) } } - val renames = RenameMap() - renames.setCircuit(c.main) val flatCircuit = c.copy(modules = c.modules.flatMap { case m if flatModules.contains(m.name) => None - case m => Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main)), renames))) + case m => Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main))))) }) - CircuitState(flatCircuit, LowForm, annos, None) + + CircuitState(flatCircuit, LowForm, annos, Some(renames)) } } diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 7132c0f3..5f48c883 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -8,7 +8,7 @@ import org.scalatest.junit.JUnitRunner import firrtl.ir.Circuit import firrtl.Parser import firrtl.passes.PassExceptions -import firrtl.annotations.{Annotation, CircuitName, ComponentName, ModuleName, Named} +import firrtl.annotations._ import firrtl.passes.{InlineAnnotation, InlineInstances} import logger.{LogLevel, Logger} import logger.LogLevel.Debug @@ -382,6 +382,83 @@ class InlineInstancesTests extends LowTransformSpec { | b <= a""".stripMargin execute(input, check, Seq(inline("Inline"))) } + + case class DummyAnno(target: ReferenceTarget) extends SingleTargetAnnotation[ReferenceTarget] { + override def duplicate(n: ReferenceTarget): Annotation = DummyAnno(n) + } + "annotations" should "be renamed" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of NestedInline + | inst bar of NestedNoInline + | foo.a <= a + | bar.a <= foo.b + | b <= bar.b + | module NestedInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | wire i_foo_a : UInt<32> + | wire i_foo_b : UInt<32> + | i_foo_b <= i_foo_a + | inst i_bar of NestedNoInline + | i_b <= i_bar.b + | i_foo_a <= i_a + | i_bar.a <= i_foo_b + | b <= i_b + | i_a <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + |""".stripMargin + val top = CircuitTarget("Top").module("Top") + val inlined = top.instOf("i", "Inline") + val nestedInlined = top.instOf("i", "Inline").instOf("foo", "NestedInline") + val nestedNotInlined = top.instOf("i", "Inline").instOf("bar", "NestedNoInline") + + executeWithAnnos(input, check, + Seq( + inline("Inline"), + inline("NestedInline"), + DummyAnno(inlined.ref("a")), + DummyAnno(inlined.ref("b")), + DummyAnno(nestedInlined.ref("a")), + DummyAnno(nestedInlined.ref("b")), + DummyAnno(nestedNotInlined.ref("a")), + DummyAnno(nestedNotInlined.ref("b")) + ), + Seq( + DummyAnno(top.ref("i_a")), + DummyAnno(top.ref("i_b")), + DummyAnno(top.ref("i_foo_a")), + DummyAnno(top.ref("i_foo_b")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")) + ) + ) + } } // Execution driven tests for inlining modules diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index 6f94275e..7fe154ec 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -29,6 +29,20 @@ abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Com (actual) should be (expected) finalState } + + def executeWithAnnos(input: String, check: String, annotations: Seq[Annotation], + checkAnnotations: Seq[Annotation]): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be (expected) + checkAnnotations.foreach { check => + (finalState.annotations.toSeq) should contain (check) + } + finalState + } // Executes the test, should throw an error // No default to be consistent with execute def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { |
