diff options
| author | Jack Koenig | 2017-06-28 17:52:56 -0700 |
|---|---|---|
| committer | Jack Koenig | 2017-06-28 17:52:56 -0700 |
| commit | 39665e1f74cfe8243067442cccf4e7eab66ade68 (patch) | |
| tree | 8ba403e298c39bc6104f32a93754079dc458752a /src | |
| parent | 818cfde4ad42ffa9ee30d0f9ae72533ede80e4ce (diff) | |
Promote ConstProp to a transform
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala (renamed from src/main/scala/firrtl/passes/ConstProp.scala) | 12 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/CInferMDirSpec.scala | 3 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ChirrtlMemSpec.scala | 3 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 39 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/LowerTypesSpec.scala | 3 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 2 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/UnitTests.scala | 12 |
8 files changed, 58 insertions, 22 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 66ae1673..8dd9b180 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -98,12 +98,12 @@ class LowFirrtlOptimization extends CoreTransform { def outputForm = LowForm def transforms = Seq( passes.RemoveValidIf, - passes.ConstProp, + new firrtl.transforms.ConstantPropagation, passes.PadWidths, - passes.ConstProp, + new firrtl.transforms.ConstantPropagation, passes.Legalize, passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - passes.ConstProp, + new firrtl.transforms.ConstantPropagation, passes.SplitExpressions, passes.CommonSubexpressionElimination, new firrtl.transforms.DeadCodeElimination) diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index f2aa1a03..930fe45a 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -1,6 +1,7 @@ // See LICENSE for license details. -package firrtl.passes +package firrtl +package transforms import firrtl._ import firrtl.ir._ @@ -10,7 +11,10 @@ import firrtl.PrimOps._ import annotation.tailrec -object ConstProp extends Pass { +class ConstantPropagation extends Transform { + def inputForm = LowForm + def outputForm = LowForm + private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e @@ -292,4 +296,8 @@ object ConstProp extends Pass { } Circuit(c.info, modulesx, c.main) } + + def execute(state: CircuitState): CircuitState = { + state.copy(circuit = run(state.circuit)) + } } diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala index 0d31038a..299142d9 100644 --- a/src/test/scala/firrtlTests/CInferMDirSpec.scala +++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala @@ -5,6 +5,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ import firrtl.passes._ +import firrtl.transforms._ import firrtl.Mappers._ import annotations._ @@ -39,7 +40,7 @@ class CInferMDir extends LowTransformSpec { def transform = new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, CInferMDirCheckPass) + def transforms = Seq(new ConstantPropagation, CInferMDirCheckPass) } "Memory" should "have correct mem port directions" in { diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index c963c8ae..6fac5047 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -5,6 +5,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ import firrtl.passes._ +import firrtl.transforms._ import firrtl.Mappers._ import annotations._ @@ -53,7 +54,7 @@ class ChirrtlMemSpec extends LowTransformSpec { def transform = new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, MemEnableCheckPass) + def transforms = Seq(new ConstantPropagation, MemEnableCheckPass) } "Sequential Memory" should "have correct enable signals" in { diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 95785717..c94adbf6 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -2,11 +2,11 @@ package firrtlTests -import org.scalatest.Matchers +import firrtl._ import firrtl.ir.Circuit import firrtl.Parser.IgnoreInfo -import firrtl.Parser import firrtl.passes._ +import firrtl.transforms._ // Tests the following cases for constant propagation: // 1) Unsigned integers are always greater than or @@ -16,17 +16,17 @@ import firrtl.passes._ // 3) Values are always greater than a number smaller // than their minimum value class ConstantPropagationSpec extends FirrtlFlatSpec { - val passes = Seq( + val transforms = Seq( ToWorkingIR, ResolveKinds, InferTypes, ResolveGenders, InferWidths, - ConstProp) - private def exec (input: String) = { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - }.serialize + new ConstantPropagation) + private def exec(input: String) = { + transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit.serialize } // ============================= "The rule x >= 0 " should " always be true if x is a UInt" in { @@ -349,4 +349,27 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { """ (parse(exec(input))) should be (parse(check)) } + + // ============================= + "ConstProp" should "work across wires" in { + val input = +"""circuit Top : + module Top : + input x : UInt<1> + output y : UInt<1> + wire z : UInt<1> + y <= z + z <= mux(x, UInt<1>(0), UInt<1>(0)) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<1> + output y : UInt<1> + wire z : UInt<1> + y <= UInt<1>(0) + z <= UInt<1>(0) +""" + (parse(exec(input))) should be (parse(check)) + } } diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index b43df713..ab367554 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -8,6 +8,7 @@ import org.scalatest.prop._ import firrtl.Parser import firrtl.ir.Circuit import firrtl.passes._ +import firrtl.transforms._ import firrtl._ class LowerTypesSpec extends FirrtlFlatSpec { @@ -27,7 +28,7 @@ class LowerTypesSpec extends FirrtlFlatSpec { ExpandWhens, CheckInitialization, Legalize, - ConstProp, + new ConstantPropagation, ResolveKinds, InferTypes, ResolveGenders, diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 25f845bc..7cbfeafe 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -22,7 +22,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) + def transforms = Seq(new ConstantPropagation, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) } ) diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 0d5d098c..f717fc18 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -8,13 +8,15 @@ import org.scalatest.prop._ import firrtl._ import firrtl.ir.Circuit import firrtl.passes._ +import firrtl.transforms._ import firrtl.Parser.IgnoreInfo class UnitTests extends FirrtlFlatSpec { - private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + private def executeTest(input: String, expected: Seq[String], transforms: Seq[Transform]) = { + val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit + val lines = c.serialize.split("\n") map normalized expected foreach { e => @@ -199,7 +201,7 @@ class UnitTests extends FirrtlFlatSpec { PullMuxes, ExpandConnects, RemoveAccesses, - ConstProp + new ConstantPropagation ) val input = """circuit AssignViaDeref : |
