diff options
Diffstat (limited to 'src/test')
| -rw-r--r-- | src/test/scala/firrtlTests/AnnotationTests.scala | 81 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/AttachSpec.scala | 11 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/DCETests.scala | 366 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/FirrtlSpec.scala | 22 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 11 |
5 files changed, 467 insertions, 24 deletions
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 81c394c1..44964eba 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -7,8 +7,10 @@ import java.io.{File, FileWriter, Writer} import firrtl.annotations.AnnotationYamlProtocol._ import firrtl.annotations._ import firrtl._ +import firrtl.transforms.OptimizableExtModuleAnnotation import firrtl.passes.InlineAnnotation import firrtl.passes.memlib.PinAnnotation +import firrtl.transforms.DontTouchAnnotation import net.jcazevedo.moultingyaml._ import org.scalatest.Matchers import logger._ @@ -43,8 +45,17 @@ trait AnnotationSpec extends LowTransformSpec { class AnnotationTests extends AnnotationSpec with Matchers { def getAMap(a: Annotation): Option[AnnotationMap] = Some(AnnotationMap(Seq(a))) def getAMap(as: Seq[Annotation]): Option[AnnotationMap] = Some(AnnotationMap(as)) - def anno(s: String, value: String ="this is a value"): Annotation = - Annotation(ComponentName(s, ModuleName("Top", CircuitName("Top"))), classOf[Transform], value) + def anno(s: String, value: String ="this is a value", mod: String = "Top"): Annotation = + Annotation(ComponentName(s, ModuleName(mod, CircuitName("Top"))), classOf[Transform], value) + def manno(mod: String): Annotation = + Annotation(ModuleName(mod, CircuitName("Top")), classOf[Transform], "some value") + // TODO unify with FirrtlMatchers, problems with multiple definitions of parse + def dontTouch(path: String): Annotation = { + val parts = path.split('.') + require(parts.size >= 2, "Must specify both module and component!") + val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top"))) + DontTouchAnnotation(name) + } "Loose and Sticky annotation on a node" should "pass through" in { val input: String = @@ -145,7 +156,6 @@ class AnnotationTests extends AnnotationSpec with Matchers { val deleted = result.deletedAnnotations exception.str should be (s"No EmittedCircuit found! Did you delete any annotations?\n$deleted") } - "Renaming" should "propagate in Lowering of memories" in { val compiler = new VerilogCompiler // Uncomment to help debugging failing tests @@ -165,7 +175,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { | m.r.en <= UInt(1) | m.r.addr <= in |""".stripMargin - val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem")) + val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"), + dontTouch("Top.m")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("m_a", "mem")) @@ -179,7 +190,6 @@ class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should not contain (anno("m")) resultAnno should not contain (anno("r")) } - "Renaming" should "propagate in RemoveChirrtl and Lowering of memories" in { val compiler = new VerilogCompiler Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug)) @@ -191,7 +201,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { | cmem m: {a: UInt<4>, b: UInt<4>[2]}[8] | read mport r = m[in], clk |""".stripMargin - val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem")) + val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem"), dontTouch("Top.m")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("m_a", "mem")) @@ -220,7 +230,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { | x.a <= zero | x.b <= zero |""".stripMargin - val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), anno("y[2]")) + val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), + anno("y[2]"), dontTouch("Top.x")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("x_a")) @@ -260,7 +271,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { anno("w.a"), anno("w.b[0]"), anno("w.b[1]"), anno("n.a"), anno("n.b[0]"), anno("n.b[1]"), anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), - anno("write.a"), anno("write.b[0]"), anno("write.b[1]") + anno("write.a"), anno("write.b[0]"), anno("write.b[1]"), + dontTouch("Top.r") ) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations @@ -314,7 +326,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r")) + val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_a")) @@ -349,7 +361,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b")) + val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"), + dontTouch("Top.r")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_b_0")) @@ -364,8 +377,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should contain (anno("r_b_1")) } - - "Renaming" should "track dce" in { + "Renaming" should "track constprop + dce" in { val compiler = new VerilogCompiler val input = """circuit Top : @@ -403,4 +415,49 @@ class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should contain (anno("out_b_0")) resultAnno should contain (anno("out_b_1")) } + + ignore should "track deleted modules AND instances in dce" in { + val compiler = new VerilogCompiler + val input = + """circuit Top : + | module Dead : + | input foo : UInt<8> + | output bar : UInt<8> + | bar <= foo + | extmodule DeadExt : + | input foo : UInt<8> + | output bar : UInt<8> + | module Top : + | input foo : UInt<8> + | output bar : UInt<8> + | inst d of Dead + | d.foo <= foo + | inst d2 of DeadExt + | d2.foo <= foo + | bar <= foo + |""".stripMargin + val annos = Seq( + OptimizableExtModuleAnnotation(ModuleName("DeadExt", CircuitName("Top"))), + manno("Dead"), manno("DeadExt"), manno("Top"), + anno("d"), anno("d2"), + anno("foo", mod = "Top"), anno("bar", mod = "Top"), + anno("foo", mod = "Dead"), anno("bar", mod = "Dead"), + anno("foo", mod = "DeadExt"), anno("bar", mod = "DeadExt") + ) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + + resultAnno should contain (manno("Top")) + resultAnno should contain (anno("foo", mod = "Top")) + resultAnno should contain (anno("bar", mod = "Top")) + + resultAnno should not contain (manno("Dead")) + resultAnno should not contain (manno("DeadExt")) + resultAnno should not contain (anno("d")) + resultAnno should not contain (anno("d2")) + resultAnno should not contain (anno("foo", mod = "Dead")) + resultAnno should not contain (anno("bar", mod = "Dead")) + resultAnno should not contain (anno("foo", mod = "DeadExt")) + resultAnno should not contain (anno("bar", mod = "DeadExt")) + } } diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala index c29a7e43..93a36f70 100644 --- a/src/test/scala/firrtlTests/AttachSpec.scala +++ b/src/test/scala/firrtlTests/AttachSpec.scala @@ -62,7 +62,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | module A: | input an: Analog<3> | module B: - | input an: Analog<3> """.stripMargin + | input an: Analog<3>""".stripMargin val check = """module Attaching( |); @@ -70,16 +70,19 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | A a ( | .an(_GEN_0) | ); - | A b ( + | B b ( | .an(_GEN_0) | ); |endmodule |module A( | inout [2:0] an |); + |module B( + | inout [2:0] an + |); |endmodule |""".stripMargin.split("\n") map normalized - executeTest(input, check, compiler) + executeTest(input, check, compiler, Seq(dontTouch("A.an"), dontDedup("A"))) } it should "attach a wire source" in { @@ -101,7 +104,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | ); |endmodule |""".stripMargin.split("\n") map normalized - executeTest(input, check, compiler) + executeTest(input, check, compiler, Seq(dontTouch("Attaching.x"))) } it should "attach multiple sources" in { diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala new file mode 100644 index 00000000..deb73b3b --- /dev/null +++ b/src/test/scala/firrtlTests/DCETests.scala @@ -0,0 +1,366 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl.ir.Circuit +import firrtl._ +import firrtl.passes._ +import firrtl.transforms._ +import firrtl.annotations._ +import firrtl.passes.memlib.SimpleTransform + +class DCETests extends FirrtlFlatSpec { + // Not using executeTest because it is for positive testing, we need to check that stuff got + // deleted + private val customTransforms = Seq( + new LowFirrtlOptimization, + new SimpleTransform(RemoveEmpty, LowForm) + ) + private def exec(input: String, check: String, annos: Seq[Annotation] = List.empty): Unit = { + val state = CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annos))) + val finalState = (new LowFirrtlCompiler).compileAndEmit(state, customTransforms) + val res = finalState.getEmittedCircuit.value + // Convert to sets for comparison + val resSet = Set(parse(res).serialize.split("\n"):_*) + val checkSet = Set(parse(check).serialize.split("\n"):_*) + resSet should be (checkSet) + } + + "Unread wire" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | wire a : UInt<1> + | z <= x + | a <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Unread wire marked dont touch" should "NOT be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | wire a : UInt<1> + | z <= x + | a <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | wire a : UInt<1> + | z <= x + | a <= x""".stripMargin + exec(input, check, Seq(dontTouch("Top.a"))) + } + "Unread register" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | reg a : UInt<1>, clk + | a <= x + | node y = asUInt(clk) + | z <= or(x, y)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | node y = asUInt(clk) + | z <= or(x, y)""".stripMargin + exec(input, check) + } + "Unread node" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | node a = not(x) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Unused ports" should "be deleted" in { + val input = + """circuit Top : + | module Sub : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | inst sub of Sub + | sub.x <= x + | z <= sub.z""".stripMargin + val check = + """circuit Top : + | module Sub : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | inst sub of Sub + | sub.x <= x + | z <= sub.z""".stripMargin + exec(input, check) + } + "Chain of unread nodes" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | node a = not(x) + | node b = or(a, a) + | node c = add(b, x) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Chain of unread wires and their connections" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | wire a : UInt<1> + | a <= x + | wire b : UInt<1> + | b <= a + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Read register" should "not be deleted" in { + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | reg r : UInt<1>, clk + | r <= x + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | reg r : UInt<1>, clk with : (reset => (UInt<1>("h0"), r)) + | r <= x + | z <= r""".stripMargin + exec(input, check) + } + "Logic that feeds into simulation constructs" should "not be deleted" in { + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | node a = not(x) + | stop(clk, a, 0) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | node a = not(x) + | z <= x + | stop(clk, a, 0)""".stripMargin + exec(input, check) + } + "Globally dead module" should "should be deleted" in { + val input = + """circuit Top : + | module Dead : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst dead of Dead + | dead.x <= x + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Globally dead extmodule" should "NOT be deleted by default" in { + val input = + """circuit Top : + | extmodule Dead : + | input x : UInt<1> + | output z : UInt<1> + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst dead of Dead + | dead.x <= x + | z <= x""".stripMargin + val check = + """circuit Top : + | extmodule Dead : + | input x : UInt<1> + | output z : UInt<1> + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst dead of Dead + | dead.x <= x + | z <= x""".stripMargin + exec(input, check) + } + "Globally dead extmodule marked optimizable" should "be deleted" in { + val input = + """circuit Top : + | extmodule Dead : + | input x : UInt<1> + | output z : UInt<1> + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst dead of Dead + | dead.x <= x + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + val doTouchAnno = OptimizableExtModuleAnnotation(ModuleName("Dead", CircuitName("Top"))) + exec(input, check, Seq(doTouchAnno)) + } + "Analog ports of extmodules" should "count as both inputs and outputs" in { + val input = + """circuit Top : + | extmodule BB1 : + | output bus : Analog<1> + | extmodule BB2 : + | output bus : Analog<1> + | output out : UInt<1> + | module Top : + | output out : UInt<1> + | inst bb1 of BB1 + | inst bb2 of BB2 + | attach (bb1.bus, bb2.bus) + | out <= bb2.out + """.stripMargin + exec(input, input) + } + // bar.z is not used and thus is dead code, but foo.z is used so this code isn't eliminated + "Module deduplication" should "should be preserved despite unused output of ONE instance" in { + val input = + """circuit Top : + | module Child : + | input x : UInt<1> + | output y : UInt<1> + | output z : UInt<1> + | y <= not(x) + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst foo of Child + | inst bar of Child + | foo.x <= x + | bar.x <= x + | node t0 = or(foo.y, foo.z) + | z <= or(t0, bar.y)""".stripMargin + val check = + """circuit Top : + | module Child : + | input x : UInt<1> + | output y : UInt<1> + | output z : UInt<1> + | y <= not(x) + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst foo of Child + | inst bar of Child + | foo.x <= x + | bar.x <= x + | node t0 = or(foo.y, foo.z) + | z <= or(t0, bar.y)""".stripMargin + exec(input, check) + } + // This currently does NOT work + behavior of "Single dead instances" + ignore should "should be deleted" in { + val input = + """circuit Top : + | module Child : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst foo of Child + | inst bar of Child + | foo.x <= x + | bar.x <= x + | z <= foo.z""".stripMargin + val check = + """circuit Top : + | module Child : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst foo of Child + | skip + | foo.x <= x + | skip + | z <= foo.z""".stripMargin + exec(input, check) + } +} diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala index f77b47f3..a45af8c7 100644 --- a/src/test/scala/firrtlTests/FirrtlSpec.scala +++ b/src/test/scala/firrtlTests/FirrtlSpec.scala @@ -12,7 +12,8 @@ import scala.io.Source import firrtl._ import firrtl.Parser.IgnoreInfo -import firrtl.annotations +import firrtl.annotations._ +import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation} import firrtl.util.BackendCompilationUtilities trait FirrtlRunners extends BackendCompilationUtilities { @@ -82,6 +83,16 @@ trait FirrtlRunners extends BackendCompilationUtilities { } trait FirrtlMatchers extends Matchers { + def dontTouch(path: String): Annotation = { + val parts = path.split('.') + require(parts.size >= 2, "Must specify both module and component!") + val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top"))) + DontTouchAnnotation(name) + } + def dontDedup(mod: String): Annotation = { + require(mod.split('.').size == 1, "Can only specify a Module, not a component or instance") + NoDedupAnnotation(ModuleName(mod, CircuitName("Top"))) + } // Replace all whitespace with a single space and remove leading and // trailing whitespace // Note this is intended for single-line strings, no newlines @@ -94,8 +105,13 @@ trait FirrtlMatchers extends Matchers { * compiler will be run on input then emitted result will each be split into * lines and normalized. */ - def executeTest(input: String, expected: Seq[String], compiler: Compiler) = { - val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) + def executeTest( + input: String, + expected: Seq[String], + compiler: Compiler, + annotations: Seq[Annotation] = Seq.empty) = { + val annoMap = AnnotationMap(annotations) + val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(annoMap))) val lines = finalState.getEmittedCircuit.value split "\n" map normalized for (e <- expected) { lines should contain (e) diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 8367f152..25f845bc 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -5,6 +5,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ import firrtl.passes._ +import firrtl.transforms._ import firrtl.passes.memlib._ import annotations._ @@ -21,7 +22,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) + def transforms = Seq(ConstProp, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) } ) @@ -199,7 +200,7 @@ circuit CustomMemory : smem mem_1 : UInt<16>[7] read mport _T_17 = mem_0[io.rAddr], clock read mport _T_19 = mem_1[io.rAddr], clock - io.dO <= _T_17 + io.dO <= and(_T_17, _T_19) when io.wEn : write mport _T_18 = mem_0[io.wAddr], clock write mport _T_20 = mem_1[io.wAddr], clock @@ -218,7 +219,7 @@ circuit CustomMemory : case e: ExtModule => true case _ => false } - require(numExtMods == 2) + numExtMods should be (2) (new java.io.File(confLoc)).delete() } @@ -237,7 +238,7 @@ circuit CustomMemory : read mport _T_17 = mem_0[io.rAddr], clock read mport _T_19 = mem_1[io.rAddr], clock read mport _T_21 = mem_2[io.rAddr], clock - io.dO <= _T_17 + io.dO <= and(_T_17, and(_T_19, _T_21)) when io.wEn : write mport _T_18 = mem_0[io.wAddr], clock write mport _T_20 = mem_1[io.wAddr], clock @@ -258,7 +259,7 @@ circuit CustomMemory : case e: ExtModule => true case _ => false } - require(numExtMods == 2) + numExtMods should be (2) (new java.io.File(confLoc)).delete() } |
