aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2017-06-28 17:52:56 -0700
committerJack Koenig2017-06-28 17:52:56 -0700
commit39665e1f74cfe8243067442cccf4e7eab66ade68 (patch)
tree8ba403e298c39bc6104f32a93754079dc458752a /src
parent818cfde4ad42ffa9ee30d0f9ae72533ede80e4ce (diff)
Promote ConstProp to a transform
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala6
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala (renamed from src/main/scala/firrtl/passes/ConstProp.scala)12
-rw-r--r--src/test/scala/firrtlTests/CInferMDirSpec.scala3
-rw-r--r--src/test/scala/firrtlTests/ChirrtlMemSpec.scala3
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala39
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala3
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala2
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala12
8 files changed, 58 insertions, 22 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 66ae1673..8dd9b180 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -98,12 +98,12 @@ class LowFirrtlOptimization extends CoreTransform {
def outputForm = LowForm
def transforms = Seq(
passes.RemoveValidIf,
- passes.ConstProp,
+ new firrtl.transforms.ConstantPropagation,
passes.PadWidths,
- passes.ConstProp,
+ new firrtl.transforms.ConstantPropagation,
passes.Legalize,
passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
- passes.ConstProp,
+ new firrtl.transforms.ConstantPropagation,
passes.SplitExpressions,
passes.CommonSubexpressionElimination,
new firrtl.transforms.DeadCodeElimination)
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index f2aa1a03..930fe45a 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -1,6 +1,7 @@
// See LICENSE for license details.
-package firrtl.passes
+package firrtl
+package transforms
import firrtl._
import firrtl.ir._
@@ -10,7 +11,10 @@ import firrtl.PrimOps._
import annotation.tailrec
-object ConstProp extends Pass {
+class ConstantPropagation extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
case (we, wt) if we == wt => e
@@ -292,4 +296,8 @@ object ConstProp extends Pass {
}
Circuit(c.info, modulesx, c.main)
}
+
+ def execute(state: CircuitState): CircuitState = {
+ state.copy(circuit = run(state.circuit))
+ }
}
diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala
index 0d31038a..299142d9 100644
--- a/src/test/scala/firrtlTests/CInferMDirSpec.scala
+++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala
@@ -5,6 +5,7 @@ package firrtlTests
import firrtl._
import firrtl.ir._
import firrtl.passes._
+import firrtl.transforms._
import firrtl.Mappers._
import annotations._
@@ -39,7 +40,7 @@ class CInferMDir extends LowTransformSpec {
def transform = new SeqTransform {
def inputForm = LowForm
def outputForm = LowForm
- def transforms = Seq(ConstProp, CInferMDirCheckPass)
+ def transforms = Seq(new ConstantPropagation, CInferMDirCheckPass)
}
"Memory" should "have correct mem port directions" in {
diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
index c963c8ae..6fac5047 100644
--- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
+++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
@@ -5,6 +5,7 @@ package firrtlTests
import firrtl._
import firrtl.ir._
import firrtl.passes._
+import firrtl.transforms._
import firrtl.Mappers._
import annotations._
@@ -53,7 +54,7 @@ class ChirrtlMemSpec extends LowTransformSpec {
def transform = new SeqTransform {
def inputForm = LowForm
def outputForm = LowForm
- def transforms = Seq(ConstProp, MemEnableCheckPass)
+ def transforms = Seq(new ConstantPropagation, MemEnableCheckPass)
}
"Sequential Memory" should "have correct enable signals" in {
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 95785717..c94adbf6 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -2,11 +2,11 @@
package firrtlTests
-import org.scalatest.Matchers
+import firrtl._
import firrtl.ir.Circuit
import firrtl.Parser.IgnoreInfo
-import firrtl.Parser
import firrtl.passes._
+import firrtl.transforms._
// Tests the following cases for constant propagation:
// 1) Unsigned integers are always greater than or
@@ -16,17 +16,17 @@ import firrtl.passes._
// 3) Values are always greater than a number smaller
// than their minimum value
class ConstantPropagationSpec extends FirrtlFlatSpec {
- val passes = Seq(
+ val transforms = Seq(
ToWorkingIR,
ResolveKinds,
InferTypes,
ResolveGenders,
InferWidths,
- ConstProp)
- private def exec (input: String) = {
- passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }.serialize
+ new ConstantPropagation)
+ private def exec(input: String) = {
+ transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, t: Transform) => t.runTransform(c)
+ }.circuit.serialize
}
// =============================
"The rule x >= 0 " should " always be true if x is a UInt" in {
@@ -349,4 +349,27 @@ class ConstantPropagationSpec extends FirrtlFlatSpec {
"""
(parse(exec(input))) should be (parse(check))
}
+
+ // =============================
+ "ConstProp" should "work across wires" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<1>
+ output y : UInt<1>
+ wire z : UInt<1>
+ y <= z
+ z <= mux(x, UInt<1>(0), UInt<1>(0))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<1>
+ output y : UInt<1>
+ wire z : UInt<1>
+ y <= UInt<1>(0)
+ z <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
}
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index b43df713..ab367554 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -8,6 +8,7 @@ import org.scalatest.prop._
import firrtl.Parser
import firrtl.ir.Circuit
import firrtl.passes._
+import firrtl.transforms._
import firrtl._
class LowerTypesSpec extends FirrtlFlatSpec {
@@ -27,7 +28,7 @@ class LowerTypesSpec extends FirrtlFlatSpec {
ExpandWhens,
CheckInitialization,
Legalize,
- ConstProp,
+ new ConstantPropagation,
ResolveKinds,
InferTypes,
ResolveGenders,
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 25f845bc..7cbfeafe 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -22,7 +22,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
new SeqTransform {
def inputForm = LowForm
def outputForm = LowForm
- def transforms = Seq(ConstProp, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty)
+ def transforms = Seq(new ConstantPropagation, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty)
}
)
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index 0d5d098c..f717fc18 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -8,13 +8,15 @@ import org.scalatest.prop._
import firrtl._
import firrtl.ir.Circuit
import firrtl.passes._
+import firrtl.transforms._
import firrtl.Parser.IgnoreInfo
class UnitTests extends FirrtlFlatSpec {
- private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
- val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+ private def executeTest(input: String, expected: Seq[String], transforms: Seq[Transform]) = {
+ val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, t: Transform) => t.runTransform(c)
+ }.circuit
+
val lines = c.serialize.split("\n") map normalized
expected foreach { e =>
@@ -199,7 +201,7 @@ class UnitTests extends FirrtlFlatSpec {
PullMuxes,
ExpandConnects,
RemoveAccesses,
- ConstProp
+ new ConstantPropagation
)
val input =
"""circuit AssignViaDeref :