aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorazidar2016-09-30 11:46:01 -0700
committerJack Koenig2016-11-04 13:29:09 -0700
commit1c36656fad15f515543d89a6407b360b4b2ebb87 (patch)
treee115e2e565f4af78aa310dae169e2df2e20a9894 /src
parent8fa9429a6e916ab2a789f5d81fa803b022805b52 (diff)
Add a pass to deduplicate modules
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Compiler.scala5
-rw-r--r--src/main/scala/firrtl/transforms/Dedup.scala89
-rw-r--r--src/test/scala/firrtlTests/AttachSpec.scala28
-rw-r--r--src/test/scala/firrtlTests/PassTests.scala1
-rw-r--r--src/test/scala/firrtlTests/transforms/DedupTests.scala112
5 files changed, 219 insertions, 16 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index 9781972e..949807db 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -176,8 +176,9 @@ object CompilerUtils {
} else {
inputForm match {
case ChirrtlForm => Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm)
- case HighForm => Seq(new IRToWorkingIR, new ResolveAndCheck, new HighFirrtlToMiddleFirrtl) ++
- getLoweringTransforms(MidForm, outputForm)
+ case HighForm =>
+ Seq(new IRToWorkingIR, new ResolveAndCheck, new transforms.DedupModules,
+ new HighFirrtlToMiddleFirrtl) ++ getLoweringTransforms(MidForm, outputForm)
case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm)
case LowForm => error("Internal Error! This shouldn't be possible") // should be caught by if above
}
diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala
new file mode 100644
index 00000000..5d953e73
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/Dedup.scala
@@ -0,0 +1,89 @@
+// See LICENSE for license details.
+
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.Annotations._
+import firrtl.passes.PassException
+
+// Datastructures
+import scala.collection.mutable
+
+// Tags an annotation to be consumed by this pass
+case class DedupAnnotation(target: Named) extends Annotation with Loose with Unstable {
+ def duplicate(n: Named) = this.copy(target=n)
+ def transform = classOf[DedupModules]
+}
+
+// Only use on legal Firrtl. Specifically, the restriction of
+// instance loops must have been checked, or else this pass can
+// infinitely recurse
+class DedupModules extends Transform {
+ def inputForm = HighForm
+ def outputForm = HighForm
+ def execute(state: CircuitState): CircuitState = state.copy(circuit = run(state.circuit))
+ def run(c: Circuit): Circuit = {
+ val moduleOrder = mutable.ArrayBuffer.empty[String]
+ val moduleMap = c.modules.map(m => m.name -> m).toMap
+ def hasInstance(b: Statement): Boolean = {
+ var has = false
+ def onStmt(s: Statement): Statement = s map onStmt match {
+ case DefInstance(i, n, m) =>
+ if(!(moduleOrder contains m)) has = true
+ s
+ case WDefInstance(i, n, m, t) =>
+ if(!(moduleOrder contains m)) has = true
+ s
+ case _ => s
+ }
+ onStmt(b)
+ has
+ }
+ def addModule(m: DefModule): DefModule = m match {
+ case Module(info, n, ps, b) =>
+ if(!hasInstance(b)) moduleOrder += m.name
+ m
+ case e: ExtModule =>
+ moduleOrder += m.name
+ m
+ case _ => m
+ }
+
+ while((moduleOrder.size < c.modules.size)) {
+ c.modules.foreach(m => if(!moduleOrder.contains(m.name)) addModule(m))
+ }
+
+ // Module body -> Module name
+ val dedupModules = mutable.HashMap.empty[String, String]
+ // Old module name -> dup module name
+ val dedupMap = mutable.HashMap.empty[String, String]
+ def onModule(m: DefModule): Option[DefModule] = {
+ def fixInstance(s: Statement): Statement = s map fixInstance match {
+ case DefInstance(i, n, m) => DefInstance(i, n, dedupMap.getOrElse(m, m))
+ case WDefInstance(i, n, m, t) => WDefInstance(i, n, dedupMap.getOrElse(m, m), t)
+ case x => x
+ }
+
+ val mx = m map fixInstance
+ val string = mx match {
+ case Module(i, n, ps, b) =>
+ ps.map(_.serialize).mkString + b.serialize
+ case ExtModule(i, n, ps, dn, p) =>
+ ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString
+ }
+ dedupModules.get(string) match {
+ case Some(dupname) =>
+ dedupMap(mx.name) = dupname
+ None
+ case None =>
+ dedupModules(string) = mx.name
+ Some(mx)
+ }
+ }
+ val modulesx = moduleOrder.flatMap(n => onModule(moduleMap(n)))
+ val modulesxMap = modulesx.map(m => m.name -> m).toMap
+ c.copy(modules = c.modules.flatMap(m => modulesxMap.get(m.name)))
+ }
+}
diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala
index 3a67bf04..337763c9 100644
--- a/src/test/scala/firrtlTests/AttachSpec.scala
+++ b/src/test/scala/firrtlTests/AttachSpec.scala
@@ -53,28 +53,28 @@ class InoutVerilog extends FirrtlFlatSpec {
| input an: Analog<3>
| inst a of A
| inst b of B
- | attach an to (a.an, b.an)
+ | attach an to (a.an1, b.an2)
| module A:
- | input an: Analog<3>
+ | input an1: Analog<3>
| module B:
- | input an: Analog<3> """.stripMargin
+ | input an2: Analog<3> """.stripMargin
val check =
"""module Attaching(
| inout [2:0] an
|);
| A a (
- | .an(an)
+ | .an1(an)
| );
| B b (
- | .an(an)
+ | .an2(an)
| );
|endmodule
|module A(
- | inout [2:0] an
+ | inout [2:0] an1
|);
|endmodule
|module B(
- | inout [2:0] an
+ | inout [2:0] an2
|);
|endmodule
|""".stripMargin.split("\n") map normalized
@@ -89,28 +89,28 @@ class InoutVerilog extends FirrtlFlatSpec {
| output an: Analog<3>
| inst a of A
| inst b of B
- | attach an to (a.an, b.an)
+ | attach an to (a.an1, b.an2)
| module A:
- | input an: Analog<3>
+ | input an1: Analog<3>
| module B:
- | input an: Analog<3> """.stripMargin
+ | input an2: Analog<3> """.stripMargin
val check =
"""module Attaching(
| inout [2:0] an
|);
| A a (
- | .an(an)
+ | .an1(an)
| );
| B b (
- | .an(an)
+ | .an2(an)
| );
|endmodule
|module A(
- | inout [2:0] an
+ | inout [2:0] an1
|);
|endmodule
|module B(
- | inout [2:0] an
+ | inout [2:0] an2
|);
|endmodule
|""".stripMargin.split("\n") map normalized
diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala
index e574d31f..1aaf77b6 100644
--- a/src/test/scala/firrtlTests/PassTests.scala
+++ b/src/test/scala/firrtlTests/PassTests.scala
@@ -36,6 +36,7 @@ abstract class SimpleTransformSpec extends FlatSpec with Matchers with Compiler
// Utility function
def parse(s: String): Circuit = Parser.parse(s.split("\n").toIterator, infoMode = IgnoreInfo)
+ def squash(c: Circuit): Circuit = RemoveEmpty.run(c)
// Executes the test. Call in tests.
def execute(writer: Writer, annotations: AnnotationMap, input: String, check: String) = {
diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala
new file mode 100644
index 00000000..fac6a1fb
--- /dev/null
+++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala
@@ -0,0 +1,112 @@
+package firrtlTests
+package transform
+
+import java.io.StringWriter
+
+import org.scalatest.FlatSpec
+import org.scalatest.Matchers
+import org.scalatest.junit.JUnitRunner
+
+import firrtl.ir.Circuit
+import firrtl.Parser
+import firrtl.passes.PassExceptions
+import firrtl.Annotations.{
+ Named,
+ CircuitName,
+ Annotation,
+ AnnotationMap
+}
+import firrtl.transforms.{DedupModules, DedupAnnotation}
+
+
+/**
+ * Tests inline instances transformation
+ */
+class DedupModuleTests extends HighTransformSpec {
+ def transform = new DedupModules
+ "The module A" should "be deduped" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | inst a1 of A
+ | inst a2 of A_
+ | module A :
+ | output x: UInt<1>
+ | x <= UInt(1)
+ | module A_ :
+ | output x: UInt<1>
+ | x <= UInt(1)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | inst a1 of A
+ | inst a2 of A
+ | module A :
+ | output x: UInt<1>
+ | x <= UInt(1)
+ """.stripMargin
+ val writer = new StringWriter()
+ val aMap = new AnnotationMap(Nil)
+ execute(writer, aMap, input, check)
+ }
+ "The module A and B" should "be deduped" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | inst a1 of A
+ | inst a2 of A_
+ | module A :
+ | output x: UInt<1>
+ | inst b of B
+ | x <= b.x
+ | module A_ :
+ | output x: UInt<1>
+ | inst b of B_
+ | x <= b.x
+ | module B :
+ | output x: UInt<1>
+ | x <= UInt(1)
+ | module B_ :
+ | output x: UInt<1>
+ | x <= UInt(1)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | inst a1 of A
+ | inst a2 of A
+ | module A :
+ | output x: UInt<1>
+ | inst b of B
+ | x <= b.x
+ | module B :
+ | output x: UInt<1>
+ | x <= UInt(1)
+ """.stripMargin
+ val writer = new StringWriter()
+ val aMap = new AnnotationMap(Nil)
+ execute(writer, aMap, input, check)
+ }
+}
+
+// Execution driven tests for inlining modules
+// TODO(izraelevitz) fix this test
+//class InlineInstancesIntegrationSpec extends FirrtlPropSpec {
+// // Shorthand for creating annotations to inline modules
+// def inlineModules(names: Seq[String]): Seq[CircuitAnnotation] =
+// Seq(StickyCircuitAnnotation(InlineCAKind, names.map(n => ModuleName(n) -> TagAnnotation).toMap))
+//
+// case class Test(name: String, dir: String, ann: Seq[CircuitAnnotation])
+//
+// val runTests = Seq(
+// Test("GCDTester", "/integration", inlineModules(Seq("DecoupledGCD")))
+// )
+//
+// runTests foreach { test =>
+// property(s"${test.name} should execute correctly with inlining") {
+// println(s"Got annotations ${test.ann}")
+// runFirrtlTest(test.name, test.dir, test.ann)
+// }
+// }
+//}