diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/test/scala/firrtl | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/test/scala/firrtl')
16 files changed, 958 insertions, 724 deletions
diff --git a/src/test/scala/firrtl/JsonProtocolSpec.scala b/src/test/scala/firrtl/JsonProtocolSpec.scala index 7d04e9fc..cc7591cb 100644 --- a/src/test/scala/firrtl/JsonProtocolSpec.scala +++ b/src/test/scala/firrtl/JsonProtocolSpec.scala @@ -4,7 +4,13 @@ package firrtlTests import org.json4s._ -import firrtl.annotations.{NoTargetAnnotation, JsonProtocol, InvalidAnnotationJSONException, HasSerializationHints, Annotation} +import firrtl.annotations.{ + Annotation, + HasSerializationHints, + InvalidAnnotationJSONException, + JsonProtocol, + NoTargetAnnotation +} import org.scalatest.flatspec.AnyFlatSpec object JsonProtocolTestClasses { @@ -13,12 +19,16 @@ object JsonProtocolTestClasses { case class ChildA(foo: Int) extends Parent case class ChildB(bar: String) extends Parent case class PolymorphicParameterAnnotation(param: Parent) extends NoTargetAnnotation - case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) extends NoTargetAnnotation with HasSerializationHints { + case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) + extends NoTargetAnnotation + with HasSerializationHints { def typeHints = Seq(param.getClass) } case class TypeParameterizedAnnotation[T](param: T) extends NoTargetAnnotation - case class TypeParameterizedAnnotationWithTypeHints[T](param: T) extends NoTargetAnnotation with HasSerializationHints { + case class TypeParameterizedAnnotationWithTypeHints[T](param: T) + extends NoTargetAnnotation + with HasSerializationHints { def typeHints = Seq(param.getClass) } } @@ -51,11 +61,11 @@ class JsonProtocolSpec extends AnyFlatSpec { "Annotations with non-primitive type parameters" should "not serialize and deserialize without type hints" in { val anno = TypeParameterizedAnnotation(ChildA(1)) val deserAnno = serializeAndDeserialize(anno) - assert (anno != deserAnno) + assert(anno != deserAnno) } it should "serialize and deserialize with type hints" in { val anno = TypeParameterizedAnnotationWithTypeHints(ChildA(1)) val deserAnno = serializeAndDeserialize(anno) - assert (anno == deserAnno) + assert(anno == deserAnno) } } diff --git a/src/test/scala/firrtl/analysis/SymbolTableSpec.scala b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala index 599b4e52..ca30b60b 100644 --- a/src/test/scala/firrtl/analysis/SymbolTableSpec.scala +++ b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala @@ -8,7 +8,7 @@ import firrtl.options.Dependency import org.scalatest.flatspec.AnyFlatSpec class SymbolTableSpec extends AnyFlatSpec { - behavior of "SymbolTable" + behavior.of("SymbolTable") private val src = """circuit m: @@ -50,9 +50,20 @@ class SymbolTableSpec extends AnyFlatSpec { assert(syms("r").tpe == ir.SIntType(ir.IntWidth(4)) && syms("r").kind == firrtl.RegKind) val mType = firrtl.passes.MemPortUtils.memType( // only dataType, depth and reader, writer, readwriter properties affect the data type - ir.DefMemory(ir.NoInfo, "???", ir.UIntType(ir.IntWidth(8)), 32, 10, 10, Seq("r"), Seq(), Seq(), ir.ReadUnderWrite.New) + ir.DefMemory( + ir.NoInfo, + "???", + ir.UIntType(ir.IntWidth(8)), + 32, + 10, + 10, + Seq("r"), + Seq(), + Seq(), + ir.ReadUnderWrite.New + ) ) - assert(syms("m") .tpe == mType && syms("m").kind == firrtl.MemKind) + assert(syms("m").tpe == mType && syms("m").kind == firrtl.MemKind) } it should "find all declarations in module m after InferTypes" in { @@ -69,7 +80,7 @@ class SymbolTableSpec extends AnyFlatSpec { assert(syms("i").tpe == iType && syms("i").kind == firrtl.InstanceKind) } - behavior of "WithSeq" + behavior.of("WithSeq") it should "preserve declaration order" in { val c = firrtl.Parser.parse(src) @@ -79,7 +90,7 @@ class SymbolTableSpec extends AnyFlatSpec { assert(syms.getSymbols.map(_.name) == Seq("clk", "x", "y", "z", "a", "i", "r", "m")) } - behavior of "ModuleTypesSymbolTable" + behavior.of("ModuleTypesSymbolTable") it should "derive the module type from the module types map" in { val c = firrtl.Parser.parse(src) diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala index 015ac4a9..f7ce9914 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -20,10 +20,16 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { sys.signals.head.e.toString } - def primop(signed: Boolean, op: String, resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(), - resAlwaysUnsigned: Boolean = false): String = { - val tpe = if(signed) "SInt" else "UInt" - val resTpe = if(resAlwaysUnsigned) "UInt" else tpe + def primop( + signed: Boolean, + op: String, + resWidth: Int, + inWidth: Seq[Int], + consts: Seq[Int] = List(), + resAlwaysUnsigned: Boolean = false + ): String = { + val tpe = if (signed) "SInt" else "UInt" + val resTpe = if (resAlwaysUnsigned) "UInt" else tpe val inTpes = inWidth.map(w => s"$tpe<$w>") primop(op, s"$resTpe<$resWidth>", inTpes, consts) } @@ -52,16 +58,24 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { it should "correctly translate the `div` primitive operation" in { // division is a little bit more complicated because the result of division by zero is undefined - assert(primop(false, "div", 8, List(8, 8)) == - "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))") - assert(primop(false, "div", 8, List(8, 4)) == - "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))") + assert( + primop(false, "div", 8, List(8, 8)) == + "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))" + ) + assert( + primop(false, "div", 8, List(8, 4)) == + "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))" + ) // signed division increases result width by 1 - assert(primop(true, "div", 8, List(7, 7)) == - "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))") - assert(primop(true, "div", 8, List(7, 4)) - == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))") + assert( + primop(true, "div", 8, List(7, 7)) == + "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))" + ) + assert( + primop(true, "div", 8, List(7, 4)) + == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))" + ) } it should "correctly translate the `rem` primitive operation" in { @@ -134,15 +148,19 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { it should "correctly translate the `dshl` primitive operation" in { assert(primop(false, "dshl", 31, List(16, 4)) == "logical_shift_left(zext(i0, 15), zext(i1, 27))") assert(primop(false, "dshl", 19, List(16, 2)) == "logical_shift_left(zext(i0, 3), zext(i1, 17))") - assert(primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == - "logical_shift_left(sext(i0, 3), zext(i1, 17))") + assert( + primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == + "logical_shift_left(sext(i0, 3), zext(i1, 17))" + ) } it should "correctly translate the `dshr` primitive operation" in { assert(primop(false, "dshr", 16, List(16, 4)) == "logical_shift_right(i0, zext(i1, 12))") assert(primop(false, "dshr", 16, List(16, 2)) == "logical_shift_right(i0, zext(i1, 14))") - assert(primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == - "arithmetic_shift_right(i0, zext(i1, 14))") + assert( + primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == + "arithmetic_shift_right(i0, zext(i1, 14))" + ) } it should "correctly translate the `cvt` primitive operation" in { @@ -197,15 +215,15 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { } it should "correctly translate the `bits` primitive operation" in { - assert(primop(false, "bits", 1, List(4), List(2,2)) == "i0[2]") - assert(primop(false, "bits", 2, List(4), List(2,1)) == "i0[2:1]") - assert(primop(false, "bits", 1, List(4), List(2,1)) == "i0[2:1][0]") - assert(primop(false, "bits", 3, List(4), List(2,1)) == "zext(i0[2:1], 1)") + assert(primop(false, "bits", 1, List(4), List(2, 2)) == "i0[2]") + assert(primop(false, "bits", 2, List(4), List(2, 1)) == "i0[2:1]") + assert(primop(false, "bits", 1, List(4), List(2, 1)) == "i0[2:1][0]") + assert(primop(false, "bits", 3, List(4), List(2, 1)) == "zext(i0[2:1], 1)") - assert(primop(true, "bits", 1, List(4), List(2,2), resAlwaysUnsigned = true) == "i0[2]") - assert(primop(true, "bits", 2, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1]") - assert(primop(true, "bits", 1, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1][0]") - assert(primop(true, "bits", 3, List(4), List(2,1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") + assert(primop(true, "bits", 1, List(4), List(2, 2), resAlwaysUnsigned = true) == "i0[2]") + assert(primop(true, "bits", 2, List(4), List(2, 1), resAlwaysUnsigned = true) == "i0[2:1]") + assert(primop(true, "bits", 1, List(4), List(2, 1), resAlwaysUnsigned = true) == "i0[2:1][0]") + assert(primop(true, "bits", 3, List(4), List(2, 1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") } it should "correctly translate the `head` primitive operation" in { @@ -221,4 +239,4 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { assert(primop(false, "tail", 4, List(5), List(1)) == "i0[3:0]") assert(primop(false, "tail", 2, List(5), List(3)) == "i0[1:0]") } -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala index ca7974c5..b41313e3 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala @@ -5,8 +5,7 @@ package firrtl.backends.experimental.smt import firrtl.{MemoryArrayInit, MemoryScalarInit, Utils} private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { - behavior of "ModuleToTransitionSystem.run" - + behavior.of("ModuleToTransitionSystem.run") it should "model registers as state" in { // if a signal is invalid, it could take on an arbitrary value in that cycle @@ -42,39 +41,39 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { private def memCircuit(depth: Int = 32) = s"""circuit m: - | module m: - | input reset : UInt<1> - | input clock : Clock - | input addr : UInt<${Utils.getUIntWidth(depth)}> - | input in : UInt<8> - | output out : UInt<8> - | - | mem m: - | data-type => UInt<8> - | depth => $depth - | reader => r - | writer => w - | read-latency => 0 - | write-latency => 1 - | read-under-write => new - | - | m.w.clk <= clock - | m.w.mask <= UInt(1) - | m.w.en <= UInt(1) - | m.w.data <= in - | m.w.addr <= addr - | - | m.r.clk <= clock - | m.r.en <= UInt(1) - | out <= m.r.data - | m.r.addr <= addr - | - |""".stripMargin + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<${Utils.getUIntWidth(depth)}> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => $depth + | reader => r + | writer => w + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin it should "model memories as state" in { val sys = toSys(memCircuit()) - assert(sys.signals.length == 9-2+1, "9 connects - 2 clock connects + 1 combinatorial read port") + assert(sys.signals.length == 9 - 2 + 1, "9 connects - 2 clock connects + 1 combinatorial read port") val sig = sys.signals.map(s => s.name -> s.e).toMap @@ -140,40 +139,39 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { it should "support memories with registered read port" in { def src(readUnderWrite: String) = s"""circuit m: - | module m: - | input reset : UInt<1> - | input clock : Clock - | input addr : UInt<5> - | input in : UInt<8> - | output out : UInt<8> - | - | mem m: - | data-type => UInt<8> - | depth => 32 - | reader => r - | writer => w1, w2 - | read-latency => 1 - | write-latency => 1 - | read-under-write => $readUnderWrite - | - | m.w1.clk <= clock - | m.w1.mask <= UInt(1) - | m.w1.en <= UInt(1) - | m.w1.data <= in - | m.w1.addr <= addr - | m.w2.clk <= clock - | m.w2.mask <= UInt(1) - | m.w2.en <= UInt(1) - | m.w2.data <= in - | m.w2.addr <= addr - | - | m.r.clk <= clock - | m.r.en <= UInt(1) - | out <= m.r.data - | m.r.addr <= addr - | - |""".stripMargin - + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin val oldValue = toSys(src("old")) val oldMData = oldValue.states.find(_.sym.name == "m.r.data").get @@ -186,9 +184,11 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { val undefinedMData = undefinedValue.states.find(_.sym.name == "m.r.data").get assert(undefinedMData.sym.toString == "m.r.data") val undefined = "RANDOM.m_r_read_under_write_undefined" - assert(undefinedMData.next.get.toString == - s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", - "randomize result if there is a write") + assert( + undefinedMData.next.get.toString == + s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", + "randomize result if there is a write" + ) } it should "support memories with potential write-write conflicts" in { @@ -228,7 +228,6 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { | |""".stripMargin - val sys = toSys(src) val m = sys.states.find(_.sym.name == "m").get diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala index 6bfb5437..209279fd 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala @@ -3,7 +3,7 @@ package firrtl.backends.experimental.smt import firrtl.annotations.Annotation -import firrtl.{MemoryInitValue, ir} +import firrtl.{ir, MemoryInitValue} import firrtl.stage.{Forms, TransformManager} import org.scalatest.flatspec.AnyFlatSpec @@ -16,8 +16,12 @@ private abstract class SMTBackendBaseSpec extends AnyFlatSpec { compiler.runTransform(firrtl.CircuitState(c, annos)).circuit } - protected def toSys(src: String, mod: String = "m", presetRegs: Set[String] = Set(), - memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + protected def toSys( + src: String, + mod: String = "m", + presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map() + ): TransitionSystem = { val circuit = compile(src) val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module] // println(module.serialize) @@ -35,4 +39,4 @@ private abstract class SMTBackendBaseSpec extends AnyFlatSpec { protected def toSMTLibStr(src: String, mod: String = "m"): String = toSMTLib(src, mod).mkString("\n") + "\n" -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala index 4c6901ea..e7c8d534 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala @@ -10,9 +10,10 @@ import firrtl.stage.RunFirrtlTransformAnnotation class AsyncResetSpec extends EndToEndSMTBaseSpec { def annos(name: String) = Seq( RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]), - GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock"))) + GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock")) + ) - "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs(RequiresZ3) in { + "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs (RequiresZ3) in { def in(resetType: String) = s"""circuit AsyncReset00: | module AsyncReset00: @@ -39,8 +40,8 @@ class AsyncResetSpec extends EndToEndSMTBaseSpec { | ; can the value of r change without the count changing? | assert(global_clock, or(not(eq(count, past_count)), eq(r, past_r)), past_valid, "count = past(count) |-> r = past(r)") |""".stripMargin - test(in("AsyncReset"), MCFail(1), kmax=2, annos=annos("AsyncReset00")) - test(in("UInt<1>"), MCSuccess, kmax=2, annos=annos("AsyncReset00")) + test(in("AsyncReset"), MCFail(1), kmax = 2, annos = annos("AsyncReset00")) + test(in("UInt<1>"), MCSuccess, kmax = 2, annos = annos("AsyncReset00")) } } diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala index 2227719b..974d2e81 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala @@ -16,24 +16,23 @@ import org.scalatest.matchers.must.Matchers import scala.sys.process._ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { - "we" should "check if Z3 is available" taggedAs(RequiresZ3) in { + "we" should "check if Z3 is available" taggedAs (RequiresZ3) in { val log = ProcessLogger(_ => (), logger.warn(_)) val ret = Process(Seq("which", "z3")).run(log).exitValue() - if(ret != 0) { - logger.error( - """The z3 SMT-Solver seems not to be installed. - |You can exclude the end-to-end smt backend tests which rely on z3 like this: - |sbt testOnly -- -l RequiresZ3 - |""".stripMargin) + if (ret != 0) { + logger.error("""The z3 SMT-Solver seems not to be installed. + |You can exclude the end-to-end smt backend tests which rely on z3 like this: + |sbt testOnly -- -l RequiresZ3 + |""".stripMargin) } assert(ret == 0) } - "Z3" should "be available in version 4" taggedAs(RequiresZ3) in { + "Z3" should "be available in version 4" taggedAs (RequiresZ3) in { assert(Z3ModelChecker.getZ3Version.startsWith("4.")) } - "a simple combinatorial check" should "pass" taggedAs(RequiresZ3) in { + "a simple combinatorial check" should "pass" taggedAs (RequiresZ3) in { val in = """circuit CC00: | module CC00: @@ -45,7 +44,7 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { test(in, MCSuccess) } - "a simple combinatorial check" should "fail immediately" taggedAs(RequiresZ3) in { + "a simple combinatorial check" should "fail immediately" taggedAs (RequiresZ3) in { val in = """circuit CC01: | module CC01: @@ -57,7 +56,7 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { test(in, MCFail(0)) } - "adding the right assumption" should "make a test pass" taggedAs(RequiresZ3) in { + "adding the right assumption" should "make a test pass" taggedAs (RequiresZ3) in { val in0 = """circuit CC01: | module CC01: @@ -75,8 +74,8 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") | assume(c, neq(a, UInt(0)), UInt(1), "a != 0") |""".stripMargin - test(in0, MCFail(0)) - test(in1, MCSuccess) + test(in0, MCFail(0)) + test(in1, MCSuccess) val in2 = """circuit CC01: @@ -91,20 +90,20 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { test(in2, MCFail(0)) } - "a register connected to preset reset" should "be initialized with the reset value" taggedAs(RequiresZ3) in { + "a register connected to preset reset" should "be initialized with the reset value" taggedAs (RequiresZ3) in { def in(rEq: Int) = s"""circuit Preset00: - | module Preset00: - | input c: Clock - | input preset: AsyncReset - | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) - | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") - |""".stripMargin + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) + | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") + |""".stripMargin test(in(3), MCSuccess, kmax = 1) test(in(2), MCFail(0)) } - "a register's initial value" should "should not change" taggedAs(RequiresZ3) in { + "a register's initial value" should "should not change" taggedAs (RequiresZ3) in { val in = """circuit Preset00: | module Preset00: @@ -127,24 +126,29 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { def test(src: String, expected: MCResult, kmax: Int = 0, clue: String = "", annos: Seq[Annotation] = Seq()): Unit = { expected match { - case MCFail(k) => assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") + case MCFail(k) => + assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") case _ => } val fir = firrtl.Parser.parse(src) val name = fir.main val testDir = BackendCompilationUtilities.createTestDirectory("EndToEndSMT." + name) // we automagically add a preset annotation if an input called preset exists - val presetAnno = if(!src.contains("input preset")) { None } else { + val presetAnno = if (!src.contains("input preset")) { None } + else { Some(PresetAnnotation(CircuitTarget(name).module(name).ref("preset"))) } - val res = (new FirrtlStage).execute(Array(), Seq( - LogLevelAnnotation(LogLevel.Error), // silence warnings for tests - RunFirrtlTransformAnnotation(new SMTLibEmitter), - RunFirrtlTransformAnnotation(new Btor2Emitter), - FirrtlCircuitAnnotation(fir), - TargetDirAnnotation(testDir.getAbsolutePath) - ) ++ presetAnno ++ annos) - assert(res.collectFirst{ case _: OutputFileAnnotation => true }.isDefined) + val res = (new FirrtlStage).execute( + Array(), + Seq( + LogLevelAnnotation(LogLevel.Error), // silence warnings for tests + RunFirrtlTransformAnnotation(new SMTLibEmitter), + RunFirrtlTransformAnnotation(new Btor2Emitter), + FirrtlCircuitAnnotation(fir), + TargetDirAnnotation(testDir.getAbsolutePath) + ) ++ presetAnno ++ annos + ) + assert(res.collectFirst { case _: OutputFileAnnotation => true }.isDefined) val r = Z3ModelChecker.bmc(testDir, name, kmax) assert(r == expected, clue + "\n" + s"$testDir") } @@ -153,7 +157,7 @@ abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { /** Minimal implementation of a Z3 based bounded model checker. * A more complete version of this with better use feedback should eventually be provided by a * chisel3 formal verification library. Do not use this implementation outside of the firrtl test suite! - * */ + */ private object Z3ModelChecker extends LazyLogging { def getZ3Version: String = { val (out, ret) = executeCmd("-version") @@ -164,14 +168,15 @@ private object Z3ModelChecker extends LazyLogging { } def bmc(testDir: File, main: String, kmax: Int): MCResult = { - assert(kmax >=0 && kmax < 50, "Trying to keep kmax in a reasonable range.") + assert(kmax >= 0 && kmax < 50, "Trying to keep kmax in a reasonable range.") val smtFile = new File(testDir, main + ".smt2") val header = read(smtFile) val steps = (0 to kmax).map(k => new File(testDir, main + s"_step$k.smt2")).zipWithIndex - steps.foreach { case (f,k) => - writeStep(f, main, header, k) - val success = executeStep(f.getAbsolutePath) - if(!success) return MCFail(k) + steps.foreach { + case (f, k) => + writeStep(f, main, header, k) + val success = executeStep(f.getAbsolutePath) + if (!success) return MCFail(k) } MCSuccess } @@ -200,21 +205,22 @@ private object Z3ModelChecker extends LazyLogging { private def step(main: String, k: Int): Iterable[String] = { // define all states (0 to k).map(ii => s"(declare-fun s$ii () $main$StateTpe)") ++ - // assert that init holds in state 0 - List(s"(assert ($main$Init s0))") ++ - // assert transition relation - (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii+1}))") ++ - // assert that assumptions hold in all states - (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ - // assert that assertions hold for all but last state - (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ - // check to see if we can violate the assertions in the last state - List(s"(assert (not ($main$Asserts s$k)))") + // assert that init holds in state 0 + List(s"(assert ($main$Init s0))") ++ + // assert transition relation + (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii + 1}))") ++ + // assert that assumptions hold in all states + (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ + // assert that assertions hold for all but last state + (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ + // check to see if we can violate the assertions in the last state + List(s"(assert (not ($main$Asserts s$k)))") } private def read(f: File): Iterable[String] = { val source = scala.io.Source.fromFile(f) - try source.getLines().toVector finally source.close() + try source.getLines().toVector + finally source.close() } // the following suffixes have to match the ones in [[SMTTransitionSystemEncoder]] diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala index 10de9cda..61e1f0f8 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -9,43 +9,43 @@ class MemorySpec extends EndToEndSMTBaseSpec { registeredTestMem(name, cmds.split("\n"), readUnderWrite) private def registeredTestMem(name: String, cmds: Iterable[String], readUnderWrite: String): String = s"""circuit $name: - | module $name: - | input reset : UInt<1> - | input clock : Clock - | input preset: AsyncReset - | input write_addr : UInt<5> - | input read_addr : UInt<5> - | input in : UInt<8> - | output out : UInt<8> - | - | mem m: - | data-type => UInt<8> - | depth => 32 - | reader => r - | writer => w - | read-latency => 1 - | write-latency => 1 - | read-under-write => $readUnderWrite - | - | m.w.clk <= clock - | m.w.mask <= UInt(1) - | m.w.en <= UInt(1) - | m.w.data <= in - | m.w.addr <= write_addr - | - | m.r.clk <= clock - | m.r.en <= UInt(1) - | out <= m.r.data - | m.r.addr <= read_addr - | - | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) - | cycle <= add(cycle, UInt(1)) - | node past_valid = geq(cycle, UInt(1)) - | - | ${cmds.mkString("\n ")} - |""".stripMargin + | module $name: + | input reset : UInt<1> + | input clock : Clock + | input preset: AsyncReset + | input write_addr : UInt<5> + | input read_addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= write_addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) + | cycle <= add(cycle, UInt(1)) + | node past_valid = geq(cycle, UInt(1)) + | + | ${cmds.mkString("\n ")} + |""".stripMargin - "Registered test memory" should "return written data after two cycles" taggedAs(RequiresZ3) in { + "Registered test memory" should "return written data after two cycles" taggedAs (RequiresZ3) in { val cmds = """node past_past_valid = geq(cycle, UInt(2)) |reg past_in: UInt<8>, clock @@ -85,23 +85,29 @@ class MemorySpec extends EndToEndSMTBaseSpec { |""".stripMargin private def m(num: Int) = CircuitTarget(s"Mem0$num").module(s"Mem0$num").ref("m") - "read-only memory" should "always return 0" taggedAs(RequiresZ3) in { - test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax=2, - annos=Seq(MemoryScalarInitAnnotation(m(1), 0))) + "read-only memory" should "always return 0" taggedAs (RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax = 2, annos = Seq(MemoryScalarInitAnnotation(m(1), 0))) } - "read-only memory" should "not always return 1" taggedAs(RequiresZ3) in { - test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax=2, - annos=Seq(MemoryScalarInitAnnotation(m(2), 0))) + "read-only memory" should "not always return 1" taggedAs (RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax = 2, annos = Seq(MemoryScalarInitAnnotation(m(2), 0))) } - "read-only memory" should "always return 1 or 2" taggedAs(RequiresZ3) in { - test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), MCSuccess, kmax=2, - annos=Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1)))) + "read-only memory" should "always return 1 or 2" taggedAs (RequiresZ3) in { + test( + readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), + MCSuccess, + kmax = 2, + annos = Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1))) + ) } - "read-only memory" should "not always return 1 or 2 or 3" taggedAs(RequiresZ3) in { - test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), MCFail(0), kmax=2, - annos=Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3)))) + "read-only memory" should "not always return 1 or 2 or 3" taggedAs (RequiresZ3) in { + test( + readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), + MCFail(0), + kmax = 2, + annos = Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3))) + ) } } diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala index cbf194dd..8ece0e23 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala @@ -13,15 +13,17 @@ import scala.sys.process.{Process, ProcessLogger} /** compiles the regression tests to SMTLib and parses the result with z3 */ class SMTCompilationTest extends AnyFlatSpec with LazyLogging { - it should "generate valid SMTLib for AddNot" taggedAs(RequiresZ3) in { compileAndParse("AddNot") } - it should "generate valid SMTLib for FPU" taggedAs(RequiresZ3) in { compileAndParse("FPU") } + it should "generate valid SMTLib for AddNot" taggedAs (RequiresZ3) in { compileAndParse("AddNot") } + it should "generate valid SMTLib for FPU" taggedAs (RequiresZ3) in { compileAndParse("FPU") } // we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer - it should "generate valid SMTLib for HwachaSequencer" taggedAs(RequiresZ3) ignore { compileAndParse("HwachaSequencer") } - it should "generate valid SMTLib for ICache" taggedAs(RequiresZ3) in { compileAndParse("ICache") } - it should "generate valid SMTLib for Ops" taggedAs(RequiresZ3) in { compileAndParse("Ops") } + it should "generate valid SMTLib for HwachaSequencer" taggedAs (RequiresZ3) ignore { + compileAndParse("HwachaSequencer") + } + it should "generate valid SMTLib for ICache" taggedAs (RequiresZ3) in { compileAndParse("ICache") } + it should "generate valid SMTLib for Ops" taggedAs (RequiresZ3) in { compileAndParse("Ops") } // TODO: enable Rob test once we support more than 2 write ports on a memory - it should "generate valid SMTLib for Rob" taggedAs(RequiresZ3) ignore { compileAndParse("Rob") } - it should "generate valid SMTLib for RocketCore" taggedAs(RequiresZ3) in { compileAndParse("RocketCore") } + it should "generate valid SMTLib for Rob" taggedAs (RequiresZ3) ignore { compileAndParse("Rob") } + it should "generate valid SMTLib for RocketCore" taggedAs (RequiresZ3) in { compileAndParse("RocketCore") } private def compileAndParse(name: String): Unit = { val testDir = BackendCompilationUtilities.createTestDirectory(name + "-smt") @@ -29,14 +31,18 @@ class SMTCompilationTest extends AnyFlatSpec with LazyLogging { BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", inputFile) val args = Array( - "-ll", "error", // surpress warnings to keep test output clean - "--target-dir", testDir.toString, - "-i", inputFile.toString, - "-E", "experimental-smt2" + "-ll", + "error", // surpress warnings to keep test output clean + "--target-dir", + testDir.toString, + "-i", + inputFile.toString, + "-E", + "experimental-smt2" // "-fct", "firrtl.backends.experimental.smt.StutteringClockTransform" ) val res = (new FirrtlStage).execute(args, Seq()) - val fileName = res.collectFirst{ case OutputFileAnnotation(file) => file }.get + val fileName = res.collectFirst { case OutputFileAnnotation(file) => file }.get val smtFile = testDir.toString + "/" + fileName + ".smt2" val log = ProcessLogger(_ => (), logger.error(_)) diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala index 8fa80b4c..8682c2ce 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala @@ -5,27 +5,26 @@ package firrtl.backends.experimental.smt.end2end /** undefined values in firrtl are modelled as fresh auxiliary variables (inputs) */ class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec { - "division by zero" should "result in an arbitrary value" taggedAs(RequiresZ3) in { + "division by zero" should "result in an arbitrary value" taggedAs (RequiresZ3) in { // the SMTLib spec defines the result of division by zero to be all 1s // https://cs.nyu.edu/pipermail/smt-lib/2015/000977.html def in(dEq: Int) = - s"""circuit CC00: - | module CC00: - | input c: Clock - | input a: UInt<2> - | input b: UInt<2> - | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") - | node d = div(a, b) - | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") - |""".stripMargin + s"""circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<2> + | input b: UInt<2> + | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") + | node d = div(a, b) + | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") + |""".stripMargin // we try to assert that (d = a / 0) is any fixed value which should be false (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") } } // TODO: rem should probably also be undefined, but the spec isn't 100% clear here - - "invalid signals" should "have an arbitrary values" taggedAs(RequiresZ3) in { + "invalid signals" should "have an arbitrary values" taggedAs (RequiresZ3) in { def in(aEq: Int) = s"""circuit CC00: | module CC00: diff --git a/src/test/scala/firrtl/ir/StructuralHashSpec.scala b/src/test/scala/firrtl/ir/StructuralHashSpec.scala index 17fe0b84..c4622939 100644 --- a/src/test/scala/firrtl/ir/StructuralHashSpec.scala +++ b/src/test/scala/firrtl/ir/StructuralHashSpec.scala @@ -6,11 +6,11 @@ import firrtl.PrimOps._ import org.scalatest.flatspec.AnyFlatSpec class StructuralHashSpec extends AnyFlatSpec { - private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n) - private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c) + private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n) + private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c) private def hash(e: Expression): HashCode = StructuralHash.sha256Node(e) - private def hash(t: Type): HashCode = StructuralHash.sha256Node(t) - private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s) + private def hash(t: Type): HashCode = StructuralHash.sha256Node(t) + private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s) private val highFirrtlCompiler = new firrtl.stage.transforms.Compiler( targets = firrtl.stage.Forms.HighForm ) @@ -24,18 +24,18 @@ class StructuralHashSpec extends AnyFlatSpec { highFirrtlCompiler.transform(firrtl.CircuitState(rawFirrtl, Seq())).circuit } - private val b0 = UIntLiteral(0,IntWidth(1)) - private val b1 = UIntLiteral(1,IntWidth(1)) + private val b0 = UIntLiteral(0, IntWidth(1)) + private val b1 = UIntLiteral(1, IntWidth(1)) private val add = DoPrim(Add, Seq(b0, b1), Seq(), UnknownType) it should "generate the same hash if the objects are structurally the same" in { - assert(hash(b0) == hash(UIntLiteral(0,IntWidth(1)))) - assert(hash(b0) != hash(UIntLiteral(1,IntWidth(1)))) - assert(hash(b0) != hash(UIntLiteral(1,IntWidth(2)))) + assert(hash(b0) == hash(UIntLiteral(0, IntWidth(1)))) + assert(hash(b0) != hash(UIntLiteral(1, IntWidth(1)))) + assert(hash(b0) != hash(UIntLiteral(1, IntWidth(2)))) - assert(hash(b1) == hash(UIntLiteral(1,IntWidth(1)))) - assert(hash(b1) != hash(UIntLiteral(0,IntWidth(1)))) - assert(hash(b1) != hash(UIntLiteral(1,IntWidth(2)))) + assert(hash(b1) == hash(UIntLiteral(1, IntWidth(1)))) + assert(hash(b1) != hash(UIntLiteral(0, IntWidth(1)))) + assert(hash(b1) != hash(UIntLiteral(1, IntWidth(2)))) } it should "ignore expression types" in { @@ -84,16 +84,19 @@ class StructuralHashSpec extends AnyFlatSpec { |""".stripMargin assert(hash(parse(a)) != hash(parse(d)), "circuits with different names are always different") - assert(hash(parse(a).modules.head) == hash(parse(d).modules.head), - "modules with different names can be structurally different") + assert( + hash(parse(a).modules.head) == hash(parse(d).modules.head), + "modules with different names can be structurally different" + ) // for the Dedup pass we do need a way to take the port names into account - assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), - "renaming ports does affect the hash if we ask to") + assert( + StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), + "renaming ports does affect the hash if we ask to" + ) } - it should "not ignore port names if asked to" in { val e = """circuit a: @@ -119,14 +122,20 @@ class StructuralHashSpec extends AnyFlatSpec { | z <= x |""".stripMargin - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), - "renaming ports does affect the hash if we ask to") - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == - StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), - "renaming internal wires should never affect the hash") - assert(hash(parse(e).modules.head) == hash(parse(g).modules.head), - "renaming internal wires should never affect the hash") + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), + "renaming ports does affect the hash if we ask to" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == + StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), + "renaming internal wires should never affect the hash" + ) + assert( + hash(parse(e).modules.head) == hash(parse(g).modules.head), + "renaming internal wires should never affect the hash" + ) } it should "not ignore port bundle names if asked to" in { @@ -154,19 +163,26 @@ class StructuralHashSpec extends AnyFlatSpec { | y.z <= x.x |""".stripMargin - assert(hash(parse(e).modules.head) == hash(parse(f).modules.head), - "renaming port bundles does normally not affect the hash") - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), - "renaming port bundles does affect the hash if we ask to") - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == - StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), - "renaming internal wire bundles should never affect the hash") - assert(hash(parse(e).modules.head) == hash(parse(g).modules.head), - "renaming internal wire bundles should never affect the hash") + assert( + hash(parse(e).modules.head) == hash(parse(f).modules.head), + "renaming port bundles does normally not affect the hash" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), + "renaming port bundles does affect the hash if we ask to" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == + StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), + "renaming internal wire bundles should never affect the hash" + ) + assert( + hash(parse(e).modules.head) == hash(parse(g).modules.head), + "renaming internal wire bundles should never affect the hash" + ) } - it should "fail on Info" in { // it does not make sense to hash Info nodes assertThrows[RuntimeException] { @@ -178,9 +194,9 @@ class StructuralHashSpec extends AnyFlatSpec { def parse(str: String): BundleType = { val src = s"""circuit c: - | module c: - | input z: $str - |""".stripMargin + | module c: + | input z: $str + |""".stripMargin val c = firrtl.Parser.parse(src) val tpe = c.modules.head.ports.head.tpe tpe.asInstanceOf[BundleType] @@ -219,11 +235,15 @@ class StructuralHashSpec extends AnyFlatSpec { // Q: should extmodule portnames always be significant since they map to the verilog pins? // A: It would be a bug for two exmodules in the same circuit to have the same defname but different // port names. This should be detected by an earlier pass and thus we do not have to deal with that situation. - assert(hash(parse(a).modules.head) == hash(parse(b).modules.head), - "two ext modules with the same defname and the same type and number of ports") - assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), - "two ext modules with significant port names") + assert( + hash(parse(a).modules.head) == hash(parse(b).modules.head), + "two ext modules with the same defname and the same type and number of ports" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), + "two ext modules with significant port names" + ) } "Blocks and empty statements" should "not affect structural equivalence" in { @@ -269,9 +289,9 @@ class StructuralHashSpec extends AnyFlatSpec { } private case object DebugHasher extends Hasher { - override def update(b: Byte): Unit = println(s"b(${b.toInt & 0xff})") - override def update(i: Int): Unit = println(s"i(${i})") - override def update(l: Long): Unit = println(s"l(${l})") - override def update(s: String): Unit = println(s"s(${s})") + override def update(b: Byte): Unit = println(s"b(${b.toInt & 0xff})") + override def update(i: Int): Unit = println(s"i(${i})") + override def update(l: Long): Unit = println(s"l(${l})") + override def update(s: String): Unit = println(s"s(${s})") override def update(b: Array[Byte]): Unit = println(s"bytes(${b.map(x => x.toInt & 0xff).mkString(", ")})") -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/passes/LowerTypesSpec.scala b/src/test/scala/firrtl/passes/LowerTypesSpec.scala index 884e51b8..0575c5da 100644 --- a/src/test/scala/firrtl/passes/LowerTypesSpec.scala +++ b/src/test/scala/firrtl/passes/LowerTypesSpec.scala @@ -8,7 +8,6 @@ import firrtl.stage.TransformManager import firrtl.stage.TransformManager.TransformDependency import org.scalatest.flatspec.AnyFlatSpec - /** Unit test style tests for [[LowerTypes]]. * You can find additional integration style tests in [[firrtlTests.LowerTypesSpec]] */ @@ -31,11 +30,12 @@ class LowerTypesEndToEndSpec extends LowerTypesBaseSpec { | $n is invalid |""".stripMargin val c = CircuitState(firrtl.Parser.parse(src), Seq()) - val c2 = lowerTypesCompiler.execute(c) + val c2 = lowerTypesCompiler.execute(c) val ps = c2.circuit.modules.head.ports.filterNot(p => namespace.contains(p.name)) - ps.map{p => + ps.map { p => val orientation = Utils.to_flip(p.direction) - s"${orientation.serialize}${p.name} : ${p.tpe.serialize}"} + s"${orientation.serialize}${p.name} : ${p.tpe.serialize}" + } } override protected def lower(n: String, tpe: String, namespace: Set[String]): Seq[String] = @@ -50,8 +50,10 @@ abstract class LowerTypesBaseSpec extends AnyFlatSpec { assert(lower("a", "{ a : UInt<1>, b : UInt<1>}") == Seq("a_a : UInt<1>", "a_b : UInt<1>")) assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}") == Seq("a_a : UInt<1>", "a_b_c : UInt<1>")) assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}") == Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]") == - Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]") == + Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>") + ) // with conflicts assert(lower("a", "{ a : UInt<1>, b : UInt<1>}", Set("a_a")) == Seq("a__a : UInt<1>", "a__b : UInt<1>")) @@ -63,40 +65,71 @@ abstract class LowerTypesBaseSpec extends AnyFlatSpec { assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>")) assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b_c")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a")) == - Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a", "a_b_0")) == - Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_b_0")) == - Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")) - - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0")) == - Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_3")) == - Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_a")) == - Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_c")) == - Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a")) == + Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a", "a_b_0")) == + Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_b_0")) == + Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>") + ) + + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0")) == + Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_3")) == + Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_a")) == + Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_c")) == + Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>") + ) // collisions inside the bundle - assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b__c : UInt<1>", "a_b_c : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b_c : UInt<1>", "a_b_b : UInt<1>")) - - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b__0 : UInt<1>", "a_b__1 : UInt<1>", "a_b_0 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_c : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>", "a_b_c : UInt<1>")) + assert( + lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b__c : UInt<1>", "a_b_c : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b_c : UInt<1>", "a_b_b : UInt<1>") + ) + + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b__0 : UInt<1>", "a_b__1 : UInt<1>", "a_b_0 : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_c : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>", "a_b_c : UInt<1>") + ) } it should "correctly lower the orientation" in { assert(lower("a", "{ flip a : UInt<1>, b : UInt<1>}") == Seq("flip a_a : UInt<1>", "a_b : UInt<1>")) - assert(lower("a", "{ flip a : UInt<1>[2], b : UInt<1>}") == - Seq("flip a_a_0 : UInt<1>", "flip a_a_1 : UInt<1>", "a_b : UInt<1>")) - assert(lower("a", "{ a : { flip c : UInt<1>, d : UInt<1>}[2], b : UInt<1>}") == - Seq("flip a_a_0_c : UInt<1>", "a_a_0_d : UInt<1>", "flip a_a_1_c : UInt<1>", "a_a_1_d : UInt<1>", "a_b : UInt<1>") + assert( + lower("a", "{ flip a : UInt<1>[2], b : UInt<1>}") == + Seq("flip a_a_0 : UInt<1>", "flip a_a_1 : UInt<1>", "a_b : UInt<1>") + ) + assert( + lower("a", "{ a : { flip c : UInt<1>, d : UInt<1>}[2], b : UInt<1>}") == + Seq( + "flip a_a_0_c : UInt<1>", + "a_a_0_d : UInt<1>", + "flip a_a_1_c : UInt<1>", + "a_a_1_d : UInt<1>", + "a_b : UInt<1>" + ) ) } } @@ -121,43 +154,45 @@ class LowerTypesRenamingSpec extends AnyFlatSpec { def one(namespace: Set[String], prefix: String): Unit = { val r = lower("a", "{ a : UInt<1>, b : UInt<1>}", namespace) - assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b"))) - assert(get(r,a.field("a")) == Set(m.ref(prefix + "a"))) - assert(get(r,a.field("b")) == Set(m.ref(prefix + "b"))) + assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b"))) + assert(get(r, a.field("a")) == Set(m.ref(prefix + "a"))) + assert(get(r, a.field("b")) == Set(m.ref(prefix + "b"))) } one(Set(), "a_") one(Set("a_a"), "a__") def two(namespace: Set[String], prefix: String): Unit = { - val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", namespace) - assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_c"))) - assert(get(r,a.field("a")) == Set(m.ref(prefix + "a"))) - assert(get(r,a.field("b")) == Set(m.ref(prefix + "b_c"))) - assert(get(r,a.field("b").field("c")) == Set(m.ref(prefix + "b_c"))) + val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", namespace) + assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_c"))) + assert(get(r, a.field("a")) == Set(m.ref(prefix + "a"))) + assert(get(r, a.field("b")) == Set(m.ref(prefix + "b_c"))) + assert(get(r, a.field("b").field("c")) == Set(m.ref(prefix + "b_c"))) } two(Set(), "a_") two(Set("a_a"), "a__") def three(namespace: Set[String], prefix: String): Unit = { val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", namespace) - assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) - assert(get(r,a.field("a")) == Set(m.ref(prefix + "a"))) - assert(get(r,a.field("b")) == Set( m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) - assert(get(r,a.field("b").index(0)) == Set(m.ref(prefix + "b_0"))) - assert(get(r,a.field("b").index(1)) == Set(m.ref(prefix + "b_1"))) + assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) + assert(get(r, a.field("a")) == Set(m.ref(prefix + "a"))) + assert(get(r, a.field("b")) == Set(m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) + assert(get(r, a.field("b").index(0)) == Set(m.ref(prefix + "b_0"))) + assert(get(r, a.field("b").index(1)) == Set(m.ref(prefix + "b_1"))) } three(Set(), "a_") three(Set("a_b_0"), "a__") def four(namespace: Set[String], prefix: String): Unit = { val r = lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", namespace) - assert(get(r,a) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "1_a"), m.ref(prefix + "0_b"), m.ref(prefix + "1_b"))) - assert(get(r,a.index(0)) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "0_b"))) - assert(get(r,a.index(1)) == Set(m.ref(prefix + "1_a"), m.ref(prefix + "1_b"))) - assert(get(r,a.index(0).field("a")) == Set(m.ref(prefix + "0_a"))) - assert(get(r,a.index(0).field("b")) == Set(m.ref(prefix + "0_b"))) - assert(get(r,a.index(1).field("a")) == Set(m.ref(prefix + "1_a"))) - assert(get(r,a.index(1).field("b")) == Set(m.ref(prefix + "1_b"))) + assert( + get(r, a) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "1_a"), m.ref(prefix + "0_b"), m.ref(prefix + "1_b")) + ) + assert(get(r, a.index(0)) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "0_b"))) + assert(get(r, a.index(1)) == Set(m.ref(prefix + "1_a"), m.ref(prefix + "1_b"))) + assert(get(r, a.index(0).field("a")) == Set(m.ref(prefix + "0_a"))) + assert(get(r, a.index(0).field("b")) == Set(m.ref(prefix + "0_b"))) + assert(get(r, a.index(1).field("a")) == Set(m.ref(prefix + "1_a"))) + assert(get(r, a.index(1).field("b")) == Set(m.ref(prefix + "1_b"))) } four(Set(), "a_") four(Set("a_0"), "a__") @@ -166,28 +201,28 @@ class LowerTypesRenamingSpec extends AnyFlatSpec { // collisions inside the bundle { val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") - assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__c"), m.ref("a_b_c"))) - assert(get(r,a.field("a")) == Set(m.ref("a_a"))) - assert(get(r,a.field("b")) == Set(m.ref("a_b__c"))) - assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b__c"))) - assert(get(r,a.field("b_c")) == Set(m.ref("a_b_c"))) + assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b__c"), m.ref("a_b_c"))) + assert(get(r, a.field("a")) == Set(m.ref("a_a"))) + assert(get(r, a.field("b")) == Set(m.ref("a_b__c"))) + assert(get(r, a.field("b").field("c")) == Set(m.ref("a_b__c"))) + assert(get(r, a.field("b_c")) == Set(m.ref("a_b_c"))) } { val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") - assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b_c"), m.ref("a_b_b"))) - assert(get(r,a.field("a")) == Set(m.ref("a_a"))) - assert(get(r,a.field("b")) == Set(m.ref("a_b_c"))) - assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b_c"))) - assert(get(r,a.field("b_b")) == Set(m.ref("a_b_b"))) + assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b_c"), m.ref("a_b_b"))) + assert(get(r, a.field("a")) == Set(m.ref("a_a"))) + assert(get(r, a.field("b")) == Set(m.ref("a_b_c"))) + assert(get(r, a.field("b").field("c")) == Set(m.ref("a_b_c"))) + assert(get(r, a.field("b_b")) == Set(m.ref("a_b_b"))) } { val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") - assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__0"), m.ref("a_b__1"), m.ref("a_b_0"))) - assert(get(r,a.field("a")) == Set(m.ref("a_a"))) - assert(get(r,a.field("b")) == Set(m.ref("a_b__0"), m.ref("a_b__1"))) - assert(get(r,a.field("b").index(0)) == Set(m.ref("a_b__0"))) - assert(get(r,a.field("b").index(1)) == Set(m.ref("a_b__1"))) - assert(get(r,a.field("b_0")) == Set(m.ref("a_b_0"))) + assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b__0"), m.ref("a_b__1"), m.ref("a_b_0"))) + assert(get(r, a.field("a")) == Set(m.ref("a_a"))) + assert(get(r, a.field("b")) == Set(m.ref("a_b__0"), m.ref("a_b__1"))) + assert(get(r, a.field("b").index(0)) == Set(m.ref("a_b__0"))) + assert(get(r, a.field("b").index(1)) == Set(m.ref("a_b__1"))) + assert(get(r, a.field("b_0")) == Set(m.ref("a_b_0"))) } } } @@ -199,8 +234,13 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { private val m = CircuitTarget("m").module("m") def resultToFieldSeq(res: Seq[(String, firrtl.ir.SubField)]): Seq[String] = res.map(_._2).map(r => s"${r.name} : ${r.tpe.serialize}") - private def lower(n: String, tpe: String, module: String, namespace: Set[String], renames: RenameMap = RenameMap()): - Lower = { + private def lower( + n: String, + tpe: String, + module: String, + namespace: Set[String], + renames: RenameMap = RenameMap() + ): Lower = { val ref = firrtl.ir.DefInstance(firrtl.ir.NoInfo, n, module, parseType(tpe)) val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, renames, Set()) @@ -269,7 +309,7 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { assert(r.get(i.ref("b")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b").field("c")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b_c")).get == Seq(i_.ref("b_c"))) - } + } } } @@ -278,101 +318,139 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { */ class LowerTypesOfMemorySpec extends AnyFlatSpec { import LowerTypesSpecUtils._ - private case class Lower(mems: Seq[firrtl.ir.DefMemory], refs: Seq[(String, firrtl.ir.SubField)], - renameMap: RenameMap) + private case class Lower( + mems: Seq[firrtl.ir.DefMemory], + refs: Seq[(String, firrtl.ir.SubField)], + renameMap: RenameMap) private val m = CircuitTarget("m").module("m") private val mem = m.ref("mem") - private def lower(name: String, tpe: String, namespace: Set[String], - r: Seq[String] = List("r"), w: Seq[String] = List("w"), rw: Seq[String] = List(), depth: Int = 2): Lower = { + private def lower( + name: String, + tpe: String, + namespace: Set[String], + r: Seq[String] = List("r"), + w: Seq[String] = List("w"), + rw: Seq[String] = List(), + depth: Int = 2 + ): Lower = { val dataType = parseType(tpe) - val mem = firrtl.ir.DefMemory(firrtl.ir.NoInfo, name, dataType, depth = depth, writeLatency = 1, readLatency = 1, - readUnderWrite = firrtl.ir.ReadUnderWrite.Undefined, readers = r, writers = w, readwriters = rw) + val mem = firrtl.ir.DefMemory( + firrtl.ir.NoInfo, + name, + dataType, + depth = depth, + writeLatency = 1, + readLatency = 1, + readUnderWrite = firrtl.ir.ReadUnderWrite.Undefined, + readers = r, + writers = w, + readwriters = rw + ) val renames = RenameMap() val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace - val(mems, refs) = DestructTypes.destructMemory(m, mem, mutableSet, renames, Set()) + val (mems, refs) = DestructTypes.destructMemory(m, mem, mutableSet, renames, Set()) Lower(mems, refs, renames) } private val UInt1 = firrtl.ir.UIntType(firrtl.ir.IntWidth(1)) it should "not rename anything for a ground type memory if there was no conflict" in { - val l = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w")) + val l = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("w")) assert(l.renameMap.underlying.isEmpty) } it should "still produce reference lookups, even for a ground type memory with no conflicts" in { - val nameToRef = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w")).refs - .map{case (n,r) => n -> r.serialize}.toSet - - assert(nameToRef == Set( - "mem.r.clk" -> "mem.r.clk", - "mem.r.en" -> "mem.r.en", - "mem.r.addr" -> "mem.r.addr", - "mem.r.data" -> "mem.r.data", - "mem.w.clk" -> "mem.w.clk", - "mem.w.en" -> "mem.w.en", - "mem.w.addr" -> "mem.w.addr", - "mem.w.data" -> "mem.w.data", - "mem.w.mask" -> "mem.w.mask" - )) + val nameToRef = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("w")).refs.map { + case (n, r) => n -> r.serialize + }.toSet + + assert( + nameToRef == Set( + "mem.r.clk" -> "mem.r.clk", + "mem.r.en" -> "mem.r.en", + "mem.r.addr" -> "mem.r.addr", + "mem.r.data" -> "mem.r.data", + "mem.w.clk" -> "mem.w.clk", + "mem.w.en" -> "mem.w.en", + "mem.w.addr" -> "mem.w.addr", + "mem.w.data" -> "mem.w.data", + "mem.w.mask" -> "mem.w.mask" + ) + ) } it should "produce references of correct type" in { - val nameToType = lower("mem", "UInt<4>", Set("mem_r", "mem_r_data"), w=Seq("w"), depth = 3).refs - .map{case (n,r) => n -> r.tpe.serialize}.toSet - - assert(nameToType == Set( - "mem.r.clk" -> "Clock", - "mem.r.en" -> "UInt<1>", - "mem.r.addr" -> "UInt<2>", // depth = 3 - "mem.r.data" -> "UInt<4>", - "mem.w.clk" -> "Clock", - "mem.w.en" -> "UInt<1>", - "mem.w.addr" -> "UInt<2>", - "mem.w.data" -> "UInt<4>", - "mem.w.mask" -> "UInt<1>" - )) + val nameToType = lower("mem", "UInt<4>", Set("mem_r", "mem_r_data"), w = Seq("w"), depth = 3).refs.map { + case (n, r) => n -> r.tpe.serialize + }.toSet + + assert( + nameToType == Set( + "mem.r.clk" -> "Clock", + "mem.r.en" -> "UInt<1>", + "mem.r.addr" -> "UInt<2>", // depth = 3 + "mem.r.data" -> "UInt<4>", + "mem.w.clk" -> "Clock", + "mem.w.en" -> "UInt<1>", + "mem.w.addr" -> "UInt<2>", + "mem.w.data" -> "UInt<4>", + "mem.w.mask" -> "UInt<1>" + ) + ) } it should "not rename ground type memories even if there are conflicts on the ports" in { // There actually isn't such a thing as conflicting ports, because they do not get flattened by LowerTypes. - val r = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("r_data")).renameMap + val r = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("r_data")).renameMap assert(r.underlying.isEmpty) } it should "rename references to lowered ports" in { - val r = lower("mem", "{ a : UInt<1>, b : UInt<1>}", Set("mem_a"), r=Seq("r", "r_data")).renameMap + val r = lower("mem", "{ a : UInt<1>, b : UInt<1>}", Set("mem_a"), r = Seq("r", "r_data")).renameMap // complete memory assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b"))) // read ports - assert(get(r, mem.field("r")) == - Set(m.ref("mem__a").field("r"), m.ref("mem__b").field("r"))) - assert(get(r, mem.field("r_data")) == - Set(m.ref("mem__a").field("r_data"), m.ref("mem__b").field("r_data"))) + assert( + get(r, mem.field("r")) == + Set(m.ref("mem__a").field("r"), m.ref("mem__b").field("r")) + ) + assert( + get(r, mem.field("r_data")) == + Set(m.ref("mem__a").field("r_data"), m.ref("mem__b").field("r_data")) + ) // port fields - assert(get(r, mem.field("r").field("data")) == - Set(m.ref("mem__a").field("r").field("data"), - m.ref("mem__b").field("r").field("data"))) - assert(get(r, mem.field("r").field("addr")) == - Set(m.ref("mem__a").field("r").field("addr"), - m.ref("mem__b").field("r").field("addr"))) - assert(get(r, mem.field("r").field("en")) == - Set(m.ref("mem__a").field("r").field("en"), - m.ref("mem__b").field("r").field("en"))) - assert(get(r, mem.field("r").field("clk")) == - Set(m.ref("mem__a").field("r").field("clk"), - m.ref("mem__b").field("r").field("clk"))) - assert(get(r, mem.field("w").field("mask")) == - Set(m.ref("mem__a").field("w").field("mask"), - m.ref("mem__b").field("w").field("mask"))) + assert( + get(r, mem.field("r").field("data")) == + Set(m.ref("mem__a").field("r").field("data"), m.ref("mem__b").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("addr")) == + Set(m.ref("mem__a").field("r").field("addr"), m.ref("mem__b").field("r").field("addr")) + ) + assert( + get(r, mem.field("r").field("en")) == + Set(m.ref("mem__a").field("r").field("en"), m.ref("mem__b").field("r").field("en")) + ) + assert( + get(r, mem.field("r").field("clk")) == + Set(m.ref("mem__a").field("r").field("clk"), m.ref("mem__b").field("r").field("clk")) + ) + assert( + get(r, mem.field("w").field("mask")) == + Set(m.ref("mem__a").field("w").field("mask"), m.ref("mem__b").field("w").field("mask")) + ) // port sub-fields - assert(get(r, mem.field("r").field("data").field("a")) == - Set(m.ref("mem__a").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").field("b")) == - Set(m.ref("mem__b").field("r").field("data"))) + assert( + get(r, mem.field("r").field("data").field("a")) == + Set(m.ref("mem__a").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").field("b")) == + Set(m.ref("mem__b").field("r").field("data")) + ) // need to rename the following: // mem -> mem__a, mem__b @@ -395,24 +473,38 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec { assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b_c"))) // read port - assert(get(r, mem.field("r")) == - Set(m.ref("mem__a").field("r"), m.ref("mem__b_c").field("r"))) + assert( + get(r, mem.field("r")) == + Set(m.ref("mem__a").field("r"), m.ref("mem__b_c").field("r")) + ) // port sub-fields - assert(get(r, mem.field("r").field("data").field("a")) == - Set(m.ref("mem__a").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").field("b")) == - Set(m.ref("mem__b_c").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").field("b").field("c")) == - Set(m.ref("mem__b_c").field("r").field("data"))) + assert( + get(r, mem.field("r").field("data").field("a")) == + Set(m.ref("mem__a").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").field("b")) == + Set(m.ref("mem__b_c").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").field("b").field("c")) == + Set(m.ref("mem__b_c").field("r").field("data")) + ) // the mask field needs to be lowered just like the data field - assert(get(r, mem.field("w").field("mask").field("a")) == - Set(m.ref("mem__a").field("w").field("mask"))) - assert(get(r, mem.field("w").field("mask").field("b")) == - Set(m.ref("mem__b_c").field("w").field("mask"))) - assert(get(r, mem.field("w").field("mask").field("b").field("c")) == - Set(m.ref("mem__b_c").field("w").field("mask"))) + assert( + get(r, mem.field("w").field("mask").field("a")) == + Set(m.ref("mem__a").field("w").field("mask")) + ) + assert( + get(r, mem.field("w").field("mask").field("b")) == + Set(m.ref("mem__b_c").field("w").field("mask")) + ) + assert( + get(r, mem.field("w").field("mask").field("b").field("c")) == + Set(m.ref("mem__b_c").field("w").field("mask")) + ) val renameCount = r.underlying.map(_._2.size).sum assert(renameCount == 11, "it is enough to rename *to* 11 different signals") @@ -420,66 +512,89 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec { } it should "return a name to RefLikeExpression map for a memory with a nested data type" in { - val nameToRef = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a")).refs - .map{case (n,r) => n -> r.serialize}.toSet - - assert(nameToRef == Set( - // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. - // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. - "mem.r.clk" -> "mem__a.r.clk", "mem.r.clk" -> "mem__b_c.r.clk", - "mem.r.en" -> "mem__a.r.en", "mem.r.en" -> "mem__b_c.r.en", - "mem.r.addr" -> "mem__a.r.addr", "mem.r.addr" -> "mem__b_c.r.addr", - "mem.w.clk" -> "mem__a.w.clk", "mem.w.clk" -> "mem__b_c.w.clk", - "mem.w.en" -> "mem__a.w.en", "mem.w.en" -> "mem__b_c.w.en", - "mem.w.addr" -> "mem__a.w.addr", "mem.w.addr" -> "mem__b_c.w.addr", - // Ground type references to the data or mask field are unique. - "mem.r.data.a" -> "mem__a.r.data", - "mem.w.data.a" -> "mem__a.w.data", - "mem.w.mask.a" -> "mem__a.w.mask", - "mem.r.data.b.c" -> "mem__b_c.r.data", - "mem.w.data.b.c" -> "mem__b_c.w.data", - "mem.w.mask.b.c" -> "mem__b_c.w.mask" - )) + val nameToRef = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a")).refs.map { + case (n, r) => n -> r.serialize + }.toSet + + assert( + nameToRef == Set( + // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. + // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. + "mem.r.clk" -> "mem__a.r.clk", + "mem.r.clk" -> "mem__b_c.r.clk", + "mem.r.en" -> "mem__a.r.en", + "mem.r.en" -> "mem__b_c.r.en", + "mem.r.addr" -> "mem__a.r.addr", + "mem.r.addr" -> "mem__b_c.r.addr", + "mem.w.clk" -> "mem__a.w.clk", + "mem.w.clk" -> "mem__b_c.w.clk", + "mem.w.en" -> "mem__a.w.en", + "mem.w.en" -> "mem__b_c.w.en", + "mem.w.addr" -> "mem__a.w.addr", + "mem.w.addr" -> "mem__b_c.w.addr", + // Ground type references to the data or mask field are unique. + "mem.r.data.a" -> "mem__a.r.data", + "mem.w.data.a" -> "mem__a.w.data", + "mem.w.mask.a" -> "mem__a.w.mask", + "mem.r.data.b.c" -> "mem__b_c.r.data", + "mem.w.data.b.c" -> "mem__b_c.w.data", + "mem.w.mask.b.c" -> "mem__b_c.w.mask" + ) + ) } it should "produce references of correct type for memories with a read/write port" in { - val refs = lower("mem", "{ a : UInt<3>, b : { c : UInt<4>} }", Set("mem_a"), - r=Seq(), w=Seq(), rw=Seq("rw"), depth = 3).refs - val nameToRef = refs.map{case (n,r) => n -> r.serialize}.toSet - val nameToType = refs.map{case (n,r) => n -> r.tpe.serialize}.toSet - - assert(nameToRef == Set( - // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. - // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. - "mem.rw.clk" -> "mem__a.rw.clk", "mem.rw.clk" -> "mem__b_c.rw.clk", - "mem.rw.en" -> "mem__a.rw.en", "mem.rw.en" -> "mem__b_c.rw.en", - "mem.rw.addr" -> "mem__a.rw.addr", "mem.rw.addr" -> "mem__b_c.rw.addr", - "mem.rw.wmode" -> "mem__a.rw.wmode", "mem.rw.wmode" -> "mem__b_c.rw.wmode", - // Ground type references to the data or mask field are unique. - "mem.rw.rdata.a" -> "mem__a.rw.rdata", - "mem.rw.wdata.a" -> "mem__a.rw.wdata", - "mem.rw.wmask.a" -> "mem__a.rw.wmask", - "mem.rw.rdata.b.c" -> "mem__b_c.rw.rdata", - "mem.rw.wdata.b.c" -> "mem__b_c.rw.wdata", - "mem.rw.wmask.b.c" -> "mem__b_c.rw.wmask" - )) - - assert(nameToType == Set( - // - "mem.rw.clk" -> "Clock", - "mem.rw.en" -> "UInt<1>", - "mem.rw.addr" -> "UInt<2>", - "mem.rw.wmode" -> "UInt<1>", - // Ground type references to the data or mask field are unique. - "mem.rw.rdata.a" -> "UInt<3>", - "mem.rw.wdata.a" -> "UInt<3>", - "mem.rw.wmask.a" -> "UInt<1>", - "mem.rw.rdata.b.c" -> "UInt<4>", - "mem.rw.wdata.b.c" -> "UInt<4>", - "mem.rw.wmask.b.c" -> "UInt<1>" - )) - } + val refs = lower( + "mem", + "{ a : UInt<3>, b : { c : UInt<4>} }", + Set("mem_a"), + r = Seq(), + w = Seq(), + rw = Seq("rw"), + depth = 3 + ).refs + val nameToRef = refs.map { case (n, r) => n -> r.serialize }.toSet + val nameToType = refs.map { case (n, r) => n -> r.tpe.serialize }.toSet + + assert( + nameToRef == Set( + // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. + // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. + "mem.rw.clk" -> "mem__a.rw.clk", + "mem.rw.clk" -> "mem__b_c.rw.clk", + "mem.rw.en" -> "mem__a.rw.en", + "mem.rw.en" -> "mem__b_c.rw.en", + "mem.rw.addr" -> "mem__a.rw.addr", + "mem.rw.addr" -> "mem__b_c.rw.addr", + "mem.rw.wmode" -> "mem__a.rw.wmode", + "mem.rw.wmode" -> "mem__b_c.rw.wmode", + // Ground type references to the data or mask field are unique. + "mem.rw.rdata.a" -> "mem__a.rw.rdata", + "mem.rw.wdata.a" -> "mem__a.rw.wdata", + "mem.rw.wmask.a" -> "mem__a.rw.wmask", + "mem.rw.rdata.b.c" -> "mem__b_c.rw.rdata", + "mem.rw.wdata.b.c" -> "mem__b_c.rw.wdata", + "mem.rw.wmask.b.c" -> "mem__b_c.rw.wmask" + ) + ) + assert( + nameToType == Set( + // + "mem.rw.clk" -> "Clock", + "mem.rw.en" -> "UInt<1>", + "mem.rw.addr" -> "UInt<2>", + "mem.rw.wmode" -> "UInt<1>", + // Ground type references to the data or mask field are unique. + "mem.rw.rdata.a" -> "UInt<3>", + "mem.rw.wdata.a" -> "UInt<3>", + "mem.rw.wmask.a" -> "UInt<1>", + "mem.rw.rdata.b.c" -> "UInt<4>", + "mem.rw.wdata.b.c" -> "UInt<4>", + "mem.rw.wmask.b.c" -> "UInt<1>" + ) + ) + } it should "rename references for vector type memories" in { val l = lower("mem", "UInt<1>[2]", Set("mem_0")) @@ -491,14 +606,20 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec { assert(get(r, mem) == Set(m.ref("mem__0"), m.ref("mem__1"))) // read port - assert(get(r, mem.field("r")) == - Set(m.ref("mem__0").field("r"), m.ref("mem__1").field("r"))) + assert( + get(r, mem.field("r")) == + Set(m.ref("mem__0").field("r"), m.ref("mem__1").field("r")) + ) // port sub-fields - assert(get(r, mem.field("r").field("data").index(0)) == - Set(m.ref("mem__0").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").index(1)) == - Set(m.ref("mem__1").field("r").field("data"))) + assert( + get(r, mem.field("r").field("data").index(0)) == + Set(m.ref("mem__0").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").index(1)) == + Set(m.ref("mem__1").field("r").field("data")) + ) val renameCount = r.underlying.map(_._2.size).sum assert(renameCount == 8, "it is enough to rename *to* 8 different signals") diff --git a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala index 007608ca..0b9b830c 100644 --- a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala +++ b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala @@ -8,7 +8,14 @@ import java.io.File import firrtl._ import firrtl.stage.phases.DriverCompatibility._ import firrtl.options.{InputAnnotationFileAnnotation, Phase, TargetDirAnnotation} -import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation} +import firrtl.stage.{ + CompilerAnnotation, + FirrtlCircuitAnnotation, + FirrtlFileAnnotation, + FirrtlSourceAnnotation, + OutputFileAnnotation, + RunFirrtlTransformAnnotation +} import firrtl.stage.phases.DriverCompatibility import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -20,7 +27,7 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth /* This method wraps some magic that lets you use the private method DriverCompatibility.topName */ def topName(annotations: AnnotationSeq): Option[String] = { val topName = PrivateMethod[Option[String]]('topName) - DriverCompatibility invokePrivate topName(annotations) + DriverCompatibility.invokePrivate(topName(annotations)) } def simpleCircuit(main: String): String = s"""|circuit $main: @@ -41,22 +48,22 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth (FirrtlFileAnnotation("src/test/resources/integration/GCDTester.pb"), "GCDTester") ) - behavior of s"${DriverCompatibility.getClass.getName}.topName (private method)" + behavior.of(s"${DriverCompatibility.getClass.getName}.topName (private method)") /* This iterates over the tails of annosWithTops. Using the ordering of annosWithTops, if this AnnotationSeq is fed to * DriverCompatibility.topName, the head annotation will be used to determine the top name. This test ensures that * topName behaves as expected. */ - for ( t <- annosWithTops.tails ) t match { + for (t <- annosWithTops.tails) t match { case Nil => it should "return None on an empty AnnotationSeq" in { - topName(Seq.empty) should be (None) + topName(Seq.empty) should be(None) } case x => val annotations = x.map(_._1) val top = x.head._2 it should s"determine a top name ('$top') from a ${annotations.head.getClass.getName}" in { - topName(annotations).get should be (top) + topName(annotations).get should be(top) } } @@ -66,152 +73,148 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth file.createNewFile() } - behavior of classOf[AddImplicitAnnotationFile].toString + behavior.of(classOf[AddImplicitAnnotationFile].toString) val testDir = "test_run_dir/DriverCompatibilitySpec" it should "not modify the annotations if an InputAnnotationFile already exists" in - new PhaseFixture(new AddImplicitAnnotationFile) { + new PhaseFixture(new AddImplicitAnnotationFile) { - createFile(testDir + "/foo.anno") - val annotations = Seq( - InputAnnotationFileAnnotation("bar.anno"), - TargetDirAnnotation(testDir), - TopNameAnnotation("foo") ) + createFile(testDir + "/foo.anno") + val annotations = + Seq(InputAnnotationFileAnnotation("bar.anno"), TargetDirAnnotation(testDir), TopNameAnnotation("foo")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } it should "add an InputAnnotationFile based on a derived topName" in - new PhaseFixture(new AddImplicitAnnotationFile) { - createFile(testDir + "/bar.anno") - val annotations = Seq( - TargetDirAnnotation(testDir), - TopNameAnnotation("bar") ) + new PhaseFixture(new AddImplicitAnnotationFile) { + createFile(testDir + "/bar.anno") + val annotations = Seq(TargetDirAnnotation(testDir), TopNameAnnotation("bar")) - val expected = annotations.toSet + - InputAnnotationFileAnnotation(testDir + "/bar.anno") + val expected = annotations.toSet + + InputAnnotationFileAnnotation(testDir + "/bar.anno") - phase.transform(annotations).toSet should be (expected) - } + phase.transform(annotations).toSet should be(expected) + } it should "not add an InputAnnotationFile for .anno.json annotations" in - new PhaseFixture(new AddImplicitAnnotationFile) { - createFile(testDir + "/baz.anno.json") - val annotations = Seq( - TargetDirAnnotation(testDir), - TopNameAnnotation("baz") ) + new PhaseFixture(new AddImplicitAnnotationFile) { + createFile(testDir + "/baz.anno.json") + val annotations = Seq(TargetDirAnnotation(testDir), TopNameAnnotation("baz")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } it should "not add an InputAnnotationFile if it cannot determine the topName" in - new PhaseFixture(new AddImplicitAnnotationFile) { - val annotations = Seq( TargetDirAnnotation(testDir) ) + new PhaseFixture(new AddImplicitAnnotationFile) { + val annotations = Seq(TargetDirAnnotation(testDir)) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } - behavior of classOf[AddImplicitFirrtlFile].toString + behavior.of(classOf[AddImplicitFirrtlFile].toString) it should "not modify the annotations if a CircuitOption is present" in - new PhaseFixture(new AddImplicitFirrtlFile) { - val annotations = Seq( FirrtlFileAnnotation("foo"), TopNameAnnotation("bar") ) + new PhaseFixture(new AddImplicitFirrtlFile) { + val annotations = Seq(FirrtlFileAnnotation("foo"), TopNameAnnotation("bar")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } it should "add an FirrtlFileAnnotation if a TopNameAnnotation is present" in - new PhaseFixture(new AddImplicitFirrtlFile) { - val annotations = Seq( TopNameAnnotation("foo") ) - val expected = annotations.toSet + - FirrtlFileAnnotation(new File("foo.fir").getPath()) + new PhaseFixture(new AddImplicitFirrtlFile) { + val annotations = Seq(TopNameAnnotation("foo")) + val expected = annotations.toSet + + FirrtlFileAnnotation(new File("foo.fir").getPath()) - phase.transform(annotations).toSet should be (expected) - } + phase.transform(annotations).toSet should be(expected) + } it should "do nothing if no TopNameAnnotation is present" in - new PhaseFixture(new AddImplicitFirrtlFile) { - val annotations = Seq( TargetDirAnnotation("foo") ) + new PhaseFixture(new AddImplicitFirrtlFile) { + val annotations = Seq(TargetDirAnnotation("foo")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } - behavior of classOf[AddImplicitEmitter].toString + behavior.of(classOf[AddImplicitEmitter].toString) - val (nc, hfc, mfc, lfc, vc, svc) = ( new NoneCompiler, - new HighFirrtlCompiler, - new MiddleFirrtlCompiler, - new LowFirrtlCompiler, - new VerilogCompiler, - new SystemVerilogCompiler ) + val (nc, hfc, mfc, lfc, vc, svc) = ( + new NoneCompiler, + new HighFirrtlCompiler, + new MiddleFirrtlCompiler, + new LowFirrtlCompiler, + new VerilogCompiler, + new SystemVerilogCompiler + ) it should "convert CompilerAnnotations into EmitCircuitAnnotations without EmitOneFilePerModuleAnnotation" in - new PhaseFixture(new AddImplicitEmitter) { - val annotations = Seq( - CompilerAnnotation(nc), - CompilerAnnotation(hfc), - CompilerAnnotation(mfc), - CompilerAnnotation(lfc), - CompilerAnnotation(vc), - CompilerAnnotation(svc) - ) - val expected = annotations - .flatMap( a => Seq(a, - RunFirrtlTransformAnnotation(a.compiler.emitter), - EmitCircuitAnnotation(a.compiler.emitter.getClass)) ) - - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitEmitter) { + val annotations = Seq( + CompilerAnnotation(nc), + CompilerAnnotation(hfc), + CompilerAnnotation(mfc), + CompilerAnnotation(lfc), + CompilerAnnotation(vc), + CompilerAnnotation(svc) + ) + val expected = annotations + .flatMap(a => + Seq(a, RunFirrtlTransformAnnotation(a.compiler.emitter), EmitCircuitAnnotation(a.compiler.emitter.getClass)) + ) + + phase.transform(annotations).toSeq should be(expected) + } it should "convert CompilerAnnotations into EmitAllodulesAnnotation with EmitOneFilePerModuleAnnotation" in - new PhaseFixture(new AddImplicitEmitter) { - val annotations = Seq( - EmitOneFilePerModuleAnnotation, - CompilerAnnotation(nc), - CompilerAnnotation(hfc), - CompilerAnnotation(mfc), - CompilerAnnotation(lfc), - CompilerAnnotation(vc), - CompilerAnnotation(svc) - ) - val expected = annotations - .flatMap{ - case a: CompilerAnnotation => Seq(a, - RunFirrtlTransformAnnotation(a.compiler.emitter), - EmitAllModulesAnnotation(a.compiler.emitter.getClass)) + new PhaseFixture(new AddImplicitEmitter) { + val annotations = Seq( + EmitOneFilePerModuleAnnotation, + CompilerAnnotation(nc), + CompilerAnnotation(hfc), + CompilerAnnotation(mfc), + CompilerAnnotation(lfc), + CompilerAnnotation(vc), + CompilerAnnotation(svc) + ) + val expected = annotations.flatMap { + case a: CompilerAnnotation => + Seq( + a, + RunFirrtlTransformAnnotation(a.compiler.emitter), + EmitAllModulesAnnotation(a.compiler.emitter.getClass) + ) case a => Seq(a) } - phase.transform(annotations).toSeq should be (expected) - } + phase.transform(annotations).toSeq should be(expected) + } - behavior of classOf[AddImplicitOutputFile].toString + behavior.of(classOf[AddImplicitOutputFile].toString) it should "add an OutputFileAnnotation derived from a TopNameAnnotation if no OutputFileAnnotation exists" in - new PhaseFixture(new AddImplicitOutputFile) { - val annotations = Seq( TopNameAnnotation("foo") ) - val expected = Seq( - OutputFileAnnotation("foo"), - TopNameAnnotation("foo") - ) - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitOutputFile) { + val annotations = Seq(TopNameAnnotation("foo")) + val expected = Seq( + OutputFileAnnotation("foo"), + TopNameAnnotation("foo") + ) + phase.transform(annotations).toSeq should be(expected) + } it should "do nothing if an OutputFileannotation already exists" in - new PhaseFixture(new AddImplicitOutputFile) { - val annotations = Seq( - TopNameAnnotation("foo"), - OutputFileAnnotation("bar") ) - val expected = annotations - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitOutputFile) { + val annotations = Seq(TopNameAnnotation("foo"), OutputFileAnnotation("bar")) + val expected = annotations + phase.transform(annotations).toSeq should be(expected) + } it should "do nothing if no TopNameAnnotation exists" in - new PhaseFixture(new AddImplicitOutputFile) { - val annotations = Seq.empty - val expected = annotations - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitOutputFile) { + val annotations = Seq.empty + val expected = annotations + phase.transform(annotations).toSeq should be(expected) + } } diff --git a/src/test/scala/firrtl/testutils/FirrtlSpec.scala b/src/test/scala/firrtl/testutils/FirrtlSpec.scala index dfc20352..a0c41085 100644 --- a/src/test/scala/firrtl/testutils/FirrtlSpec.scala +++ b/src/test/scala/firrtl/testutils/FirrtlSpec.scala @@ -46,11 +46,13 @@ object RenameTop extends Transform { val c = state.circuit val ns = Namespace(c) - val newTopName = state.annotations.collectFirst({ - case RenameTopAnnotation(name) => - require(ns.tryName(name)) - name - }).getOrElse(c.main) + val newTopName = state.annotations + .collectFirst({ + case RenameTopAnnotation(name) => + require(ns.tryName(name)) + name + }) + .getOrElse(c.main) state.annotations.collect { case ModuleNamespaceAnnotation(mustNotCollideNS) => require(mustNotCollideNS.tryName(newTopName)) @@ -70,6 +72,7 @@ object RenameTop extends Transform { trait FirrtlRunners extends BackendCompilationUtilities { val cppHarnessResourceName: String = "/firrtl/testTop.cpp" + /** Extra transforms to run by default */ val extraCheckTransforms = Seq(new CheckLowForm) @@ -80,10 +83,12 @@ trait FirrtlRunners extends BackendCompilationUtilities { * @param customAnnotations Optional Firrtl annotations * @param timesteps the maximum number of timesteps to consider */ - def firrtlEquivalenceTest(input: String, - customTransforms: Seq[Transform] = Seq.empty, - customAnnotations: AnnotationSeq = Seq.empty, - timesteps: Int = 1): Unit = { + def firrtlEquivalenceTest( + input: String, + customTransforms: Seq[Transform] = Seq.empty, + customAnnotations: AnnotationSeq = Seq.empty, + timesteps: Int = 1 + ): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) val prefix = circuit.main val testDir = createTestDirectory(prefix + "_equivalence_test") @@ -93,12 +98,12 @@ trait FirrtlRunners extends BackendCompilationUtilities { def getBaseAnnos(topName: String) = { val baseTransforms = RenameTop +: extraCheckTransforms TargetDirAnnotation(testDir.toString) +: - InfoModeAnnotation("ignore") +: - RenameTopAnnotation(topName) +: - stage.FirrtlCircuitAnnotation(circuit) +: - stage.CompilerAnnotation("mverilog") +: - stage.OutputFileAnnotation(topName) +: - toAnnos(baseTransforms) + InfoModeAnnotation("ignore") +: + RenameTopAnnotation(topName) +: + stage.FirrtlCircuitAnnotation(circuit) +: + stage.CompilerAnnotation("mverilog") +: + stage.OutputFileAnnotation(topName) +: + toAnnos(baseTransforms) } val customName = s"${prefix}_custom" @@ -111,7 +116,8 @@ trait FirrtlRunners extends BackendCompilationUtilities { val refAnnos = getBaseAnnos(refSuggestedName) ++: Seq(RunFirrtlTransformAnnotation(new RenameModules), nsAnno) val refResult = (new firrtl.stage.FirrtlStage).execute(Array.empty, refAnnos) - val refName = refResult.collectFirst({ case stage.FirrtlCircuitAnnotation(c) => c.main }).getOrElse(refSuggestedName) + val refName = + refResult.collectFirst({ case stage.FirrtlCircuitAnnotation(c) => c.main }).getOrElse(refSuggestedName) assert(BackendCompilationUtilities.yosysExpectSuccess(customName, refName, testDir, timesteps)) } @@ -123,6 +129,7 @@ trait FirrtlRunners extends BackendCompilationUtilities { val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms) res.getEmittedCircuit.value } + /** Compile a Firrtl file * * @param prefix is the name of the Firrtl file without path or file extension @@ -130,25 +137,27 @@ trait FirrtlRunners extends BackendCompilationUtilities { * @param annotations Optional Firrtl annotations */ def compileFirrtlTest( - prefix: String, - srcDir: String, - customTransforms: Seq[Transform] = Seq.empty, - annotations: AnnotationSeq = Seq.empty): File = { + prefix: String, + srcDir: String, + customTransforms: Seq[Transform] = Seq.empty, + annotations: AnnotationSeq = Seq.empty + ): File = { val testDir = createTestDirectory(prefix) val inputFile = new File(testDir, s"${prefix}.fir") copyResourceToFile(s"${srcDir}/${prefix}.fir", inputFile) val annos = FirrtlFileAnnotation(inputFile.toString) +: - TargetDirAnnotation(testDir.toString) +: - InfoModeAnnotation("ignore") +: - annotations ++: - (customTransforms ++ extraCheckTransforms).map(RunFirrtlTransformAnnotation(_)) + TargetDirAnnotation(testDir.toString) +: + InfoModeAnnotation("ignore") +: + annotations ++: + (customTransforms ++ extraCheckTransforms).map(RunFirrtlTransformAnnotation(_)) (new firrtl.stage.FirrtlStage).execute(Array.empty, annos) testDir } + /** Execute a Firrtl Test * * @param prefix is the name of the Firrtl file without path or file extension @@ -157,25 +166,26 @@ trait FirrtlRunners extends BackendCompilationUtilities { * @param annotations Optional Firrtl annotations */ def runFirrtlTest( - prefix: String, - srcDir: String, - verilogPrefixes: Seq[String] = Seq.empty, - customTransforms: Seq[Transform] = Seq.empty, - annotations: AnnotationSeq = Seq.empty) = { + prefix: String, + srcDir: String, + verilogPrefixes: Seq[String] = Seq.empty, + customTransforms: Seq[Transform] = Seq.empty, + annotations: AnnotationSeq = Seq.empty + ) = { val testDir = compileFirrtlTest(prefix, srcDir, customTransforms, annotations) val harness = new File(testDir, s"top.cpp") copyResourceToFile(cppHarnessResourceName, harness) // Note file copying side effect - val verilogFiles = verilogPrefixes map { vprefix => + val verilogFiles = verilogPrefixes.map { vprefix => val file = new File(testDir, s"$vprefix.v") copyResourceToFile(s"$srcDir/$vprefix.v", file) file } verilogToCpp(prefix, testDir, verilogFiles, harness) #&& - cppToExe(prefix, testDir) ! - loggingProcessLogger + cppToExe(prefix, testDir) ! + loggingProcessLogger assert(executeExpectingSuccess(prefix, testDir)) } } @@ -201,6 +211,7 @@ trait FirrtlMatchers extends Matchers { require(!s.contains("\n")) s.replaceAll("\\s+", " ").trim } + /** Helper to make circuits that are the same appear the same */ def canonicalize(circuit: Circuit): Circuit = { import firrtl.Mappers._ @@ -208,19 +219,21 @@ trait FirrtlMatchers extends Matchers { circuit.map(onModule) } def parse(str: String) = Parser.parse(str.split("\n").toIterator, UseInfo) + /** Helper for executing tests * compiler will be run on input then emitted result will each be split into * lines and normalized. */ def executeTest( - input: String, - expected: Seq[String], - compiler: Compiler, - annotations: Seq[Annotation] = Seq.empty) = { + input: String, + expected: Seq[String], + compiler: Compiler, + annotations: Seq[Annotation] = Seq.empty + ) = { val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) - val lines = finalState.getEmittedCircuit.value split "\n" map normalized + val lines = finalState.getEmittedCircuit.value.split("\n").map(normalized) for (e <- expected) { - lines should contain (e) + lines should contain(e) } } } @@ -239,10 +252,12 @@ object FirrtlCheckers extends FirrtlMatchers { case Some(res) => res // Otherwise keep digging case None => - require(node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode], - "Error! Unexpected FirrtlNode that does not implement Product!") + require( + node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode], + "Error! Unexpected FirrtlNode that does not implement Product!" + ) val iter = node match { - case p: Product => p.productIterator + case p: Product => p.productIterator case i: Iterable[Any] => i.iterator case _ => Iterator.empty } @@ -296,57 +311,63 @@ class TestFirrtlFlatSpec extends FirrtlFlatSpec { import FirrtlCheckers._ val c = parse(""" - |circuit Test: - | module Test : - | input in : UInt<8> - | output out : UInt<8> - | out <= in - |""".stripMargin) + |circuit Test: + | module Test : + | input in : UInt<8> + | output out : UInt<8> + | out <= in + |""".stripMargin) val state = CircuitState(c, ChirrtlForm) val compiled = (new LowFirrtlCompiler).compileAndEmit(state, List.empty) // While useful, ScalaTest helpers should be used over search - behavior of "Search" + behavior.of("Search") it should "be supported on Circuit" in { - assert(c search { - case Connect(_, Reference("out",_, _, _), Reference("in", _, _, _)) => true + assert(c.search { + case Connect(_, Reference("out", _, _, _), Reference("in", _, _, _)) => true }) } it should "be supported on CircuitStates" in { - assert(state search { - case Connect(_, Reference("out", _, _, _), Reference("in",_, _, _)) => true + assert(state.search { + case Connect(_, Reference("out", _, _, _), Reference("in", _, _, _)) => true }) } it should "be supported on the results of compilers" in { - assert(compiled search { - case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true + assert(compiled.search { + case Connect(_, WRef("out", _, _, _), WRef("in", _, _, _)) => true }) } // Use these!!! - behavior of "ScalaTest helpers" + behavior.of("ScalaTest helpers") they should "work for lines of emitted text" in { - compiled should containLine (s"input in : UInt<8>") - compiled should containLine (s"output out : UInt<8>") - compiled should containLine (s"out <= in") + compiled should containLine(s"input in : UInt<8>") + compiled should containLine(s"output out : UInt<8>") + compiled should containLine(s"out <= in") } they should "work for partial functions matching on subtrees" in { val UInt8 = UIntType(IntWidth(8)) // BigInt unapply is weird compiled should containTree { case Port(_, "in", Input, UInt8) => true } compiled should containTree { case Port(_, "out", Output, UInt8) => true } - compiled should containTree { case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true } + compiled should containTree { case Connect(_, WRef("out", _, _, _), WRef("in", _, _, _)) => true } } } /** Super class for execution driven Firrtl tests */ -abstract class ExecutionTest(name: String, dir: String, vFiles: Seq[String] = Seq.empty, annotations: AnnotationSeq = Seq.empty) extends FirrtlPropSpec { +abstract class ExecutionTest( + name: String, + dir: String, + vFiles: Seq[String] = Seq.empty, + annotations: AnnotationSeq = Seq.empty) + extends FirrtlPropSpec { property(s"$name should execute correctly") { runFirrtlTest(name, dir, vFiles, annotations = annotations) } } + /** Super class for compilation driven Firrtl tests */ abstract class CompilationTest(name: String, dir: String) extends FirrtlPropSpec { property(s"$name should compile correctly") { @@ -444,7 +465,9 @@ abstract class EquivalenceTest(transforms: Seq[Transform], name: String, dir: St throw new FileNotFoundException(s"Resource '$fileName'") } val source = scala.io.Source.fromInputStream(in) - val input = try source.mkString finally source.close() + val input = + try source.mkString + finally source.close() s"$name with ${transforms.map(_.name).mkString(", ")}" should s"be equivalent to $name without ${transforms.map(_.name).mkString(", ")}" in { diff --git a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala index c1f0943a..4ae6a7be 100644 --- a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala +++ b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala @@ -1,6 +1,6 @@ package firrtl.testutils -import firrtl.{AnnotationSeq, CircuitState, EmitCircuitAnnotation, ir} +import firrtl.{ir, AnnotationSeq, CircuitState, EmitCircuitAnnotation} import firrtl.options.Dependency import firrtl.passes.RemoveEmpty import firrtl.stage.TransformManager.TransformDependency @@ -11,30 +11,33 @@ class VerilogTransformSpec extends LeanTransformSpec(Seq(Dependency[firrtl.Veril class LowFirrtlTransformSpec extends LeanTransformSpec(Seq(Dependency[firrtl.LowFirrtlEmitter])) /** The new cool kid on the block, creates a custom compiler for your transform. */ -class LeanTransformSpec(protected val transforms: Seq[TransformDependency]) extends AnyFlatSpec with FirrtlMatchers with LazyLogging { +class LeanTransformSpec(protected val transforms: Seq[TransformDependency]) + extends AnyFlatSpec + with FirrtlMatchers + with LazyLogging { private val compiler = new firrtl.stage.transforms.Compiler(transforms) private val emitterAnnos = LeanTransformSpec.deriveEmitCircuitAnnotations(transforms) protected def compile(src: String): CircuitState = compile(src, Seq()) protected def compile(src: String, annos: AnnotationSeq): CircuitState = compile(firrtl.Parser.parse(src), annos) - protected def compile(c: ir.Circuit): CircuitState = compile(c, Seq()) - protected def compile(c: ir.Circuit, annos: AnnotationSeq): CircuitState = + protected def compile(c: ir.Circuit): CircuitState = compile(c, Seq()) + protected def compile(c: ir.Circuit, annos: AnnotationSeq): CircuitState = compiler.transform(CircuitState(c, emitterAnnos ++ annos)) - protected def execute(input: String, check: String): CircuitState = execute(input, check ,Seq()) + protected def execute(input: String, check: String): CircuitState = execute(input, check, Seq()) protected def execute(input: String, check: String, inAnnos: AnnotationSeq): CircuitState = { val finalState = compiler.transform(CircuitState(parse(input), inAnnos)) val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize val expected = parse(check).serialize logger.debug(actual) logger.debug(expected) - actual should be (expected) + actual should be(expected) finalState } } private object LeanTransformSpec { private def deriveEmitCircuitAnnotations(transforms: Iterable[TransformDependency]): AnnotationSeq = { - val emitters = transforms.map(_.getObject()).collect{ case e: firrtl.Emitter => e } + val emitters = transforms.map(_.getObject()).collect { case e: firrtl.Emitter => e } emitters.map(e => EmitCircuitAnnotation(e.getClass)).toSeq } } @@ -47,4 +50,4 @@ trait MakeCompiler { new firrtl.stage.transforms.Compiler(Seq(Dependency[firrtl.MinimumVerilogEmitter]) ++ transforms) protected def makeLowFirrtlCompiler(transforms: Seq[TransformDependency] = Seq()) = new firrtl.stage.transforms.Compiler(Seq(Dependency[firrtl.LowFirrtlEmitter]) ++ transforms) -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/testutils/PassTests.scala b/src/test/scala/firrtl/testutils/PassTests.scala index 49dea199..7a5dc306 100644 --- a/src/test/scala/firrtl/testutils/PassTests.scala +++ b/src/test/scala/firrtl/testutils/PassTests.scala @@ -15,49 +15,53 @@ import org.scalatest.flatspec.AnyFlatSpec // An example methodology for testing Firrtl Passes // Spec class should extend this class abstract class SimpleTransformSpec extends AnyFlatSpec with FirrtlMatchers with Compiler with LazyLogging { - // Utility function - def squash(c: Circuit): Circuit = RemoveEmpty.run(c) - - // Executes the test. Call in tests. - // annotations cannot have default value because scalatest trait Suite has a default value - def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = { - val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) - val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize - val expected = parse(check).serialize - logger.debug(actual) - logger.debug(expected) - (actual) should be (expected) - finalState - } - - def executeWithAnnos(input: String, check: String, annotations: Seq[Annotation], - checkAnnotations: Seq[Annotation]): CircuitState = { - val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) - val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize - val expected = parse(check).serialize - logger.debug(actual) - logger.debug(expected) - (actual) should be (expected) - - annotations.foreach { anno => - logger.debug(anno.serialize) - } - - finalState.annotations.toSeq.foreach { anno => - logger.debug(anno.serialize) - } - checkAnnotations.foreach { check => - (finalState.annotations.toSeq) should contain (check) - } - finalState - } - // Executes the test, should throw an error - // No default to be consistent with execute - def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { - intercept[PassExceptions] { - compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty) - } - } + // Utility function + def squash(c: Circuit): Circuit = RemoveEmpty.run(c) + + // Executes the test. Call in tests. + // annotations cannot have default value because scalatest trait Suite has a default value + def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be(expected) + finalState + } + + def executeWithAnnos( + input: String, + check: String, + annotations: Seq[Annotation], + checkAnnotations: Seq[Annotation] + ): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be(expected) + + annotations.foreach { anno => + logger.debug(anno.serialize) + } + + finalState.annotations.toSeq.foreach { anno => + logger.debug(anno.serialize) + } + checkAnnotations.foreach { check => + (finalState.annotations.toSeq) should contain(check) + } + finalState + } + // Executes the test, should throw an error + // No default to be consistent with execute + def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { + intercept[PassExceptions] { + compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty) + } + } } @deprecated( @@ -86,19 +90,19 @@ object ReRunResolveAndCheck extends Transform with DependencyAPIMigration with I } trait LowTransformSpec extends SimpleTransformSpec { - def emitter = new LowFirrtlEmitter - def transform: Transform - def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.LowForm.map(_.getObject) + def emitter = new LowFirrtlEmitter + def transform: Transform + def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.LowForm.map(_.getObject) } trait MiddleTransformSpec extends SimpleTransformSpec { - def emitter = new MiddleFirrtlEmitter - def transform: Transform - def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.MidForm.map(_.getObject) + def emitter = new MiddleFirrtlEmitter + def transform: Transform + def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.MidForm.map(_.getObject) } trait HighTransformSpec extends SimpleTransformSpec { - def emitter = new HighFirrtlEmitter - def transform: Transform - def transforms = transform +: ReRunResolveAndCheck +: Forms.HighForm.map(_.getObject) + def emitter = new HighFirrtlEmitter + def transform: Transform + def transforms = transform +: ReRunResolveAndCheck +: Forms.HighForm.map(_.getObject) } |
