From dbb4de2a4e6c2122e7c0def7d1c4ff38a79e1558 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Wed, 28 Jun 2017 17:00:13 -0700 Subject: [Testing] Have SimpleTransformSpec mix in FirrtlMatchers Gives all transform specs access to useful utilities (like dontTouch). Deletes some duplicate code. Parsing mode UseInfo is fine for everything, only matters if the test actually uses info. --- src/test/scala/firrtlTests/AnnotationTests.scala | 7 ------- src/test/scala/firrtlTests/FirrtlSpec.scala | 4 ++-- src/test/scala/firrtlTests/PassTests.scala | 3 +-- 3 files changed, 3 insertions(+), 11 deletions(-) (limited to 'src') diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 3e93081e..c8b83bd2 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -49,13 +49,6 @@ class AnnotationTests extends AnnotationSpec with Matchers { Annotation(ComponentName(s, ModuleName(mod, CircuitName("Top"))), classOf[Transform], value) def manno(mod: String): Annotation = Annotation(ModuleName(mod, CircuitName("Top")), classOf[Transform], "some value") - // TODO unify with FirrtlMatchers, problems with multiple definitions of parse - def dontTouch(path: String): Annotation = { - val parts = path.split('.') - require(parts.size >= 2, "Must specify both module and component!") - val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top"))) - DontTouchAnnotation(name) - } "Loose and Sticky annotation on a node" should "pass through" in { val input: String = diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala index a45af8c7..07f83142 100644 --- a/src/test/scala/firrtlTests/FirrtlSpec.scala +++ b/src/test/scala/firrtlTests/FirrtlSpec.scala @@ -11,7 +11,7 @@ import org.scalatest.prop._ import scala.io.Source import firrtl._ -import firrtl.Parser.IgnoreInfo +import firrtl.Parser.UseInfo import firrtl.annotations._ import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation} import firrtl.util.BackendCompilationUtilities @@ -100,7 +100,7 @@ trait FirrtlMatchers extends Matchers { require(!s.contains("\n")) s.replaceAll("\\s+", " ").trim } - def parse(str: String) = Parser.parse(str.split("\n").toIterator, IgnoreInfo) + def parse(str: String) = Parser.parse(str.split("\n").toIterator, UseInfo) /** Helper for executing tests * compiler will be run on input then emitted result will each be split into * lines and normalized. diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index e22fd513..7fa7e8ef 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -13,9 +13,8 @@ import logger._ // An example methodology for testing Firrtl Passes // Spec class should extend this class -abstract class SimpleTransformSpec extends FlatSpec with Matchers with Compiler with LazyLogging { +abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Compiler with LazyLogging { // Utility function - def parse(s: String): Circuit = Parser.parse(s.split("\n").toIterator, infoMode = UseInfo) def squash(c: Circuit): Circuit = RemoveEmpty.run(c) // Executes the test. Call in tests. -- cgit v1.2.3 From 818cfde4ad42ffa9ee30d0f9ae72533ede80e4ce Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Wed, 28 Jun 2017 17:31:00 -0700 Subject: [Testing] Clean up SimpleTransformSpec execute methods This makes it more concise to write tests --- src/test/scala/firrtlTests/AnnotationTests.scala | 10 +++--- .../scala/firrtlTests/InlineInstancesTests.scala | 38 ++++++++++------------ src/test/scala/firrtlTests/PassTests.scala | 11 ++++--- .../transforms/BlacklBoxSourceHelperSpec.scala | 6 ++-- .../scala/firrtlTests/transforms/DedupTests.scala | 12 +++---- 5 files changed, 37 insertions(+), 40 deletions(-) (limited to 'src') diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index c8b83bd2..aeefbbe3 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -23,13 +23,13 @@ trait AnnotationSpec extends LowTransformSpec { def transform = new ResolveAndCheck // Check if Annotation Exception is thrown - override def failingexecute(annotations: AnnotationMap, input: String): Exception = { + override def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { intercept[AnnotationException] { - compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), Seq.empty) + compile(CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annotations))), Seq.empty) } } - def execute(aMap: Option[AnnotationMap], input: String, check: Annotation): Unit = { - val cr = compile(CircuitState(parse(input), ChirrtlForm, aMap), Seq.empty) + def execute(input: String, check: Annotation, annotations: Seq[Annotation]): Unit = { + val cr = compile(CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annotations))), Seq.empty) cr.annotations.get.annotations should contain (check) } } @@ -58,7 +58,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { | input b : UInt<1> | node c = b""".stripMargin val ta = anno("c", "") - execute(getAMap(ta), input, ta) + execute(input, ta, Seq(ta)) } "Annotations" should "be readable from file" in { diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 9e8f8054..4398df48 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -6,7 +6,7 @@ import org.scalatest.FlatSpec import org.scalatest.Matchers import org.scalatest.junit.JUnitRunner import firrtl.ir.Circuit -import firrtl.{AnnotationMap, Parser} +import firrtl.Parser import firrtl.passes.PassExceptions import firrtl.annotations.{Annotation, CircuitName, ComponentName, ModuleName, Named} import firrtl.passes.{InlineAnnotation, InlineInstances} @@ -18,7 +18,14 @@ import logger.LogLevel.Debug * Tests inline instances transformation */ class InlineInstancesTests extends LowTransformSpec { - def transform = new InlineInstances + def transform = new InlineInstances + def inline(mod: String): Annotation = { + val parts = mod.split('.') + val modName = ModuleName(parts.head, CircuitName("Top")) // If this fails, bad input + val name = if (parts.size == 1) modName + else ComponentName(parts.tail.mkString("."), modName) + InlineAnnotation(name) + } // Set this to debug, this will apply to all tests // Logger.setLevel(this.getClass, Debug) "The module Inline" should "be inlined" in { @@ -44,8 +51,7 @@ class InlineInstancesTests extends LowTransformSpec { | i$b <= i$a | b <= i$b | i$a <= a""".stripMargin - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top"))))) - execute(aMap, input, check) + execute(input, check, Seq(inline("Inline"))) } "The all instances of Simple" should "be inlined" in { @@ -77,8 +83,7 @@ class InlineInstancesTests extends LowTransformSpec { | b <= i1$b | i0$a <= a | i1$a <= i0$b""".stripMargin - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top"))))) - execute(aMap, input, check) + execute(input, check, Seq(inline("Simple"))) } "Only one instance of Simple" should "be inlined" in { @@ -112,8 +117,7 @@ class InlineInstancesTests extends LowTransformSpec { | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin - val aMap = new AnnotationMap(Seq(InlineAnnotation(ComponentName("i0",ModuleName("Top", CircuitName("Top")))))) - execute(aMap, input, check) + execute(input, check, Seq(inline("Top.i0"))) } "All instances of A" should "be inlined" in { @@ -157,8 +161,7 @@ class InlineInstancesTests extends LowTransformSpec { | i$b <= i$a | b <= i$b | i$a <= a""".stripMargin - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) - execute(aMap, input, check) + execute(input, check, Seq(inline("A"))) } "Non-inlined instances" should "still prepend prefix" in { @@ -196,8 +199,7 @@ class InlineInstancesTests extends LowTransformSpec { | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) - execute(aMap, input, check) + execute(input, check, Seq(inline("A"))) } // ---- Errors ---- @@ -214,8 +216,7 @@ class InlineInstancesTests extends LowTransformSpec { | extmodule A : | input a : UInt<32> | output b : UInt<32>""".stripMargin - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) - failingexecute(aMap, input) + failingexecute(input, Seq(inline("A"))) } // 2) ext instance "External instance" should "not be inlined" in { @@ -230,8 +231,7 @@ class InlineInstancesTests extends LowTransformSpec { | extmodule A : | input a : UInt<32> | output b : UInt<32>""".stripMargin - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) - failingexecute(aMap, input) + failingexecute(input, Seq(inline("A"))) } // 3) no module "Inlined module" should "exist" in { @@ -241,8 +241,7 @@ class InlineInstancesTests extends LowTransformSpec { | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) - failingexecute(aMap, input) + failingexecute(input, Seq(inline("A"))) } // 4) no inst "Inlined instance" should "exist" in { @@ -252,8 +251,7 @@ class InlineInstancesTests extends LowTransformSpec { | input a : UInt<32> | output b : UInt<32> | b <= a""".stripMargin - val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top"))))) - failingexecute(aMap, input) + failingexecute(input, Seq(inline("A"))) } } diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index 7fa7e8ef..6727533e 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -9,6 +9,7 @@ import firrtl.ir.Circuit import firrtl.Parser.UseInfo import firrtl.passes.{Pass, PassExceptions, RemoveEmpty} import firrtl._ +import firrtl.annotations._ import logger._ // An example methodology for testing Firrtl Passes @@ -18,8 +19,9 @@ abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Com def squash(c: Circuit): Circuit = RemoveEmpty.run(c) // Executes the test. Call in tests. - def execute(annotations: AnnotationMap, input: String, check: String): Unit = { - val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(annotations))) + // annotations cannot have default value because scalatest trait Suite has a default value + def execute(input: String, check: String, annotations: Seq[Annotation]): Unit = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annotations)))) val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize val expected = parse(check).serialize logger.debug(actual) @@ -27,9 +29,10 @@ abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Com (actual) should be (expected) } // Executes the test, should throw an error - def failingexecute(annotations: AnnotationMap, input: String): Exception = { + // No default to be consistent with execute + def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { intercept[PassExceptions] { - compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), Seq.empty) + compile(CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annotations))), Seq.empty) } } } diff --git a/src/test/scala/firrtlTests/transforms/BlacklBoxSourceHelperSpec.scala b/src/test/scala/firrtlTests/transforms/BlacklBoxSourceHelperSpec.scala index 8cd51b2a..bf294fe9 100644 --- a/src/test/scala/firrtlTests/transforms/BlacklBoxSourceHelperSpec.scala +++ b/src/test/scala/firrtlTests/transforms/BlacklBoxSourceHelperSpec.scala @@ -78,12 +78,12 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { "annotated external modules" should "appear in output directory" in { - val aMap = AnnotationMap(Seq( + val annos = Seq( Annotation(moduleName, classOf[BlackBoxSourceHelper], BlackBoxTargetDir("test_run_dir").serialize), Annotation(moduleName, classOf[BlackBoxSourceHelper], BlackBoxResource("/blackboxes/AdderExtModule.v").serialize) - )) + ) - execute(aMap, input, output) + execute(input, output, annos) new java.io.File("test_run_dir/AdderExtModule.v").exists should be (true) new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.FileListName}").exists should be (true) diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala index 7148dd11..74c4b4e7 100644 --- a/src/test/scala/firrtlTests/transforms/DedupTests.scala +++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala @@ -46,8 +46,7 @@ class DedupModuleTests extends HighTransformSpec { | output x: UInt<1> | x <= UInt(1) """.stripMargin - val aMap = new AnnotationMap(Nil) - execute(aMap, input, check) + execute(input, check, Seq.empty) } "The module A and B" should "be deduped" in { val input = @@ -83,8 +82,7 @@ class DedupModuleTests extends HighTransformSpec { | output x: UInt<1> | x <= UInt(1) """.stripMargin - val aMap = new AnnotationMap(Nil) - execute(aMap, input, check) + execute(input, check, Seq.empty) } "The module A and B with comments" should "be deduped" in { val input = @@ -120,8 +118,7 @@ class DedupModuleTests extends HighTransformSpec { | output x: UInt<1> | x <= UInt(1) """.stripMargin - val aMap = new AnnotationMap(Nil) - execute(aMap, input, check) + execute(input, check, Seq.empty) } "The module B, but not A, with comments" should "be deduped if not annotated" in { val input = @@ -148,8 +145,7 @@ class DedupModuleTests extends HighTransformSpec { | output x: UInt<1> @[xx 1:1] | x <= UInt(1) """.stripMargin - val aMap = new AnnotationMap(Seq(NoDedupAnnotation(ModuleName("A", CircuitName("Top"))))) - execute(aMap, input, check) + execute(input, check, Seq(dontDedup("A"))) } } -- cgit v1.2.3 From 39665e1f74cfe8243067442cccf4e7eab66ade68 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Wed, 28 Jun 2017 17:52:56 -0700 Subject: Promote ConstProp to a transform --- src/main/scala/firrtl/LoweringCompilers.scala | 6 +- src/main/scala/firrtl/passes/ConstProp.scala | 295 -------------------- .../firrtl/transforms/ConstantPropagation.scala | 303 +++++++++++++++++++++ src/test/scala/firrtlTests/CInferMDirSpec.scala | 3 +- src/test/scala/firrtlTests/ChirrtlMemSpec.scala | 3 +- .../firrtlTests/ConstantPropagationTests.scala | 39 ++- src/test/scala/firrtlTests/LowerTypesSpec.scala | 3 +- src/test/scala/firrtlTests/ReplSeqMemTests.scala | 2 +- src/test/scala/firrtlTests/UnitTests.scala | 12 +- 9 files changed, 351 insertions(+), 315 deletions(-) delete mode 100644 src/main/scala/firrtl/passes/ConstProp.scala create mode 100644 src/main/scala/firrtl/transforms/ConstantPropagation.scala (limited to 'src') diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 66ae1673..8dd9b180 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -98,12 +98,12 @@ class LowFirrtlOptimization extends CoreTransform { def outputForm = LowForm def transforms = Seq( passes.RemoveValidIf, - passes.ConstProp, + new firrtl.transforms.ConstantPropagation, passes.PadWidths, - passes.ConstProp, + new firrtl.transforms.ConstantPropagation, passes.Legalize, passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - passes.ConstProp, + new firrtl.transforms.ConstantPropagation, passes.SplitExpressions, passes.CommonSubexpressionElimination, new firrtl.transforms.DeadCodeElimination) diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala deleted file mode 100644 index f2aa1a03..00000000 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ /dev/null @@ -1,295 +0,0 @@ -// See LICENSE for license details. - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Utils._ -import firrtl.Mappers._ -import firrtl.PrimOps._ - -import annotation.tailrec - -object ConstProp extends Pass { - private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { - case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) - case (we, wt) if we == wt => e - } - - private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) - - trait FoldLogicalOp { - def fold(c1: Literal, c2: Literal): Expression - def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression - - def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { - case (lhs: Literal, rhs: Literal) => fold(lhs, rhs) - case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) - case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) - case _ => e - } - } - - object FoldAND extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) - case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) - case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs - case _ => e - } - } - - object FoldOR extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) - case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs - case _ => e - } - } - - object FoldXOR extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) - case _ => e - } - } - - object FoldEqual extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case _ => e - } - } - - object FoldNotEqual extends FoldLogicalOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case _ => e - } - } - - private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { - case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) - case _ => e - } - - private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { - case 0 => e.args.head - case x => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) - case _ => e - } - } - - private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { - case 0 => e.args.head - case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) - // take sign bit if shift amount is larger than arg width - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) - case _ => e - } - } - - private def foldComparison(e: DoPrim) = { - def foldIfZeroedArg(x: Expression): Expression = { - def isUInt(e: Expression): Boolean = e.tpe match { - case UIntType(_) => true - case _ => false - } - def isZero(e: Expression) = e match { - case UIntLiteral(value, _) => value == BigInt(0) - case SIntLiteral(value, _) => value == BigInt(0) - case _ => false - } - x match { - case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero - case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one - case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero - case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one - case ex => ex - } - } - - def foldIfOutsideRange(x: Expression): Expression = { - //Note, only abides by a partial ordering - case class Range(min: BigInt, max: BigInt) { - def === (that: Range) = - Seq(this.min, this.max, that.min, that.max) - .sliding(2,1) - .map(x => x.head == x(1)) - .reduce(_ && _) - def > (that: Range) = this.min > that.max - def >= (that: Range) = this.min >= that.max - def < (that: Range) = this.max < that.min - def <= (that: Range) = this.max <= that.min - } - def range(e: Expression): Range = e match { - case UIntLiteral(value, _) => Range(value, value) - case SIntLiteral(value, _) => Range(value, value) - case _ => e.tpe match { - case SIntType(IntWidth(width)) => Range( - min = BigInt(0) - BigInt(2).pow(width.toInt - 1), - max = BigInt(2).pow(width.toInt - 1) - BigInt(1) - ) - case UIntType(IntWidth(width)) => Range( - min = BigInt(0), - max = BigInt(2).pow(width.toInt) - BigInt(1) - ) - } - } - // Calculates an expression's range of values - x match { - case ex: DoPrim => - def r0 = range(ex.args.head) - def r1 = range(ex.args(1)) - ex.op match { - // Always true - case Lt if r0 < r1 => one - case Leq if r0 <= r1 => one - case Gt if r0 > r1 => one - case Geq if r0 >= r1 => one - // Always false - case Lt if r0 >= r1 => zero - case Leq if r0 > r1 => zero - case Gt if r0 <= r1 => zero - case Geq if r0 < r1 => zero - case _ => ex - } - case ex => ex - } - } - foldIfZeroedArg(foldIfOutsideRange(e)) - } - - private def constPropPrim(e: DoPrim): Expression = e.op match { - case Shl => foldShiftLeft(e) - case Shr => foldShiftRight(e) - case Cat => foldConcat(e) - case And => FoldAND(e) - case Or => FoldOR(e) - case Xor => FoldXOR(e) - case Eq => FoldEqual(e) - case Neq => FoldNotEqual(e) - case (Lt | Leq | Gt | Geq) => foldComparison(e) - case Not => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) - case _ => e - } - case AsUInt => e.args.head match { - case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case u: UIntLiteral => u - case _ => e - } - case AsSInt => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) - case s: SIntLiteral => s - case _ => e - } - case Pad => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) - case _ if bitWidth(e.args.head.tpe) == e.consts.head => e.args.head - case _ => e - } - case Bits => e.args.head match { - case lit: Literal => - val hi = e.consts.head.toInt - val lo = e.consts(1).toInt - require(hi >= lo) - UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) - case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { - case t: UIntType => x - case _ => asUInt(x, e.tpe) - } - case _ => e - } - case _ => e - } - - private def constPropMuxCond(m: Mux) = m.cond match { - case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) - case _ => m - } - - private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { - case _ if m.tval == m.fval => m.tval - case (t: UIntLiteral, f: UIntLiteral) => - if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond - else constPropMuxCond(m) - case _ => constPropMuxCond(m) - } - - private def constPropNodeRef(r: WRef, e: Expression) = e match { - case _: UIntLiteral | _: SIntLiteral | _: WRef => e - case _ => r - } - - // Two pass process - // 1. Propagate constants in expressions and forward propagate references - // 2. Propagate references again for backwards reference (Wires) - // TODO Replacing all wires with nodes makes the second pass unnecessary - @tailrec - private def constPropModule(m: Module): Module = { - var nPropagated = 0L - val nodeMap = collection.mutable.HashMap[String, Expression]() - - def backPropExpr(expr: Expression): Expression = { - val old = expr map backPropExpr - val propagated = old match { - case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) - case x => x - } - if (old ne propagated) { - nPropagated += 1 - } - propagated - } - def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr - - def constPropExpression(e: Expression): Expression = { - val old = e map constPropExpression - val propagated = old match { - case p: DoPrim => constPropPrim(p) - case m: Mux => constPropMux(m) - case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) - case x => x - } - propagated - } - - def constPropStmt(s: Statement): Statement = { - val stmtx = s map constPropStmt map constPropExpression - stmtx match { - case x: DefNode => nodeMap(x.name) = x.value - case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => - val exprx = constPropExpression(pad(expr, wtpe)) - nodeMap(wname) = exprx - case _ => - } - stmtx - } - - val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) - if (nPropagated > 0) constPropModule(res) else res - } - - def run(c: Circuit): Circuit = { - val modulesx = c.modules.map { - case m: ExtModule => m - case m: Module => constPropModule(m) - } - Circuit(c.info, modulesx, c.main) - } -} diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala new file mode 100644 index 00000000..930fe45a --- /dev/null +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -0,0 +1,303 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.PrimOps._ + +import annotation.tailrec + +class ConstantPropagation extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { + case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) + case (we, wt) if we == wt => e + } + + private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) + + trait FoldLogicalOp { + def fold(c1: Literal, c2: Literal): Expression + def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression + + def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { + case (lhs: Literal, rhs: Literal) => fold(lhs, rhs) + case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) + case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) + case _ => e + } + } + + object FoldAND extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs + case _ => e + } + } + + object FoldOR extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs + case _ => e + } + } + + object FoldXOR extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case _ => e + } + } + + object FoldEqual extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs + case _ => e + } + } + + object FoldNotEqual extends FoldLogicalOp { + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs + case _ => e + } + } + + private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { + case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) + case _ => e + } + + private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { + case 0 => e.args.head + case x => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) + case _ => e + } + } + + private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { + case 0 => e.args.head + case x => e.args.head match { + // TODO when amount >= x.width, return a zero-width wire + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) + // take sign bit if shift amount is larger than arg width + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) + case _ => e + } + } + + private def foldComparison(e: DoPrim) = { + def foldIfZeroedArg(x: Expression): Expression = { + def isUInt(e: Expression): Boolean = e.tpe match { + case UIntType(_) => true + case _ => false + } + def isZero(e: Expression) = e match { + case UIntLiteral(value, _) => value == BigInt(0) + case SIntLiteral(value, _) => value == BigInt(0) + case _ => false + } + x match { + case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero + case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one + case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero + case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one + case ex => ex + } + } + + def foldIfOutsideRange(x: Expression): Expression = { + //Note, only abides by a partial ordering + case class Range(min: BigInt, max: BigInt) { + def === (that: Range) = + Seq(this.min, this.max, that.min, that.max) + .sliding(2,1) + .map(x => x.head == x(1)) + .reduce(_ && _) + def > (that: Range) = this.min > that.max + def >= (that: Range) = this.min >= that.max + def < (that: Range) = this.max < that.min + def <= (that: Range) = this.max <= that.min + } + def range(e: Expression): Range = e match { + case UIntLiteral(value, _) => Range(value, value) + case SIntLiteral(value, _) => Range(value, value) + case _ => e.tpe match { + case SIntType(IntWidth(width)) => Range( + min = BigInt(0) - BigInt(2).pow(width.toInt - 1), + max = BigInt(2).pow(width.toInt - 1) - BigInt(1) + ) + case UIntType(IntWidth(width)) => Range( + min = BigInt(0), + max = BigInt(2).pow(width.toInt) - BigInt(1) + ) + } + } + // Calculates an expression's range of values + x match { + case ex: DoPrim => + def r0 = range(ex.args.head) + def r1 = range(ex.args(1)) + ex.op match { + // Always true + case Lt if r0 < r1 => one + case Leq if r0 <= r1 => one + case Gt if r0 > r1 => one + case Geq if r0 >= r1 => one + // Always false + case Lt if r0 >= r1 => zero + case Leq if r0 > r1 => zero + case Gt if r0 <= r1 => zero + case Geq if r0 < r1 => zero + case _ => ex + } + case ex => ex + } + } + foldIfZeroedArg(foldIfOutsideRange(e)) + } + + private def constPropPrim(e: DoPrim): Expression = e.op match { + case Shl => foldShiftLeft(e) + case Shr => foldShiftRight(e) + case Cat => foldConcat(e) + case And => FoldAND(e) + case Or => FoldOR(e) + case Xor => FoldXOR(e) + case Eq => FoldEqual(e) + case Neq => FoldNotEqual(e) + case (Lt | Leq | Gt | Geq) => foldComparison(e) + case Not => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) + case _ => e + } + case AsUInt => e.args.head match { + case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) + case u: UIntLiteral => u + case _ => e + } + case AsSInt => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) + case s: SIntLiteral => s + case _ => e + } + case Pad => e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) + case _ if bitWidth(e.args.head.tpe) == e.consts.head => e.args.head + case _ => e + } + case Bits => e.args.head match { + case lit: Literal => + val hi = e.consts.head.toInt + val lo = e.consts(1).toInt + require(hi >= lo) + UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) + case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { + case t: UIntType => x + case _ => asUInt(x, e.tpe) + } + case _ => e + } + case _ => e + } + + private def constPropMuxCond(m: Mux) = m.cond match { + case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) + case _ => m + } + + private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { + case _ if m.tval == m.fval => m.tval + case (t: UIntLiteral, f: UIntLiteral) => + if (t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1)) m.cond + else constPropMuxCond(m) + case _ => constPropMuxCond(m) + } + + private def constPropNodeRef(r: WRef, e: Expression) = e match { + case _: UIntLiteral | _: SIntLiteral | _: WRef => e + case _ => r + } + + // Two pass process + // 1. Propagate constants in expressions and forward propagate references + // 2. Propagate references again for backwards reference (Wires) + // TODO Replacing all wires with nodes makes the second pass unnecessary + @tailrec + private def constPropModule(m: Module): Module = { + var nPropagated = 0L + val nodeMap = collection.mutable.HashMap[String, Expression]() + + def backPropExpr(expr: Expression): Expression = { + val old = expr map backPropExpr + val propagated = old match { + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + if (old ne propagated) { + nPropagated += 1 + } + propagated + } + def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr + + def constPropExpression(e: Expression): Expression = { + val old = e map constPropExpression + val propagated = old match { + case p: DoPrim => constPropPrim(p) + case m: Mux => constPropMux(m) + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + propagated + } + + def constPropStmt(s: Statement): Statement = { + val stmtx = s map constPropStmt map constPropExpression + stmtx match { + case x: DefNode => nodeMap(x.name) = x.value + case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => + val exprx = constPropExpression(pad(expr, wtpe)) + nodeMap(wname) = exprx + case _ => + } + stmtx + } + + val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) + if (nPropagated > 0) constPropModule(res) else res + } + + def run(c: Circuit): Circuit = { + val modulesx = c.modules.map { + case m: ExtModule => m + case m: Module => constPropModule(m) + } + Circuit(c.info, modulesx, c.main) + } + + def execute(state: CircuitState): CircuitState = { + state.copy(circuit = run(state.circuit)) + } +} diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala index 0d31038a..299142d9 100644 --- a/src/test/scala/firrtlTests/CInferMDirSpec.scala +++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala @@ -5,6 +5,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ import firrtl.passes._ +import firrtl.transforms._ import firrtl.Mappers._ import annotations._ @@ -39,7 +40,7 @@ class CInferMDir extends LowTransformSpec { def transform = new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, CInferMDirCheckPass) + def transforms = Seq(new ConstantPropagation, CInferMDirCheckPass) } "Memory" should "have correct mem port directions" in { diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index c963c8ae..6fac5047 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -5,6 +5,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ import firrtl.passes._ +import firrtl.transforms._ import firrtl.Mappers._ import annotations._ @@ -53,7 +54,7 @@ class ChirrtlMemSpec extends LowTransformSpec { def transform = new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, MemEnableCheckPass) + def transforms = Seq(new ConstantPropagation, MemEnableCheckPass) } "Sequential Memory" should "have correct enable signals" in { diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 95785717..c94adbf6 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -2,11 +2,11 @@ package firrtlTests -import org.scalatest.Matchers +import firrtl._ import firrtl.ir.Circuit import firrtl.Parser.IgnoreInfo -import firrtl.Parser import firrtl.passes._ +import firrtl.transforms._ // Tests the following cases for constant propagation: // 1) Unsigned integers are always greater than or @@ -16,17 +16,17 @@ import firrtl.passes._ // 3) Values are always greater than a number smaller // than their minimum value class ConstantPropagationSpec extends FirrtlFlatSpec { - val passes = Seq( + val transforms = Seq( ToWorkingIR, ResolveKinds, InferTypes, ResolveGenders, InferWidths, - ConstProp) - private def exec (input: String) = { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - }.serialize + new ConstantPropagation) + private def exec(input: String) = { + transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit.serialize } // ============================= "The rule x >= 0 " should " always be true if x is a UInt" in { @@ -346,6 +346,29 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "ConstProp" should "work across wires" in { + val input = +"""circuit Top : + module Top : + input x : UInt<1> + output y : UInt<1> + wire z : UInt<1> + y <= z + z <= mux(x, UInt<1>(0), UInt<1>(0)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<1> + output y : UInt<1> + wire z : UInt<1> + y <= UInt<1>(0) + z <= UInt<1>(0) """ (parse(exec(input))) should be (parse(check)) } diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index b43df713..ab367554 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -8,6 +8,7 @@ import org.scalatest.prop._ import firrtl.Parser import firrtl.ir.Circuit import firrtl.passes._ +import firrtl.transforms._ import firrtl._ class LowerTypesSpec extends FirrtlFlatSpec { @@ -27,7 +28,7 @@ class LowerTypesSpec extends FirrtlFlatSpec { ExpandWhens, CheckInitialization, Legalize, - ConstProp, + new ConstantPropagation, ResolveKinds, InferTypes, ResolveGenders, diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 25f845bc..7cbfeafe 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -22,7 +22,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) + def transforms = Seq(new ConstantPropagation, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) } ) diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 0d5d098c..f717fc18 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -8,13 +8,15 @@ import org.scalatest.prop._ import firrtl._ import firrtl.ir.Circuit import firrtl.passes._ +import firrtl.transforms._ import firrtl.Parser.IgnoreInfo class UnitTests extends FirrtlFlatSpec { - private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + private def executeTest(input: String, expected: Seq[String], transforms: Seq[Transform]) = { + val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit + val lines = c.serialize.split("\n") map normalized expected foreach { e => @@ -199,7 +201,7 @@ class UnitTests extends FirrtlFlatSpec { PullMuxes, ExpandConnects, RemoveAccesses, - ConstProp + new ConstantPropagation ) val input = """circuit AssignViaDeref : -- cgit v1.2.3 From a0aeafa3d591f9bcc14eca6d8a41eb2155f1b5b0 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Wed, 28 Jun 2017 18:49:32 -0700 Subject: Make Constant Propagation respect dontTouch Constant Propagation will not optimize across components marked dontTouch --- .../firrtl/transforms/ConstantPropagation.scala | 25 +++++++++---- .../firrtlTests/ConstantPropagationTests.scala | 43 ++++++++++++++++++++++ 2 files changed, 61 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 930fe45a..efe06e9b 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -4,6 +4,7 @@ package firrtl package transforms import firrtl._ +import firrtl.annotations._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ @@ -243,7 +244,7 @@ class ConstantPropagation extends Transform { // 2. Propagate references again for backwards reference (Wires) // TODO Replacing all wires with nodes makes the second pass unnecessary @tailrec - private def constPropModule(m: Module): Module = { + private def constPropModule(m: Module, dontTouches: Set[String]): Module = { var nPropagated = 0L val nodeMap = collection.mutable.HashMap[String, Expression]() @@ -276,8 +277,8 @@ class ConstantPropagation extends Transform { def constPropStmt(s: Statement): Statement = { val stmtx = s map constPropStmt map constPropExpression stmtx match { - case x: DefNode => nodeMap(x.name) = x.value - case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => + case x: DefNode if !dontTouches.contains(x.name) => nodeMap(x.name) = x.value + case Connect(_, WRef(wname, wtpe, WireKind, _), expr) if !dontTouches.contains(wname) => val exprx = constPropExpression(pad(expr, wtpe)) nodeMap(wname) = exprx case _ => @@ -286,18 +287,28 @@ class ConstantPropagation extends Transform { } val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) - if (nPropagated > 0) constPropModule(res) else res + if (nPropagated > 0) constPropModule(res, dontTouches) else res } - def run(c: Circuit): Circuit = { + private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = { val modulesx = c.modules.map { case m: ExtModule => m - case m: Module => constPropModule(m) + case m: Module => constPropModule(m, dontTouchMap.getOrElse(m.name, Set.empty)) } Circuit(c.info, modulesx, c.main) } def execute(state: CircuitState): CircuitState = { - state.copy(circuit = run(state.circuit)) + val dontTouches: Seq[(String, String)] = state.annotations match { + case Some(aMap) => aMap.annotations.collect { + case DontTouchAnnotation(ComponentName(c, ModuleName(m, _))) => m -> c + } + case None => Seq.empty + } + // Map from module name to component names + val dontTouchMap: Map[String, Set[String]] = + dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet) + + state.copy(circuit = run(state.circuit, dontTouchMap)) } } diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index c94adbf6..f818f9c0 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -373,3 +373,46 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { (parse(exec(input))) should be (parse(check)) } } + +// More sophisticated tests of the full compiler +class ConstantPropagationIntegrationSpec extends LowTransformSpec { + def transform = new LowFirrtlOptimization + + "ConstProp" should "should not optimize across dontTouch on nodes" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + execute(input, check, Seq(dontTouch("Top.z"))) + } + + it should "should not optimize across dontTouch on wires" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | wire z : UInt<1> + | y <= z + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | wire z : UInt<1> + | y <= z + | z <= x""".stripMargin + execute(input, check, Seq(dontTouch("Top.z"))) + } +} -- cgit v1.2.3