aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtl
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/test/scala/firrtl
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/test/scala/firrtl')
-rw-r--r--src/test/scala/firrtl/JsonProtocolSpec.scala20
-rw-r--r--src/test/scala/firrtl/analysis/SymbolTableSpec.scala21
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala68
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala137
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala12
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala9
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala102
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala102
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala30
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala23
-rw-r--r--src/test/scala/firrtl/ir/StructuralHashSpec.scala120
-rw-r--r--src/test/scala/firrtl/passes/LowerTypesSpec.scala541
-rw-r--r--src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala225
-rw-r--r--src/test/scala/firrtl/testutils/FirrtlSpec.scala145
-rw-r--r--src/test/scala/firrtl/testutils/LeanTransformSpec.scala19
-rw-r--r--src/test/scala/firrtl/testutils/PassTests.scala108
16 files changed, 958 insertions, 724 deletions
diff --git a/src/test/scala/firrtl/JsonProtocolSpec.scala b/src/test/scala/firrtl/JsonProtocolSpec.scala
index 7d04e9fc..cc7591cb 100644
--- a/src/test/scala/firrtl/JsonProtocolSpec.scala
+++ b/src/test/scala/firrtl/JsonProtocolSpec.scala
@@ -4,7 +4,13 @@ package firrtlTests
import org.json4s._
-import firrtl.annotations.{NoTargetAnnotation, JsonProtocol, InvalidAnnotationJSONException, HasSerializationHints, Annotation}
+import firrtl.annotations.{
+ Annotation,
+ HasSerializationHints,
+ InvalidAnnotationJSONException,
+ JsonProtocol,
+ NoTargetAnnotation
+}
import org.scalatest.flatspec.AnyFlatSpec
object JsonProtocolTestClasses {
@@ -13,12 +19,16 @@ object JsonProtocolTestClasses {
case class ChildA(foo: Int) extends Parent
case class ChildB(bar: String) extends Parent
case class PolymorphicParameterAnnotation(param: Parent) extends NoTargetAnnotation
- case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) extends NoTargetAnnotation with HasSerializationHints {
+ case class PolymorphicParameterAnnotationWithTypeHints(param: Parent)
+ extends NoTargetAnnotation
+ with HasSerializationHints {
def typeHints = Seq(param.getClass)
}
case class TypeParameterizedAnnotation[T](param: T) extends NoTargetAnnotation
- case class TypeParameterizedAnnotationWithTypeHints[T](param: T) extends NoTargetAnnotation with HasSerializationHints {
+ case class TypeParameterizedAnnotationWithTypeHints[T](param: T)
+ extends NoTargetAnnotation
+ with HasSerializationHints {
def typeHints = Seq(param.getClass)
}
}
@@ -51,11 +61,11 @@ class JsonProtocolSpec extends AnyFlatSpec {
"Annotations with non-primitive type parameters" should "not serialize and deserialize without type hints" in {
val anno = TypeParameterizedAnnotation(ChildA(1))
val deserAnno = serializeAndDeserialize(anno)
- assert (anno != deserAnno)
+ assert(anno != deserAnno)
}
it should "serialize and deserialize with type hints" in {
val anno = TypeParameterizedAnnotationWithTypeHints(ChildA(1))
val deserAnno = serializeAndDeserialize(anno)
- assert (anno == deserAnno)
+ assert(anno == deserAnno)
}
}
diff --git a/src/test/scala/firrtl/analysis/SymbolTableSpec.scala b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala
index 599b4e52..ca30b60b 100644
--- a/src/test/scala/firrtl/analysis/SymbolTableSpec.scala
+++ b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala
@@ -8,7 +8,7 @@ import firrtl.options.Dependency
import org.scalatest.flatspec.AnyFlatSpec
class SymbolTableSpec extends AnyFlatSpec {
- behavior of "SymbolTable"
+ behavior.of("SymbolTable")
private val src =
"""circuit m:
@@ -50,9 +50,20 @@ class SymbolTableSpec extends AnyFlatSpec {
assert(syms("r").tpe == ir.SIntType(ir.IntWidth(4)) && syms("r").kind == firrtl.RegKind)
val mType = firrtl.passes.MemPortUtils.memType(
// only dataType, depth and reader, writer, readwriter properties affect the data type
- ir.DefMemory(ir.NoInfo, "???", ir.UIntType(ir.IntWidth(8)), 32, 10, 10, Seq("r"), Seq(), Seq(), ir.ReadUnderWrite.New)
+ ir.DefMemory(
+ ir.NoInfo,
+ "???",
+ ir.UIntType(ir.IntWidth(8)),
+ 32,
+ 10,
+ 10,
+ Seq("r"),
+ Seq(),
+ Seq(),
+ ir.ReadUnderWrite.New
+ )
)
- assert(syms("m") .tpe == mType && syms("m").kind == firrtl.MemKind)
+ assert(syms("m").tpe == mType && syms("m").kind == firrtl.MemKind)
}
it should "find all declarations in module m after InferTypes" in {
@@ -69,7 +80,7 @@ class SymbolTableSpec extends AnyFlatSpec {
assert(syms("i").tpe == iType && syms("i").kind == firrtl.InstanceKind)
}
- behavior of "WithSeq"
+ behavior.of("WithSeq")
it should "preserve declaration order" in {
val c = firrtl.Parser.parse(src)
@@ -79,7 +90,7 @@ class SymbolTableSpec extends AnyFlatSpec {
assert(syms.getSymbols.map(_.name) == Seq("clk", "x", "y", "z", "a", "i", "r", "m"))
}
- behavior of "ModuleTypesSymbolTable"
+ behavior.of("ModuleTypesSymbolTable")
it should "derive the module type from the module types map" in {
val c = firrtl.Parser.parse(src)
diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
index 015ac4a9..f7ce9914 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
@@ -20,10 +20,16 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec {
sys.signals.head.e.toString
}
- def primop(signed: Boolean, op: String, resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(),
- resAlwaysUnsigned: Boolean = false): String = {
- val tpe = if(signed) "SInt" else "UInt"
- val resTpe = if(resAlwaysUnsigned) "UInt" else tpe
+ def primop(
+ signed: Boolean,
+ op: String,
+ resWidth: Int,
+ inWidth: Seq[Int],
+ consts: Seq[Int] = List(),
+ resAlwaysUnsigned: Boolean = false
+ ): String = {
+ val tpe = if (signed) "SInt" else "UInt"
+ val resTpe = if (resAlwaysUnsigned) "UInt" else tpe
val inTpes = inWidth.map(w => s"$tpe<$w>")
primop(op, s"$resTpe<$resWidth>", inTpes, consts)
}
@@ -52,16 +58,24 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec {
it should "correctly translate the `div` primitive operation" in {
// division is a little bit more complicated because the result of division by zero is undefined
- assert(primop(false, "div", 8, List(8, 8)) ==
- "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))")
- assert(primop(false, "div", 8, List(8, 4)) ==
- "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))")
+ assert(
+ primop(false, "div", 8, List(8, 8)) ==
+ "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))"
+ )
+ assert(
+ primop(false, "div", 8, List(8, 4)) ==
+ "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))"
+ )
// signed division increases result width by 1
- assert(primop(true, "div", 8, List(7, 7)) ==
- "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))")
- assert(primop(true, "div", 8, List(7, 4))
- == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))")
+ assert(
+ primop(true, "div", 8, List(7, 7)) ==
+ "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))"
+ )
+ assert(
+ primop(true, "div", 8, List(7, 4))
+ == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))"
+ )
}
it should "correctly translate the `rem` primitive operation" in {
@@ -134,15 +148,19 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec {
it should "correctly translate the `dshl` primitive operation" in {
assert(primop(false, "dshl", 31, List(16, 4)) == "logical_shift_left(zext(i0, 15), zext(i1, 27))")
assert(primop(false, "dshl", 19, List(16, 2)) == "logical_shift_left(zext(i0, 3), zext(i1, 17))")
- assert(primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) ==
- "logical_shift_left(sext(i0, 3), zext(i1, 17))")
+ assert(
+ primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) ==
+ "logical_shift_left(sext(i0, 3), zext(i1, 17))"
+ )
}
it should "correctly translate the `dshr` primitive operation" in {
assert(primop(false, "dshr", 16, List(16, 4)) == "logical_shift_right(i0, zext(i1, 12))")
assert(primop(false, "dshr", 16, List(16, 2)) == "logical_shift_right(i0, zext(i1, 14))")
- assert(primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) ==
- "arithmetic_shift_right(i0, zext(i1, 14))")
+ assert(
+ primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) ==
+ "arithmetic_shift_right(i0, zext(i1, 14))"
+ )
}
it should "correctly translate the `cvt` primitive operation" in {
@@ -197,15 +215,15 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec {
}
it should "correctly translate the `bits` primitive operation" in {
- assert(primop(false, "bits", 1, List(4), List(2,2)) == "i0[2]")
- assert(primop(false, "bits", 2, List(4), List(2,1)) == "i0[2:1]")
- assert(primop(false, "bits", 1, List(4), List(2,1)) == "i0[2:1][0]")
- assert(primop(false, "bits", 3, List(4), List(2,1)) == "zext(i0[2:1], 1)")
+ assert(primop(false, "bits", 1, List(4), List(2, 2)) == "i0[2]")
+ assert(primop(false, "bits", 2, List(4), List(2, 1)) == "i0[2:1]")
+ assert(primop(false, "bits", 1, List(4), List(2, 1)) == "i0[2:1][0]")
+ assert(primop(false, "bits", 3, List(4), List(2, 1)) == "zext(i0[2:1], 1)")
- assert(primop(true, "bits", 1, List(4), List(2,2), resAlwaysUnsigned = true) == "i0[2]")
- assert(primop(true, "bits", 2, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1]")
- assert(primop(true, "bits", 1, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1][0]")
- assert(primop(true, "bits", 3, List(4), List(2,1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)")
+ assert(primop(true, "bits", 1, List(4), List(2, 2), resAlwaysUnsigned = true) == "i0[2]")
+ assert(primop(true, "bits", 2, List(4), List(2, 1), resAlwaysUnsigned = true) == "i0[2:1]")
+ assert(primop(true, "bits", 1, List(4), List(2, 1), resAlwaysUnsigned = true) == "i0[2:1][0]")
+ assert(primop(true, "bits", 3, List(4), List(2, 1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)")
}
it should "correctly translate the `head` primitive operation" in {
@@ -221,4 +239,4 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec {
assert(primop(false, "tail", 4, List(5), List(1)) == "i0[3:0]")
assert(primop(false, "tail", 2, List(5), List(3)) == "i0[1:0]")
}
-} \ No newline at end of file
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala
index ca7974c5..b41313e3 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala
@@ -5,8 +5,7 @@ package firrtl.backends.experimental.smt
import firrtl.{MemoryArrayInit, MemoryScalarInit, Utils}
private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec {
- behavior of "ModuleToTransitionSystem.run"
-
+ behavior.of("ModuleToTransitionSystem.run")
it should "model registers as state" in {
// if a signal is invalid, it could take on an arbitrary value in that cycle
@@ -42,39 +41,39 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec {
private def memCircuit(depth: Int = 32) =
s"""circuit m:
- | module m:
- | input reset : UInt<1>
- | input clock : Clock
- | input addr : UInt<${Utils.getUIntWidth(depth)}>
- | input in : UInt<8>
- | output out : UInt<8>
- |
- | mem m:
- | data-type => UInt<8>
- | depth => $depth
- | reader => r
- | writer => w
- | read-latency => 0
- | write-latency => 1
- | read-under-write => new
- |
- | m.w.clk <= clock
- | m.w.mask <= UInt(1)
- | m.w.en <= UInt(1)
- | m.w.data <= in
- | m.w.addr <= addr
- |
- | m.r.clk <= clock
- | m.r.en <= UInt(1)
- | out <= m.r.data
- | m.r.addr <= addr
- |
- |""".stripMargin
+ | module m:
+ | input reset : UInt<1>
+ | input clock : Clock
+ | input addr : UInt<${Utils.getUIntWidth(depth)}>
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | mem m:
+ | data-type => UInt<8>
+ | depth => $depth
+ | reader => r
+ | writer => w
+ | read-latency => 0
+ | write-latency => 1
+ | read-under-write => new
+ |
+ | m.w.clk <= clock
+ | m.w.mask <= UInt(1)
+ | m.w.en <= UInt(1)
+ | m.w.data <= in
+ | m.w.addr <= addr
+ |
+ | m.r.clk <= clock
+ | m.r.en <= UInt(1)
+ | out <= m.r.data
+ | m.r.addr <= addr
+ |
+ |""".stripMargin
it should "model memories as state" in {
val sys = toSys(memCircuit())
- assert(sys.signals.length == 9-2+1, "9 connects - 2 clock connects + 1 combinatorial read port")
+ assert(sys.signals.length == 9 - 2 + 1, "9 connects - 2 clock connects + 1 combinatorial read port")
val sig = sys.signals.map(s => s.name -> s.e).toMap
@@ -140,40 +139,39 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec {
it should "support memories with registered read port" in {
def src(readUnderWrite: String) =
s"""circuit m:
- | module m:
- | input reset : UInt<1>
- | input clock : Clock
- | input addr : UInt<5>
- | input in : UInt<8>
- | output out : UInt<8>
- |
- | mem m:
- | data-type => UInt<8>
- | depth => 32
- | reader => r
- | writer => w1, w2
- | read-latency => 1
- | write-latency => 1
- | read-under-write => $readUnderWrite
- |
- | m.w1.clk <= clock
- | m.w1.mask <= UInt(1)
- | m.w1.en <= UInt(1)
- | m.w1.data <= in
- | m.w1.addr <= addr
- | m.w2.clk <= clock
- | m.w2.mask <= UInt(1)
- | m.w2.en <= UInt(1)
- | m.w2.data <= in
- | m.w2.addr <= addr
- |
- | m.r.clk <= clock
- | m.r.en <= UInt(1)
- | out <= m.r.data
- | m.r.addr <= addr
- |
- |""".stripMargin
-
+ | module m:
+ | input reset : UInt<1>
+ | input clock : Clock
+ | input addr : UInt<5>
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | mem m:
+ | data-type => UInt<8>
+ | depth => 32
+ | reader => r
+ | writer => w1, w2
+ | read-latency => 1
+ | write-latency => 1
+ | read-under-write => $readUnderWrite
+ |
+ | m.w1.clk <= clock
+ | m.w1.mask <= UInt(1)
+ | m.w1.en <= UInt(1)
+ | m.w1.data <= in
+ | m.w1.addr <= addr
+ | m.w2.clk <= clock
+ | m.w2.mask <= UInt(1)
+ | m.w2.en <= UInt(1)
+ | m.w2.data <= in
+ | m.w2.addr <= addr
+ |
+ | m.r.clk <= clock
+ | m.r.en <= UInt(1)
+ | out <= m.r.data
+ | m.r.addr <= addr
+ |
+ |""".stripMargin
val oldValue = toSys(src("old"))
val oldMData = oldValue.states.find(_.sym.name == "m.r.data").get
@@ -186,9 +184,11 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec {
val undefinedMData = undefinedValue.states.find(_.sym.name == "m.r.data").get
assert(undefinedMData.sym.toString == "m.r.data")
val undefined = "RANDOM.m_r_read_under_write_undefined"
- assert(undefinedMData.next.get.toString ==
- s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])",
- "randomize result if there is a write")
+ assert(
+ undefinedMData.next.get.toString ==
+ s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])",
+ "randomize result if there is a write"
+ )
}
it should "support memories with potential write-write conflicts" in {
@@ -228,7 +228,6 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec {
|
|""".stripMargin
-
val sys = toSys(src)
val m = sys.states.find(_.sym.name == "m").get
diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala
index 6bfb5437..209279fd 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala
@@ -3,7 +3,7 @@
package firrtl.backends.experimental.smt
import firrtl.annotations.Annotation
-import firrtl.{MemoryInitValue, ir}
+import firrtl.{ir, MemoryInitValue}
import firrtl.stage.{Forms, TransformManager}
import org.scalatest.flatspec.AnyFlatSpec
@@ -16,8 +16,12 @@ private abstract class SMTBackendBaseSpec extends AnyFlatSpec {
compiler.runTransform(firrtl.CircuitState(c, annos)).circuit
}
- protected def toSys(src: String, mod: String = "m", presetRegs: Set[String] = Set(),
- memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = {
+ protected def toSys(
+ src: String,
+ mod: String = "m",
+ presetRegs: Set[String] = Set(),
+ memInit: Map[String, MemoryInitValue] = Map()
+ ): TransitionSystem = {
val circuit = compile(src)
val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module]
// println(module.serialize)
@@ -35,4 +39,4 @@ private abstract class SMTBackendBaseSpec extends AnyFlatSpec {
protected def toSMTLibStr(src: String, mod: String = "m"): String =
toSMTLib(src, mod).mkString("\n") + "\n"
-} \ No newline at end of file
+}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala
index 4c6901ea..e7c8d534 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala
@@ -10,9 +10,10 @@ import firrtl.stage.RunFirrtlTransformAnnotation
class AsyncResetSpec extends EndToEndSMTBaseSpec {
def annos(name: String) = Seq(
RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]),
- GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock")))
+ GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock"))
+ )
- "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs(RequiresZ3) in {
+ "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs (RequiresZ3) in {
def in(resetType: String) =
s"""circuit AsyncReset00:
| module AsyncReset00:
@@ -39,8 +40,8 @@ class AsyncResetSpec extends EndToEndSMTBaseSpec {
| ; can the value of r change without the count changing?
| assert(global_clock, or(not(eq(count, past_count)), eq(r, past_r)), past_valid, "count = past(count) |-> r = past(r)")
|""".stripMargin
- test(in("AsyncReset"), MCFail(1), kmax=2, annos=annos("AsyncReset00"))
- test(in("UInt<1>"), MCSuccess, kmax=2, annos=annos("AsyncReset00"))
+ test(in("AsyncReset"), MCFail(1), kmax = 2, annos = annos("AsyncReset00"))
+ test(in("UInt<1>"), MCSuccess, kmax = 2, annos = annos("AsyncReset00"))
}
}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala
index 2227719b..974d2e81 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala
@@ -16,24 +16,23 @@ import org.scalatest.matchers.must.Matchers
import scala.sys.process._
class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging {
- "we" should "check if Z3 is available" taggedAs(RequiresZ3) in {
+ "we" should "check if Z3 is available" taggedAs (RequiresZ3) in {
val log = ProcessLogger(_ => (), logger.warn(_))
val ret = Process(Seq("which", "z3")).run(log).exitValue()
- if(ret != 0) {
- logger.error(
- """The z3 SMT-Solver seems not to be installed.
- |You can exclude the end-to-end smt backend tests which rely on z3 like this:
- |sbt testOnly -- -l RequiresZ3
- |""".stripMargin)
+ if (ret != 0) {
+ logger.error("""The z3 SMT-Solver seems not to be installed.
+ |You can exclude the end-to-end smt backend tests which rely on z3 like this:
+ |sbt testOnly -- -l RequiresZ3
+ |""".stripMargin)
}
assert(ret == 0)
}
- "Z3" should "be available in version 4" taggedAs(RequiresZ3) in {
+ "Z3" should "be available in version 4" taggedAs (RequiresZ3) in {
assert(Z3ModelChecker.getZ3Version.startsWith("4."))
}
- "a simple combinatorial check" should "pass" taggedAs(RequiresZ3) in {
+ "a simple combinatorial check" should "pass" taggedAs (RequiresZ3) in {
val in =
"""circuit CC00:
| module CC00:
@@ -45,7 +44,7 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging {
test(in, MCSuccess)
}
- "a simple combinatorial check" should "fail immediately" taggedAs(RequiresZ3) in {
+ "a simple combinatorial check" should "fail immediately" taggedAs (RequiresZ3) in {
val in =
"""circuit CC01:
| module CC01:
@@ -57,7 +56,7 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging {
test(in, MCFail(0))
}
- "adding the right assumption" should "make a test pass" taggedAs(RequiresZ3) in {
+ "adding the right assumption" should "make a test pass" taggedAs (RequiresZ3) in {
val in0 =
"""circuit CC01:
| module CC01:
@@ -75,8 +74,8 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging {
| assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0")
| assume(c, neq(a, UInt(0)), UInt(1), "a != 0")
|""".stripMargin
- test(in0, MCFail(0))
- test(in1, MCSuccess)
+ test(in0, MCFail(0))
+ test(in1, MCSuccess)
val in2 =
"""circuit CC01:
@@ -91,20 +90,20 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging {
test(in2, MCFail(0))
}
- "a register connected to preset reset" should "be initialized with the reset value" taggedAs(RequiresZ3) in {
+ "a register connected to preset reset" should "be initialized with the reset value" taggedAs (RequiresZ3) in {
def in(rEq: Int) =
s"""circuit Preset00:
- | module Preset00:
- | input c: Clock
- | input preset: AsyncReset
- | reg r: UInt<4>, c with: (reset => (preset, UInt(3)))
- | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq")
- |""".stripMargin
+ | module Preset00:
+ | input c: Clock
+ | input preset: AsyncReset
+ | reg r: UInt<4>, c with: (reset => (preset, UInt(3)))
+ | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq")
+ |""".stripMargin
test(in(3), MCSuccess, kmax = 1)
test(in(2), MCFail(0))
}
- "a register's initial value" should "should not change" taggedAs(RequiresZ3) in {
+ "a register's initial value" should "should not change" taggedAs (RequiresZ3) in {
val in =
"""circuit Preset00:
| module Preset00:
@@ -127,24 +126,29 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging {
abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers {
def test(src: String, expected: MCResult, kmax: Int = 0, clue: String = "", annos: Seq[Annotation] = Seq()): Unit = {
expected match {
- case MCFail(k) => assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)")
+ case MCFail(k) =>
+ assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)")
case _ =>
}
val fir = firrtl.Parser.parse(src)
val name = fir.main
val testDir = BackendCompilationUtilities.createTestDirectory("EndToEndSMT." + name)
// we automagically add a preset annotation if an input called preset exists
- val presetAnno = if(!src.contains("input preset")) { None } else {
+ val presetAnno = if (!src.contains("input preset")) { None }
+ else {
Some(PresetAnnotation(CircuitTarget(name).module(name).ref("preset")))
}
- val res = (new FirrtlStage).execute(Array(), Seq(
- LogLevelAnnotation(LogLevel.Error), // silence warnings for tests
- RunFirrtlTransformAnnotation(new SMTLibEmitter),
- RunFirrtlTransformAnnotation(new Btor2Emitter),
- FirrtlCircuitAnnotation(fir),
- TargetDirAnnotation(testDir.getAbsolutePath)
- ) ++ presetAnno ++ annos)
- assert(res.collectFirst{ case _: OutputFileAnnotation => true }.isDefined)
+ val res = (new FirrtlStage).execute(
+ Array(),
+ Seq(
+ LogLevelAnnotation(LogLevel.Error), // silence warnings for tests
+ RunFirrtlTransformAnnotation(new SMTLibEmitter),
+ RunFirrtlTransformAnnotation(new Btor2Emitter),
+ FirrtlCircuitAnnotation(fir),
+ TargetDirAnnotation(testDir.getAbsolutePath)
+ ) ++ presetAnno ++ annos
+ )
+ assert(res.collectFirst { case _: OutputFileAnnotation => true }.isDefined)
val r = Z3ModelChecker.bmc(testDir, name, kmax)
assert(r == expected, clue + "\n" + s"$testDir")
}
@@ -153,7 +157,7 @@ abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers {
/** Minimal implementation of a Z3 based bounded model checker.
* A more complete version of this with better use feedback should eventually be provided by a
* chisel3 formal verification library. Do not use this implementation outside of the firrtl test suite!
- * */
+ */
private object Z3ModelChecker extends LazyLogging {
def getZ3Version: String = {
val (out, ret) = executeCmd("-version")
@@ -164,14 +168,15 @@ private object Z3ModelChecker extends LazyLogging {
}
def bmc(testDir: File, main: String, kmax: Int): MCResult = {
- assert(kmax >=0 && kmax < 50, "Trying to keep kmax in a reasonable range.")
+ assert(kmax >= 0 && kmax < 50, "Trying to keep kmax in a reasonable range.")
val smtFile = new File(testDir, main + ".smt2")
val header = read(smtFile)
val steps = (0 to kmax).map(k => new File(testDir, main + s"_step$k.smt2")).zipWithIndex
- steps.foreach { case (f,k) =>
- writeStep(f, main, header, k)
- val success = executeStep(f.getAbsolutePath)
- if(!success) return MCFail(k)
+ steps.foreach {
+ case (f, k) =>
+ writeStep(f, main, header, k)
+ val success = executeStep(f.getAbsolutePath)
+ if (!success) return MCFail(k)
}
MCSuccess
}
@@ -200,21 +205,22 @@ private object Z3ModelChecker extends LazyLogging {
private def step(main: String, k: Int): Iterable[String] = {
// define all states
(0 to k).map(ii => s"(declare-fun s$ii () $main$StateTpe)") ++
- // assert that init holds in state 0
- List(s"(assert ($main$Init s0))") ++
- // assert transition relation
- (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii+1}))") ++
- // assert that assumptions hold in all states
- (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++
- // assert that assertions hold for all but last state
- (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++
- // check to see if we can violate the assertions in the last state
- List(s"(assert (not ($main$Asserts s$k)))")
+ // assert that init holds in state 0
+ List(s"(assert ($main$Init s0))") ++
+ // assert transition relation
+ (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii + 1}))") ++
+ // assert that assumptions hold in all states
+ (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++
+ // assert that assertions hold for all but last state
+ (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++
+ // check to see if we can violate the assertions in the last state
+ List(s"(assert (not ($main$Asserts s$k)))")
}
private def read(f: File): Iterable[String] = {
val source = scala.io.Source.fromFile(f)
- try source.getLines().toVector finally source.close()
+ try source.getLines().toVector
+ finally source.close()
}
// the following suffixes have to match the ones in [[SMTTransitionSystemEncoder]]
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala
index 10de9cda..61e1f0f8 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala
@@ -9,43 +9,43 @@ class MemorySpec extends EndToEndSMTBaseSpec {
registeredTestMem(name, cmds.split("\n"), readUnderWrite)
private def registeredTestMem(name: String, cmds: Iterable[String], readUnderWrite: String): String =
s"""circuit $name:
- | module $name:
- | input reset : UInt<1>
- | input clock : Clock
- | input preset: AsyncReset
- | input write_addr : UInt<5>
- | input read_addr : UInt<5>
- | input in : UInt<8>
- | output out : UInt<8>
- |
- | mem m:
- | data-type => UInt<8>
- | depth => 32
- | reader => r
- | writer => w
- | read-latency => 1
- | write-latency => 1
- | read-under-write => $readUnderWrite
- |
- | m.w.clk <= clock
- | m.w.mask <= UInt(1)
- | m.w.en <= UInt(1)
- | m.w.data <= in
- | m.w.addr <= write_addr
- |
- | m.r.clk <= clock
- | m.r.en <= UInt(1)
- | out <= m.r.data
- | m.r.addr <= read_addr
- |
- | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0)))
- | cycle <= add(cycle, UInt(1))
- | node past_valid = geq(cycle, UInt(1))
- |
- | ${cmds.mkString("\n ")}
- |""".stripMargin
+ | module $name:
+ | input reset : UInt<1>
+ | input clock : Clock
+ | input preset: AsyncReset
+ | input write_addr : UInt<5>
+ | input read_addr : UInt<5>
+ | input in : UInt<8>
+ | output out : UInt<8>
+ |
+ | mem m:
+ | data-type => UInt<8>
+ | depth => 32
+ | reader => r
+ | writer => w
+ | read-latency => 1
+ | write-latency => 1
+ | read-under-write => $readUnderWrite
+ |
+ | m.w.clk <= clock
+ | m.w.mask <= UInt(1)
+ | m.w.en <= UInt(1)
+ | m.w.data <= in
+ | m.w.addr <= write_addr
+ |
+ | m.r.clk <= clock
+ | m.r.en <= UInt(1)
+ | out <= m.r.data
+ | m.r.addr <= read_addr
+ |
+ | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0)))
+ | cycle <= add(cycle, UInt(1))
+ | node past_valid = geq(cycle, UInt(1))
+ |
+ | ${cmds.mkString("\n ")}
+ |""".stripMargin
- "Registered test memory" should "return written data after two cycles" taggedAs(RequiresZ3) in {
+ "Registered test memory" should "return written data after two cycles" taggedAs (RequiresZ3) in {
val cmds =
"""node past_past_valid = geq(cycle, UInt(2))
|reg past_in: UInt<8>, clock
@@ -85,23 +85,29 @@ class MemorySpec extends EndToEndSMTBaseSpec {
|""".stripMargin
private def m(num: Int) = CircuitTarget(s"Mem0$num").module(s"Mem0$num").ref("m")
- "read-only memory" should "always return 0" taggedAs(RequiresZ3) in {
- test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax=2,
- annos=Seq(MemoryScalarInitAnnotation(m(1), 0)))
+ "read-only memory" should "always return 0" taggedAs (RequiresZ3) in {
+ test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax = 2, annos = Seq(MemoryScalarInitAnnotation(m(1), 0)))
}
- "read-only memory" should "not always return 1" taggedAs(RequiresZ3) in {
- test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax=2,
- annos=Seq(MemoryScalarInitAnnotation(m(2), 0)))
+ "read-only memory" should "not always return 1" taggedAs (RequiresZ3) in {
+ test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax = 2, annos = Seq(MemoryScalarInitAnnotation(m(2), 0)))
}
- "read-only memory" should "always return 1 or 2" taggedAs(RequiresZ3) in {
- test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), MCSuccess, kmax=2,
- annos=Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1))))
+ "read-only memory" should "always return 1 or 2" taggedAs (RequiresZ3) in {
+ test(
+ readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3),
+ MCSuccess,
+ kmax = 2,
+ annos = Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1)))
+ )
}
- "read-only memory" should "not always return 1 or 2 or 3" taggedAs(RequiresZ3) in {
- test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), MCFail(0), kmax=2,
- annos=Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3))))
+ "read-only memory" should "not always return 1 or 2 or 3" taggedAs (RequiresZ3) in {
+ test(
+ readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4),
+ MCFail(0),
+ kmax = 2,
+ annos = Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3)))
+ )
}
}
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala
index cbf194dd..8ece0e23 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala
@@ -13,15 +13,17 @@ import scala.sys.process.{Process, ProcessLogger}
/** compiles the regression tests to SMTLib and parses the result with z3 */
class SMTCompilationTest extends AnyFlatSpec with LazyLogging {
- it should "generate valid SMTLib for AddNot" taggedAs(RequiresZ3) in { compileAndParse("AddNot") }
- it should "generate valid SMTLib for FPU" taggedAs(RequiresZ3) in { compileAndParse("FPU") }
+ it should "generate valid SMTLib for AddNot" taggedAs (RequiresZ3) in { compileAndParse("AddNot") }
+ it should "generate valid SMTLib for FPU" taggedAs (RequiresZ3) in { compileAndParse("FPU") }
// we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer
- it should "generate valid SMTLib for HwachaSequencer" taggedAs(RequiresZ3) ignore { compileAndParse("HwachaSequencer") }
- it should "generate valid SMTLib for ICache" taggedAs(RequiresZ3) in { compileAndParse("ICache") }
- it should "generate valid SMTLib for Ops" taggedAs(RequiresZ3) in { compileAndParse("Ops") }
+ it should "generate valid SMTLib for HwachaSequencer" taggedAs (RequiresZ3) ignore {
+ compileAndParse("HwachaSequencer")
+ }
+ it should "generate valid SMTLib for ICache" taggedAs (RequiresZ3) in { compileAndParse("ICache") }
+ it should "generate valid SMTLib for Ops" taggedAs (RequiresZ3) in { compileAndParse("Ops") }
// TODO: enable Rob test once we support more than 2 write ports on a memory
- it should "generate valid SMTLib for Rob" taggedAs(RequiresZ3) ignore { compileAndParse("Rob") }
- it should "generate valid SMTLib for RocketCore" taggedAs(RequiresZ3) in { compileAndParse("RocketCore") }
+ it should "generate valid SMTLib for Rob" taggedAs (RequiresZ3) ignore { compileAndParse("Rob") }
+ it should "generate valid SMTLib for RocketCore" taggedAs (RequiresZ3) in { compileAndParse("RocketCore") }
private def compileAndParse(name: String): Unit = {
val testDir = BackendCompilationUtilities.createTestDirectory(name + "-smt")
@@ -29,14 +31,18 @@ class SMTCompilationTest extends AnyFlatSpec with LazyLogging {
BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", inputFile)
val args = Array(
- "-ll", "error", // surpress warnings to keep test output clean
- "--target-dir", testDir.toString,
- "-i", inputFile.toString,
- "-E", "experimental-smt2"
+ "-ll",
+ "error", // surpress warnings to keep test output clean
+ "--target-dir",
+ testDir.toString,
+ "-i",
+ inputFile.toString,
+ "-E",
+ "experimental-smt2"
// "-fct", "firrtl.backends.experimental.smt.StutteringClockTransform"
)
val res = (new FirrtlStage).execute(args, Seq())
- val fileName = res.collectFirst{ case OutputFileAnnotation(file) => file }.get
+ val fileName = res.collectFirst { case OutputFileAnnotation(file) => file }.get
val smtFile = testDir.toString + "/" + fileName + ".smt2"
val log = ProcessLogger(_ => (), logger.error(_))
diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala
index 8fa80b4c..8682c2ce 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala
@@ -5,27 +5,26 @@ package firrtl.backends.experimental.smt.end2end
/** undefined values in firrtl are modelled as fresh auxiliary variables (inputs) */
class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec {
- "division by zero" should "result in an arbitrary value" taggedAs(RequiresZ3) in {
+ "division by zero" should "result in an arbitrary value" taggedAs (RequiresZ3) in {
// the SMTLib spec defines the result of division by zero to be all 1s
// https://cs.nyu.edu/pipermail/smt-lib/2015/000977.html
def in(dEq: Int) =
- s"""circuit CC00:
- | module CC00:
- | input c: Clock
- | input a: UInt<2>
- | input b: UInt<2>
- | assume(c, eq(b, UInt(0)), UInt(1), "b = 0")
- | node d = div(a, b)
- | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq")
- |""".stripMargin
+ s"""circuit CC00:
+ | module CC00:
+ | input c: Clock
+ | input a: UInt<2>
+ | input b: UInt<2>
+ | assume(c, eq(b, UInt(0)), UInt(1), "b = 0")
+ | node d = div(a, b)
+ | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq")
+ |""".stripMargin
// we try to assert that (d = a / 0) is any fixed value which should be false
(0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") }
}
// TODO: rem should probably also be undefined, but the spec isn't 100% clear here
-
- "invalid signals" should "have an arbitrary values" taggedAs(RequiresZ3) in {
+ "invalid signals" should "have an arbitrary values" taggedAs (RequiresZ3) in {
def in(aEq: Int) =
s"""circuit CC00:
| module CC00:
diff --git a/src/test/scala/firrtl/ir/StructuralHashSpec.scala b/src/test/scala/firrtl/ir/StructuralHashSpec.scala
index 17fe0b84..c4622939 100644
--- a/src/test/scala/firrtl/ir/StructuralHashSpec.scala
+++ b/src/test/scala/firrtl/ir/StructuralHashSpec.scala
@@ -6,11 +6,11 @@ import firrtl.PrimOps._
import org.scalatest.flatspec.AnyFlatSpec
class StructuralHashSpec extends AnyFlatSpec {
- private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n)
- private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c)
+ private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n)
+ private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c)
private def hash(e: Expression): HashCode = StructuralHash.sha256Node(e)
- private def hash(t: Type): HashCode = StructuralHash.sha256Node(t)
- private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s)
+ private def hash(t: Type): HashCode = StructuralHash.sha256Node(t)
+ private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s)
private val highFirrtlCompiler = new firrtl.stage.transforms.Compiler(
targets = firrtl.stage.Forms.HighForm
)
@@ -24,18 +24,18 @@ class StructuralHashSpec extends AnyFlatSpec {
highFirrtlCompiler.transform(firrtl.CircuitState(rawFirrtl, Seq())).circuit
}
- private val b0 = UIntLiteral(0,IntWidth(1))
- private val b1 = UIntLiteral(1,IntWidth(1))
+ private val b0 = UIntLiteral(0, IntWidth(1))
+ private val b1 = UIntLiteral(1, IntWidth(1))
private val add = DoPrim(Add, Seq(b0, b1), Seq(), UnknownType)
it should "generate the same hash if the objects are structurally the same" in {
- assert(hash(b0) == hash(UIntLiteral(0,IntWidth(1))))
- assert(hash(b0) != hash(UIntLiteral(1,IntWidth(1))))
- assert(hash(b0) != hash(UIntLiteral(1,IntWidth(2))))
+ assert(hash(b0) == hash(UIntLiteral(0, IntWidth(1))))
+ assert(hash(b0) != hash(UIntLiteral(1, IntWidth(1))))
+ assert(hash(b0) != hash(UIntLiteral(1, IntWidth(2))))
- assert(hash(b1) == hash(UIntLiteral(1,IntWidth(1))))
- assert(hash(b1) != hash(UIntLiteral(0,IntWidth(1))))
- assert(hash(b1) != hash(UIntLiteral(1,IntWidth(2))))
+ assert(hash(b1) == hash(UIntLiteral(1, IntWidth(1))))
+ assert(hash(b1) != hash(UIntLiteral(0, IntWidth(1))))
+ assert(hash(b1) != hash(UIntLiteral(1, IntWidth(2))))
}
it should "ignore expression types" in {
@@ -84,16 +84,19 @@ class StructuralHashSpec extends AnyFlatSpec {
|""".stripMargin
assert(hash(parse(a)) != hash(parse(d)), "circuits with different names are always different")
- assert(hash(parse(a).modules.head) == hash(parse(d).modules.head),
- "modules with different names can be structurally different")
+ assert(
+ hash(parse(a).modules.head) == hash(parse(d).modules.head),
+ "modules with different names can be structurally different"
+ )
// for the Dedup pass we do need a way to take the port names into account
- assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) !=
- StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head),
- "renaming ports does affect the hash if we ask to")
+ assert(
+ StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) !=
+ StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head),
+ "renaming ports does affect the hash if we ask to"
+ )
}
-
it should "not ignore port names if asked to" in {
val e =
"""circuit a:
@@ -119,14 +122,20 @@ class StructuralHashSpec extends AnyFlatSpec {
| z <= x
|""".stripMargin
- assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) !=
- StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head),
- "renaming ports does affect the hash if we ask to")
- assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) ==
- StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head),
- "renaming internal wires should never affect the hash")
- assert(hash(parse(e).modules.head) == hash(parse(g).modules.head),
- "renaming internal wires should never affect the hash")
+ assert(
+ StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) !=
+ StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head),
+ "renaming ports does affect the hash if we ask to"
+ )
+ assert(
+ StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) ==
+ StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head),
+ "renaming internal wires should never affect the hash"
+ )
+ assert(
+ hash(parse(e).modules.head) == hash(parse(g).modules.head),
+ "renaming internal wires should never affect the hash"
+ )
}
it should "not ignore port bundle names if asked to" in {
@@ -154,19 +163,26 @@ class StructuralHashSpec extends AnyFlatSpec {
| y.z <= x.x
|""".stripMargin
- assert(hash(parse(e).modules.head) == hash(parse(f).modules.head),
- "renaming port bundles does normally not affect the hash")
- assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) !=
- StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head),
- "renaming port bundles does affect the hash if we ask to")
- assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) ==
- StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head),
- "renaming internal wire bundles should never affect the hash")
- assert(hash(parse(e).modules.head) == hash(parse(g).modules.head),
- "renaming internal wire bundles should never affect the hash")
+ assert(
+ hash(parse(e).modules.head) == hash(parse(f).modules.head),
+ "renaming port bundles does normally not affect the hash"
+ )
+ assert(
+ StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) !=
+ StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head),
+ "renaming port bundles does affect the hash if we ask to"
+ )
+ assert(
+ StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) ==
+ StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head),
+ "renaming internal wire bundles should never affect the hash"
+ )
+ assert(
+ hash(parse(e).modules.head) == hash(parse(g).modules.head),
+ "renaming internal wire bundles should never affect the hash"
+ )
}
-
it should "fail on Info" in {
// it does not make sense to hash Info nodes
assertThrows[RuntimeException] {
@@ -178,9 +194,9 @@ class StructuralHashSpec extends AnyFlatSpec {
def parse(str: String): BundleType = {
val src =
s"""circuit c:
- | module c:
- | input z: $str
- |""".stripMargin
+ | module c:
+ | input z: $str
+ |""".stripMargin
val c = firrtl.Parser.parse(src)
val tpe = c.modules.head.ports.head.tpe
tpe.asInstanceOf[BundleType]
@@ -219,11 +235,15 @@ class StructuralHashSpec extends AnyFlatSpec {
// Q: should extmodule portnames always be significant since they map to the verilog pins?
// A: It would be a bug for two exmodules in the same circuit to have the same defname but different
// port names. This should be detected by an earlier pass and thus we do not have to deal with that situation.
- assert(hash(parse(a).modules.head) == hash(parse(b).modules.head),
- "two ext modules with the same defname and the same type and number of ports")
- assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) !=
- StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head),
- "two ext modules with significant port names")
+ assert(
+ hash(parse(a).modules.head) == hash(parse(b).modules.head),
+ "two ext modules with the same defname and the same type and number of ports"
+ )
+ assert(
+ StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) !=
+ StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head),
+ "two ext modules with significant port names"
+ )
}
"Blocks and empty statements" should "not affect structural equivalence" in {
@@ -269,9 +289,9 @@ class StructuralHashSpec extends AnyFlatSpec {
}
private case object DebugHasher extends Hasher {
- override def update(b: Byte): Unit = println(s"b(${b.toInt & 0xff})")
- override def update(i: Int): Unit = println(s"i(${i})")
- override def update(l: Long): Unit = println(s"l(${l})")
- override def update(s: String): Unit = println(s"s(${s})")
+ override def update(b: Byte): Unit = println(s"b(${b.toInt & 0xff})")
+ override def update(i: Int): Unit = println(s"i(${i})")
+ override def update(l: Long): Unit = println(s"l(${l})")
+ override def update(s: String): Unit = println(s"s(${s})")
override def update(b: Array[Byte]): Unit = println(s"bytes(${b.map(x => x.toInt & 0xff).mkString(", ")})")
-} \ No newline at end of file
+}
diff --git a/src/test/scala/firrtl/passes/LowerTypesSpec.scala b/src/test/scala/firrtl/passes/LowerTypesSpec.scala
index 884e51b8..0575c5da 100644
--- a/src/test/scala/firrtl/passes/LowerTypesSpec.scala
+++ b/src/test/scala/firrtl/passes/LowerTypesSpec.scala
@@ -8,7 +8,6 @@ import firrtl.stage.TransformManager
import firrtl.stage.TransformManager.TransformDependency
import org.scalatest.flatspec.AnyFlatSpec
-
/** Unit test style tests for [[LowerTypes]].
* You can find additional integration style tests in [[firrtlTests.LowerTypesSpec]]
*/
@@ -31,11 +30,12 @@ class LowerTypesEndToEndSpec extends LowerTypesBaseSpec {
| $n is invalid
|""".stripMargin
val c = CircuitState(firrtl.Parser.parse(src), Seq())
- val c2 = lowerTypesCompiler.execute(c)
+ val c2 = lowerTypesCompiler.execute(c)
val ps = c2.circuit.modules.head.ports.filterNot(p => namespace.contains(p.name))
- ps.map{p =>
+ ps.map { p =>
val orientation = Utils.to_flip(p.direction)
- s"${orientation.serialize}${p.name} : ${p.tpe.serialize}"}
+ s"${orientation.serialize}${p.name} : ${p.tpe.serialize}"
+ }
}
override protected def lower(n: String, tpe: String, namespace: Set[String]): Seq[String] =
@@ -50,8 +50,10 @@ abstract class LowerTypesBaseSpec extends AnyFlatSpec {
assert(lower("a", "{ a : UInt<1>, b : UInt<1>}") == Seq("a_a : UInt<1>", "a_b : UInt<1>"))
assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}") == Seq("a_a : UInt<1>", "a_b_c : UInt<1>"))
assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}") == Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>"))
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]") ==
- Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>"))
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>}[2]") ==
+ Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")
+ )
// with conflicts
assert(lower("a", "{ a : UInt<1>, b : UInt<1>}", Set("a_a")) == Seq("a__a : UInt<1>", "a__b : UInt<1>"))
@@ -63,40 +65,71 @@ abstract class LowerTypesBaseSpec extends AnyFlatSpec {
assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>"))
assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b_c")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>"))
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a")) ==
- Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>"))
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a", "a_b_0")) ==
- Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>"))
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_b_0")) ==
- Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>"))
-
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0")) ==
- Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>"))
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_3")) ==
- Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>"))
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_a")) ==
- Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>"))
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_c")) ==
- Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>"))
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a")) ==
+ Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")
+ )
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a", "a_b_0")) ==
+ Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")
+ )
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_b_0")) ==
+ Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")
+ )
+
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0")) ==
+ Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>")
+ )
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_3")) ==
+ Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")
+ )
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_a")) ==
+ Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>")
+ )
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_c")) ==
+ Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")
+ )
// collisions inside the bundle
- assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") ==
- Seq("a_a : UInt<1>", "a_b__c : UInt<1>", "a_b_c : UInt<1>"))
- assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") ==
- Seq("a_a : UInt<1>", "a_b_c : UInt<1>", "a_b_b : UInt<1>"))
-
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") ==
- Seq("a_a : UInt<1>", "a_b__0 : UInt<1>", "a_b__1 : UInt<1>", "a_b_0 : UInt<1>"))
- assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_c : UInt<1>}") ==
- Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>", "a_b_c : UInt<1>"))
+ assert(
+ lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") ==
+ Seq("a_a : UInt<1>", "a_b__c : UInt<1>", "a_b_c : UInt<1>")
+ )
+ assert(
+ lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") ==
+ Seq("a_a : UInt<1>", "a_b_c : UInt<1>", "a_b_b : UInt<1>")
+ )
+
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") ==
+ Seq("a_a : UInt<1>", "a_b__0 : UInt<1>", "a_b__1 : UInt<1>", "a_b_0 : UInt<1>")
+ )
+ assert(
+ lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_c : UInt<1>}") ==
+ Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>", "a_b_c : UInt<1>")
+ )
}
it should "correctly lower the orientation" in {
assert(lower("a", "{ flip a : UInt<1>, b : UInt<1>}") == Seq("flip a_a : UInt<1>", "a_b : UInt<1>"))
- assert(lower("a", "{ flip a : UInt<1>[2], b : UInt<1>}") ==
- Seq("flip a_a_0 : UInt<1>", "flip a_a_1 : UInt<1>", "a_b : UInt<1>"))
- assert(lower("a", "{ a : { flip c : UInt<1>, d : UInt<1>}[2], b : UInt<1>}") ==
- Seq("flip a_a_0_c : UInt<1>", "a_a_0_d : UInt<1>", "flip a_a_1_c : UInt<1>", "a_a_1_d : UInt<1>", "a_b : UInt<1>")
+ assert(
+ lower("a", "{ flip a : UInt<1>[2], b : UInt<1>}") ==
+ Seq("flip a_a_0 : UInt<1>", "flip a_a_1 : UInt<1>", "a_b : UInt<1>")
+ )
+ assert(
+ lower("a", "{ a : { flip c : UInt<1>, d : UInt<1>}[2], b : UInt<1>}") ==
+ Seq(
+ "flip a_a_0_c : UInt<1>",
+ "a_a_0_d : UInt<1>",
+ "flip a_a_1_c : UInt<1>",
+ "a_a_1_d : UInt<1>",
+ "a_b : UInt<1>"
+ )
)
}
}
@@ -121,43 +154,45 @@ class LowerTypesRenamingSpec extends AnyFlatSpec {
def one(namespace: Set[String], prefix: String): Unit = {
val r = lower("a", "{ a : UInt<1>, b : UInt<1>}", namespace)
- assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b")))
- assert(get(r,a.field("a")) == Set(m.ref(prefix + "a")))
- assert(get(r,a.field("b")) == Set(m.ref(prefix + "b")))
+ assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b")))
+ assert(get(r, a.field("a")) == Set(m.ref(prefix + "a")))
+ assert(get(r, a.field("b")) == Set(m.ref(prefix + "b")))
}
one(Set(), "a_")
one(Set("a_a"), "a__")
def two(namespace: Set[String], prefix: String): Unit = {
- val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", namespace)
- assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_c")))
- assert(get(r,a.field("a")) == Set(m.ref(prefix + "a")))
- assert(get(r,a.field("b")) == Set(m.ref(prefix + "b_c")))
- assert(get(r,a.field("b").field("c")) == Set(m.ref(prefix + "b_c")))
+ val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", namespace)
+ assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_c")))
+ assert(get(r, a.field("a")) == Set(m.ref(prefix + "a")))
+ assert(get(r, a.field("b")) == Set(m.ref(prefix + "b_c")))
+ assert(get(r, a.field("b").field("c")) == Set(m.ref(prefix + "b_c")))
}
two(Set(), "a_")
two(Set("a_a"), "a__")
def three(namespace: Set[String], prefix: String): Unit = {
val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", namespace)
- assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_0"), m.ref(prefix + "b_1")))
- assert(get(r,a.field("a")) == Set(m.ref(prefix + "a")))
- assert(get(r,a.field("b")) == Set( m.ref(prefix + "b_0"), m.ref(prefix + "b_1")))
- assert(get(r,a.field("b").index(0)) == Set(m.ref(prefix + "b_0")))
- assert(get(r,a.field("b").index(1)) == Set(m.ref(prefix + "b_1")))
+ assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_0"), m.ref(prefix + "b_1")))
+ assert(get(r, a.field("a")) == Set(m.ref(prefix + "a")))
+ assert(get(r, a.field("b")) == Set(m.ref(prefix + "b_0"), m.ref(prefix + "b_1")))
+ assert(get(r, a.field("b").index(0)) == Set(m.ref(prefix + "b_0")))
+ assert(get(r, a.field("b").index(1)) == Set(m.ref(prefix + "b_1")))
}
three(Set(), "a_")
three(Set("a_b_0"), "a__")
def four(namespace: Set[String], prefix: String): Unit = {
val r = lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", namespace)
- assert(get(r,a) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "1_a"), m.ref(prefix + "0_b"), m.ref(prefix + "1_b")))
- assert(get(r,a.index(0)) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "0_b")))
- assert(get(r,a.index(1)) == Set(m.ref(prefix + "1_a"), m.ref(prefix + "1_b")))
- assert(get(r,a.index(0).field("a")) == Set(m.ref(prefix + "0_a")))
- assert(get(r,a.index(0).field("b")) == Set(m.ref(prefix + "0_b")))
- assert(get(r,a.index(1).field("a")) == Set(m.ref(prefix + "1_a")))
- assert(get(r,a.index(1).field("b")) == Set(m.ref(prefix + "1_b")))
+ assert(
+ get(r, a) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "1_a"), m.ref(prefix + "0_b"), m.ref(prefix + "1_b"))
+ )
+ assert(get(r, a.index(0)) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "0_b")))
+ assert(get(r, a.index(1)) == Set(m.ref(prefix + "1_a"), m.ref(prefix + "1_b")))
+ assert(get(r, a.index(0).field("a")) == Set(m.ref(prefix + "0_a")))
+ assert(get(r, a.index(0).field("b")) == Set(m.ref(prefix + "0_b")))
+ assert(get(r, a.index(1).field("a")) == Set(m.ref(prefix + "1_a")))
+ assert(get(r, a.index(1).field("b")) == Set(m.ref(prefix + "1_b")))
}
four(Set(), "a_")
four(Set("a_0"), "a__")
@@ -166,28 +201,28 @@ class LowerTypesRenamingSpec extends AnyFlatSpec {
// collisions inside the bundle
{
val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}")
- assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__c"), m.ref("a_b_c")))
- assert(get(r,a.field("a")) == Set(m.ref("a_a")))
- assert(get(r,a.field("b")) == Set(m.ref("a_b__c")))
- assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b__c")))
- assert(get(r,a.field("b_c")) == Set(m.ref("a_b_c")))
+ assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b__c"), m.ref("a_b_c")))
+ assert(get(r, a.field("a")) == Set(m.ref("a_a")))
+ assert(get(r, a.field("b")) == Set(m.ref("a_b__c")))
+ assert(get(r, a.field("b").field("c")) == Set(m.ref("a_b__c")))
+ assert(get(r, a.field("b_c")) == Set(m.ref("a_b_c")))
}
{
val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}")
- assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b_c"), m.ref("a_b_b")))
- assert(get(r,a.field("a")) == Set(m.ref("a_a")))
- assert(get(r,a.field("b")) == Set(m.ref("a_b_c")))
- assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b_c")))
- assert(get(r,a.field("b_b")) == Set(m.ref("a_b_b")))
+ assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b_c"), m.ref("a_b_b")))
+ assert(get(r, a.field("a")) == Set(m.ref("a_a")))
+ assert(get(r, a.field("b")) == Set(m.ref("a_b_c")))
+ assert(get(r, a.field("b").field("c")) == Set(m.ref("a_b_c")))
+ assert(get(r, a.field("b_b")) == Set(m.ref("a_b_b")))
}
{
val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}")
- assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__0"), m.ref("a_b__1"), m.ref("a_b_0")))
- assert(get(r,a.field("a")) == Set(m.ref("a_a")))
- assert(get(r,a.field("b")) == Set(m.ref("a_b__0"), m.ref("a_b__1")))
- assert(get(r,a.field("b").index(0)) == Set(m.ref("a_b__0")))
- assert(get(r,a.field("b").index(1)) == Set(m.ref("a_b__1")))
- assert(get(r,a.field("b_0")) == Set(m.ref("a_b_0")))
+ assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b__0"), m.ref("a_b__1"), m.ref("a_b_0")))
+ assert(get(r, a.field("a")) == Set(m.ref("a_a")))
+ assert(get(r, a.field("b")) == Set(m.ref("a_b__0"), m.ref("a_b__1")))
+ assert(get(r, a.field("b").index(0)) == Set(m.ref("a_b__0")))
+ assert(get(r, a.field("b").index(1)) == Set(m.ref("a_b__1")))
+ assert(get(r, a.field("b_0")) == Set(m.ref("a_b_0")))
}
}
}
@@ -199,8 +234,13 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec {
private val m = CircuitTarget("m").module("m")
def resultToFieldSeq(res: Seq[(String, firrtl.ir.SubField)]): Seq[String] =
res.map(_._2).map(r => s"${r.name} : ${r.tpe.serialize}")
- private def lower(n: String, tpe: String, module: String, namespace: Set[String], renames: RenameMap = RenameMap()):
- Lower = {
+ private def lower(
+ n: String,
+ tpe: String,
+ module: String,
+ namespace: Set[String],
+ renames: RenameMap = RenameMap()
+ ): Lower = {
val ref = firrtl.ir.DefInstance(firrtl.ir.NoInfo, n, module, parseType(tpe))
val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace
val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, renames, Set())
@@ -269,7 +309,7 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec {
assert(r.get(i.ref("b")).get == Seq(i_.ref("b__c")))
assert(r.get(i.ref("b").field("c")).get == Seq(i_.ref("b__c")))
assert(r.get(i.ref("b_c")).get == Seq(i_.ref("b_c")))
- }
+ }
}
}
@@ -278,101 +318,139 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec {
*/
class LowerTypesOfMemorySpec extends AnyFlatSpec {
import LowerTypesSpecUtils._
- private case class Lower(mems: Seq[firrtl.ir.DefMemory], refs: Seq[(String, firrtl.ir.SubField)],
- renameMap: RenameMap)
+ private case class Lower(
+ mems: Seq[firrtl.ir.DefMemory],
+ refs: Seq[(String, firrtl.ir.SubField)],
+ renameMap: RenameMap)
private val m = CircuitTarget("m").module("m")
private val mem = m.ref("mem")
- private def lower(name: String, tpe: String, namespace: Set[String],
- r: Seq[String] = List("r"), w: Seq[String] = List("w"), rw: Seq[String] = List(), depth: Int = 2): Lower = {
+ private def lower(
+ name: String,
+ tpe: String,
+ namespace: Set[String],
+ r: Seq[String] = List("r"),
+ w: Seq[String] = List("w"),
+ rw: Seq[String] = List(),
+ depth: Int = 2
+ ): Lower = {
val dataType = parseType(tpe)
- val mem = firrtl.ir.DefMemory(firrtl.ir.NoInfo, name, dataType, depth = depth, writeLatency = 1, readLatency = 1,
- readUnderWrite = firrtl.ir.ReadUnderWrite.Undefined, readers = r, writers = w, readwriters = rw)
+ val mem = firrtl.ir.DefMemory(
+ firrtl.ir.NoInfo,
+ name,
+ dataType,
+ depth = depth,
+ writeLatency = 1,
+ readLatency = 1,
+ readUnderWrite = firrtl.ir.ReadUnderWrite.Undefined,
+ readers = r,
+ writers = w,
+ readwriters = rw
+ )
val renames = RenameMap()
val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace
- val(mems, refs) = DestructTypes.destructMemory(m, mem, mutableSet, renames, Set())
+ val (mems, refs) = DestructTypes.destructMemory(m, mem, mutableSet, renames, Set())
Lower(mems, refs, renames)
}
private val UInt1 = firrtl.ir.UIntType(firrtl.ir.IntWidth(1))
it should "not rename anything for a ground type memory if there was no conflict" in {
- val l = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w"))
+ val l = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("w"))
assert(l.renameMap.underlying.isEmpty)
}
it should "still produce reference lookups, even for a ground type memory with no conflicts" in {
- val nameToRef = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w")).refs
- .map{case (n,r) => n -> r.serialize}.toSet
-
- assert(nameToRef == Set(
- "mem.r.clk" -> "mem.r.clk",
- "mem.r.en" -> "mem.r.en",
- "mem.r.addr" -> "mem.r.addr",
- "mem.r.data" -> "mem.r.data",
- "mem.w.clk" -> "mem.w.clk",
- "mem.w.en" -> "mem.w.en",
- "mem.w.addr" -> "mem.w.addr",
- "mem.w.data" -> "mem.w.data",
- "mem.w.mask" -> "mem.w.mask"
- ))
+ val nameToRef = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("w")).refs.map {
+ case (n, r) => n -> r.serialize
+ }.toSet
+
+ assert(
+ nameToRef == Set(
+ "mem.r.clk" -> "mem.r.clk",
+ "mem.r.en" -> "mem.r.en",
+ "mem.r.addr" -> "mem.r.addr",
+ "mem.r.data" -> "mem.r.data",
+ "mem.w.clk" -> "mem.w.clk",
+ "mem.w.en" -> "mem.w.en",
+ "mem.w.addr" -> "mem.w.addr",
+ "mem.w.data" -> "mem.w.data",
+ "mem.w.mask" -> "mem.w.mask"
+ )
+ )
}
it should "produce references of correct type" in {
- val nameToType = lower("mem", "UInt<4>", Set("mem_r", "mem_r_data"), w=Seq("w"), depth = 3).refs
- .map{case (n,r) => n -> r.tpe.serialize}.toSet
-
- assert(nameToType == Set(
- "mem.r.clk" -> "Clock",
- "mem.r.en" -> "UInt<1>",
- "mem.r.addr" -> "UInt<2>", // depth = 3
- "mem.r.data" -> "UInt<4>",
- "mem.w.clk" -> "Clock",
- "mem.w.en" -> "UInt<1>",
- "mem.w.addr" -> "UInt<2>",
- "mem.w.data" -> "UInt<4>",
- "mem.w.mask" -> "UInt<1>"
- ))
+ val nameToType = lower("mem", "UInt<4>", Set("mem_r", "mem_r_data"), w = Seq("w"), depth = 3).refs.map {
+ case (n, r) => n -> r.tpe.serialize
+ }.toSet
+
+ assert(
+ nameToType == Set(
+ "mem.r.clk" -> "Clock",
+ "mem.r.en" -> "UInt<1>",
+ "mem.r.addr" -> "UInt<2>", // depth = 3
+ "mem.r.data" -> "UInt<4>",
+ "mem.w.clk" -> "Clock",
+ "mem.w.en" -> "UInt<1>",
+ "mem.w.addr" -> "UInt<2>",
+ "mem.w.data" -> "UInt<4>",
+ "mem.w.mask" -> "UInt<1>"
+ )
+ )
}
it should "not rename ground type memories even if there are conflicts on the ports" in {
// There actually isn't such a thing as conflicting ports, because they do not get flattened by LowerTypes.
- val r = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("r_data")).renameMap
+ val r = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("r_data")).renameMap
assert(r.underlying.isEmpty)
}
it should "rename references to lowered ports" in {
- val r = lower("mem", "{ a : UInt<1>, b : UInt<1>}", Set("mem_a"), r=Seq("r", "r_data")).renameMap
+ val r = lower("mem", "{ a : UInt<1>, b : UInt<1>}", Set("mem_a"), r = Seq("r", "r_data")).renameMap
// complete memory
assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b")))
// read ports
- assert(get(r, mem.field("r")) ==
- Set(m.ref("mem__a").field("r"), m.ref("mem__b").field("r")))
- assert(get(r, mem.field("r_data")) ==
- Set(m.ref("mem__a").field("r_data"), m.ref("mem__b").field("r_data")))
+ assert(
+ get(r, mem.field("r")) ==
+ Set(m.ref("mem__a").field("r"), m.ref("mem__b").field("r"))
+ )
+ assert(
+ get(r, mem.field("r_data")) ==
+ Set(m.ref("mem__a").field("r_data"), m.ref("mem__b").field("r_data"))
+ )
// port fields
- assert(get(r, mem.field("r").field("data")) ==
- Set(m.ref("mem__a").field("r").field("data"),
- m.ref("mem__b").field("r").field("data")))
- assert(get(r, mem.field("r").field("addr")) ==
- Set(m.ref("mem__a").field("r").field("addr"),
- m.ref("mem__b").field("r").field("addr")))
- assert(get(r, mem.field("r").field("en")) ==
- Set(m.ref("mem__a").field("r").field("en"),
- m.ref("mem__b").field("r").field("en")))
- assert(get(r, mem.field("r").field("clk")) ==
- Set(m.ref("mem__a").field("r").field("clk"),
- m.ref("mem__b").field("r").field("clk")))
- assert(get(r, mem.field("w").field("mask")) ==
- Set(m.ref("mem__a").field("w").field("mask"),
- m.ref("mem__b").field("w").field("mask")))
+ assert(
+ get(r, mem.field("r").field("data")) ==
+ Set(m.ref("mem__a").field("r").field("data"), m.ref("mem__b").field("r").field("data"))
+ )
+ assert(
+ get(r, mem.field("r").field("addr")) ==
+ Set(m.ref("mem__a").field("r").field("addr"), m.ref("mem__b").field("r").field("addr"))
+ )
+ assert(
+ get(r, mem.field("r").field("en")) ==
+ Set(m.ref("mem__a").field("r").field("en"), m.ref("mem__b").field("r").field("en"))
+ )
+ assert(
+ get(r, mem.field("r").field("clk")) ==
+ Set(m.ref("mem__a").field("r").field("clk"), m.ref("mem__b").field("r").field("clk"))
+ )
+ assert(
+ get(r, mem.field("w").field("mask")) ==
+ Set(m.ref("mem__a").field("w").field("mask"), m.ref("mem__b").field("w").field("mask"))
+ )
// port sub-fields
- assert(get(r, mem.field("r").field("data").field("a")) ==
- Set(m.ref("mem__a").field("r").field("data")))
- assert(get(r, mem.field("r").field("data").field("b")) ==
- Set(m.ref("mem__b").field("r").field("data")))
+ assert(
+ get(r, mem.field("r").field("data").field("a")) ==
+ Set(m.ref("mem__a").field("r").field("data"))
+ )
+ assert(
+ get(r, mem.field("r").field("data").field("b")) ==
+ Set(m.ref("mem__b").field("r").field("data"))
+ )
// need to rename the following:
// mem -> mem__a, mem__b
@@ -395,24 +473,38 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec {
assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b_c")))
// read port
- assert(get(r, mem.field("r")) ==
- Set(m.ref("mem__a").field("r"), m.ref("mem__b_c").field("r")))
+ assert(
+ get(r, mem.field("r")) ==
+ Set(m.ref("mem__a").field("r"), m.ref("mem__b_c").field("r"))
+ )
// port sub-fields
- assert(get(r, mem.field("r").field("data").field("a")) ==
- Set(m.ref("mem__a").field("r").field("data")))
- assert(get(r, mem.field("r").field("data").field("b")) ==
- Set(m.ref("mem__b_c").field("r").field("data")))
- assert(get(r, mem.field("r").field("data").field("b").field("c")) ==
- Set(m.ref("mem__b_c").field("r").field("data")))
+ assert(
+ get(r, mem.field("r").field("data").field("a")) ==
+ Set(m.ref("mem__a").field("r").field("data"))
+ )
+ assert(
+ get(r, mem.field("r").field("data").field("b")) ==
+ Set(m.ref("mem__b_c").field("r").field("data"))
+ )
+ assert(
+ get(r, mem.field("r").field("data").field("b").field("c")) ==
+ Set(m.ref("mem__b_c").field("r").field("data"))
+ )
// the mask field needs to be lowered just like the data field
- assert(get(r, mem.field("w").field("mask").field("a")) ==
- Set(m.ref("mem__a").field("w").field("mask")))
- assert(get(r, mem.field("w").field("mask").field("b")) ==
- Set(m.ref("mem__b_c").field("w").field("mask")))
- assert(get(r, mem.field("w").field("mask").field("b").field("c")) ==
- Set(m.ref("mem__b_c").field("w").field("mask")))
+ assert(
+ get(r, mem.field("w").field("mask").field("a")) ==
+ Set(m.ref("mem__a").field("w").field("mask"))
+ )
+ assert(
+ get(r, mem.field("w").field("mask").field("b")) ==
+ Set(m.ref("mem__b_c").field("w").field("mask"))
+ )
+ assert(
+ get(r, mem.field("w").field("mask").field("b").field("c")) ==
+ Set(m.ref("mem__b_c").field("w").field("mask"))
+ )
val renameCount = r.underlying.map(_._2.size).sum
assert(renameCount == 11, "it is enough to rename *to* 11 different signals")
@@ -420,66 +512,89 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec {
}
it should "return a name to RefLikeExpression map for a memory with a nested data type" in {
- val nameToRef = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a")).refs
- .map{case (n,r) => n -> r.serialize}.toSet
-
- assert(nameToRef == Set(
- // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated.
- // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do.
- "mem.r.clk" -> "mem__a.r.clk", "mem.r.clk" -> "mem__b_c.r.clk",
- "mem.r.en" -> "mem__a.r.en", "mem.r.en" -> "mem__b_c.r.en",
- "mem.r.addr" -> "mem__a.r.addr", "mem.r.addr" -> "mem__b_c.r.addr",
- "mem.w.clk" -> "mem__a.w.clk", "mem.w.clk" -> "mem__b_c.w.clk",
- "mem.w.en" -> "mem__a.w.en", "mem.w.en" -> "mem__b_c.w.en",
- "mem.w.addr" -> "mem__a.w.addr", "mem.w.addr" -> "mem__b_c.w.addr",
- // Ground type references to the data or mask field are unique.
- "mem.r.data.a" -> "mem__a.r.data",
- "mem.w.data.a" -> "mem__a.w.data",
- "mem.w.mask.a" -> "mem__a.w.mask",
- "mem.r.data.b.c" -> "mem__b_c.r.data",
- "mem.w.data.b.c" -> "mem__b_c.w.data",
- "mem.w.mask.b.c" -> "mem__b_c.w.mask"
- ))
+ val nameToRef = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a")).refs.map {
+ case (n, r) => n -> r.serialize
+ }.toSet
+
+ assert(
+ nameToRef == Set(
+ // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated.
+ // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do.
+ "mem.r.clk" -> "mem__a.r.clk",
+ "mem.r.clk" -> "mem__b_c.r.clk",
+ "mem.r.en" -> "mem__a.r.en",
+ "mem.r.en" -> "mem__b_c.r.en",
+ "mem.r.addr" -> "mem__a.r.addr",
+ "mem.r.addr" -> "mem__b_c.r.addr",
+ "mem.w.clk" -> "mem__a.w.clk",
+ "mem.w.clk" -> "mem__b_c.w.clk",
+ "mem.w.en" -> "mem__a.w.en",
+ "mem.w.en" -> "mem__b_c.w.en",
+ "mem.w.addr" -> "mem__a.w.addr",
+ "mem.w.addr" -> "mem__b_c.w.addr",
+ // Ground type references to the data or mask field are unique.
+ "mem.r.data.a" -> "mem__a.r.data",
+ "mem.w.data.a" -> "mem__a.w.data",
+ "mem.w.mask.a" -> "mem__a.w.mask",
+ "mem.r.data.b.c" -> "mem__b_c.r.data",
+ "mem.w.data.b.c" -> "mem__b_c.w.data",
+ "mem.w.mask.b.c" -> "mem__b_c.w.mask"
+ )
+ )
}
it should "produce references of correct type for memories with a read/write port" in {
- val refs = lower("mem", "{ a : UInt<3>, b : { c : UInt<4>} }", Set("mem_a"),
- r=Seq(), w=Seq(), rw=Seq("rw"), depth = 3).refs
- val nameToRef = refs.map{case (n,r) => n -> r.serialize}.toSet
- val nameToType = refs.map{case (n,r) => n -> r.tpe.serialize}.toSet
-
- assert(nameToRef == Set(
- // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated.
- // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do.
- "mem.rw.clk" -> "mem__a.rw.clk", "mem.rw.clk" -> "mem__b_c.rw.clk",
- "mem.rw.en" -> "mem__a.rw.en", "mem.rw.en" -> "mem__b_c.rw.en",
- "mem.rw.addr" -> "mem__a.rw.addr", "mem.rw.addr" -> "mem__b_c.rw.addr",
- "mem.rw.wmode" -> "mem__a.rw.wmode", "mem.rw.wmode" -> "mem__b_c.rw.wmode",
- // Ground type references to the data or mask field are unique.
- "mem.rw.rdata.a" -> "mem__a.rw.rdata",
- "mem.rw.wdata.a" -> "mem__a.rw.wdata",
- "mem.rw.wmask.a" -> "mem__a.rw.wmask",
- "mem.rw.rdata.b.c" -> "mem__b_c.rw.rdata",
- "mem.rw.wdata.b.c" -> "mem__b_c.rw.wdata",
- "mem.rw.wmask.b.c" -> "mem__b_c.rw.wmask"
- ))
-
- assert(nameToType == Set(
- //
- "mem.rw.clk" -> "Clock",
- "mem.rw.en" -> "UInt<1>",
- "mem.rw.addr" -> "UInt<2>",
- "mem.rw.wmode" -> "UInt<1>",
- // Ground type references to the data or mask field are unique.
- "mem.rw.rdata.a" -> "UInt<3>",
- "mem.rw.wdata.a" -> "UInt<3>",
- "mem.rw.wmask.a" -> "UInt<1>",
- "mem.rw.rdata.b.c" -> "UInt<4>",
- "mem.rw.wdata.b.c" -> "UInt<4>",
- "mem.rw.wmask.b.c" -> "UInt<1>"
- ))
- }
+ val refs = lower(
+ "mem",
+ "{ a : UInt<3>, b : { c : UInt<4>} }",
+ Set("mem_a"),
+ r = Seq(),
+ w = Seq(),
+ rw = Seq("rw"),
+ depth = 3
+ ).refs
+ val nameToRef = refs.map { case (n, r) => n -> r.serialize }.toSet
+ val nameToType = refs.map { case (n, r) => n -> r.tpe.serialize }.toSet
+
+ assert(
+ nameToRef == Set(
+ // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated.
+ // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do.
+ "mem.rw.clk" -> "mem__a.rw.clk",
+ "mem.rw.clk" -> "mem__b_c.rw.clk",
+ "mem.rw.en" -> "mem__a.rw.en",
+ "mem.rw.en" -> "mem__b_c.rw.en",
+ "mem.rw.addr" -> "mem__a.rw.addr",
+ "mem.rw.addr" -> "mem__b_c.rw.addr",
+ "mem.rw.wmode" -> "mem__a.rw.wmode",
+ "mem.rw.wmode" -> "mem__b_c.rw.wmode",
+ // Ground type references to the data or mask field are unique.
+ "mem.rw.rdata.a" -> "mem__a.rw.rdata",
+ "mem.rw.wdata.a" -> "mem__a.rw.wdata",
+ "mem.rw.wmask.a" -> "mem__a.rw.wmask",
+ "mem.rw.rdata.b.c" -> "mem__b_c.rw.rdata",
+ "mem.rw.wdata.b.c" -> "mem__b_c.rw.wdata",
+ "mem.rw.wmask.b.c" -> "mem__b_c.rw.wmask"
+ )
+ )
+ assert(
+ nameToType == Set(
+ //
+ "mem.rw.clk" -> "Clock",
+ "mem.rw.en" -> "UInt<1>",
+ "mem.rw.addr" -> "UInt<2>",
+ "mem.rw.wmode" -> "UInt<1>",
+ // Ground type references to the data or mask field are unique.
+ "mem.rw.rdata.a" -> "UInt<3>",
+ "mem.rw.wdata.a" -> "UInt<3>",
+ "mem.rw.wmask.a" -> "UInt<1>",
+ "mem.rw.rdata.b.c" -> "UInt<4>",
+ "mem.rw.wdata.b.c" -> "UInt<4>",
+ "mem.rw.wmask.b.c" -> "UInt<1>"
+ )
+ )
+ }
it should "rename references for vector type memories" in {
val l = lower("mem", "UInt<1>[2]", Set("mem_0"))
@@ -491,14 +606,20 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec {
assert(get(r, mem) == Set(m.ref("mem__0"), m.ref("mem__1")))
// read port
- assert(get(r, mem.field("r")) ==
- Set(m.ref("mem__0").field("r"), m.ref("mem__1").field("r")))
+ assert(
+ get(r, mem.field("r")) ==
+ Set(m.ref("mem__0").field("r"), m.ref("mem__1").field("r"))
+ )
// port sub-fields
- assert(get(r, mem.field("r").field("data").index(0)) ==
- Set(m.ref("mem__0").field("r").field("data")))
- assert(get(r, mem.field("r").field("data").index(1)) ==
- Set(m.ref("mem__1").field("r").field("data")))
+ assert(
+ get(r, mem.field("r").field("data").index(0)) ==
+ Set(m.ref("mem__0").field("r").field("data"))
+ )
+ assert(
+ get(r, mem.field("r").field("data").index(1)) ==
+ Set(m.ref("mem__1").field("r").field("data"))
+ )
val renameCount = r.underlying.map(_._2.size).sum
assert(renameCount == 8, "it is enough to rename *to* 8 different signals")
diff --git a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala
index 007608ca..0b9b830c 100644
--- a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala
+++ b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala
@@ -8,7 +8,14 @@ import java.io.File
import firrtl._
import firrtl.stage.phases.DriverCompatibility._
import firrtl.options.{InputAnnotationFileAnnotation, Phase, TargetDirAnnotation}
-import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation}
+import firrtl.stage.{
+ CompilerAnnotation,
+ FirrtlCircuitAnnotation,
+ FirrtlFileAnnotation,
+ FirrtlSourceAnnotation,
+ OutputFileAnnotation,
+ RunFirrtlTransformAnnotation
+}
import firrtl.stage.phases.DriverCompatibility
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
@@ -20,7 +27,7 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth
/* This method wraps some magic that lets you use the private method DriverCompatibility.topName */
def topName(annotations: AnnotationSeq): Option[String] = {
val topName = PrivateMethod[Option[String]]('topName)
- DriverCompatibility invokePrivate topName(annotations)
+ DriverCompatibility.invokePrivate(topName(annotations))
}
def simpleCircuit(main: String): String = s"""|circuit $main:
@@ -41,22 +48,22 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth
(FirrtlFileAnnotation("src/test/resources/integration/GCDTester.pb"), "GCDTester")
)
- behavior of s"${DriverCompatibility.getClass.getName}.topName (private method)"
+ behavior.of(s"${DriverCompatibility.getClass.getName}.topName (private method)")
/* This iterates over the tails of annosWithTops. Using the ordering of annosWithTops, if this AnnotationSeq is fed to
* DriverCompatibility.topName, the head annotation will be used to determine the top name. This test ensures that
* topName behaves as expected.
*/
- for ( t <- annosWithTops.tails ) t match {
+ for (t <- annosWithTops.tails) t match {
case Nil =>
it should "return None on an empty AnnotationSeq" in {
- topName(Seq.empty) should be (None)
+ topName(Seq.empty) should be(None)
}
case x =>
val annotations = x.map(_._1)
val top = x.head._2
it should s"determine a top name ('$top') from a ${annotations.head.getClass.getName}" in {
- topName(annotations).get should be (top)
+ topName(annotations).get should be(top)
}
}
@@ -66,152 +73,148 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth
file.createNewFile()
}
- behavior of classOf[AddImplicitAnnotationFile].toString
+ behavior.of(classOf[AddImplicitAnnotationFile].toString)
val testDir = "test_run_dir/DriverCompatibilitySpec"
it should "not modify the annotations if an InputAnnotationFile already exists" in
- new PhaseFixture(new AddImplicitAnnotationFile) {
+ new PhaseFixture(new AddImplicitAnnotationFile) {
- createFile(testDir + "/foo.anno")
- val annotations = Seq(
- InputAnnotationFileAnnotation("bar.anno"),
- TargetDirAnnotation(testDir),
- TopNameAnnotation("foo") )
+ createFile(testDir + "/foo.anno")
+ val annotations =
+ Seq(InputAnnotationFileAnnotation("bar.anno"), TargetDirAnnotation(testDir), TopNameAnnotation("foo"))
- phase.transform(annotations).toSeq should be (annotations)
- }
+ phase.transform(annotations).toSeq should be(annotations)
+ }
it should "add an InputAnnotationFile based on a derived topName" in
- new PhaseFixture(new AddImplicitAnnotationFile) {
- createFile(testDir + "/bar.anno")
- val annotations = Seq(
- TargetDirAnnotation(testDir),
- TopNameAnnotation("bar") )
+ new PhaseFixture(new AddImplicitAnnotationFile) {
+ createFile(testDir + "/bar.anno")
+ val annotations = Seq(TargetDirAnnotation(testDir), TopNameAnnotation("bar"))
- val expected = annotations.toSet +
- InputAnnotationFileAnnotation(testDir + "/bar.anno")
+ val expected = annotations.toSet +
+ InputAnnotationFileAnnotation(testDir + "/bar.anno")
- phase.transform(annotations).toSet should be (expected)
- }
+ phase.transform(annotations).toSet should be(expected)
+ }
it should "not add an InputAnnotationFile for .anno.json annotations" in
- new PhaseFixture(new AddImplicitAnnotationFile) {
- createFile(testDir + "/baz.anno.json")
- val annotations = Seq(
- TargetDirAnnotation(testDir),
- TopNameAnnotation("baz") )
+ new PhaseFixture(new AddImplicitAnnotationFile) {
+ createFile(testDir + "/baz.anno.json")
+ val annotations = Seq(TargetDirAnnotation(testDir), TopNameAnnotation("baz"))
- phase.transform(annotations).toSeq should be (annotations)
- }
+ phase.transform(annotations).toSeq should be(annotations)
+ }
it should "not add an InputAnnotationFile if it cannot determine the topName" in
- new PhaseFixture(new AddImplicitAnnotationFile) {
- val annotations = Seq( TargetDirAnnotation(testDir) )
+ new PhaseFixture(new AddImplicitAnnotationFile) {
+ val annotations = Seq(TargetDirAnnotation(testDir))
- phase.transform(annotations).toSeq should be (annotations)
- }
+ phase.transform(annotations).toSeq should be(annotations)
+ }
- behavior of classOf[AddImplicitFirrtlFile].toString
+ behavior.of(classOf[AddImplicitFirrtlFile].toString)
it should "not modify the annotations if a CircuitOption is present" in
- new PhaseFixture(new AddImplicitFirrtlFile) {
- val annotations = Seq( FirrtlFileAnnotation("foo"), TopNameAnnotation("bar") )
+ new PhaseFixture(new AddImplicitFirrtlFile) {
+ val annotations = Seq(FirrtlFileAnnotation("foo"), TopNameAnnotation("bar"))
- phase.transform(annotations).toSeq should be (annotations)
- }
+ phase.transform(annotations).toSeq should be(annotations)
+ }
it should "add an FirrtlFileAnnotation if a TopNameAnnotation is present" in
- new PhaseFixture(new AddImplicitFirrtlFile) {
- val annotations = Seq( TopNameAnnotation("foo") )
- val expected = annotations.toSet +
- FirrtlFileAnnotation(new File("foo.fir").getPath())
+ new PhaseFixture(new AddImplicitFirrtlFile) {
+ val annotations = Seq(TopNameAnnotation("foo"))
+ val expected = annotations.toSet +
+ FirrtlFileAnnotation(new File("foo.fir").getPath())
- phase.transform(annotations).toSet should be (expected)
- }
+ phase.transform(annotations).toSet should be(expected)
+ }
it should "do nothing if no TopNameAnnotation is present" in
- new PhaseFixture(new AddImplicitFirrtlFile) {
- val annotations = Seq( TargetDirAnnotation("foo") )
+ new PhaseFixture(new AddImplicitFirrtlFile) {
+ val annotations = Seq(TargetDirAnnotation("foo"))
- phase.transform(annotations).toSeq should be (annotations)
- }
+ phase.transform(annotations).toSeq should be(annotations)
+ }
- behavior of classOf[AddImplicitEmitter].toString
+ behavior.of(classOf[AddImplicitEmitter].toString)
- val (nc, hfc, mfc, lfc, vc, svc) = ( new NoneCompiler,
- new HighFirrtlCompiler,
- new MiddleFirrtlCompiler,
- new LowFirrtlCompiler,
- new VerilogCompiler,
- new SystemVerilogCompiler )
+ val (nc, hfc, mfc, lfc, vc, svc) = (
+ new NoneCompiler,
+ new HighFirrtlCompiler,
+ new MiddleFirrtlCompiler,
+ new LowFirrtlCompiler,
+ new VerilogCompiler,
+ new SystemVerilogCompiler
+ )
it should "convert CompilerAnnotations into EmitCircuitAnnotations without EmitOneFilePerModuleAnnotation" in
- new PhaseFixture(new AddImplicitEmitter) {
- val annotations = Seq(
- CompilerAnnotation(nc),
- CompilerAnnotation(hfc),
- CompilerAnnotation(mfc),
- CompilerAnnotation(lfc),
- CompilerAnnotation(vc),
- CompilerAnnotation(svc)
- )
- val expected = annotations
- .flatMap( a => Seq(a,
- RunFirrtlTransformAnnotation(a.compiler.emitter),
- EmitCircuitAnnotation(a.compiler.emitter.getClass)) )
-
- phase.transform(annotations).toSeq should be (expected)
- }
+ new PhaseFixture(new AddImplicitEmitter) {
+ val annotations = Seq(
+ CompilerAnnotation(nc),
+ CompilerAnnotation(hfc),
+ CompilerAnnotation(mfc),
+ CompilerAnnotation(lfc),
+ CompilerAnnotation(vc),
+ CompilerAnnotation(svc)
+ )
+ val expected = annotations
+ .flatMap(a =>
+ Seq(a, RunFirrtlTransformAnnotation(a.compiler.emitter), EmitCircuitAnnotation(a.compiler.emitter.getClass))
+ )
+
+ phase.transform(annotations).toSeq should be(expected)
+ }
it should "convert CompilerAnnotations into EmitAllodulesAnnotation with EmitOneFilePerModuleAnnotation" in
- new PhaseFixture(new AddImplicitEmitter) {
- val annotations = Seq(
- EmitOneFilePerModuleAnnotation,
- CompilerAnnotation(nc),
- CompilerAnnotation(hfc),
- CompilerAnnotation(mfc),
- CompilerAnnotation(lfc),
- CompilerAnnotation(vc),
- CompilerAnnotation(svc)
- )
- val expected = annotations
- .flatMap{
- case a: CompilerAnnotation => Seq(a,
- RunFirrtlTransformAnnotation(a.compiler.emitter),
- EmitAllModulesAnnotation(a.compiler.emitter.getClass))
+ new PhaseFixture(new AddImplicitEmitter) {
+ val annotations = Seq(
+ EmitOneFilePerModuleAnnotation,
+ CompilerAnnotation(nc),
+ CompilerAnnotation(hfc),
+ CompilerAnnotation(mfc),
+ CompilerAnnotation(lfc),
+ CompilerAnnotation(vc),
+ CompilerAnnotation(svc)
+ )
+ val expected = annotations.flatMap {
+ case a: CompilerAnnotation =>
+ Seq(
+ a,
+ RunFirrtlTransformAnnotation(a.compiler.emitter),
+ EmitAllModulesAnnotation(a.compiler.emitter.getClass)
+ )
case a => Seq(a)
}
- phase.transform(annotations).toSeq should be (expected)
- }
+ phase.transform(annotations).toSeq should be(expected)
+ }
- behavior of classOf[AddImplicitOutputFile].toString
+ behavior.of(classOf[AddImplicitOutputFile].toString)
it should "add an OutputFileAnnotation derived from a TopNameAnnotation if no OutputFileAnnotation exists" in
- new PhaseFixture(new AddImplicitOutputFile) {
- val annotations = Seq( TopNameAnnotation("foo") )
- val expected = Seq(
- OutputFileAnnotation("foo"),
- TopNameAnnotation("foo")
- )
- phase.transform(annotations).toSeq should be (expected)
- }
+ new PhaseFixture(new AddImplicitOutputFile) {
+ val annotations = Seq(TopNameAnnotation("foo"))
+ val expected = Seq(
+ OutputFileAnnotation("foo"),
+ TopNameAnnotation("foo")
+ )
+ phase.transform(annotations).toSeq should be(expected)
+ }
it should "do nothing if an OutputFileannotation already exists" in
- new PhaseFixture(new AddImplicitOutputFile) {
- val annotations = Seq(
- TopNameAnnotation("foo"),
- OutputFileAnnotation("bar") )
- val expected = annotations
- phase.transform(annotations).toSeq should be (expected)
- }
+ new PhaseFixture(new AddImplicitOutputFile) {
+ val annotations = Seq(TopNameAnnotation("foo"), OutputFileAnnotation("bar"))
+ val expected = annotations
+ phase.transform(annotations).toSeq should be(expected)
+ }
it should "do nothing if no TopNameAnnotation exists" in
- new PhaseFixture(new AddImplicitOutputFile) {
- val annotations = Seq.empty
- val expected = annotations
- phase.transform(annotations).toSeq should be (expected)
- }
+ new PhaseFixture(new AddImplicitOutputFile) {
+ val annotations = Seq.empty
+ val expected = annotations
+ phase.transform(annotations).toSeq should be(expected)
+ }
}
diff --git a/src/test/scala/firrtl/testutils/FirrtlSpec.scala b/src/test/scala/firrtl/testutils/FirrtlSpec.scala
index dfc20352..a0c41085 100644
--- a/src/test/scala/firrtl/testutils/FirrtlSpec.scala
+++ b/src/test/scala/firrtl/testutils/FirrtlSpec.scala
@@ -46,11 +46,13 @@ object RenameTop extends Transform {
val c = state.circuit
val ns = Namespace(c)
- val newTopName = state.annotations.collectFirst({
- case RenameTopAnnotation(name) =>
- require(ns.tryName(name))
- name
- }).getOrElse(c.main)
+ val newTopName = state.annotations
+ .collectFirst({
+ case RenameTopAnnotation(name) =>
+ require(ns.tryName(name))
+ name
+ })
+ .getOrElse(c.main)
state.annotations.collect {
case ModuleNamespaceAnnotation(mustNotCollideNS) => require(mustNotCollideNS.tryName(newTopName))
@@ -70,6 +72,7 @@ object RenameTop extends Transform {
trait FirrtlRunners extends BackendCompilationUtilities {
val cppHarnessResourceName: String = "/firrtl/testTop.cpp"
+
/** Extra transforms to run by default */
val extraCheckTransforms = Seq(new CheckLowForm)
@@ -80,10 +83,12 @@ trait FirrtlRunners extends BackendCompilationUtilities {
* @param customAnnotations Optional Firrtl annotations
* @param timesteps the maximum number of timesteps to consider
*/
- def firrtlEquivalenceTest(input: String,
- customTransforms: Seq[Transform] = Seq.empty,
- customAnnotations: AnnotationSeq = Seq.empty,
- timesteps: Int = 1): Unit = {
+ def firrtlEquivalenceTest(
+ input: String,
+ customTransforms: Seq[Transform] = Seq.empty,
+ customAnnotations: AnnotationSeq = Seq.empty,
+ timesteps: Int = 1
+ ): Unit = {
val circuit = Parser.parse(input.split("\n").toIterator)
val prefix = circuit.main
val testDir = createTestDirectory(prefix + "_equivalence_test")
@@ -93,12 +98,12 @@ trait FirrtlRunners extends BackendCompilationUtilities {
def getBaseAnnos(topName: String) = {
val baseTransforms = RenameTop +: extraCheckTransforms
TargetDirAnnotation(testDir.toString) +:
- InfoModeAnnotation("ignore") +:
- RenameTopAnnotation(topName) +:
- stage.FirrtlCircuitAnnotation(circuit) +:
- stage.CompilerAnnotation("mverilog") +:
- stage.OutputFileAnnotation(topName) +:
- toAnnos(baseTransforms)
+ InfoModeAnnotation("ignore") +:
+ RenameTopAnnotation(topName) +:
+ stage.FirrtlCircuitAnnotation(circuit) +:
+ stage.CompilerAnnotation("mverilog") +:
+ stage.OutputFileAnnotation(topName) +:
+ toAnnos(baseTransforms)
}
val customName = s"${prefix}_custom"
@@ -111,7 +116,8 @@ trait FirrtlRunners extends BackendCompilationUtilities {
val refAnnos = getBaseAnnos(refSuggestedName) ++: Seq(RunFirrtlTransformAnnotation(new RenameModules), nsAnno)
val refResult = (new firrtl.stage.FirrtlStage).execute(Array.empty, refAnnos)
- val refName = refResult.collectFirst({ case stage.FirrtlCircuitAnnotation(c) => c.main }).getOrElse(refSuggestedName)
+ val refName =
+ refResult.collectFirst({ case stage.FirrtlCircuitAnnotation(c) => c.main }).getOrElse(refSuggestedName)
assert(BackendCompilationUtilities.yosysExpectSuccess(customName, refName, testDir, timesteps))
}
@@ -123,6 +129,7 @@ trait FirrtlRunners extends BackendCompilationUtilities {
val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms)
res.getEmittedCircuit.value
}
+
/** Compile a Firrtl file
*
* @param prefix is the name of the Firrtl file without path or file extension
@@ -130,25 +137,27 @@ trait FirrtlRunners extends BackendCompilationUtilities {
* @param annotations Optional Firrtl annotations
*/
def compileFirrtlTest(
- prefix: String,
- srcDir: String,
- customTransforms: Seq[Transform] = Seq.empty,
- annotations: AnnotationSeq = Seq.empty): File = {
+ prefix: String,
+ srcDir: String,
+ customTransforms: Seq[Transform] = Seq.empty,
+ annotations: AnnotationSeq = Seq.empty
+ ): File = {
val testDir = createTestDirectory(prefix)
val inputFile = new File(testDir, s"${prefix}.fir")
copyResourceToFile(s"${srcDir}/${prefix}.fir", inputFile)
val annos =
FirrtlFileAnnotation(inputFile.toString) +:
- TargetDirAnnotation(testDir.toString) +:
- InfoModeAnnotation("ignore") +:
- annotations ++:
- (customTransforms ++ extraCheckTransforms).map(RunFirrtlTransformAnnotation(_))
+ TargetDirAnnotation(testDir.toString) +:
+ InfoModeAnnotation("ignore") +:
+ annotations ++:
+ (customTransforms ++ extraCheckTransforms).map(RunFirrtlTransformAnnotation(_))
(new firrtl.stage.FirrtlStage).execute(Array.empty, annos)
testDir
}
+
/** Execute a Firrtl Test
*
* @param prefix is the name of the Firrtl file without path or file extension
@@ -157,25 +166,26 @@ trait FirrtlRunners extends BackendCompilationUtilities {
* @param annotations Optional Firrtl annotations
*/
def runFirrtlTest(
- prefix: String,
- srcDir: String,
- verilogPrefixes: Seq[String] = Seq.empty,
- customTransforms: Seq[Transform] = Seq.empty,
- annotations: AnnotationSeq = Seq.empty) = {
+ prefix: String,
+ srcDir: String,
+ verilogPrefixes: Seq[String] = Seq.empty,
+ customTransforms: Seq[Transform] = Seq.empty,
+ annotations: AnnotationSeq = Seq.empty
+ ) = {
val testDir = compileFirrtlTest(prefix, srcDir, customTransforms, annotations)
val harness = new File(testDir, s"top.cpp")
copyResourceToFile(cppHarnessResourceName, harness)
// Note file copying side effect
- val verilogFiles = verilogPrefixes map { vprefix =>
+ val verilogFiles = verilogPrefixes.map { vprefix =>
val file = new File(testDir, s"$vprefix.v")
copyResourceToFile(s"$srcDir/$vprefix.v", file)
file
}
verilogToCpp(prefix, testDir, verilogFiles, harness) #&&
- cppToExe(prefix, testDir) !
- loggingProcessLogger
+ cppToExe(prefix, testDir) !
+ loggingProcessLogger
assert(executeExpectingSuccess(prefix, testDir))
}
}
@@ -201,6 +211,7 @@ trait FirrtlMatchers extends Matchers {
require(!s.contains("\n"))
s.replaceAll("\\s+", " ").trim
}
+
/** Helper to make circuits that are the same appear the same */
def canonicalize(circuit: Circuit): Circuit = {
import firrtl.Mappers._
@@ -208,19 +219,21 @@ trait FirrtlMatchers extends Matchers {
circuit.map(onModule)
}
def parse(str: String) = Parser.parse(str.split("\n").toIterator, UseInfo)
+
/** Helper for executing tests
* compiler will be run on input then emitted result will each be split into
* lines and normalized.
*/
def executeTest(
- input: String,
- expected: Seq[String],
- compiler: Compiler,
- annotations: Seq[Annotation] = Seq.empty) = {
+ input: String,
+ expected: Seq[String],
+ compiler: Compiler,
+ annotations: Seq[Annotation] = Seq.empty
+ ) = {
val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations))
- val lines = finalState.getEmittedCircuit.value split "\n" map normalized
+ val lines = finalState.getEmittedCircuit.value.split("\n").map(normalized)
for (e <- expected) {
- lines should contain (e)
+ lines should contain(e)
}
}
}
@@ -239,10 +252,12 @@ object FirrtlCheckers extends FirrtlMatchers {
case Some(res) => res
// Otherwise keep digging
case None =>
- require(node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode],
- "Error! Unexpected FirrtlNode that does not implement Product!")
+ require(
+ node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode],
+ "Error! Unexpected FirrtlNode that does not implement Product!"
+ )
val iter = node match {
- case p: Product => p.productIterator
+ case p: Product => p.productIterator
case i: Iterable[Any] => i.iterator
case _ => Iterator.empty
}
@@ -296,57 +311,63 @@ class TestFirrtlFlatSpec extends FirrtlFlatSpec {
import FirrtlCheckers._
val c = parse("""
- |circuit Test:
- | module Test :
- | input in : UInt<8>
- | output out : UInt<8>
- | out <= in
- |""".stripMargin)
+ |circuit Test:
+ | module Test :
+ | input in : UInt<8>
+ | output out : UInt<8>
+ | out <= in
+ |""".stripMargin)
val state = CircuitState(c, ChirrtlForm)
val compiled = (new LowFirrtlCompiler).compileAndEmit(state, List.empty)
// While useful, ScalaTest helpers should be used over search
- behavior of "Search"
+ behavior.of("Search")
it should "be supported on Circuit" in {
- assert(c search {
- case Connect(_, Reference("out",_, _, _), Reference("in", _, _, _)) => true
+ assert(c.search {
+ case Connect(_, Reference("out", _, _, _), Reference("in", _, _, _)) => true
})
}
it should "be supported on CircuitStates" in {
- assert(state search {
- case Connect(_, Reference("out", _, _, _), Reference("in",_, _, _)) => true
+ assert(state.search {
+ case Connect(_, Reference("out", _, _, _), Reference("in", _, _, _)) => true
})
}
it should "be supported on the results of compilers" in {
- assert(compiled search {
- case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true
+ assert(compiled.search {
+ case Connect(_, WRef("out", _, _, _), WRef("in", _, _, _)) => true
})
}
// Use these!!!
- behavior of "ScalaTest helpers"
+ behavior.of("ScalaTest helpers")
they should "work for lines of emitted text" in {
- compiled should containLine (s"input in : UInt<8>")
- compiled should containLine (s"output out : UInt<8>")
- compiled should containLine (s"out <= in")
+ compiled should containLine(s"input in : UInt<8>")
+ compiled should containLine(s"output out : UInt<8>")
+ compiled should containLine(s"out <= in")
}
they should "work for partial functions matching on subtrees" in {
val UInt8 = UIntType(IntWidth(8)) // BigInt unapply is weird
compiled should containTree { case Port(_, "in", Input, UInt8) => true }
compiled should containTree { case Port(_, "out", Output, UInt8) => true }
- compiled should containTree { case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true }
+ compiled should containTree { case Connect(_, WRef("out", _, _, _), WRef("in", _, _, _)) => true }
}
}
/** Super class for execution driven Firrtl tests */
-abstract class ExecutionTest(name: String, dir: String, vFiles: Seq[String] = Seq.empty, annotations: AnnotationSeq = Seq.empty) extends FirrtlPropSpec {
+abstract class ExecutionTest(
+ name: String,
+ dir: String,
+ vFiles: Seq[String] = Seq.empty,
+ annotations: AnnotationSeq = Seq.empty)
+ extends FirrtlPropSpec {
property(s"$name should execute correctly") {
runFirrtlTest(name, dir, vFiles, annotations = annotations)
}
}
+
/** Super class for compilation driven Firrtl tests */
abstract class CompilationTest(name: String, dir: String) extends FirrtlPropSpec {
property(s"$name should compile correctly") {
@@ -444,7 +465,9 @@ abstract class EquivalenceTest(transforms: Seq[Transform], name: String, dir: St
throw new FileNotFoundException(s"Resource '$fileName'")
}
val source = scala.io.Source.fromInputStream(in)
- val input = try source.mkString finally source.close()
+ val input =
+ try source.mkString
+ finally source.close()
s"$name with ${transforms.map(_.name).mkString(", ")}" should
s"be equivalent to $name without ${transforms.map(_.name).mkString(", ")}" in {
diff --git a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala
index c1f0943a..4ae6a7be 100644
--- a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala
+++ b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala
@@ -1,6 +1,6 @@
package firrtl.testutils
-import firrtl.{AnnotationSeq, CircuitState, EmitCircuitAnnotation, ir}
+import firrtl.{ir, AnnotationSeq, CircuitState, EmitCircuitAnnotation}
import firrtl.options.Dependency
import firrtl.passes.RemoveEmpty
import firrtl.stage.TransformManager.TransformDependency
@@ -11,30 +11,33 @@ class VerilogTransformSpec extends LeanTransformSpec(Seq(Dependency[firrtl.Veril
class LowFirrtlTransformSpec extends LeanTransformSpec(Seq(Dependency[firrtl.LowFirrtlEmitter]))
/** The new cool kid on the block, creates a custom compiler for your transform. */
-class LeanTransformSpec(protected val transforms: Seq[TransformDependency]) extends AnyFlatSpec with FirrtlMatchers with LazyLogging {
+class LeanTransformSpec(protected val transforms: Seq[TransformDependency])
+ extends AnyFlatSpec
+ with FirrtlMatchers
+ with LazyLogging {
private val compiler = new firrtl.stage.transforms.Compiler(transforms)
private val emitterAnnos = LeanTransformSpec.deriveEmitCircuitAnnotations(transforms)
protected def compile(src: String): CircuitState = compile(src, Seq())
protected def compile(src: String, annos: AnnotationSeq): CircuitState = compile(firrtl.Parser.parse(src), annos)
- protected def compile(c: ir.Circuit): CircuitState = compile(c, Seq())
- protected def compile(c: ir.Circuit, annos: AnnotationSeq): CircuitState =
+ protected def compile(c: ir.Circuit): CircuitState = compile(c, Seq())
+ protected def compile(c: ir.Circuit, annos: AnnotationSeq): CircuitState =
compiler.transform(CircuitState(c, emitterAnnos ++ annos))
- protected def execute(input: String, check: String): CircuitState = execute(input, check ,Seq())
+ protected def execute(input: String, check: String): CircuitState = execute(input, check, Seq())
protected def execute(input: String, check: String, inAnnos: AnnotationSeq): CircuitState = {
val finalState = compiler.transform(CircuitState(parse(input), inAnnos))
val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize
val expected = parse(check).serialize
logger.debug(actual)
logger.debug(expected)
- actual should be (expected)
+ actual should be(expected)
finalState
}
}
private object LeanTransformSpec {
private def deriveEmitCircuitAnnotations(transforms: Iterable[TransformDependency]): AnnotationSeq = {
- val emitters = transforms.map(_.getObject()).collect{ case e: firrtl.Emitter => e }
+ val emitters = transforms.map(_.getObject()).collect { case e: firrtl.Emitter => e }
emitters.map(e => EmitCircuitAnnotation(e.getClass)).toSeq
}
}
@@ -47,4 +50,4 @@ trait MakeCompiler {
new firrtl.stage.transforms.Compiler(Seq(Dependency[firrtl.MinimumVerilogEmitter]) ++ transforms)
protected def makeLowFirrtlCompiler(transforms: Seq[TransformDependency] = Seq()) =
new firrtl.stage.transforms.Compiler(Seq(Dependency[firrtl.LowFirrtlEmitter]) ++ transforms)
-} \ No newline at end of file
+}
diff --git a/src/test/scala/firrtl/testutils/PassTests.scala b/src/test/scala/firrtl/testutils/PassTests.scala
index 49dea199..7a5dc306 100644
--- a/src/test/scala/firrtl/testutils/PassTests.scala
+++ b/src/test/scala/firrtl/testutils/PassTests.scala
@@ -15,49 +15,53 @@ import org.scalatest.flatspec.AnyFlatSpec
// An example methodology for testing Firrtl Passes
// Spec class should extend this class
abstract class SimpleTransformSpec extends AnyFlatSpec with FirrtlMatchers with Compiler with LazyLogging {
- // Utility function
- def squash(c: Circuit): Circuit = RemoveEmpty.run(c)
-
- // Executes the test. Call in tests.
- // annotations cannot have default value because scalatest trait Suite has a default value
- def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = {
- val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations))
- val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize
- val expected = parse(check).serialize
- logger.debug(actual)
- logger.debug(expected)
- (actual) should be (expected)
- finalState
- }
-
- def executeWithAnnos(input: String, check: String, annotations: Seq[Annotation],
- checkAnnotations: Seq[Annotation]): CircuitState = {
- val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations))
- val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize
- val expected = parse(check).serialize
- logger.debug(actual)
- logger.debug(expected)
- (actual) should be (expected)
-
- annotations.foreach { anno =>
- logger.debug(anno.serialize)
- }
-
- finalState.annotations.toSeq.foreach { anno =>
- logger.debug(anno.serialize)
- }
- checkAnnotations.foreach { check =>
- (finalState.annotations.toSeq) should contain (check)
- }
- finalState
- }
- // Executes the test, should throw an error
- // No default to be consistent with execute
- def failingexecute(input: String, annotations: Seq[Annotation]): Exception = {
- intercept[PassExceptions] {
- compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty)
- }
- }
+ // Utility function
+ def squash(c: Circuit): Circuit = RemoveEmpty.run(c)
+
+ // Executes the test. Call in tests.
+ // annotations cannot have default value because scalatest trait Suite has a default value
+ def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = {
+ val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations))
+ val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize
+ val expected = parse(check).serialize
+ logger.debug(actual)
+ logger.debug(expected)
+ (actual) should be(expected)
+ finalState
+ }
+
+ def executeWithAnnos(
+ input: String,
+ check: String,
+ annotations: Seq[Annotation],
+ checkAnnotations: Seq[Annotation]
+ ): CircuitState = {
+ val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations))
+ val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize
+ val expected = parse(check).serialize
+ logger.debug(actual)
+ logger.debug(expected)
+ (actual) should be(expected)
+
+ annotations.foreach { anno =>
+ logger.debug(anno.serialize)
+ }
+
+ finalState.annotations.toSeq.foreach { anno =>
+ logger.debug(anno.serialize)
+ }
+ checkAnnotations.foreach { check =>
+ (finalState.annotations.toSeq) should contain(check)
+ }
+ finalState
+ }
+ // Executes the test, should throw an error
+ // No default to be consistent with execute
+ def failingexecute(input: String, annotations: Seq[Annotation]): Exception = {
+ intercept[PassExceptions] {
+ compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty)
+ }
+ }
}
@deprecated(
@@ -86,19 +90,19 @@ object ReRunResolveAndCheck extends Transform with DependencyAPIMigration with I
}
trait LowTransformSpec extends SimpleTransformSpec {
- def emitter = new LowFirrtlEmitter
- def transform: Transform
- def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.LowForm.map(_.getObject)
+ def emitter = new LowFirrtlEmitter
+ def transform: Transform
+ def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.LowForm.map(_.getObject)
}
trait MiddleTransformSpec extends SimpleTransformSpec {
- def emitter = new MiddleFirrtlEmitter
- def transform: Transform
- def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.MidForm.map(_.getObject)
+ def emitter = new MiddleFirrtlEmitter
+ def transform: Transform
+ def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.MidForm.map(_.getObject)
}
trait HighTransformSpec extends SimpleTransformSpec {
- def emitter = new HighFirrtlEmitter
- def transform: Transform
- def transforms = transform +: ReRunResolveAndCheck +: Forms.HighForm.map(_.getObject)
+ def emitter = new HighFirrtlEmitter
+ def transform: Transform
+ def transforms = transform +: ReRunResolveAndCheck +: Forms.HighForm.map(_.getObject)
}