aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2017-06-29 14:20:09 -0700
committerGitHub2017-06-29 14:20:09 -0700
commit905cac96053caf4b6c87ac0b9c8addf313d1085c (patch)
tree7d0bcf384f63e0176acdd70f9524369bb5bb4ce0 /src
parent8eb69dd91e58915f8dad5e42da0a3fe686c628d8 (diff)
parenta0aeafa3d591f9bcc14eca6d8a41eb2155f1b5b0 (diff)
Merge pull request #617 from freechipsproject/const-prop-regs
Improvements to Constant Propagation and Testing
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala6
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala (renamed from src/main/scala/firrtl/passes/ConstProp.scala)35
-rw-r--r--src/test/scala/firrtlTests/AnnotationTests.scala17
-rw-r--r--src/test/scala/firrtlTests/CInferMDirSpec.scala3
-rw-r--r--src/test/scala/firrtlTests/ChirrtlMemSpec.scala3
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala82
-rw-r--r--src/test/scala/firrtlTests/FirrtlSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/InlineInstancesTests.scala38
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala3
-rw-r--r--src/test/scala/firrtlTests/PassTests.scala14
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala2
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala12
-rw-r--r--src/test/scala/firrtlTests/transforms/BlacklBoxSourceHelperSpec.scala6
-rw-r--r--src/test/scala/firrtlTests/transforms/DedupTests.scala12
14 files changed, 158 insertions, 79 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 66ae1673..8dd9b180 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -98,12 +98,12 @@ class LowFirrtlOptimization extends CoreTransform {
def outputForm = LowForm
def transforms = Seq(
passes.RemoveValidIf,
- passes.ConstProp,
+ new firrtl.transforms.ConstantPropagation,
passes.PadWidths,
- passes.ConstProp,
+ new firrtl.transforms.ConstantPropagation,
passes.Legalize,
passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
- passes.ConstProp,
+ new firrtl.transforms.ConstantPropagation,
passes.SplitExpressions,
passes.CommonSubexpressionElimination,
new firrtl.transforms.DeadCodeElimination)
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index f2aa1a03..efe06e9b 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -1,8 +1,10 @@
// See LICENSE for license details.
-package firrtl.passes
+package firrtl
+package transforms
import firrtl._
+import firrtl.annotations._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
@@ -10,7 +12,10 @@ import firrtl.PrimOps._
import annotation.tailrec
-object ConstProp extends Pass {
+class ConstantPropagation extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
case (we, wt) if we == wt => e
@@ -239,7 +244,7 @@ object ConstProp extends Pass {
// 2. Propagate references again for backwards reference (Wires)
// TODO Replacing all wires with nodes makes the second pass unnecessary
@tailrec
- private def constPropModule(m: Module): Module = {
+ private def constPropModule(m: Module, dontTouches: Set[String]): Module = {
var nPropagated = 0L
val nodeMap = collection.mutable.HashMap[String, Expression]()
@@ -272,8 +277,8 @@ object ConstProp extends Pass {
def constPropStmt(s: Statement): Statement = {
val stmtx = s map constPropStmt map constPropExpression
stmtx match {
- case x: DefNode => nodeMap(x.name) = x.value
- case Connect(_, WRef(wname, wtpe, WireKind, _), expr) =>
+ case x: DefNode if !dontTouches.contains(x.name) => nodeMap(x.name) = x.value
+ case Connect(_, WRef(wname, wtpe, WireKind, _), expr) if !dontTouches.contains(wname) =>
val exprx = constPropExpression(pad(expr, wtpe))
nodeMap(wname) = exprx
case _ =>
@@ -282,14 +287,28 @@ object ConstProp extends Pass {
}
val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body)))
- if (nPropagated > 0) constPropModule(res) else res
+ if (nPropagated > 0) constPropModule(res, dontTouches) else res
}
- def run(c: Circuit): Circuit = {
+ private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = {
val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => constPropModule(m)
+ case m: Module => constPropModule(m, dontTouchMap.getOrElse(m.name, Set.empty))
}
Circuit(c.info, modulesx, c.main)
}
+
+ def execute(state: CircuitState): CircuitState = {
+ val dontTouches: Seq[(String, String)] = state.annotations match {
+ case Some(aMap) => aMap.annotations.collect {
+ case DontTouchAnnotation(ComponentName(c, ModuleName(m, _))) => m -> c
+ }
+ case None => Seq.empty
+ }
+ // Map from module name to component names
+ val dontTouchMap: Map[String, Set[String]] =
+ dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet)
+
+ state.copy(circuit = run(state.circuit, dontTouchMap))
+ }
}
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala
index 3e93081e..aeefbbe3 100644
--- a/src/test/scala/firrtlTests/AnnotationTests.scala
+++ b/src/test/scala/firrtlTests/AnnotationTests.scala
@@ -23,13 +23,13 @@ trait AnnotationSpec extends LowTransformSpec {
def transform = new ResolveAndCheck
// Check if Annotation Exception is thrown
- override def failingexecute(annotations: AnnotationMap, input: String): Exception = {
+ override def failingexecute(input: String, annotations: Seq[Annotation]): Exception = {
intercept[AnnotationException] {
- compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), Seq.empty)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annotations))), Seq.empty)
}
}
- def execute(aMap: Option[AnnotationMap], input: String, check: Annotation): Unit = {
- val cr = compile(CircuitState(parse(input), ChirrtlForm, aMap), Seq.empty)
+ def execute(input: String, check: Annotation, annotations: Seq[Annotation]): Unit = {
+ val cr = compile(CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annotations))), Seq.empty)
cr.annotations.get.annotations should contain (check)
}
}
@@ -49,13 +49,6 @@ class AnnotationTests extends AnnotationSpec with Matchers {
Annotation(ComponentName(s, ModuleName(mod, CircuitName("Top"))), classOf[Transform], value)
def manno(mod: String): Annotation =
Annotation(ModuleName(mod, CircuitName("Top")), classOf[Transform], "some value")
- // TODO unify with FirrtlMatchers, problems with multiple definitions of parse
- def dontTouch(path: String): Annotation = {
- val parts = path.split('.')
- require(parts.size >= 2, "Must specify both module and component!")
- val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top")))
- DontTouchAnnotation(name)
- }
"Loose and Sticky annotation on a node" should "pass through" in {
val input: String =
@@ -65,7 +58,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| input b : UInt<1>
| node c = b""".stripMargin
val ta = anno("c", "")
- execute(getAMap(ta), input, ta)
+ execute(input, ta, Seq(ta))
}
"Annotations" should "be readable from file" in {
diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala
index 0d31038a..299142d9 100644
--- a/src/test/scala/firrtlTests/CInferMDirSpec.scala
+++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala
@@ -5,6 +5,7 @@ package firrtlTests
import firrtl._
import firrtl.ir._
import firrtl.passes._
+import firrtl.transforms._
import firrtl.Mappers._
import annotations._
@@ -39,7 +40,7 @@ class CInferMDir extends LowTransformSpec {
def transform = new SeqTransform {
def inputForm = LowForm
def outputForm = LowForm
- def transforms = Seq(ConstProp, CInferMDirCheckPass)
+ def transforms = Seq(new ConstantPropagation, CInferMDirCheckPass)
}
"Memory" should "have correct mem port directions" in {
diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
index c963c8ae..6fac5047 100644
--- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
+++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
@@ -5,6 +5,7 @@ package firrtlTests
import firrtl._
import firrtl.ir._
import firrtl.passes._
+import firrtl.transforms._
import firrtl.Mappers._
import annotations._
@@ -53,7 +54,7 @@ class ChirrtlMemSpec extends LowTransformSpec {
def transform = new SeqTransform {
def inputForm = LowForm
def outputForm = LowForm
- def transforms = Seq(ConstProp, MemEnableCheckPass)
+ def transforms = Seq(new ConstantPropagation, MemEnableCheckPass)
}
"Sequential Memory" should "have correct enable signals" in {
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 95785717..f818f9c0 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -2,11 +2,11 @@
package firrtlTests
-import org.scalatest.Matchers
+import firrtl._
import firrtl.ir.Circuit
import firrtl.Parser.IgnoreInfo
-import firrtl.Parser
import firrtl.passes._
+import firrtl.transforms._
// Tests the following cases for constant propagation:
// 1) Unsigned integers are always greater than or
@@ -16,17 +16,17 @@ import firrtl.passes._
// 3) Values are always greater than a number smaller
// than their minimum value
class ConstantPropagationSpec extends FirrtlFlatSpec {
- val passes = Seq(
+ val transforms = Seq(
ToWorkingIR,
ResolveKinds,
InferTypes,
ResolveGenders,
InferWidths,
- ConstProp)
- private def exec (input: String) = {
- passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }.serialize
+ new ConstantPropagation)
+ private def exec(input: String) = {
+ transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, t: Transform) => t.runTransform(c)
+ }.circuit.serialize
}
// =============================
"The rule x >= 0 " should " always be true if x is a UInt" in {
@@ -349,4 +349,70 @@ class ConstantPropagationSpec extends FirrtlFlatSpec {
"""
(parse(exec(input))) should be (parse(check))
}
+
+ // =============================
+ "ConstProp" should "work across wires" in {
+ val input =
+"""circuit Top :
+ module Top :
+ input x : UInt<1>
+ output y : UInt<1>
+ wire z : UInt<1>
+ y <= z
+ z <= mux(x, UInt<1>(0), UInt<1>(0))
+"""
+ val check =
+"""circuit Top :
+ module Top :
+ input x : UInt<1>
+ output y : UInt<1>
+ wire z : UInt<1>
+ y <= UInt<1>(0)
+ z <= UInt<1>(0)
+"""
+ (parse(exec(input))) should be (parse(check))
+ }
+}
+
+// More sophisticated tests of the full compiler
+class ConstantPropagationIntegrationSpec extends LowTransformSpec {
+ def transform = new LowFirrtlOptimization
+
+ "ConstProp" should "should not optimize across dontTouch on nodes" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | node z = x
+ | y <= z""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | node z = x
+ | y <= z""".stripMargin
+ execute(input, check, Seq(dontTouch("Top.z")))
+ }
+
+ it should "should not optimize across dontTouch on wires" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | wire z : UInt<1>
+ | y <= z
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | wire z : UInt<1>
+ | y <= z
+ | z <= x""".stripMargin
+ execute(input, check, Seq(dontTouch("Top.z")))
+ }
}
diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala
index a45af8c7..07f83142 100644
--- a/src/test/scala/firrtlTests/FirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/FirrtlSpec.scala
@@ -11,7 +11,7 @@ import org.scalatest.prop._
import scala.io.Source
import firrtl._
-import firrtl.Parser.IgnoreInfo
+import firrtl.Parser.UseInfo
import firrtl.annotations._
import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation}
import firrtl.util.BackendCompilationUtilities
@@ -100,7 +100,7 @@ trait FirrtlMatchers extends Matchers {
require(!s.contains("\n"))
s.replaceAll("\\s+", " ").trim
}
- def parse(str: String) = Parser.parse(str.split("\n").toIterator, IgnoreInfo)
+ def parse(str: String) = Parser.parse(str.split("\n").toIterator, UseInfo)
/** Helper for executing tests
* compiler will be run on input then emitted result will each be split into
* lines and normalized.
diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala
index 9e8f8054..4398df48 100644
--- a/src/test/scala/firrtlTests/InlineInstancesTests.scala
+++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala
@@ -6,7 +6,7 @@ import org.scalatest.FlatSpec
import org.scalatest.Matchers
import org.scalatest.junit.JUnitRunner
import firrtl.ir.Circuit
-import firrtl.{AnnotationMap, Parser}
+import firrtl.Parser
import firrtl.passes.PassExceptions
import firrtl.annotations.{Annotation, CircuitName, ComponentName, ModuleName, Named}
import firrtl.passes.{InlineAnnotation, InlineInstances}
@@ -18,7 +18,14 @@ import logger.LogLevel.Debug
* Tests inline instances transformation
*/
class InlineInstancesTests extends LowTransformSpec {
- def transform = new InlineInstances
+ def transform = new InlineInstances
+ def inline(mod: String): Annotation = {
+ val parts = mod.split('.')
+ val modName = ModuleName(parts.head, CircuitName("Top")) // If this fails, bad input
+ val name = if (parts.size == 1) modName
+ else ComponentName(parts.tail.mkString("."), modName)
+ InlineAnnotation(name)
+ }
// Set this to debug, this will apply to all tests
// Logger.setLevel(this.getClass, Debug)
"The module Inline" should "be inlined" in {
@@ -44,8 +51,7 @@ class InlineInstancesTests extends LowTransformSpec {
| i$b <= i$a
| b <= i$b
| i$a <= a""".stripMargin
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Inline", CircuitName("Top")))))
- execute(aMap, input, check)
+ execute(input, check, Seq(inline("Inline")))
}
"The all instances of Simple" should "be inlined" in {
@@ -77,8 +83,7 @@ class InlineInstancesTests extends LowTransformSpec {
| b <= i1$b
| i0$a <= a
| i1$a <= i0$b""".stripMargin
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("Simple", CircuitName("Top")))))
- execute(aMap, input, check)
+ execute(input, check, Seq(inline("Simple")))
}
"Only one instance of Simple" should "be inlined" in {
@@ -112,8 +117,7 @@ class InlineInstancesTests extends LowTransformSpec {
| input a : UInt<32>
| output b : UInt<32>
| b <= a""".stripMargin
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ComponentName("i0",ModuleName("Top", CircuitName("Top"))))))
- execute(aMap, input, check)
+ execute(input, check, Seq(inline("Top.i0")))
}
"All instances of A" should "be inlined" in {
@@ -157,8 +161,7 @@ class InlineInstancesTests extends LowTransformSpec {
| i$b <= i$a
| b <= i$b
| i$a <= a""".stripMargin
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
- execute(aMap, input, check)
+ execute(input, check, Seq(inline("A")))
}
"Non-inlined instances" should "still prepend prefix" in {
@@ -196,8 +199,7 @@ class InlineInstancesTests extends LowTransformSpec {
| input a : UInt<32>
| output b : UInt<32>
| b <= a""".stripMargin
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
- execute(aMap, input, check)
+ execute(input, check, Seq(inline("A")))
}
// ---- Errors ----
@@ -214,8 +216,7 @@ class InlineInstancesTests extends LowTransformSpec {
| extmodule A :
| input a : UInt<32>
| output b : UInt<32>""".stripMargin
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
- failingexecute(aMap, input)
+ failingexecute(input, Seq(inline("A")))
}
// 2) ext instance
"External instance" should "not be inlined" in {
@@ -230,8 +231,7 @@ class InlineInstancesTests extends LowTransformSpec {
| extmodule A :
| input a : UInt<32>
| output b : UInt<32>""".stripMargin
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
- failingexecute(aMap, input)
+ failingexecute(input, Seq(inline("A")))
}
// 3) no module
"Inlined module" should "exist" in {
@@ -241,8 +241,7 @@ class InlineInstancesTests extends LowTransformSpec {
| input a : UInt<32>
| output b : UInt<32>
| b <= a""".stripMargin
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
- failingexecute(aMap, input)
+ failingexecute(input, Seq(inline("A")))
}
// 4) no inst
"Inlined instance" should "exist" in {
@@ -252,8 +251,7 @@ class InlineInstancesTests extends LowTransformSpec {
| input a : UInt<32>
| output b : UInt<32>
| b <= a""".stripMargin
- val aMap = new AnnotationMap(Seq(InlineAnnotation(ModuleName("A", CircuitName("Top")))))
- failingexecute(aMap, input)
+ failingexecute(input, Seq(inline("A")))
}
}
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index b43df713..ab367554 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -8,6 +8,7 @@ import org.scalatest.prop._
import firrtl.Parser
import firrtl.ir.Circuit
import firrtl.passes._
+import firrtl.transforms._
import firrtl._
class LowerTypesSpec extends FirrtlFlatSpec {
@@ -27,7 +28,7 @@ class LowerTypesSpec extends FirrtlFlatSpec {
ExpandWhens,
CheckInitialization,
Legalize,
- ConstProp,
+ new ConstantPropagation,
ResolveKinds,
InferTypes,
ResolveGenders,
diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala
index e22fd513..6727533e 100644
--- a/src/test/scala/firrtlTests/PassTests.scala
+++ b/src/test/scala/firrtlTests/PassTests.scala
@@ -9,18 +9,19 @@ import firrtl.ir.Circuit
import firrtl.Parser.UseInfo
import firrtl.passes.{Pass, PassExceptions, RemoveEmpty}
import firrtl._
+import firrtl.annotations._
import logger._
// An example methodology for testing Firrtl Passes
// Spec class should extend this class
-abstract class SimpleTransformSpec extends FlatSpec with Matchers with Compiler with LazyLogging {
+abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Compiler with LazyLogging {
// Utility function
- def parse(s: String): Circuit = Parser.parse(s.split("\n").toIterator, infoMode = UseInfo)
def squash(c: Circuit): Circuit = RemoveEmpty.run(c)
// Executes the test. Call in tests.
- def execute(annotations: AnnotationMap, input: String, check: String): Unit = {
- val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(annotations)))
+ // annotations cannot have default value because scalatest trait Suite has a default value
+ def execute(input: String, check: String, annotations: Seq[Annotation]): Unit = {
+ val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annotations))))
val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize
val expected = parse(check).serialize
logger.debug(actual)
@@ -28,9 +29,10 @@ abstract class SimpleTransformSpec extends FlatSpec with Matchers with Compiler
(actual) should be (expected)
}
// Executes the test, should throw an error
- def failingexecute(annotations: AnnotationMap, input: String): Exception = {
+ // No default to be consistent with execute
+ def failingexecute(input: String, annotations: Seq[Annotation]): Exception = {
intercept[PassExceptions] {
- compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), Seq.empty)
+ compile(CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annotations))), Seq.empty)
}
}
}
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 25f845bc..7cbfeafe 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -22,7 +22,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
new SeqTransform {
def inputForm = LowForm
def outputForm = LowForm
- def transforms = Seq(ConstProp, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty)
+ def transforms = Seq(new ConstantPropagation, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty)
}
)
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index 0d5d098c..f717fc18 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -8,13 +8,15 @@ import org.scalatest.prop._
import firrtl._
import firrtl.ir.Circuit
import firrtl.passes._
+import firrtl.transforms._
import firrtl.Parser.IgnoreInfo
class UnitTests extends FirrtlFlatSpec {
- private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
- val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
- }
+ private def executeTest(input: String, expected: Seq[String], transforms: Seq[Transform]) = {
+ val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, t: Transform) => t.runTransform(c)
+ }.circuit
+
val lines = c.serialize.split("\n") map normalized
expected foreach { e =>
@@ -199,7 +201,7 @@ class UnitTests extends FirrtlFlatSpec {
PullMuxes,
ExpandConnects,
RemoveAccesses,
- ConstProp
+ new ConstantPropagation
)
val input =
"""circuit AssignViaDeref :
diff --git a/src/test/scala/firrtlTests/transforms/BlacklBoxSourceHelperSpec.scala b/src/test/scala/firrtlTests/transforms/BlacklBoxSourceHelperSpec.scala
index 8cd51b2a..bf294fe9 100644
--- a/src/test/scala/firrtlTests/transforms/BlacklBoxSourceHelperSpec.scala
+++ b/src/test/scala/firrtlTests/transforms/BlacklBoxSourceHelperSpec.scala
@@ -78,12 +78,12 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec {
"annotated external modules" should "appear in output directory" in {
- val aMap = AnnotationMap(Seq(
+ val annos = Seq(
Annotation(moduleName, classOf[BlackBoxSourceHelper], BlackBoxTargetDir("test_run_dir").serialize),
Annotation(moduleName, classOf[BlackBoxSourceHelper], BlackBoxResource("/blackboxes/AdderExtModule.v").serialize)
- ))
+ )
- execute(aMap, input, output)
+ execute(input, output, annos)
new java.io.File("test_run_dir/AdderExtModule.v").exists should be (true)
new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.FileListName}").exists should be (true)
diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala
index 7148dd11..74c4b4e7 100644
--- a/src/test/scala/firrtlTests/transforms/DedupTests.scala
+++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala
@@ -46,8 +46,7 @@ class DedupModuleTests extends HighTransformSpec {
| output x: UInt<1>
| x <= UInt(1)
""".stripMargin
- val aMap = new AnnotationMap(Nil)
- execute(aMap, input, check)
+ execute(input, check, Seq.empty)
}
"The module A and B" should "be deduped" in {
val input =
@@ -83,8 +82,7 @@ class DedupModuleTests extends HighTransformSpec {
| output x: UInt<1>
| x <= UInt(1)
""".stripMargin
- val aMap = new AnnotationMap(Nil)
- execute(aMap, input, check)
+ execute(input, check, Seq.empty)
}
"The module A and B with comments" should "be deduped" in {
val input =
@@ -120,8 +118,7 @@ class DedupModuleTests extends HighTransformSpec {
| output x: UInt<1>
| x <= UInt(1)
""".stripMargin
- val aMap = new AnnotationMap(Nil)
- execute(aMap, input, check)
+ execute(input, check, Seq.empty)
}
"The module B, but not A, with comments" should "be deduped if not annotated" in {
val input =
@@ -148,8 +145,7 @@ class DedupModuleTests extends HighTransformSpec {
| output x: UInt<1> @[xx 1:1]
| x <= UInt(1)
""".stripMargin
- val aMap = new AnnotationMap(Seq(NoDedupAnnotation(ModuleName("A", CircuitName("Top")))))
- execute(aMap, input, check)
+ execute(input, check, Seq(dontDedup("A")))
}
}