diff options
| author | azidar | 2016-09-30 11:46:01 -0700 |
|---|---|---|
| committer | Jack Koenig | 2016-11-04 13:29:09 -0700 |
| commit | 1c36656fad15f515543d89a6407b360b4b2ebb87 (patch) | |
| tree | e115e2e565f4af78aa310dae169e2df2e20a9894 /src | |
| parent | 8fa9429a6e916ab2a789f5d81fa803b022805b52 (diff) | |
Add a pass to deduplicate modules
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Compiler.scala | 5 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/Dedup.scala | 89 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/AttachSpec.scala | 28 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/PassTests.scala | 1 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/transforms/DedupTests.scala | 112 |
5 files changed, 219 insertions, 16 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 9781972e..949807db 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -176,8 +176,9 @@ object CompilerUtils { } else { inputForm match { case ChirrtlForm => Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm) - case HighForm => Seq(new IRToWorkingIR, new ResolveAndCheck, new HighFirrtlToMiddleFirrtl) ++ - getLoweringTransforms(MidForm, outputForm) + case HighForm => + Seq(new IRToWorkingIR, new ResolveAndCheck, new transforms.DedupModules, + new HighFirrtlToMiddleFirrtl) ++ getLoweringTransforms(MidForm, outputForm) case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) case LowForm => error("Internal Error! This shouldn't be possible") // should be caught by if above } diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala new file mode 100644 index 00000000..5d953e73 --- /dev/null +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -0,0 +1,89 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.Annotations._ +import firrtl.passes.PassException + +// Datastructures +import scala.collection.mutable + +// Tags an annotation to be consumed by this pass +case class DedupAnnotation(target: Named) extends Annotation with Loose with Unstable { + def duplicate(n: Named) = this.copy(target=n) + def transform = classOf[DedupModules] +} + +// Only use on legal Firrtl. Specifically, the restriction of +// instance loops must have been checked, or else this pass can +// infinitely recurse +class DedupModules extends Transform { + def inputForm = HighForm + def outputForm = HighForm + def execute(state: CircuitState): CircuitState = state.copy(circuit = run(state.circuit)) + def run(c: Circuit): Circuit = { + val moduleOrder = mutable.ArrayBuffer.empty[String] + val moduleMap = c.modules.map(m => m.name -> m).toMap + def hasInstance(b: Statement): Boolean = { + var has = false + def onStmt(s: Statement): Statement = s map onStmt match { + case DefInstance(i, n, m) => + if(!(moduleOrder contains m)) has = true + s + case WDefInstance(i, n, m, t) => + if(!(moduleOrder contains m)) has = true + s + case _ => s + } + onStmt(b) + has + } + def addModule(m: DefModule): DefModule = m match { + case Module(info, n, ps, b) => + if(!hasInstance(b)) moduleOrder += m.name + m + case e: ExtModule => + moduleOrder += m.name + m + case _ => m + } + + while((moduleOrder.size < c.modules.size)) { + c.modules.foreach(m => if(!moduleOrder.contains(m.name)) addModule(m)) + } + + // Module body -> Module name + val dedupModules = mutable.HashMap.empty[String, String] + // Old module name -> dup module name + val dedupMap = mutable.HashMap.empty[String, String] + def onModule(m: DefModule): Option[DefModule] = { + def fixInstance(s: Statement): Statement = s map fixInstance match { + case DefInstance(i, n, m) => DefInstance(i, n, dedupMap.getOrElse(m, m)) + case WDefInstance(i, n, m, t) => WDefInstance(i, n, dedupMap.getOrElse(m, m), t) + case x => x + } + + val mx = m map fixInstance + val string = mx match { + case Module(i, n, ps, b) => + ps.map(_.serialize).mkString + b.serialize + case ExtModule(i, n, ps, dn, p) => + ps.map(_.serialize).mkString + dn + p.map(_.serialize).mkString + } + dedupModules.get(string) match { + case Some(dupname) => + dedupMap(mx.name) = dupname + None + case None => + dedupModules(string) = mx.name + Some(mx) + } + } + val modulesx = moduleOrder.flatMap(n => onModule(moduleMap(n))) + val modulesxMap = modulesx.map(m => m.name -> m).toMap + c.copy(modules = c.modules.flatMap(m => modulesxMap.get(m.name))) + } +} diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala index 3a67bf04..337763c9 100644 --- a/src/test/scala/firrtlTests/AttachSpec.scala +++ b/src/test/scala/firrtlTests/AttachSpec.scala @@ -53,28 +53,28 @@ class InoutVerilog extends FirrtlFlatSpec { | input an: Analog<3> | inst a of A | inst b of B - | attach an to (a.an, b.an) + | attach an to (a.an1, b.an2) | module A: - | input an: Analog<3> + | input an1: Analog<3> | module B: - | input an: Analog<3> """.stripMargin + | input an2: Analog<3> """.stripMargin val check = """module Attaching( | inout [2:0] an |); | A a ( - | .an(an) + | .an1(an) | ); | B b ( - | .an(an) + | .an2(an) | ); |endmodule |module A( - | inout [2:0] an + | inout [2:0] an1 |); |endmodule |module B( - | inout [2:0] an + | inout [2:0] an2 |); |endmodule |""".stripMargin.split("\n") map normalized @@ -89,28 +89,28 @@ class InoutVerilog extends FirrtlFlatSpec { | output an: Analog<3> | inst a of A | inst b of B - | attach an to (a.an, b.an) + | attach an to (a.an1, b.an2) | module A: - | input an: Analog<3> + | input an1: Analog<3> | module B: - | input an: Analog<3> """.stripMargin + | input an2: Analog<3> """.stripMargin val check = """module Attaching( | inout [2:0] an |); | A a ( - | .an(an) + | .an1(an) | ); | B b ( - | .an(an) + | .an2(an) | ); |endmodule |module A( - | inout [2:0] an + | inout [2:0] an1 |); |endmodule |module B( - | inout [2:0] an + | inout [2:0] an2 |); |endmodule |""".stripMargin.split("\n") map normalized diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index e574d31f..1aaf77b6 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -36,6 +36,7 @@ abstract class SimpleTransformSpec extends FlatSpec with Matchers with Compiler // Utility function def parse(s: String): Circuit = Parser.parse(s.split("\n").toIterator, infoMode = IgnoreInfo) + def squash(c: Circuit): Circuit = RemoveEmpty.run(c) // Executes the test. Call in tests. def execute(writer: Writer, annotations: AnnotationMap, input: String, check: String) = { diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala new file mode 100644 index 00000000..fac6a1fb --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala @@ -0,0 +1,112 @@ +package firrtlTests +package transform + +import java.io.StringWriter + +import org.scalatest.FlatSpec +import org.scalatest.Matchers +import org.scalatest.junit.JUnitRunner + +import firrtl.ir.Circuit +import firrtl.Parser +import firrtl.passes.PassExceptions +import firrtl.Annotations.{ + Named, + CircuitName, + Annotation, + AnnotationMap +} +import firrtl.transforms.{DedupModules, DedupAnnotation} + + +/** + * Tests inline instances transformation + */ +class DedupModuleTests extends HighTransformSpec { + def transform = new DedupModules + "The module A" should "be deduped" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val writer = new StringWriter() + val aMap = new AnnotationMap(Nil) + execute(writer, aMap, input, check) + } + "The module A and B" should "be deduped" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module A_ : + | output x: UInt<1> + | inst b of B_ + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + """.stripMargin + val writer = new StringWriter() + val aMap = new AnnotationMap(Nil) + execute(writer, aMap, input, check) + } +} + +// Execution driven tests for inlining modules +// TODO(izraelevitz) fix this test +//class InlineInstancesIntegrationSpec extends FirrtlPropSpec { +// // Shorthand for creating annotations to inline modules +// def inlineModules(names: Seq[String]): Seq[CircuitAnnotation] = +// Seq(StickyCircuitAnnotation(InlineCAKind, names.map(n => ModuleName(n) -> TagAnnotation).toMap)) +// +// case class Test(name: String, dir: String, ann: Seq[CircuitAnnotation]) +// +// val runTests = Seq( +// Test("GCDTester", "/integration", inlineModules(Seq("DecoupledGCD"))) +// ) +// +// runTests foreach { test => +// property(s"${test.name} should execute correctly with inlining") { +// println(s"Got annotations ${test.ann}") +// runFirrtlTest(test.name, test.dir, test.ann) +// } +// } +//} |
