diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/test/scala | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/test/scala')
131 files changed, 8356 insertions, 7801 deletions
diff --git a/src/test/scala/firrtl/JsonProtocolSpec.scala b/src/test/scala/firrtl/JsonProtocolSpec.scala index 7d04e9fc..cc7591cb 100644 --- a/src/test/scala/firrtl/JsonProtocolSpec.scala +++ b/src/test/scala/firrtl/JsonProtocolSpec.scala @@ -4,7 +4,13 @@ package firrtlTests import org.json4s._ -import firrtl.annotations.{NoTargetAnnotation, JsonProtocol, InvalidAnnotationJSONException, HasSerializationHints, Annotation} +import firrtl.annotations.{ + Annotation, + HasSerializationHints, + InvalidAnnotationJSONException, + JsonProtocol, + NoTargetAnnotation +} import org.scalatest.flatspec.AnyFlatSpec object JsonProtocolTestClasses { @@ -13,12 +19,16 @@ object JsonProtocolTestClasses { case class ChildA(foo: Int) extends Parent case class ChildB(bar: String) extends Parent case class PolymorphicParameterAnnotation(param: Parent) extends NoTargetAnnotation - case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) extends NoTargetAnnotation with HasSerializationHints { + case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) + extends NoTargetAnnotation + with HasSerializationHints { def typeHints = Seq(param.getClass) } case class TypeParameterizedAnnotation[T](param: T) extends NoTargetAnnotation - case class TypeParameterizedAnnotationWithTypeHints[T](param: T) extends NoTargetAnnotation with HasSerializationHints { + case class TypeParameterizedAnnotationWithTypeHints[T](param: T) + extends NoTargetAnnotation + with HasSerializationHints { def typeHints = Seq(param.getClass) } } @@ -51,11 +61,11 @@ class JsonProtocolSpec extends AnyFlatSpec { "Annotations with non-primitive type parameters" should "not serialize and deserialize without type hints" in { val anno = TypeParameterizedAnnotation(ChildA(1)) val deserAnno = serializeAndDeserialize(anno) - assert (anno != deserAnno) + assert(anno != deserAnno) } it should "serialize and deserialize with type hints" in { val anno = TypeParameterizedAnnotationWithTypeHints(ChildA(1)) val deserAnno = serializeAndDeserialize(anno) - assert (anno == deserAnno) + assert(anno == deserAnno) } } diff --git a/src/test/scala/firrtl/analysis/SymbolTableSpec.scala b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala index 599b4e52..ca30b60b 100644 --- a/src/test/scala/firrtl/analysis/SymbolTableSpec.scala +++ b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala @@ -8,7 +8,7 @@ import firrtl.options.Dependency import org.scalatest.flatspec.AnyFlatSpec class SymbolTableSpec extends AnyFlatSpec { - behavior of "SymbolTable" + behavior.of("SymbolTable") private val src = """circuit m: @@ -50,9 +50,20 @@ class SymbolTableSpec extends AnyFlatSpec { assert(syms("r").tpe == ir.SIntType(ir.IntWidth(4)) && syms("r").kind == firrtl.RegKind) val mType = firrtl.passes.MemPortUtils.memType( // only dataType, depth and reader, writer, readwriter properties affect the data type - ir.DefMemory(ir.NoInfo, "???", ir.UIntType(ir.IntWidth(8)), 32, 10, 10, Seq("r"), Seq(), Seq(), ir.ReadUnderWrite.New) + ir.DefMemory( + ir.NoInfo, + "???", + ir.UIntType(ir.IntWidth(8)), + 32, + 10, + 10, + Seq("r"), + Seq(), + Seq(), + ir.ReadUnderWrite.New + ) ) - assert(syms("m") .tpe == mType && syms("m").kind == firrtl.MemKind) + assert(syms("m").tpe == mType && syms("m").kind == firrtl.MemKind) } it should "find all declarations in module m after InferTypes" in { @@ -69,7 +80,7 @@ class SymbolTableSpec extends AnyFlatSpec { assert(syms("i").tpe == iType && syms("i").kind == firrtl.InstanceKind) } - behavior of "WithSeq" + behavior.of("WithSeq") it should "preserve declaration order" in { val c = firrtl.Parser.parse(src) @@ -79,7 +90,7 @@ class SymbolTableSpec extends AnyFlatSpec { assert(syms.getSymbols.map(_.name) == Seq("clk", "x", "y", "z", "a", "i", "r", "m")) } - behavior of "ModuleTypesSymbolTable" + behavior.of("ModuleTypesSymbolTable") it should "derive the module type from the module types map" in { val c = firrtl.Parser.parse(src) diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala index 015ac4a9..f7ce9914 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -20,10 +20,16 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { sys.signals.head.e.toString } - def primop(signed: Boolean, op: String, resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(), - resAlwaysUnsigned: Boolean = false): String = { - val tpe = if(signed) "SInt" else "UInt" - val resTpe = if(resAlwaysUnsigned) "UInt" else tpe + def primop( + signed: Boolean, + op: String, + resWidth: Int, + inWidth: Seq[Int], + consts: Seq[Int] = List(), + resAlwaysUnsigned: Boolean = false + ): String = { + val tpe = if (signed) "SInt" else "UInt" + val resTpe = if (resAlwaysUnsigned) "UInt" else tpe val inTpes = inWidth.map(w => s"$tpe<$w>") primop(op, s"$resTpe<$resWidth>", inTpes, consts) } @@ -52,16 +58,24 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { it should "correctly translate the `div` primitive operation" in { // division is a little bit more complicated because the result of division by zero is undefined - assert(primop(false, "div", 8, List(8, 8)) == - "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))") - assert(primop(false, "div", 8, List(8, 4)) == - "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))") + assert( + primop(false, "div", 8, List(8, 8)) == + "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))" + ) + assert( + primop(false, "div", 8, List(8, 4)) == + "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))" + ) // signed division increases result width by 1 - assert(primop(true, "div", 8, List(7, 7)) == - "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))") - assert(primop(true, "div", 8, List(7, 4)) - == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))") + assert( + primop(true, "div", 8, List(7, 7)) == + "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))" + ) + assert( + primop(true, "div", 8, List(7, 4)) + == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))" + ) } it should "correctly translate the `rem` primitive operation" in { @@ -134,15 +148,19 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { it should "correctly translate the `dshl` primitive operation" in { assert(primop(false, "dshl", 31, List(16, 4)) == "logical_shift_left(zext(i0, 15), zext(i1, 27))") assert(primop(false, "dshl", 19, List(16, 2)) == "logical_shift_left(zext(i0, 3), zext(i1, 17))") - assert(primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == - "logical_shift_left(sext(i0, 3), zext(i1, 17))") + assert( + primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == + "logical_shift_left(sext(i0, 3), zext(i1, 17))" + ) } it should "correctly translate the `dshr` primitive operation" in { assert(primop(false, "dshr", 16, List(16, 4)) == "logical_shift_right(i0, zext(i1, 12))") assert(primop(false, "dshr", 16, List(16, 2)) == "logical_shift_right(i0, zext(i1, 14))") - assert(primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == - "arithmetic_shift_right(i0, zext(i1, 14))") + assert( + primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == + "arithmetic_shift_right(i0, zext(i1, 14))" + ) } it should "correctly translate the `cvt` primitive operation" in { @@ -197,15 +215,15 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { } it should "correctly translate the `bits` primitive operation" in { - assert(primop(false, "bits", 1, List(4), List(2,2)) == "i0[2]") - assert(primop(false, "bits", 2, List(4), List(2,1)) == "i0[2:1]") - assert(primop(false, "bits", 1, List(4), List(2,1)) == "i0[2:1][0]") - assert(primop(false, "bits", 3, List(4), List(2,1)) == "zext(i0[2:1], 1)") + assert(primop(false, "bits", 1, List(4), List(2, 2)) == "i0[2]") + assert(primop(false, "bits", 2, List(4), List(2, 1)) == "i0[2:1]") + assert(primop(false, "bits", 1, List(4), List(2, 1)) == "i0[2:1][0]") + assert(primop(false, "bits", 3, List(4), List(2, 1)) == "zext(i0[2:1], 1)") - assert(primop(true, "bits", 1, List(4), List(2,2), resAlwaysUnsigned = true) == "i0[2]") - assert(primop(true, "bits", 2, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1]") - assert(primop(true, "bits", 1, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1][0]") - assert(primop(true, "bits", 3, List(4), List(2,1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") + assert(primop(true, "bits", 1, List(4), List(2, 2), resAlwaysUnsigned = true) == "i0[2]") + assert(primop(true, "bits", 2, List(4), List(2, 1), resAlwaysUnsigned = true) == "i0[2:1]") + assert(primop(true, "bits", 1, List(4), List(2, 1), resAlwaysUnsigned = true) == "i0[2:1][0]") + assert(primop(true, "bits", 3, List(4), List(2, 1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") } it should "correctly translate the `head` primitive operation" in { @@ -221,4 +239,4 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { assert(primop(false, "tail", 4, List(5), List(1)) == "i0[3:0]") assert(primop(false, "tail", 2, List(5), List(3)) == "i0[1:0]") } -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala index ca7974c5..b41313e3 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala @@ -5,8 +5,7 @@ package firrtl.backends.experimental.smt import firrtl.{MemoryArrayInit, MemoryScalarInit, Utils} private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { - behavior of "ModuleToTransitionSystem.run" - + behavior.of("ModuleToTransitionSystem.run") it should "model registers as state" in { // if a signal is invalid, it could take on an arbitrary value in that cycle @@ -42,39 +41,39 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { private def memCircuit(depth: Int = 32) = s"""circuit m: - | module m: - | input reset : UInt<1> - | input clock : Clock - | input addr : UInt<${Utils.getUIntWidth(depth)}> - | input in : UInt<8> - | output out : UInt<8> - | - | mem m: - | data-type => UInt<8> - | depth => $depth - | reader => r - | writer => w - | read-latency => 0 - | write-latency => 1 - | read-under-write => new - | - | m.w.clk <= clock - | m.w.mask <= UInt(1) - | m.w.en <= UInt(1) - | m.w.data <= in - | m.w.addr <= addr - | - | m.r.clk <= clock - | m.r.en <= UInt(1) - | out <= m.r.data - | m.r.addr <= addr - | - |""".stripMargin + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<${Utils.getUIntWidth(depth)}> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => $depth + | reader => r + | writer => w + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin it should "model memories as state" in { val sys = toSys(memCircuit()) - assert(sys.signals.length == 9-2+1, "9 connects - 2 clock connects + 1 combinatorial read port") + assert(sys.signals.length == 9 - 2 + 1, "9 connects - 2 clock connects + 1 combinatorial read port") val sig = sys.signals.map(s => s.name -> s.e).toMap @@ -140,40 +139,39 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { it should "support memories with registered read port" in { def src(readUnderWrite: String) = s"""circuit m: - | module m: - | input reset : UInt<1> - | input clock : Clock - | input addr : UInt<5> - | input in : UInt<8> - | output out : UInt<8> - | - | mem m: - | data-type => UInt<8> - | depth => 32 - | reader => r - | writer => w1, w2 - | read-latency => 1 - | write-latency => 1 - | read-under-write => $readUnderWrite - | - | m.w1.clk <= clock - | m.w1.mask <= UInt(1) - | m.w1.en <= UInt(1) - | m.w1.data <= in - | m.w1.addr <= addr - | m.w2.clk <= clock - | m.w2.mask <= UInt(1) - | m.w2.en <= UInt(1) - | m.w2.data <= in - | m.w2.addr <= addr - | - | m.r.clk <= clock - | m.r.en <= UInt(1) - | out <= m.r.data - | m.r.addr <= addr - | - |""".stripMargin - + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin val oldValue = toSys(src("old")) val oldMData = oldValue.states.find(_.sym.name == "m.r.data").get @@ -186,9 +184,11 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { val undefinedMData = undefinedValue.states.find(_.sym.name == "m.r.data").get assert(undefinedMData.sym.toString == "m.r.data") val undefined = "RANDOM.m_r_read_under_write_undefined" - assert(undefinedMData.next.get.toString == - s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", - "randomize result if there is a write") + assert( + undefinedMData.next.get.toString == + s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", + "randomize result if there is a write" + ) } it should "support memories with potential write-write conflicts" in { @@ -228,7 +228,6 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { | |""".stripMargin - val sys = toSys(src) val m = sys.states.find(_.sym.name == "m").get diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala index 6bfb5437..209279fd 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala @@ -3,7 +3,7 @@ package firrtl.backends.experimental.smt import firrtl.annotations.Annotation -import firrtl.{MemoryInitValue, ir} +import firrtl.{ir, MemoryInitValue} import firrtl.stage.{Forms, TransformManager} import org.scalatest.flatspec.AnyFlatSpec @@ -16,8 +16,12 @@ private abstract class SMTBackendBaseSpec extends AnyFlatSpec { compiler.runTransform(firrtl.CircuitState(c, annos)).circuit } - protected def toSys(src: String, mod: String = "m", presetRegs: Set[String] = Set(), - memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + protected def toSys( + src: String, + mod: String = "m", + presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map() + ): TransitionSystem = { val circuit = compile(src) val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module] // println(module.serialize) @@ -35,4 +39,4 @@ private abstract class SMTBackendBaseSpec extends AnyFlatSpec { protected def toSMTLibStr(src: String, mod: String = "m"): String = toSMTLib(src, mod).mkString("\n") + "\n" -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala index 4c6901ea..e7c8d534 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala @@ -10,9 +10,10 @@ import firrtl.stage.RunFirrtlTransformAnnotation class AsyncResetSpec extends EndToEndSMTBaseSpec { def annos(name: String) = Seq( RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]), - GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock"))) + GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock")) + ) - "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs(RequiresZ3) in { + "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs (RequiresZ3) in { def in(resetType: String) = s"""circuit AsyncReset00: | module AsyncReset00: @@ -39,8 +40,8 @@ class AsyncResetSpec extends EndToEndSMTBaseSpec { | ; can the value of r change without the count changing? | assert(global_clock, or(not(eq(count, past_count)), eq(r, past_r)), past_valid, "count = past(count) |-> r = past(r)") |""".stripMargin - test(in("AsyncReset"), MCFail(1), kmax=2, annos=annos("AsyncReset00")) - test(in("UInt<1>"), MCSuccess, kmax=2, annos=annos("AsyncReset00")) + test(in("AsyncReset"), MCFail(1), kmax = 2, annos = annos("AsyncReset00")) + test(in("UInt<1>"), MCSuccess, kmax = 2, annos = annos("AsyncReset00")) } } diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala index 2227719b..974d2e81 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala @@ -16,24 +16,23 @@ import org.scalatest.matchers.must.Matchers import scala.sys.process._ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { - "we" should "check if Z3 is available" taggedAs(RequiresZ3) in { + "we" should "check if Z3 is available" taggedAs (RequiresZ3) in { val log = ProcessLogger(_ => (), logger.warn(_)) val ret = Process(Seq("which", "z3")).run(log).exitValue() - if(ret != 0) { - logger.error( - """The z3 SMT-Solver seems not to be installed. - |You can exclude the end-to-end smt backend tests which rely on z3 like this: - |sbt testOnly -- -l RequiresZ3 - |""".stripMargin) + if (ret != 0) { + logger.error("""The z3 SMT-Solver seems not to be installed. + |You can exclude the end-to-end smt backend tests which rely on z3 like this: + |sbt testOnly -- -l RequiresZ3 + |""".stripMargin) } assert(ret == 0) } - "Z3" should "be available in version 4" taggedAs(RequiresZ3) in { + "Z3" should "be available in version 4" taggedAs (RequiresZ3) in { assert(Z3ModelChecker.getZ3Version.startsWith("4.")) } - "a simple combinatorial check" should "pass" taggedAs(RequiresZ3) in { + "a simple combinatorial check" should "pass" taggedAs (RequiresZ3) in { val in = """circuit CC00: | module CC00: @@ -45,7 +44,7 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { test(in, MCSuccess) } - "a simple combinatorial check" should "fail immediately" taggedAs(RequiresZ3) in { + "a simple combinatorial check" should "fail immediately" taggedAs (RequiresZ3) in { val in = """circuit CC01: | module CC01: @@ -57,7 +56,7 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { test(in, MCFail(0)) } - "adding the right assumption" should "make a test pass" taggedAs(RequiresZ3) in { + "adding the right assumption" should "make a test pass" taggedAs (RequiresZ3) in { val in0 = """circuit CC01: | module CC01: @@ -75,8 +74,8 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") | assume(c, neq(a, UInt(0)), UInt(1), "a != 0") |""".stripMargin - test(in0, MCFail(0)) - test(in1, MCSuccess) + test(in0, MCFail(0)) + test(in1, MCSuccess) val in2 = """circuit CC01: @@ -91,20 +90,20 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { test(in2, MCFail(0)) } - "a register connected to preset reset" should "be initialized with the reset value" taggedAs(RequiresZ3) in { + "a register connected to preset reset" should "be initialized with the reset value" taggedAs (RequiresZ3) in { def in(rEq: Int) = s"""circuit Preset00: - | module Preset00: - | input c: Clock - | input preset: AsyncReset - | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) - | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") - |""".stripMargin + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) + | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") + |""".stripMargin test(in(3), MCSuccess, kmax = 1) test(in(2), MCFail(0)) } - "a register's initial value" should "should not change" taggedAs(RequiresZ3) in { + "a register's initial value" should "should not change" taggedAs (RequiresZ3) in { val in = """circuit Preset00: | module Preset00: @@ -127,24 +126,29 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { def test(src: String, expected: MCResult, kmax: Int = 0, clue: String = "", annos: Seq[Annotation] = Seq()): Unit = { expected match { - case MCFail(k) => assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") + case MCFail(k) => + assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") case _ => } val fir = firrtl.Parser.parse(src) val name = fir.main val testDir = BackendCompilationUtilities.createTestDirectory("EndToEndSMT." + name) // we automagically add a preset annotation if an input called preset exists - val presetAnno = if(!src.contains("input preset")) { None } else { + val presetAnno = if (!src.contains("input preset")) { None } + else { Some(PresetAnnotation(CircuitTarget(name).module(name).ref("preset"))) } - val res = (new FirrtlStage).execute(Array(), Seq( - LogLevelAnnotation(LogLevel.Error), // silence warnings for tests - RunFirrtlTransformAnnotation(new SMTLibEmitter), - RunFirrtlTransformAnnotation(new Btor2Emitter), - FirrtlCircuitAnnotation(fir), - TargetDirAnnotation(testDir.getAbsolutePath) - ) ++ presetAnno ++ annos) - assert(res.collectFirst{ case _: OutputFileAnnotation => true }.isDefined) + val res = (new FirrtlStage).execute( + Array(), + Seq( + LogLevelAnnotation(LogLevel.Error), // silence warnings for tests + RunFirrtlTransformAnnotation(new SMTLibEmitter), + RunFirrtlTransformAnnotation(new Btor2Emitter), + FirrtlCircuitAnnotation(fir), + TargetDirAnnotation(testDir.getAbsolutePath) + ) ++ presetAnno ++ annos + ) + assert(res.collectFirst { case _: OutputFileAnnotation => true }.isDefined) val r = Z3ModelChecker.bmc(testDir, name, kmax) assert(r == expected, clue + "\n" + s"$testDir") } @@ -153,7 +157,7 @@ abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { /** Minimal implementation of a Z3 based bounded model checker. * A more complete version of this with better use feedback should eventually be provided by a * chisel3 formal verification library. Do not use this implementation outside of the firrtl test suite! - * */ + */ private object Z3ModelChecker extends LazyLogging { def getZ3Version: String = { val (out, ret) = executeCmd("-version") @@ -164,14 +168,15 @@ private object Z3ModelChecker extends LazyLogging { } def bmc(testDir: File, main: String, kmax: Int): MCResult = { - assert(kmax >=0 && kmax < 50, "Trying to keep kmax in a reasonable range.") + assert(kmax >= 0 && kmax < 50, "Trying to keep kmax in a reasonable range.") val smtFile = new File(testDir, main + ".smt2") val header = read(smtFile) val steps = (0 to kmax).map(k => new File(testDir, main + s"_step$k.smt2")).zipWithIndex - steps.foreach { case (f,k) => - writeStep(f, main, header, k) - val success = executeStep(f.getAbsolutePath) - if(!success) return MCFail(k) + steps.foreach { + case (f, k) => + writeStep(f, main, header, k) + val success = executeStep(f.getAbsolutePath) + if (!success) return MCFail(k) } MCSuccess } @@ -200,21 +205,22 @@ private object Z3ModelChecker extends LazyLogging { private def step(main: String, k: Int): Iterable[String] = { // define all states (0 to k).map(ii => s"(declare-fun s$ii () $main$StateTpe)") ++ - // assert that init holds in state 0 - List(s"(assert ($main$Init s0))") ++ - // assert transition relation - (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii+1}))") ++ - // assert that assumptions hold in all states - (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ - // assert that assertions hold for all but last state - (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ - // check to see if we can violate the assertions in the last state - List(s"(assert (not ($main$Asserts s$k)))") + // assert that init holds in state 0 + List(s"(assert ($main$Init s0))") ++ + // assert transition relation + (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii + 1}))") ++ + // assert that assumptions hold in all states + (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ + // assert that assertions hold for all but last state + (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ + // check to see if we can violate the assertions in the last state + List(s"(assert (not ($main$Asserts s$k)))") } private def read(f: File): Iterable[String] = { val source = scala.io.Source.fromFile(f) - try source.getLines().toVector finally source.close() + try source.getLines().toVector + finally source.close() } // the following suffixes have to match the ones in [[SMTTransitionSystemEncoder]] diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala index 10de9cda..61e1f0f8 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -9,43 +9,43 @@ class MemorySpec extends EndToEndSMTBaseSpec { registeredTestMem(name, cmds.split("\n"), readUnderWrite) private def registeredTestMem(name: String, cmds: Iterable[String], readUnderWrite: String): String = s"""circuit $name: - | module $name: - | input reset : UInt<1> - | input clock : Clock - | input preset: AsyncReset - | input write_addr : UInt<5> - | input read_addr : UInt<5> - | input in : UInt<8> - | output out : UInt<8> - | - | mem m: - | data-type => UInt<8> - | depth => 32 - | reader => r - | writer => w - | read-latency => 1 - | write-latency => 1 - | read-under-write => $readUnderWrite - | - | m.w.clk <= clock - | m.w.mask <= UInt(1) - | m.w.en <= UInt(1) - | m.w.data <= in - | m.w.addr <= write_addr - | - | m.r.clk <= clock - | m.r.en <= UInt(1) - | out <= m.r.data - | m.r.addr <= read_addr - | - | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) - | cycle <= add(cycle, UInt(1)) - | node past_valid = geq(cycle, UInt(1)) - | - | ${cmds.mkString("\n ")} - |""".stripMargin + | module $name: + | input reset : UInt<1> + | input clock : Clock + | input preset: AsyncReset + | input write_addr : UInt<5> + | input read_addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= write_addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) + | cycle <= add(cycle, UInt(1)) + | node past_valid = geq(cycle, UInt(1)) + | + | ${cmds.mkString("\n ")} + |""".stripMargin - "Registered test memory" should "return written data after two cycles" taggedAs(RequiresZ3) in { + "Registered test memory" should "return written data after two cycles" taggedAs (RequiresZ3) in { val cmds = """node past_past_valid = geq(cycle, UInt(2)) |reg past_in: UInt<8>, clock @@ -85,23 +85,29 @@ class MemorySpec extends EndToEndSMTBaseSpec { |""".stripMargin private def m(num: Int) = CircuitTarget(s"Mem0$num").module(s"Mem0$num").ref("m") - "read-only memory" should "always return 0" taggedAs(RequiresZ3) in { - test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax=2, - annos=Seq(MemoryScalarInitAnnotation(m(1), 0))) + "read-only memory" should "always return 0" taggedAs (RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax = 2, annos = Seq(MemoryScalarInitAnnotation(m(1), 0))) } - "read-only memory" should "not always return 1" taggedAs(RequiresZ3) in { - test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax=2, - annos=Seq(MemoryScalarInitAnnotation(m(2), 0))) + "read-only memory" should "not always return 1" taggedAs (RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax = 2, annos = Seq(MemoryScalarInitAnnotation(m(2), 0))) } - "read-only memory" should "always return 1 or 2" taggedAs(RequiresZ3) in { - test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), MCSuccess, kmax=2, - annos=Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1)))) + "read-only memory" should "always return 1 or 2" taggedAs (RequiresZ3) in { + test( + readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), + MCSuccess, + kmax = 2, + annos = Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1))) + ) } - "read-only memory" should "not always return 1 or 2 or 3" taggedAs(RequiresZ3) in { - test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), MCFail(0), kmax=2, - annos=Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3)))) + "read-only memory" should "not always return 1 or 2 or 3" taggedAs (RequiresZ3) in { + test( + readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), + MCFail(0), + kmax = 2, + annos = Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3))) + ) } } diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala index cbf194dd..8ece0e23 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala @@ -13,15 +13,17 @@ import scala.sys.process.{Process, ProcessLogger} /** compiles the regression tests to SMTLib and parses the result with z3 */ class SMTCompilationTest extends AnyFlatSpec with LazyLogging { - it should "generate valid SMTLib for AddNot" taggedAs(RequiresZ3) in { compileAndParse("AddNot") } - it should "generate valid SMTLib for FPU" taggedAs(RequiresZ3) in { compileAndParse("FPU") } + it should "generate valid SMTLib for AddNot" taggedAs (RequiresZ3) in { compileAndParse("AddNot") } + it should "generate valid SMTLib for FPU" taggedAs (RequiresZ3) in { compileAndParse("FPU") } // we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer - it should "generate valid SMTLib for HwachaSequencer" taggedAs(RequiresZ3) ignore { compileAndParse("HwachaSequencer") } - it should "generate valid SMTLib for ICache" taggedAs(RequiresZ3) in { compileAndParse("ICache") } - it should "generate valid SMTLib for Ops" taggedAs(RequiresZ3) in { compileAndParse("Ops") } + it should "generate valid SMTLib for HwachaSequencer" taggedAs (RequiresZ3) ignore { + compileAndParse("HwachaSequencer") + } + it should "generate valid SMTLib for ICache" taggedAs (RequiresZ3) in { compileAndParse("ICache") } + it should "generate valid SMTLib for Ops" taggedAs (RequiresZ3) in { compileAndParse("Ops") } // TODO: enable Rob test once we support more than 2 write ports on a memory - it should "generate valid SMTLib for Rob" taggedAs(RequiresZ3) ignore { compileAndParse("Rob") } - it should "generate valid SMTLib for RocketCore" taggedAs(RequiresZ3) in { compileAndParse("RocketCore") } + it should "generate valid SMTLib for Rob" taggedAs (RequiresZ3) ignore { compileAndParse("Rob") } + it should "generate valid SMTLib for RocketCore" taggedAs (RequiresZ3) in { compileAndParse("RocketCore") } private def compileAndParse(name: String): Unit = { val testDir = BackendCompilationUtilities.createTestDirectory(name + "-smt") @@ -29,14 +31,18 @@ class SMTCompilationTest extends AnyFlatSpec with LazyLogging { BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", inputFile) val args = Array( - "-ll", "error", // surpress warnings to keep test output clean - "--target-dir", testDir.toString, - "-i", inputFile.toString, - "-E", "experimental-smt2" + "-ll", + "error", // surpress warnings to keep test output clean + "--target-dir", + testDir.toString, + "-i", + inputFile.toString, + "-E", + "experimental-smt2" // "-fct", "firrtl.backends.experimental.smt.StutteringClockTransform" ) val res = (new FirrtlStage).execute(args, Seq()) - val fileName = res.collectFirst{ case OutputFileAnnotation(file) => file }.get + val fileName = res.collectFirst { case OutputFileAnnotation(file) => file }.get val smtFile = testDir.toString + "/" + fileName + ".smt2" val log = ProcessLogger(_ => (), logger.error(_)) diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala index 8fa80b4c..8682c2ce 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala @@ -5,27 +5,26 @@ package firrtl.backends.experimental.smt.end2end /** undefined values in firrtl are modelled as fresh auxiliary variables (inputs) */ class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec { - "division by zero" should "result in an arbitrary value" taggedAs(RequiresZ3) in { + "division by zero" should "result in an arbitrary value" taggedAs (RequiresZ3) in { // the SMTLib spec defines the result of division by zero to be all 1s // https://cs.nyu.edu/pipermail/smt-lib/2015/000977.html def in(dEq: Int) = - s"""circuit CC00: - | module CC00: - | input c: Clock - | input a: UInt<2> - | input b: UInt<2> - | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") - | node d = div(a, b) - | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") - |""".stripMargin + s"""circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<2> + | input b: UInt<2> + | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") + | node d = div(a, b) + | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") + |""".stripMargin // we try to assert that (d = a / 0) is any fixed value which should be false (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") } } // TODO: rem should probably also be undefined, but the spec isn't 100% clear here - - "invalid signals" should "have an arbitrary values" taggedAs(RequiresZ3) in { + "invalid signals" should "have an arbitrary values" taggedAs (RequiresZ3) in { def in(aEq: Int) = s"""circuit CC00: | module CC00: diff --git a/src/test/scala/firrtl/ir/StructuralHashSpec.scala b/src/test/scala/firrtl/ir/StructuralHashSpec.scala index 17fe0b84..c4622939 100644 --- a/src/test/scala/firrtl/ir/StructuralHashSpec.scala +++ b/src/test/scala/firrtl/ir/StructuralHashSpec.scala @@ -6,11 +6,11 @@ import firrtl.PrimOps._ import org.scalatest.flatspec.AnyFlatSpec class StructuralHashSpec extends AnyFlatSpec { - private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n) - private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c) + private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n) + private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c) private def hash(e: Expression): HashCode = StructuralHash.sha256Node(e) - private def hash(t: Type): HashCode = StructuralHash.sha256Node(t) - private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s) + private def hash(t: Type): HashCode = StructuralHash.sha256Node(t) + private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s) private val highFirrtlCompiler = new firrtl.stage.transforms.Compiler( targets = firrtl.stage.Forms.HighForm ) @@ -24,18 +24,18 @@ class StructuralHashSpec extends AnyFlatSpec { highFirrtlCompiler.transform(firrtl.CircuitState(rawFirrtl, Seq())).circuit } - private val b0 = UIntLiteral(0,IntWidth(1)) - private val b1 = UIntLiteral(1,IntWidth(1)) + private val b0 = UIntLiteral(0, IntWidth(1)) + private val b1 = UIntLiteral(1, IntWidth(1)) private val add = DoPrim(Add, Seq(b0, b1), Seq(), UnknownType) it should "generate the same hash if the objects are structurally the same" in { - assert(hash(b0) == hash(UIntLiteral(0,IntWidth(1)))) - assert(hash(b0) != hash(UIntLiteral(1,IntWidth(1)))) - assert(hash(b0) != hash(UIntLiteral(1,IntWidth(2)))) + assert(hash(b0) == hash(UIntLiteral(0, IntWidth(1)))) + assert(hash(b0) != hash(UIntLiteral(1, IntWidth(1)))) + assert(hash(b0) != hash(UIntLiteral(1, IntWidth(2)))) - assert(hash(b1) == hash(UIntLiteral(1,IntWidth(1)))) - assert(hash(b1) != hash(UIntLiteral(0,IntWidth(1)))) - assert(hash(b1) != hash(UIntLiteral(1,IntWidth(2)))) + assert(hash(b1) == hash(UIntLiteral(1, IntWidth(1)))) + assert(hash(b1) != hash(UIntLiteral(0, IntWidth(1)))) + assert(hash(b1) != hash(UIntLiteral(1, IntWidth(2)))) } it should "ignore expression types" in { @@ -84,16 +84,19 @@ class StructuralHashSpec extends AnyFlatSpec { |""".stripMargin assert(hash(parse(a)) != hash(parse(d)), "circuits with different names are always different") - assert(hash(parse(a).modules.head) == hash(parse(d).modules.head), - "modules with different names can be structurally different") + assert( + hash(parse(a).modules.head) == hash(parse(d).modules.head), + "modules with different names can be structurally different" + ) // for the Dedup pass we do need a way to take the port names into account - assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), - "renaming ports does affect the hash if we ask to") + assert( + StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), + "renaming ports does affect the hash if we ask to" + ) } - it should "not ignore port names if asked to" in { val e = """circuit a: @@ -119,14 +122,20 @@ class StructuralHashSpec extends AnyFlatSpec { | z <= x |""".stripMargin - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), - "renaming ports does affect the hash if we ask to") - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == - StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), - "renaming internal wires should never affect the hash") - assert(hash(parse(e).modules.head) == hash(parse(g).modules.head), - "renaming internal wires should never affect the hash") + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), + "renaming ports does affect the hash if we ask to" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == + StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), + "renaming internal wires should never affect the hash" + ) + assert( + hash(parse(e).modules.head) == hash(parse(g).modules.head), + "renaming internal wires should never affect the hash" + ) } it should "not ignore port bundle names if asked to" in { @@ -154,19 +163,26 @@ class StructuralHashSpec extends AnyFlatSpec { | y.z <= x.x |""".stripMargin - assert(hash(parse(e).modules.head) == hash(parse(f).modules.head), - "renaming port bundles does normally not affect the hash") - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), - "renaming port bundles does affect the hash if we ask to") - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == - StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), - "renaming internal wire bundles should never affect the hash") - assert(hash(parse(e).modules.head) == hash(parse(g).modules.head), - "renaming internal wire bundles should never affect the hash") + assert( + hash(parse(e).modules.head) == hash(parse(f).modules.head), + "renaming port bundles does normally not affect the hash" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), + "renaming port bundles does affect the hash if we ask to" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == + StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), + "renaming internal wire bundles should never affect the hash" + ) + assert( + hash(parse(e).modules.head) == hash(parse(g).modules.head), + "renaming internal wire bundles should never affect the hash" + ) } - it should "fail on Info" in { // it does not make sense to hash Info nodes assertThrows[RuntimeException] { @@ -178,9 +194,9 @@ class StructuralHashSpec extends AnyFlatSpec { def parse(str: String): BundleType = { val src = s"""circuit c: - | module c: - | input z: $str - |""".stripMargin + | module c: + | input z: $str + |""".stripMargin val c = firrtl.Parser.parse(src) val tpe = c.modules.head.ports.head.tpe tpe.asInstanceOf[BundleType] @@ -219,11 +235,15 @@ class StructuralHashSpec extends AnyFlatSpec { // Q: should extmodule portnames always be significant since they map to the verilog pins? // A: It would be a bug for two exmodules in the same circuit to have the same defname but different // port names. This should be detected by an earlier pass and thus we do not have to deal with that situation. - assert(hash(parse(a).modules.head) == hash(parse(b).modules.head), - "two ext modules with the same defname and the same type and number of ports") - assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), - "two ext modules with significant port names") + assert( + hash(parse(a).modules.head) == hash(parse(b).modules.head), + "two ext modules with the same defname and the same type and number of ports" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), + "two ext modules with significant port names" + ) } "Blocks and empty statements" should "not affect structural equivalence" in { @@ -269,9 +289,9 @@ class StructuralHashSpec extends AnyFlatSpec { } private case object DebugHasher extends Hasher { - override def update(b: Byte): Unit = println(s"b(${b.toInt & 0xff})") - override def update(i: Int): Unit = println(s"i(${i})") - override def update(l: Long): Unit = println(s"l(${l})") - override def update(s: String): Unit = println(s"s(${s})") + override def update(b: Byte): Unit = println(s"b(${b.toInt & 0xff})") + override def update(i: Int): Unit = println(s"i(${i})") + override def update(l: Long): Unit = println(s"l(${l})") + override def update(s: String): Unit = println(s"s(${s})") override def update(b: Array[Byte]): Unit = println(s"bytes(${b.map(x => x.toInt & 0xff).mkString(", ")})") -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/passes/LowerTypesSpec.scala b/src/test/scala/firrtl/passes/LowerTypesSpec.scala index 884e51b8..0575c5da 100644 --- a/src/test/scala/firrtl/passes/LowerTypesSpec.scala +++ b/src/test/scala/firrtl/passes/LowerTypesSpec.scala @@ -8,7 +8,6 @@ import firrtl.stage.TransformManager import firrtl.stage.TransformManager.TransformDependency import org.scalatest.flatspec.AnyFlatSpec - /** Unit test style tests for [[LowerTypes]]. * You can find additional integration style tests in [[firrtlTests.LowerTypesSpec]] */ @@ -31,11 +30,12 @@ class LowerTypesEndToEndSpec extends LowerTypesBaseSpec { | $n is invalid |""".stripMargin val c = CircuitState(firrtl.Parser.parse(src), Seq()) - val c2 = lowerTypesCompiler.execute(c) + val c2 = lowerTypesCompiler.execute(c) val ps = c2.circuit.modules.head.ports.filterNot(p => namespace.contains(p.name)) - ps.map{p => + ps.map { p => val orientation = Utils.to_flip(p.direction) - s"${orientation.serialize}${p.name} : ${p.tpe.serialize}"} + s"${orientation.serialize}${p.name} : ${p.tpe.serialize}" + } } override protected def lower(n: String, tpe: String, namespace: Set[String]): Seq[String] = @@ -50,8 +50,10 @@ abstract class LowerTypesBaseSpec extends AnyFlatSpec { assert(lower("a", "{ a : UInt<1>, b : UInt<1>}") == Seq("a_a : UInt<1>", "a_b : UInt<1>")) assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}") == Seq("a_a : UInt<1>", "a_b_c : UInt<1>")) assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}") == Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]") == - Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]") == + Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>") + ) // with conflicts assert(lower("a", "{ a : UInt<1>, b : UInt<1>}", Set("a_a")) == Seq("a__a : UInt<1>", "a__b : UInt<1>")) @@ -63,40 +65,71 @@ abstract class LowerTypesBaseSpec extends AnyFlatSpec { assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>")) assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b_c")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a")) == - Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a", "a_b_0")) == - Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_b_0")) == - Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")) - - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0")) == - Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_3")) == - Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_a")) == - Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_c")) == - Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a")) == + Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a", "a_b_0")) == + Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_b_0")) == + Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>") + ) + + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0")) == + Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_3")) == + Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_a")) == + Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_c")) == + Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>") + ) // collisions inside the bundle - assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b__c : UInt<1>", "a_b_c : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b_c : UInt<1>", "a_b_b : UInt<1>")) - - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b__0 : UInt<1>", "a_b__1 : UInt<1>", "a_b_0 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_c : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>", "a_b_c : UInt<1>")) + assert( + lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b__c : UInt<1>", "a_b_c : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b_c : UInt<1>", "a_b_b : UInt<1>") + ) + + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b__0 : UInt<1>", "a_b__1 : UInt<1>", "a_b_0 : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_c : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>", "a_b_c : UInt<1>") + ) } it should "correctly lower the orientation" in { assert(lower("a", "{ flip a : UInt<1>, b : UInt<1>}") == Seq("flip a_a : UInt<1>", "a_b : UInt<1>")) - assert(lower("a", "{ flip a : UInt<1>[2], b : UInt<1>}") == - Seq("flip a_a_0 : UInt<1>", "flip a_a_1 : UInt<1>", "a_b : UInt<1>")) - assert(lower("a", "{ a : { flip c : UInt<1>, d : UInt<1>}[2], b : UInt<1>}") == - Seq("flip a_a_0_c : UInt<1>", "a_a_0_d : UInt<1>", "flip a_a_1_c : UInt<1>", "a_a_1_d : UInt<1>", "a_b : UInt<1>") + assert( + lower("a", "{ flip a : UInt<1>[2], b : UInt<1>}") == + Seq("flip a_a_0 : UInt<1>", "flip a_a_1 : UInt<1>", "a_b : UInt<1>") + ) + assert( + lower("a", "{ a : { flip c : UInt<1>, d : UInt<1>}[2], b : UInt<1>}") == + Seq( + "flip a_a_0_c : UInt<1>", + "a_a_0_d : UInt<1>", + "flip a_a_1_c : UInt<1>", + "a_a_1_d : UInt<1>", + "a_b : UInt<1>" + ) ) } } @@ -121,43 +154,45 @@ class LowerTypesRenamingSpec extends AnyFlatSpec { def one(namespace: Set[String], prefix: String): Unit = { val r = lower("a", "{ a : UInt<1>, b : UInt<1>}", namespace) - assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b"))) - assert(get(r,a.field("a")) == Set(m.ref(prefix + "a"))) - assert(get(r,a.field("b")) == Set(m.ref(prefix + "b"))) + assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b"))) + assert(get(r, a.field("a")) == Set(m.ref(prefix + "a"))) + assert(get(r, a.field("b")) == Set(m.ref(prefix + "b"))) } one(Set(), "a_") one(Set("a_a"), "a__") def two(namespace: Set[String], prefix: String): Unit = { - val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", namespace) - assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_c"))) - assert(get(r,a.field("a")) == Set(m.ref(prefix + "a"))) - assert(get(r,a.field("b")) == Set(m.ref(prefix + "b_c"))) - assert(get(r,a.field("b").field("c")) == Set(m.ref(prefix + "b_c"))) + val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", namespace) + assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_c"))) + assert(get(r, a.field("a")) == Set(m.ref(prefix + "a"))) + assert(get(r, a.field("b")) == Set(m.ref(prefix + "b_c"))) + assert(get(r, a.field("b").field("c")) == Set(m.ref(prefix + "b_c"))) } two(Set(), "a_") two(Set("a_a"), "a__") def three(namespace: Set[String], prefix: String): Unit = { val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", namespace) - assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) - assert(get(r,a.field("a")) == Set(m.ref(prefix + "a"))) - assert(get(r,a.field("b")) == Set( m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) - assert(get(r,a.field("b").index(0)) == Set(m.ref(prefix + "b_0"))) - assert(get(r,a.field("b").index(1)) == Set(m.ref(prefix + "b_1"))) + assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) + assert(get(r, a.field("a")) == Set(m.ref(prefix + "a"))) + assert(get(r, a.field("b")) == Set(m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) + assert(get(r, a.field("b").index(0)) == Set(m.ref(prefix + "b_0"))) + assert(get(r, a.field("b").index(1)) == Set(m.ref(prefix + "b_1"))) } three(Set(), "a_") three(Set("a_b_0"), "a__") def four(namespace: Set[String], prefix: String): Unit = { val r = lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", namespace) - assert(get(r,a) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "1_a"), m.ref(prefix + "0_b"), m.ref(prefix + "1_b"))) - assert(get(r,a.index(0)) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "0_b"))) - assert(get(r,a.index(1)) == Set(m.ref(prefix + "1_a"), m.ref(prefix + "1_b"))) - assert(get(r,a.index(0).field("a")) == Set(m.ref(prefix + "0_a"))) - assert(get(r,a.index(0).field("b")) == Set(m.ref(prefix + "0_b"))) - assert(get(r,a.index(1).field("a")) == Set(m.ref(prefix + "1_a"))) - assert(get(r,a.index(1).field("b")) == Set(m.ref(prefix + "1_b"))) + assert( + get(r, a) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "1_a"), m.ref(prefix + "0_b"), m.ref(prefix + "1_b")) + ) + assert(get(r, a.index(0)) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "0_b"))) + assert(get(r, a.index(1)) == Set(m.ref(prefix + "1_a"), m.ref(prefix + "1_b"))) + assert(get(r, a.index(0).field("a")) == Set(m.ref(prefix + "0_a"))) + assert(get(r, a.index(0).field("b")) == Set(m.ref(prefix + "0_b"))) + assert(get(r, a.index(1).field("a")) == Set(m.ref(prefix + "1_a"))) + assert(get(r, a.index(1).field("b")) == Set(m.ref(prefix + "1_b"))) } four(Set(), "a_") four(Set("a_0"), "a__") @@ -166,28 +201,28 @@ class LowerTypesRenamingSpec extends AnyFlatSpec { // collisions inside the bundle { val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") - assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__c"), m.ref("a_b_c"))) - assert(get(r,a.field("a")) == Set(m.ref("a_a"))) - assert(get(r,a.field("b")) == Set(m.ref("a_b__c"))) - assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b__c"))) - assert(get(r,a.field("b_c")) == Set(m.ref("a_b_c"))) + assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b__c"), m.ref("a_b_c"))) + assert(get(r, a.field("a")) == Set(m.ref("a_a"))) + assert(get(r, a.field("b")) == Set(m.ref("a_b__c"))) + assert(get(r, a.field("b").field("c")) == Set(m.ref("a_b__c"))) + assert(get(r, a.field("b_c")) == Set(m.ref("a_b_c"))) } { val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") - assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b_c"), m.ref("a_b_b"))) - assert(get(r,a.field("a")) == Set(m.ref("a_a"))) - assert(get(r,a.field("b")) == Set(m.ref("a_b_c"))) - assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b_c"))) - assert(get(r,a.field("b_b")) == Set(m.ref("a_b_b"))) + assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b_c"), m.ref("a_b_b"))) + assert(get(r, a.field("a")) == Set(m.ref("a_a"))) + assert(get(r, a.field("b")) == Set(m.ref("a_b_c"))) + assert(get(r, a.field("b").field("c")) == Set(m.ref("a_b_c"))) + assert(get(r, a.field("b_b")) == Set(m.ref("a_b_b"))) } { val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") - assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__0"), m.ref("a_b__1"), m.ref("a_b_0"))) - assert(get(r,a.field("a")) == Set(m.ref("a_a"))) - assert(get(r,a.field("b")) == Set(m.ref("a_b__0"), m.ref("a_b__1"))) - assert(get(r,a.field("b").index(0)) == Set(m.ref("a_b__0"))) - assert(get(r,a.field("b").index(1)) == Set(m.ref("a_b__1"))) - assert(get(r,a.field("b_0")) == Set(m.ref("a_b_0"))) + assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b__0"), m.ref("a_b__1"), m.ref("a_b_0"))) + assert(get(r, a.field("a")) == Set(m.ref("a_a"))) + assert(get(r, a.field("b")) == Set(m.ref("a_b__0"), m.ref("a_b__1"))) + assert(get(r, a.field("b").index(0)) == Set(m.ref("a_b__0"))) + assert(get(r, a.field("b").index(1)) == Set(m.ref("a_b__1"))) + assert(get(r, a.field("b_0")) == Set(m.ref("a_b_0"))) } } } @@ -199,8 +234,13 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { private val m = CircuitTarget("m").module("m") def resultToFieldSeq(res: Seq[(String, firrtl.ir.SubField)]): Seq[String] = res.map(_._2).map(r => s"${r.name} : ${r.tpe.serialize}") - private def lower(n: String, tpe: String, module: String, namespace: Set[String], renames: RenameMap = RenameMap()): - Lower = { + private def lower( + n: String, + tpe: String, + module: String, + namespace: Set[String], + renames: RenameMap = RenameMap() + ): Lower = { val ref = firrtl.ir.DefInstance(firrtl.ir.NoInfo, n, module, parseType(tpe)) val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, renames, Set()) @@ -269,7 +309,7 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { assert(r.get(i.ref("b")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b").field("c")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b_c")).get == Seq(i_.ref("b_c"))) - } + } } } @@ -278,101 +318,139 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { */ class LowerTypesOfMemorySpec extends AnyFlatSpec { import LowerTypesSpecUtils._ - private case class Lower(mems: Seq[firrtl.ir.DefMemory], refs: Seq[(String, firrtl.ir.SubField)], - renameMap: RenameMap) + private case class Lower( + mems: Seq[firrtl.ir.DefMemory], + refs: Seq[(String, firrtl.ir.SubField)], + renameMap: RenameMap) private val m = CircuitTarget("m").module("m") private val mem = m.ref("mem") - private def lower(name: String, tpe: String, namespace: Set[String], - r: Seq[String] = List("r"), w: Seq[String] = List("w"), rw: Seq[String] = List(), depth: Int = 2): Lower = { + private def lower( + name: String, + tpe: String, + namespace: Set[String], + r: Seq[String] = List("r"), + w: Seq[String] = List("w"), + rw: Seq[String] = List(), + depth: Int = 2 + ): Lower = { val dataType = parseType(tpe) - val mem = firrtl.ir.DefMemory(firrtl.ir.NoInfo, name, dataType, depth = depth, writeLatency = 1, readLatency = 1, - readUnderWrite = firrtl.ir.ReadUnderWrite.Undefined, readers = r, writers = w, readwriters = rw) + val mem = firrtl.ir.DefMemory( + firrtl.ir.NoInfo, + name, + dataType, + depth = depth, + writeLatency = 1, + readLatency = 1, + readUnderWrite = firrtl.ir.ReadUnderWrite.Undefined, + readers = r, + writers = w, + readwriters = rw + ) val renames = RenameMap() val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace - val(mems, refs) = DestructTypes.destructMemory(m, mem, mutableSet, renames, Set()) + val (mems, refs) = DestructTypes.destructMemory(m, mem, mutableSet, renames, Set()) Lower(mems, refs, renames) } private val UInt1 = firrtl.ir.UIntType(firrtl.ir.IntWidth(1)) it should "not rename anything for a ground type memory if there was no conflict" in { - val l = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w")) + val l = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("w")) assert(l.renameMap.underlying.isEmpty) } it should "still produce reference lookups, even for a ground type memory with no conflicts" in { - val nameToRef = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w")).refs - .map{case (n,r) => n -> r.serialize}.toSet - - assert(nameToRef == Set( - "mem.r.clk" -> "mem.r.clk", - "mem.r.en" -> "mem.r.en", - "mem.r.addr" -> "mem.r.addr", - "mem.r.data" -> "mem.r.data", - "mem.w.clk" -> "mem.w.clk", - "mem.w.en" -> "mem.w.en", - "mem.w.addr" -> "mem.w.addr", - "mem.w.data" -> "mem.w.data", - "mem.w.mask" -> "mem.w.mask" - )) + val nameToRef = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("w")).refs.map { + case (n, r) => n -> r.serialize + }.toSet + + assert( + nameToRef == Set( + "mem.r.clk" -> "mem.r.clk", + "mem.r.en" -> "mem.r.en", + "mem.r.addr" -> "mem.r.addr", + "mem.r.data" -> "mem.r.data", + "mem.w.clk" -> "mem.w.clk", + "mem.w.en" -> "mem.w.en", + "mem.w.addr" -> "mem.w.addr", + "mem.w.data" -> "mem.w.data", + "mem.w.mask" -> "mem.w.mask" + ) + ) } it should "produce references of correct type" in { - val nameToType = lower("mem", "UInt<4>", Set("mem_r", "mem_r_data"), w=Seq("w"), depth = 3).refs - .map{case (n,r) => n -> r.tpe.serialize}.toSet - - assert(nameToType == Set( - "mem.r.clk" -> "Clock", - "mem.r.en" -> "UInt<1>", - "mem.r.addr" -> "UInt<2>", // depth = 3 - "mem.r.data" -> "UInt<4>", - "mem.w.clk" -> "Clock", - "mem.w.en" -> "UInt<1>", - "mem.w.addr" -> "UInt<2>", - "mem.w.data" -> "UInt<4>", - "mem.w.mask" -> "UInt<1>" - )) + val nameToType = lower("mem", "UInt<4>", Set("mem_r", "mem_r_data"), w = Seq("w"), depth = 3).refs.map { + case (n, r) => n -> r.tpe.serialize + }.toSet + + assert( + nameToType == Set( + "mem.r.clk" -> "Clock", + "mem.r.en" -> "UInt<1>", + "mem.r.addr" -> "UInt<2>", // depth = 3 + "mem.r.data" -> "UInt<4>", + "mem.w.clk" -> "Clock", + "mem.w.en" -> "UInt<1>", + "mem.w.addr" -> "UInt<2>", + "mem.w.data" -> "UInt<4>", + "mem.w.mask" -> "UInt<1>" + ) + ) } it should "not rename ground type memories even if there are conflicts on the ports" in { // There actually isn't such a thing as conflicting ports, because they do not get flattened by LowerTypes. - val r = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("r_data")).renameMap + val r = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("r_data")).renameMap assert(r.underlying.isEmpty) } it should "rename references to lowered ports" in { - val r = lower("mem", "{ a : UInt<1>, b : UInt<1>}", Set("mem_a"), r=Seq("r", "r_data")).renameMap + val r = lower("mem", "{ a : UInt<1>, b : UInt<1>}", Set("mem_a"), r = Seq("r", "r_data")).renameMap // complete memory assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b"))) // read ports - assert(get(r, mem.field("r")) == - Set(m.ref("mem__a").field("r"), m.ref("mem__b").field("r"))) - assert(get(r, mem.field("r_data")) == - Set(m.ref("mem__a").field("r_data"), m.ref("mem__b").field("r_data"))) + assert( + get(r, mem.field("r")) == + Set(m.ref("mem__a").field("r"), m.ref("mem__b").field("r")) + ) + assert( + get(r, mem.field("r_data")) == + Set(m.ref("mem__a").field("r_data"), m.ref("mem__b").field("r_data")) + ) // port fields - assert(get(r, mem.field("r").field("data")) == - Set(m.ref("mem__a").field("r").field("data"), - m.ref("mem__b").field("r").field("data"))) - assert(get(r, mem.field("r").field("addr")) == - Set(m.ref("mem__a").field("r").field("addr"), - m.ref("mem__b").field("r").field("addr"))) - assert(get(r, mem.field("r").field("en")) == - Set(m.ref("mem__a").field("r").field("en"), - m.ref("mem__b").field("r").field("en"))) - assert(get(r, mem.field("r").field("clk")) == - Set(m.ref("mem__a").field("r").field("clk"), - m.ref("mem__b").field("r").field("clk"))) - assert(get(r, mem.field("w").field("mask")) == - Set(m.ref("mem__a").field("w").field("mask"), - m.ref("mem__b").field("w").field("mask"))) + assert( + get(r, mem.field("r").field("data")) == + Set(m.ref("mem__a").field("r").field("data"), m.ref("mem__b").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("addr")) == + Set(m.ref("mem__a").field("r").field("addr"), m.ref("mem__b").field("r").field("addr")) + ) + assert( + get(r, mem.field("r").field("en")) == + Set(m.ref("mem__a").field("r").field("en"), m.ref("mem__b").field("r").field("en")) + ) + assert( + get(r, mem.field("r").field("clk")) == + Set(m.ref("mem__a").field("r").field("clk"), m.ref("mem__b").field("r").field("clk")) + ) + assert( + get(r, mem.field("w").field("mask")) == + Set(m.ref("mem__a").field("w").field("mask"), m.ref("mem__b").field("w").field("mask")) + ) // port sub-fields - assert(get(r, mem.field("r").field("data").field("a")) == - Set(m.ref("mem__a").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").field("b")) == - Set(m.ref("mem__b").field("r").field("data"))) + assert( + get(r, mem.field("r").field("data").field("a")) == + Set(m.ref("mem__a").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").field("b")) == + Set(m.ref("mem__b").field("r").field("data")) + ) // need to rename the following: // mem -> mem__a, mem__b @@ -395,24 +473,38 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec { assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b_c"))) // read port - assert(get(r, mem.field("r")) == - Set(m.ref("mem__a").field("r"), m.ref("mem__b_c").field("r"))) + assert( + get(r, mem.field("r")) == + Set(m.ref("mem__a").field("r"), m.ref("mem__b_c").field("r")) + ) // port sub-fields - assert(get(r, mem.field("r").field("data").field("a")) == - Set(m.ref("mem__a").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").field("b")) == - Set(m.ref("mem__b_c").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").field("b").field("c")) == - Set(m.ref("mem__b_c").field("r").field("data"))) + assert( + get(r, mem.field("r").field("data").field("a")) == + Set(m.ref("mem__a").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").field("b")) == + Set(m.ref("mem__b_c").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").field("b").field("c")) == + Set(m.ref("mem__b_c").field("r").field("data")) + ) // the mask field needs to be lowered just like the data field - assert(get(r, mem.field("w").field("mask").field("a")) == - Set(m.ref("mem__a").field("w").field("mask"))) - assert(get(r, mem.field("w").field("mask").field("b")) == - Set(m.ref("mem__b_c").field("w").field("mask"))) - assert(get(r, mem.field("w").field("mask").field("b").field("c")) == - Set(m.ref("mem__b_c").field("w").field("mask"))) + assert( + get(r, mem.field("w").field("mask").field("a")) == + Set(m.ref("mem__a").field("w").field("mask")) + ) + assert( + get(r, mem.field("w").field("mask").field("b")) == + Set(m.ref("mem__b_c").field("w").field("mask")) + ) + assert( + get(r, mem.field("w").field("mask").field("b").field("c")) == + Set(m.ref("mem__b_c").field("w").field("mask")) + ) val renameCount = r.underlying.map(_._2.size).sum assert(renameCount == 11, "it is enough to rename *to* 11 different signals") @@ -420,66 +512,89 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec { } it should "return a name to RefLikeExpression map for a memory with a nested data type" in { - val nameToRef = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a")).refs - .map{case (n,r) => n -> r.serialize}.toSet - - assert(nameToRef == Set( - // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. - // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. - "mem.r.clk" -> "mem__a.r.clk", "mem.r.clk" -> "mem__b_c.r.clk", - "mem.r.en" -> "mem__a.r.en", "mem.r.en" -> "mem__b_c.r.en", - "mem.r.addr" -> "mem__a.r.addr", "mem.r.addr" -> "mem__b_c.r.addr", - "mem.w.clk" -> "mem__a.w.clk", "mem.w.clk" -> "mem__b_c.w.clk", - "mem.w.en" -> "mem__a.w.en", "mem.w.en" -> "mem__b_c.w.en", - "mem.w.addr" -> "mem__a.w.addr", "mem.w.addr" -> "mem__b_c.w.addr", - // Ground type references to the data or mask field are unique. - "mem.r.data.a" -> "mem__a.r.data", - "mem.w.data.a" -> "mem__a.w.data", - "mem.w.mask.a" -> "mem__a.w.mask", - "mem.r.data.b.c" -> "mem__b_c.r.data", - "mem.w.data.b.c" -> "mem__b_c.w.data", - "mem.w.mask.b.c" -> "mem__b_c.w.mask" - )) + val nameToRef = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a")).refs.map { + case (n, r) => n -> r.serialize + }.toSet + + assert( + nameToRef == Set( + // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. + // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. + "mem.r.clk" -> "mem__a.r.clk", + "mem.r.clk" -> "mem__b_c.r.clk", + "mem.r.en" -> "mem__a.r.en", + "mem.r.en" -> "mem__b_c.r.en", + "mem.r.addr" -> "mem__a.r.addr", + "mem.r.addr" -> "mem__b_c.r.addr", + "mem.w.clk" -> "mem__a.w.clk", + "mem.w.clk" -> "mem__b_c.w.clk", + "mem.w.en" -> "mem__a.w.en", + "mem.w.en" -> "mem__b_c.w.en", + "mem.w.addr" -> "mem__a.w.addr", + "mem.w.addr" -> "mem__b_c.w.addr", + // Ground type references to the data or mask field are unique. + "mem.r.data.a" -> "mem__a.r.data", + "mem.w.data.a" -> "mem__a.w.data", + "mem.w.mask.a" -> "mem__a.w.mask", + "mem.r.data.b.c" -> "mem__b_c.r.data", + "mem.w.data.b.c" -> "mem__b_c.w.data", + "mem.w.mask.b.c" -> "mem__b_c.w.mask" + ) + ) } it should "produce references of correct type for memories with a read/write port" in { - val refs = lower("mem", "{ a : UInt<3>, b : { c : UInt<4>} }", Set("mem_a"), - r=Seq(), w=Seq(), rw=Seq("rw"), depth = 3).refs - val nameToRef = refs.map{case (n,r) => n -> r.serialize}.toSet - val nameToType = refs.map{case (n,r) => n -> r.tpe.serialize}.toSet - - assert(nameToRef == Set( - // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. - // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. - "mem.rw.clk" -> "mem__a.rw.clk", "mem.rw.clk" -> "mem__b_c.rw.clk", - "mem.rw.en" -> "mem__a.rw.en", "mem.rw.en" -> "mem__b_c.rw.en", - "mem.rw.addr" -> "mem__a.rw.addr", "mem.rw.addr" -> "mem__b_c.rw.addr", - "mem.rw.wmode" -> "mem__a.rw.wmode", "mem.rw.wmode" -> "mem__b_c.rw.wmode", - // Ground type references to the data or mask field are unique. - "mem.rw.rdata.a" -> "mem__a.rw.rdata", - "mem.rw.wdata.a" -> "mem__a.rw.wdata", - "mem.rw.wmask.a" -> "mem__a.rw.wmask", - "mem.rw.rdata.b.c" -> "mem__b_c.rw.rdata", - "mem.rw.wdata.b.c" -> "mem__b_c.rw.wdata", - "mem.rw.wmask.b.c" -> "mem__b_c.rw.wmask" - )) - - assert(nameToType == Set( - // - "mem.rw.clk" -> "Clock", - "mem.rw.en" -> "UInt<1>", - "mem.rw.addr" -> "UInt<2>", - "mem.rw.wmode" -> "UInt<1>", - // Ground type references to the data or mask field are unique. - "mem.rw.rdata.a" -> "UInt<3>", - "mem.rw.wdata.a" -> "UInt<3>", - "mem.rw.wmask.a" -> "UInt<1>", - "mem.rw.rdata.b.c" -> "UInt<4>", - "mem.rw.wdata.b.c" -> "UInt<4>", - "mem.rw.wmask.b.c" -> "UInt<1>" - )) - } + val refs = lower( + "mem", + "{ a : UInt<3>, b : { c : UInt<4>} }", + Set("mem_a"), + r = Seq(), + w = Seq(), + rw = Seq("rw"), + depth = 3 + ).refs + val nameToRef = refs.map { case (n, r) => n -> r.serialize }.toSet + val nameToType = refs.map { case (n, r) => n -> r.tpe.serialize }.toSet + + assert( + nameToRef == Set( + // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. + // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. + "mem.rw.clk" -> "mem__a.rw.clk", + "mem.rw.clk" -> "mem__b_c.rw.clk", + "mem.rw.en" -> "mem__a.rw.en", + "mem.rw.en" -> "mem__b_c.rw.en", + "mem.rw.addr" -> "mem__a.rw.addr", + "mem.rw.addr" -> "mem__b_c.rw.addr", + "mem.rw.wmode" -> "mem__a.rw.wmode", + "mem.rw.wmode" -> "mem__b_c.rw.wmode", + // Ground type references to the data or mask field are unique. + "mem.rw.rdata.a" -> "mem__a.rw.rdata", + "mem.rw.wdata.a" -> "mem__a.rw.wdata", + "mem.rw.wmask.a" -> "mem__a.rw.wmask", + "mem.rw.rdata.b.c" -> "mem__b_c.rw.rdata", + "mem.rw.wdata.b.c" -> "mem__b_c.rw.wdata", + "mem.rw.wmask.b.c" -> "mem__b_c.rw.wmask" + ) + ) + assert( + nameToType == Set( + // + "mem.rw.clk" -> "Clock", + "mem.rw.en" -> "UInt<1>", + "mem.rw.addr" -> "UInt<2>", + "mem.rw.wmode" -> "UInt<1>", + // Ground type references to the data or mask field are unique. + "mem.rw.rdata.a" -> "UInt<3>", + "mem.rw.wdata.a" -> "UInt<3>", + "mem.rw.wmask.a" -> "UInt<1>", + "mem.rw.rdata.b.c" -> "UInt<4>", + "mem.rw.wdata.b.c" -> "UInt<4>", + "mem.rw.wmask.b.c" -> "UInt<1>" + ) + ) + } it should "rename references for vector type memories" in { val l = lower("mem", "UInt<1>[2]", Set("mem_0")) @@ -491,14 +606,20 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec { assert(get(r, mem) == Set(m.ref("mem__0"), m.ref("mem__1"))) // read port - assert(get(r, mem.field("r")) == - Set(m.ref("mem__0").field("r"), m.ref("mem__1").field("r"))) + assert( + get(r, mem.field("r")) == + Set(m.ref("mem__0").field("r"), m.ref("mem__1").field("r")) + ) // port sub-fields - assert(get(r, mem.field("r").field("data").index(0)) == - Set(m.ref("mem__0").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").index(1)) == - Set(m.ref("mem__1").field("r").field("data"))) + assert( + get(r, mem.field("r").field("data").index(0)) == + Set(m.ref("mem__0").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").index(1)) == + Set(m.ref("mem__1").field("r").field("data")) + ) val renameCount = r.underlying.map(_._2.size).sum assert(renameCount == 8, "it is enough to rename *to* 8 different signals") diff --git a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala index 007608ca..0b9b830c 100644 --- a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala +++ b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala @@ -8,7 +8,14 @@ import java.io.File import firrtl._ import firrtl.stage.phases.DriverCompatibility._ import firrtl.options.{InputAnnotationFileAnnotation, Phase, TargetDirAnnotation} -import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation} +import firrtl.stage.{ + CompilerAnnotation, + FirrtlCircuitAnnotation, + FirrtlFileAnnotation, + FirrtlSourceAnnotation, + OutputFileAnnotation, + RunFirrtlTransformAnnotation +} import firrtl.stage.phases.DriverCompatibility import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -20,7 +27,7 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth /* This method wraps some magic that lets you use the private method DriverCompatibility.topName */ def topName(annotations: AnnotationSeq): Option[String] = { val topName = PrivateMethod[Option[String]]('topName) - DriverCompatibility invokePrivate topName(annotations) + DriverCompatibility.invokePrivate(topName(annotations)) } def simpleCircuit(main: String): String = s"""|circuit $main: @@ -41,22 +48,22 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth (FirrtlFileAnnotation("src/test/resources/integration/GCDTester.pb"), "GCDTester") ) - behavior of s"${DriverCompatibility.getClass.getName}.topName (private method)" + behavior.of(s"${DriverCompatibility.getClass.getName}.topName (private method)") /* This iterates over the tails of annosWithTops. Using the ordering of annosWithTops, if this AnnotationSeq is fed to * DriverCompatibility.topName, the head annotation will be used to determine the top name. This test ensures that * topName behaves as expected. */ - for ( t <- annosWithTops.tails ) t match { + for (t <- annosWithTops.tails) t match { case Nil => it should "return None on an empty AnnotationSeq" in { - topName(Seq.empty) should be (None) + topName(Seq.empty) should be(None) } case x => val annotations = x.map(_._1) val top = x.head._2 it should s"determine a top name ('$top') from a ${annotations.head.getClass.getName}" in { - topName(annotations).get should be (top) + topName(annotations).get should be(top) } } @@ -66,152 +73,148 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth file.createNewFile() } - behavior of classOf[AddImplicitAnnotationFile].toString + behavior.of(classOf[AddImplicitAnnotationFile].toString) val testDir = "test_run_dir/DriverCompatibilitySpec" it should "not modify the annotations if an InputAnnotationFile already exists" in - new PhaseFixture(new AddImplicitAnnotationFile) { + new PhaseFixture(new AddImplicitAnnotationFile) { - createFile(testDir + "/foo.anno") - val annotations = Seq( - InputAnnotationFileAnnotation("bar.anno"), - TargetDirAnnotation(testDir), - TopNameAnnotation("foo") ) + createFile(testDir + "/foo.anno") + val annotations = + Seq(InputAnnotationFileAnnotation("bar.anno"), TargetDirAnnotation(testDir), TopNameAnnotation("foo")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } it should "add an InputAnnotationFile based on a derived topName" in - new PhaseFixture(new AddImplicitAnnotationFile) { - createFile(testDir + "/bar.anno") - val annotations = Seq( - TargetDirAnnotation(testDir), - TopNameAnnotation("bar") ) + new PhaseFixture(new AddImplicitAnnotationFile) { + createFile(testDir + "/bar.anno") + val annotations = Seq(TargetDirAnnotation(testDir), TopNameAnnotation("bar")) - val expected = annotations.toSet + - InputAnnotationFileAnnotation(testDir + "/bar.anno") + val expected = annotations.toSet + + InputAnnotationFileAnnotation(testDir + "/bar.anno") - phase.transform(annotations).toSet should be (expected) - } + phase.transform(annotations).toSet should be(expected) + } it should "not add an InputAnnotationFile for .anno.json annotations" in - new PhaseFixture(new AddImplicitAnnotationFile) { - createFile(testDir + "/baz.anno.json") - val annotations = Seq( - TargetDirAnnotation(testDir), - TopNameAnnotation("baz") ) + new PhaseFixture(new AddImplicitAnnotationFile) { + createFile(testDir + "/baz.anno.json") + val annotations = Seq(TargetDirAnnotation(testDir), TopNameAnnotation("baz")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } it should "not add an InputAnnotationFile if it cannot determine the topName" in - new PhaseFixture(new AddImplicitAnnotationFile) { - val annotations = Seq( TargetDirAnnotation(testDir) ) + new PhaseFixture(new AddImplicitAnnotationFile) { + val annotations = Seq(TargetDirAnnotation(testDir)) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } - behavior of classOf[AddImplicitFirrtlFile].toString + behavior.of(classOf[AddImplicitFirrtlFile].toString) it should "not modify the annotations if a CircuitOption is present" in - new PhaseFixture(new AddImplicitFirrtlFile) { - val annotations = Seq( FirrtlFileAnnotation("foo"), TopNameAnnotation("bar") ) + new PhaseFixture(new AddImplicitFirrtlFile) { + val annotations = Seq(FirrtlFileAnnotation("foo"), TopNameAnnotation("bar")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } it should "add an FirrtlFileAnnotation if a TopNameAnnotation is present" in - new PhaseFixture(new AddImplicitFirrtlFile) { - val annotations = Seq( TopNameAnnotation("foo") ) - val expected = annotations.toSet + - FirrtlFileAnnotation(new File("foo.fir").getPath()) + new PhaseFixture(new AddImplicitFirrtlFile) { + val annotations = Seq(TopNameAnnotation("foo")) + val expected = annotations.toSet + + FirrtlFileAnnotation(new File("foo.fir").getPath()) - phase.transform(annotations).toSet should be (expected) - } + phase.transform(annotations).toSet should be(expected) + } it should "do nothing if no TopNameAnnotation is present" in - new PhaseFixture(new AddImplicitFirrtlFile) { - val annotations = Seq( TargetDirAnnotation("foo") ) + new PhaseFixture(new AddImplicitFirrtlFile) { + val annotations = Seq(TargetDirAnnotation("foo")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } - behavior of classOf[AddImplicitEmitter].toString + behavior.of(classOf[AddImplicitEmitter].toString) - val (nc, hfc, mfc, lfc, vc, svc) = ( new NoneCompiler, - new HighFirrtlCompiler, - new MiddleFirrtlCompiler, - new LowFirrtlCompiler, - new VerilogCompiler, - new SystemVerilogCompiler ) + val (nc, hfc, mfc, lfc, vc, svc) = ( + new NoneCompiler, + new HighFirrtlCompiler, + new MiddleFirrtlCompiler, + new LowFirrtlCompiler, + new VerilogCompiler, + new SystemVerilogCompiler + ) it should "convert CompilerAnnotations into EmitCircuitAnnotations without EmitOneFilePerModuleAnnotation" in - new PhaseFixture(new AddImplicitEmitter) { - val annotations = Seq( - CompilerAnnotation(nc), - CompilerAnnotation(hfc), - CompilerAnnotation(mfc), - CompilerAnnotation(lfc), - CompilerAnnotation(vc), - CompilerAnnotation(svc) - ) - val expected = annotations - .flatMap( a => Seq(a, - RunFirrtlTransformAnnotation(a.compiler.emitter), - EmitCircuitAnnotation(a.compiler.emitter.getClass)) ) - - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitEmitter) { + val annotations = Seq( + CompilerAnnotation(nc), + CompilerAnnotation(hfc), + CompilerAnnotation(mfc), + CompilerAnnotation(lfc), + CompilerAnnotation(vc), + CompilerAnnotation(svc) + ) + val expected = annotations + .flatMap(a => + Seq(a, RunFirrtlTransformAnnotation(a.compiler.emitter), EmitCircuitAnnotation(a.compiler.emitter.getClass)) + ) + + phase.transform(annotations).toSeq should be(expected) + } it should "convert CompilerAnnotations into EmitAllodulesAnnotation with EmitOneFilePerModuleAnnotation" in - new PhaseFixture(new AddImplicitEmitter) { - val annotations = Seq( - EmitOneFilePerModuleAnnotation, - CompilerAnnotation(nc), - CompilerAnnotation(hfc), - CompilerAnnotation(mfc), - CompilerAnnotation(lfc), - CompilerAnnotation(vc), - CompilerAnnotation(svc) - ) - val expected = annotations - .flatMap{ - case a: CompilerAnnotation => Seq(a, - RunFirrtlTransformAnnotation(a.compiler.emitter), - EmitAllModulesAnnotation(a.compiler.emitter.getClass)) + new PhaseFixture(new AddImplicitEmitter) { + val annotations = Seq( + EmitOneFilePerModuleAnnotation, + CompilerAnnotation(nc), + CompilerAnnotation(hfc), + CompilerAnnotation(mfc), + CompilerAnnotation(lfc), + CompilerAnnotation(vc), + CompilerAnnotation(svc) + ) + val expected = annotations.flatMap { + case a: CompilerAnnotation => + Seq( + a, + RunFirrtlTransformAnnotation(a.compiler.emitter), + EmitAllModulesAnnotation(a.compiler.emitter.getClass) + ) case a => Seq(a) } - phase.transform(annotations).toSeq should be (expected) - } + phase.transform(annotations).toSeq should be(expected) + } - behavior of classOf[AddImplicitOutputFile].toString + behavior.of(classOf[AddImplicitOutputFile].toString) it should "add an OutputFileAnnotation derived from a TopNameAnnotation if no OutputFileAnnotation exists" in - new PhaseFixture(new AddImplicitOutputFile) { - val annotations = Seq( TopNameAnnotation("foo") ) - val expected = Seq( - OutputFileAnnotation("foo"), - TopNameAnnotation("foo") - ) - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitOutputFile) { + val annotations = Seq(TopNameAnnotation("foo")) + val expected = Seq( + OutputFileAnnotation("foo"), + TopNameAnnotation("foo") + ) + phase.transform(annotations).toSeq should be(expected) + } it should "do nothing if an OutputFileannotation already exists" in - new PhaseFixture(new AddImplicitOutputFile) { - val annotations = Seq( - TopNameAnnotation("foo"), - OutputFileAnnotation("bar") ) - val expected = annotations - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitOutputFile) { + val annotations = Seq(TopNameAnnotation("foo"), OutputFileAnnotation("bar")) + val expected = annotations + phase.transform(annotations).toSeq should be(expected) + } it should "do nothing if no TopNameAnnotation exists" in - new PhaseFixture(new AddImplicitOutputFile) { - val annotations = Seq.empty - val expected = annotations - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitOutputFile) { + val annotations = Seq.empty + val expected = annotations + phase.transform(annotations).toSeq should be(expected) + } } diff --git a/src/test/scala/firrtl/testutils/FirrtlSpec.scala b/src/test/scala/firrtl/testutils/FirrtlSpec.scala index dfc20352..a0c41085 100644 --- a/src/test/scala/firrtl/testutils/FirrtlSpec.scala +++ b/src/test/scala/firrtl/testutils/FirrtlSpec.scala @@ -46,11 +46,13 @@ object RenameTop extends Transform { val c = state.circuit val ns = Namespace(c) - val newTopName = state.annotations.collectFirst({ - case RenameTopAnnotation(name) => - require(ns.tryName(name)) - name - }).getOrElse(c.main) + val newTopName = state.annotations + .collectFirst({ + case RenameTopAnnotation(name) => + require(ns.tryName(name)) + name + }) + .getOrElse(c.main) state.annotations.collect { case ModuleNamespaceAnnotation(mustNotCollideNS) => require(mustNotCollideNS.tryName(newTopName)) @@ -70,6 +72,7 @@ object RenameTop extends Transform { trait FirrtlRunners extends BackendCompilationUtilities { val cppHarnessResourceName: String = "/firrtl/testTop.cpp" + /** Extra transforms to run by default */ val extraCheckTransforms = Seq(new CheckLowForm) @@ -80,10 +83,12 @@ trait FirrtlRunners extends BackendCompilationUtilities { * @param customAnnotations Optional Firrtl annotations * @param timesteps the maximum number of timesteps to consider */ - def firrtlEquivalenceTest(input: String, - customTransforms: Seq[Transform] = Seq.empty, - customAnnotations: AnnotationSeq = Seq.empty, - timesteps: Int = 1): Unit = { + def firrtlEquivalenceTest( + input: String, + customTransforms: Seq[Transform] = Seq.empty, + customAnnotations: AnnotationSeq = Seq.empty, + timesteps: Int = 1 + ): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) val prefix = circuit.main val testDir = createTestDirectory(prefix + "_equivalence_test") @@ -93,12 +98,12 @@ trait FirrtlRunners extends BackendCompilationUtilities { def getBaseAnnos(topName: String) = { val baseTransforms = RenameTop +: extraCheckTransforms TargetDirAnnotation(testDir.toString) +: - InfoModeAnnotation("ignore") +: - RenameTopAnnotation(topName) +: - stage.FirrtlCircuitAnnotation(circuit) +: - stage.CompilerAnnotation("mverilog") +: - stage.OutputFileAnnotation(topName) +: - toAnnos(baseTransforms) + InfoModeAnnotation("ignore") +: + RenameTopAnnotation(topName) +: + stage.FirrtlCircuitAnnotation(circuit) +: + stage.CompilerAnnotation("mverilog") +: + stage.OutputFileAnnotation(topName) +: + toAnnos(baseTransforms) } val customName = s"${prefix}_custom" @@ -111,7 +116,8 @@ trait FirrtlRunners extends BackendCompilationUtilities { val refAnnos = getBaseAnnos(refSuggestedName) ++: Seq(RunFirrtlTransformAnnotation(new RenameModules), nsAnno) val refResult = (new firrtl.stage.FirrtlStage).execute(Array.empty, refAnnos) - val refName = refResult.collectFirst({ case stage.FirrtlCircuitAnnotation(c) => c.main }).getOrElse(refSuggestedName) + val refName = + refResult.collectFirst({ case stage.FirrtlCircuitAnnotation(c) => c.main }).getOrElse(refSuggestedName) assert(BackendCompilationUtilities.yosysExpectSuccess(customName, refName, testDir, timesteps)) } @@ -123,6 +129,7 @@ trait FirrtlRunners extends BackendCompilationUtilities { val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms) res.getEmittedCircuit.value } + /** Compile a Firrtl file * * @param prefix is the name of the Firrtl file without path or file extension @@ -130,25 +137,27 @@ trait FirrtlRunners extends BackendCompilationUtilities { * @param annotations Optional Firrtl annotations */ def compileFirrtlTest( - prefix: String, - srcDir: String, - customTransforms: Seq[Transform] = Seq.empty, - annotations: AnnotationSeq = Seq.empty): File = { + prefix: String, + srcDir: String, + customTransforms: Seq[Transform] = Seq.empty, + annotations: AnnotationSeq = Seq.empty + ): File = { val testDir = createTestDirectory(prefix) val inputFile = new File(testDir, s"${prefix}.fir") copyResourceToFile(s"${srcDir}/${prefix}.fir", inputFile) val annos = FirrtlFileAnnotation(inputFile.toString) +: - TargetDirAnnotation(testDir.toString) +: - InfoModeAnnotation("ignore") +: - annotations ++: - (customTransforms ++ extraCheckTransforms).map(RunFirrtlTransformAnnotation(_)) + TargetDirAnnotation(testDir.toString) +: + InfoModeAnnotation("ignore") +: + annotations ++: + (customTransforms ++ extraCheckTransforms).map(RunFirrtlTransformAnnotation(_)) (new firrtl.stage.FirrtlStage).execute(Array.empty, annos) testDir } + /** Execute a Firrtl Test * * @param prefix is the name of the Firrtl file without path or file extension @@ -157,25 +166,26 @@ trait FirrtlRunners extends BackendCompilationUtilities { * @param annotations Optional Firrtl annotations */ def runFirrtlTest( - prefix: String, - srcDir: String, - verilogPrefixes: Seq[String] = Seq.empty, - customTransforms: Seq[Transform] = Seq.empty, - annotations: AnnotationSeq = Seq.empty) = { + prefix: String, + srcDir: String, + verilogPrefixes: Seq[String] = Seq.empty, + customTransforms: Seq[Transform] = Seq.empty, + annotations: AnnotationSeq = Seq.empty + ) = { val testDir = compileFirrtlTest(prefix, srcDir, customTransforms, annotations) val harness = new File(testDir, s"top.cpp") copyResourceToFile(cppHarnessResourceName, harness) // Note file copying side effect - val verilogFiles = verilogPrefixes map { vprefix => + val verilogFiles = verilogPrefixes.map { vprefix => val file = new File(testDir, s"$vprefix.v") copyResourceToFile(s"$srcDir/$vprefix.v", file) file } verilogToCpp(prefix, testDir, verilogFiles, harness) #&& - cppToExe(prefix, testDir) ! - loggingProcessLogger + cppToExe(prefix, testDir) ! + loggingProcessLogger assert(executeExpectingSuccess(prefix, testDir)) } } @@ -201,6 +211,7 @@ trait FirrtlMatchers extends Matchers { require(!s.contains("\n")) s.replaceAll("\\s+", " ").trim } + /** Helper to make circuits that are the same appear the same */ def canonicalize(circuit: Circuit): Circuit = { import firrtl.Mappers._ @@ -208,19 +219,21 @@ trait FirrtlMatchers extends Matchers { circuit.map(onModule) } def parse(str: String) = Parser.parse(str.split("\n").toIterator, UseInfo) + /** Helper for executing tests * compiler will be run on input then emitted result will each be split into * lines and normalized. */ def executeTest( - input: String, - expected: Seq[String], - compiler: Compiler, - annotations: Seq[Annotation] = Seq.empty) = { + input: String, + expected: Seq[String], + compiler: Compiler, + annotations: Seq[Annotation] = Seq.empty + ) = { val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) - val lines = finalState.getEmittedCircuit.value split "\n" map normalized + val lines = finalState.getEmittedCircuit.value.split("\n").map(normalized) for (e <- expected) { - lines should contain (e) + lines should contain(e) } } } @@ -239,10 +252,12 @@ object FirrtlCheckers extends FirrtlMatchers { case Some(res) => res // Otherwise keep digging case None => - require(node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode], - "Error! Unexpected FirrtlNode that does not implement Product!") + require( + node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode], + "Error! Unexpected FirrtlNode that does not implement Product!" + ) val iter = node match { - case p: Product => p.productIterator + case p: Product => p.productIterator case i: Iterable[Any] => i.iterator case _ => Iterator.empty } @@ -296,57 +311,63 @@ class TestFirrtlFlatSpec extends FirrtlFlatSpec { import FirrtlCheckers._ val c = parse(""" - |circuit Test: - | module Test : - | input in : UInt<8> - | output out : UInt<8> - | out <= in - |""".stripMargin) + |circuit Test: + | module Test : + | input in : UInt<8> + | output out : UInt<8> + | out <= in + |""".stripMargin) val state = CircuitState(c, ChirrtlForm) val compiled = (new LowFirrtlCompiler).compileAndEmit(state, List.empty) // While useful, ScalaTest helpers should be used over search - behavior of "Search" + behavior.of("Search") it should "be supported on Circuit" in { - assert(c search { - case Connect(_, Reference("out",_, _, _), Reference("in", _, _, _)) => true + assert(c.search { + case Connect(_, Reference("out", _, _, _), Reference("in", _, _, _)) => true }) } it should "be supported on CircuitStates" in { - assert(state search { - case Connect(_, Reference("out", _, _, _), Reference("in",_, _, _)) => true + assert(state.search { + case Connect(_, Reference("out", _, _, _), Reference("in", _, _, _)) => true }) } it should "be supported on the results of compilers" in { - assert(compiled search { - case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true + assert(compiled.search { + case Connect(_, WRef("out", _, _, _), WRef("in", _, _, _)) => true }) } // Use these!!! - behavior of "ScalaTest helpers" + behavior.of("ScalaTest helpers") they should "work for lines of emitted text" in { - compiled should containLine (s"input in : UInt<8>") - compiled should containLine (s"output out : UInt<8>") - compiled should containLine (s"out <= in") + compiled should containLine(s"input in : UInt<8>") + compiled should containLine(s"output out : UInt<8>") + compiled should containLine(s"out <= in") } they should "work for partial functions matching on subtrees" in { val UInt8 = UIntType(IntWidth(8)) // BigInt unapply is weird compiled should containTree { case Port(_, "in", Input, UInt8) => true } compiled should containTree { case Port(_, "out", Output, UInt8) => true } - compiled should containTree { case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true } + compiled should containTree { case Connect(_, WRef("out", _, _, _), WRef("in", _, _, _)) => true } } } /** Super class for execution driven Firrtl tests */ -abstract class ExecutionTest(name: String, dir: String, vFiles: Seq[String] = Seq.empty, annotations: AnnotationSeq = Seq.empty) extends FirrtlPropSpec { +abstract class ExecutionTest( + name: String, + dir: String, + vFiles: Seq[String] = Seq.empty, + annotations: AnnotationSeq = Seq.empty) + extends FirrtlPropSpec { property(s"$name should execute correctly") { runFirrtlTest(name, dir, vFiles, annotations = annotations) } } + /** Super class for compilation driven Firrtl tests */ abstract class CompilationTest(name: String, dir: String) extends FirrtlPropSpec { property(s"$name should compile correctly") { @@ -444,7 +465,9 @@ abstract class EquivalenceTest(transforms: Seq[Transform], name: String, dir: St throw new FileNotFoundException(s"Resource '$fileName'") } val source = scala.io.Source.fromInputStream(in) - val input = try source.mkString finally source.close() + val input = + try source.mkString + finally source.close() s"$name with ${transforms.map(_.name).mkString(", ")}" should s"be equivalent to $name without ${transforms.map(_.name).mkString(", ")}" in { diff --git a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala index c1f0943a..4ae6a7be 100644 --- a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala +++ b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala @@ -1,6 +1,6 @@ package firrtl.testutils -import firrtl.{AnnotationSeq, CircuitState, EmitCircuitAnnotation, ir} +import firrtl.{ir, AnnotationSeq, CircuitState, EmitCircuitAnnotation} import firrtl.options.Dependency import firrtl.passes.RemoveEmpty import firrtl.stage.TransformManager.TransformDependency @@ -11,30 +11,33 @@ class VerilogTransformSpec extends LeanTransformSpec(Seq(Dependency[firrtl.Veril class LowFirrtlTransformSpec extends LeanTransformSpec(Seq(Dependency[firrtl.LowFirrtlEmitter])) /** The new cool kid on the block, creates a custom compiler for your transform. */ -class LeanTransformSpec(protected val transforms: Seq[TransformDependency]) extends AnyFlatSpec with FirrtlMatchers with LazyLogging { +class LeanTransformSpec(protected val transforms: Seq[TransformDependency]) + extends AnyFlatSpec + with FirrtlMatchers + with LazyLogging { private val compiler = new firrtl.stage.transforms.Compiler(transforms) private val emitterAnnos = LeanTransformSpec.deriveEmitCircuitAnnotations(transforms) protected def compile(src: String): CircuitState = compile(src, Seq()) protected def compile(src: String, annos: AnnotationSeq): CircuitState = compile(firrtl.Parser.parse(src), annos) - protected def compile(c: ir.Circuit): CircuitState = compile(c, Seq()) - protected def compile(c: ir.Circuit, annos: AnnotationSeq): CircuitState = + protected def compile(c: ir.Circuit): CircuitState = compile(c, Seq()) + protected def compile(c: ir.Circuit, annos: AnnotationSeq): CircuitState = compiler.transform(CircuitState(c, emitterAnnos ++ annos)) - protected def execute(input: String, check: String): CircuitState = execute(input, check ,Seq()) + protected def execute(input: String, check: String): CircuitState = execute(input, check, Seq()) protected def execute(input: String, check: String, inAnnos: AnnotationSeq): CircuitState = { val finalState = compiler.transform(CircuitState(parse(input), inAnnos)) val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize val expected = parse(check).serialize logger.debug(actual) logger.debug(expected) - actual should be (expected) + actual should be(expected) finalState } } private object LeanTransformSpec { private def deriveEmitCircuitAnnotations(transforms: Iterable[TransformDependency]): AnnotationSeq = { - val emitters = transforms.map(_.getObject()).collect{ case e: firrtl.Emitter => e } + val emitters = transforms.map(_.getObject()).collect { case e: firrtl.Emitter => e } emitters.map(e => EmitCircuitAnnotation(e.getClass)).toSeq } } @@ -47,4 +50,4 @@ trait MakeCompiler { new firrtl.stage.transforms.Compiler(Seq(Dependency[firrtl.MinimumVerilogEmitter]) ++ transforms) protected def makeLowFirrtlCompiler(transforms: Seq[TransformDependency] = Seq()) = new firrtl.stage.transforms.Compiler(Seq(Dependency[firrtl.LowFirrtlEmitter]) ++ transforms) -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/testutils/PassTests.scala b/src/test/scala/firrtl/testutils/PassTests.scala index 49dea199..7a5dc306 100644 --- a/src/test/scala/firrtl/testutils/PassTests.scala +++ b/src/test/scala/firrtl/testutils/PassTests.scala @@ -15,49 +15,53 @@ import org.scalatest.flatspec.AnyFlatSpec // An example methodology for testing Firrtl Passes // Spec class should extend this class abstract class SimpleTransformSpec extends AnyFlatSpec with FirrtlMatchers with Compiler with LazyLogging { - // Utility function - def squash(c: Circuit): Circuit = RemoveEmpty.run(c) - - // Executes the test. Call in tests. - // annotations cannot have default value because scalatest trait Suite has a default value - def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = { - val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) - val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize - val expected = parse(check).serialize - logger.debug(actual) - logger.debug(expected) - (actual) should be (expected) - finalState - } - - def executeWithAnnos(input: String, check: String, annotations: Seq[Annotation], - checkAnnotations: Seq[Annotation]): CircuitState = { - val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) - val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize - val expected = parse(check).serialize - logger.debug(actual) - logger.debug(expected) - (actual) should be (expected) - - annotations.foreach { anno => - logger.debug(anno.serialize) - } - - finalState.annotations.toSeq.foreach { anno => - logger.debug(anno.serialize) - } - checkAnnotations.foreach { check => - (finalState.annotations.toSeq) should contain (check) - } - finalState - } - // Executes the test, should throw an error - // No default to be consistent with execute - def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { - intercept[PassExceptions] { - compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty) - } - } + // Utility function + def squash(c: Circuit): Circuit = RemoveEmpty.run(c) + + // Executes the test. Call in tests. + // annotations cannot have default value because scalatest trait Suite has a default value + def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be(expected) + finalState + } + + def executeWithAnnos( + input: String, + check: String, + annotations: Seq[Annotation], + checkAnnotations: Seq[Annotation] + ): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be(expected) + + annotations.foreach { anno => + logger.debug(anno.serialize) + } + + finalState.annotations.toSeq.foreach { anno => + logger.debug(anno.serialize) + } + checkAnnotations.foreach { check => + (finalState.annotations.toSeq) should contain(check) + } + finalState + } + // Executes the test, should throw an error + // No default to be consistent with execute + def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { + intercept[PassExceptions] { + compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty) + } + } } @deprecated( @@ -86,19 +90,19 @@ object ReRunResolveAndCheck extends Transform with DependencyAPIMigration with I } trait LowTransformSpec extends SimpleTransformSpec { - def emitter = new LowFirrtlEmitter - def transform: Transform - def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.LowForm.map(_.getObject) + def emitter = new LowFirrtlEmitter + def transform: Transform + def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.LowForm.map(_.getObject) } trait MiddleTransformSpec extends SimpleTransformSpec { - def emitter = new MiddleFirrtlEmitter - def transform: Transform - def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.MidForm.map(_.getObject) + def emitter = new MiddleFirrtlEmitter + def transform: Transform + def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.MidForm.map(_.getObject) } trait HighTransformSpec extends SimpleTransformSpec { - def emitter = new HighFirrtlEmitter - def transform: Transform - def transforms = transform +: ReRunResolveAndCheck +: Forms.HighForm.map(_.getObject) + def emitter = new HighFirrtlEmitter + def transform: Transform + def transforms = transform +: ReRunResolveAndCheck +: Forms.HighForm.map(_.getObject) } diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 4017503e..6f8dd574 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -15,7 +15,6 @@ import firrtl.util.BackendCompilationUtilities import firrtl.testutils._ import org.scalatest.matchers.should.Matchers - object AnnotationTests { class DeletingTransform extends Transform { @@ -31,26 +30,26 @@ object AnnotationTests { abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with MakeCompiler { import AnnotationTests._ - def anno(s: String, value: String ="this is a value", mod: String = "Top"): Annotation + def anno(s: String, value: String = "this is a value", mod: String = "Top"): Annotation def manno(mod: String): Annotation "Annotation on a node" should "pass through" in { val input: String = """circuit Top : - | module Top : - | input a : UInt<1>[2] - | input b : UInt<1> - | node c = b""".stripMargin + | module Top : + | input a : UInt<1>[2] + | input b : UInt<1> + | node c = b""".stripMargin val ta = anno("c", "") val r = compile(input, Seq(ta)) - r.annotations.toSeq should contain (ta) + r.annotations.toSeq should contain(ta) } "Deleting annotations" should "create a DeletedAnnotation" in { val transform = Dependency[DeletingTransform] val compiler = makeVerilogCompiler(Seq(transform)) val input = - """circuit Top : + """circuit Top : | module Top : | input in: UInt<3> |""".stripMargin @@ -65,7 +64,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with result.getEmittedCircuit }) val deleted = result.deletedAnnotations - exception.getMessage should be (s"No EmittedCircuit found! Did you delete any annotations?\n$deleted") + exception.getMessage should be(s"No EmittedCircuit found! Did you delete any annotations?\n$deleted") } "Renaming" should "propagate in Lowering of memories" in { @@ -73,7 +72,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with // Uncomment to help debugging failing tests // Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug)) val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input in: UInt<3> @@ -87,25 +86,24 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | m.r.en <= UInt(1) | m.r.addr <= in |""".stripMargin - val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"), - dontTouch("Top.m")) + val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"), dontTouch("Top.m")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("m_a", "mem")) - resultAnno should contain (anno("m_b_0", "mem")) - resultAnno should contain (anno("m_b_1", "mem")) - resultAnno should contain (anno("m_a.r.data", "all")) - resultAnno should contain (anno("m_b_0.r.data", "all")) - resultAnno should contain (anno("m_b_1.r.data", "all")) - resultAnno should contain (anno("m_b_0.r.data", "sub")) - resultAnno should contain (anno("m_b_1.r.data", "sub")) + resultAnno should contain(anno("m_a", "mem")) + resultAnno should contain(anno("m_b_0", "mem")) + resultAnno should contain(anno("m_b_1", "mem")) + resultAnno should contain(anno("m_a.r.data", "all")) + resultAnno should contain(anno("m_b_0.r.data", "all")) + resultAnno should contain(anno("m_b_1.r.data", "all")) + resultAnno should contain(anno("m_b_0.r.data", "sub")) + resultAnno should contain(anno("m_b_1.r.data", "sub")) resultAnno should not contain (anno("m")) resultAnno should not contain (anno("r")) } "Renaming" should "propagate in RemoveChirrtl and Lowering of memories" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input in: UInt<3> @@ -115,14 +113,14 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem"), dontTouch("Top.m")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("m_a", "mem")) - resultAnno should contain (anno("m_b_0", "mem")) - resultAnno should contain (anno("m_b_1", "mem")) - resultAnno should contain (anno("m_a.r.data", "all")) - resultAnno should contain (anno("m_b_0.r.data", "all")) - resultAnno should contain (anno("m_b_1.r.data", "all")) - resultAnno should contain (anno("m_b_0.r.data", "sub")) - resultAnno should contain (anno("m_b_1.r.data", "sub")) + resultAnno should contain(anno("m_a", "mem")) + resultAnno should contain(anno("m_b_0", "mem")) + resultAnno should contain(anno("m_b_1", "mem")) + resultAnno should contain(anno("m_a.r.data", "all")) + resultAnno should contain(anno("m_b_0.r.data", "all")) + resultAnno should contain(anno("m_b_1.r.data", "all")) + resultAnno should contain(anno("m_b_0.r.data", "sub")) + resultAnno should contain(anno("m_b_1.r.data", "sub")) resultAnno should not contain (anno("m")) resultAnno should not contain (anno("r")) } @@ -130,7 +128,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with "Renaming" should "propagate in ZeroWidth" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input zero: UInt<0> | wire x: {a: UInt<3>, b: UInt<0>} @@ -141,11 +139,11 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | x.a <= zero | x.b <= zero |""".stripMargin - val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), - anno("y[2]"), dontTouch("Top.x")) + val annos = + Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), anno("y[2]"), dontTouch("Top.x")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("x_a")) + resultAnno should contain(anno("x_a")) resultAnno should not contain (anno("zero")) resultAnno should not contain (anno("x.a")) resultAnno should not contain (anno("x.b")) @@ -161,7 +159,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with "Renaming subcomponents" should "propagate in Lowering" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input pred: UInt<1> @@ -176,12 +174,24 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | write <= in |""".stripMargin val annos = Seq( - anno("in.a"), anno("in.b[0]"), anno("in.b[1]"), - anno("out.a"), anno("out.b[0]"), anno("out.b[1]"), - anno("w.a"), anno("w.b[0]"), anno("w.b[1]"), - anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), - anno("write.a"), anno("write.b[0]"), anno("write.b[1]"), - dontTouch("Top.r"), dontTouch("Top.w"), dontTouch("Top.mem") + anno("in.a"), + anno("in.b[0]"), + anno("in.b[1]"), + anno("out.a"), + anno("out.b[0]"), + anno("out.b[1]"), + anno("w.a"), + anno("w.b[0]"), + anno("w.b[1]"), + anno("r.a"), + anno("r.b[0]"), + anno("r.b[1]"), + anno("write.a"), + anno("write.b[0]"), + anno("write.b[1]"), + dontTouch("Top.r"), + dontTouch("Top.w"), + dontTouch("Top.mem") ) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq @@ -200,27 +210,27 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with resultAnno should not contain (anno("r.a")) resultAnno should not contain (anno("r.b[0]")) resultAnno should not contain (anno("r.b[1]")) - resultAnno should contain (anno("in_a")) - resultAnno should contain (anno("in_b_0")) - resultAnno should contain (anno("in_b_1")) - resultAnno should contain (anno("out_a")) - resultAnno should contain (anno("out_b_0")) - resultAnno should contain (anno("out_b_1")) - resultAnno should contain (anno("w_a")) - resultAnno should contain (anno("w_b_0")) - resultAnno should contain (anno("w_b_1")) - resultAnno should contain (anno("r_a")) - resultAnno should contain (anno("r_b_0")) - resultAnno should contain (anno("r_b_1")) - resultAnno should contain (anno("mem_a.write.data")) - resultAnno should contain (anno("mem_b_0.write.data")) - resultAnno should contain (anno("mem_b_1.write.data")) + resultAnno should contain(anno("in_a")) + resultAnno should contain(anno("in_b_0")) + resultAnno should contain(anno("in_b_1")) + resultAnno should contain(anno("out_a")) + resultAnno should contain(anno("out_b_0")) + resultAnno should contain(anno("out_b_1")) + resultAnno should contain(anno("w_a")) + resultAnno should contain(anno("w_b_0")) + resultAnno should contain(anno("w_b_1")) + resultAnno should contain(anno("r_a")) + resultAnno should contain(anno("r_b_0")) + resultAnno should contain(anno("r_b_1")) + resultAnno should contain(anno("mem_a.write.data")) + resultAnno should contain(anno("mem_b_0.write.data")) + resultAnno should contain(anno("mem_b_1.write.data")) } "Renaming components" should "expand in Lowering" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input pred: UInt<1> @@ -231,28 +241,27 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | out <= mux(pred, in, w) | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in"), anno("out"), anno("w"), anno("r"), dontTouch("Top.r"), - dontTouch("Top.w")) + val annos = Seq(anno("in"), anno("out"), anno("w"), anno("r"), dontTouch("Top.r"), dontTouch("Top.w")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("in_a")) - resultAnno should contain (anno("in_b_0")) - resultAnno should contain (anno("in_b_1")) - resultAnno should contain (anno("out_a")) - resultAnno should contain (anno("out_b_0")) - resultAnno should contain (anno("out_b_1")) - resultAnno should contain (anno("w_a")) - resultAnno should contain (anno("w_b_0")) - resultAnno should contain (anno("w_b_1")) - resultAnno should contain (anno("r_a")) - resultAnno should contain (anno("r_b_0")) - resultAnno should contain (anno("r_b_1")) + resultAnno should contain(anno("in_a")) + resultAnno should contain(anno("in_b_0")) + resultAnno should contain(anno("in_b_1")) + resultAnno should contain(anno("out_a")) + resultAnno should contain(anno("out_b_0")) + resultAnno should contain(anno("out_b_1")) + resultAnno should contain(anno("w_a")) + resultAnno should contain(anno("w_b_0")) + resultAnno should contain(anno("w_b_1")) + resultAnno should contain(anno("r_a")) + resultAnno should contain(anno("r_b_0")) + resultAnno should contain(anno("r_b_1")) } "Renaming subcomponents that aren't leaves" should "expand in Lowering" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input pred: UInt<1> @@ -264,24 +273,23 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("r.b"), - dontTouch("Top.r"), dontTouch("Top.w")) + val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("r.b"), dontTouch("Top.r"), dontTouch("Top.w")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("in_b_0")) - resultAnno should contain (anno("in_b_1")) - resultAnno should contain (anno("out_b_0")) - resultAnno should contain (anno("out_b_1")) - resultAnno should contain (anno("w_b_0")) - resultAnno should contain (anno("w_b_1")) - resultAnno should contain (anno("r_b_0")) - resultAnno should contain (anno("r_b_1")) + resultAnno should contain(anno("in_b_0")) + resultAnno should contain(anno("in_b_1")) + resultAnno should contain(anno("out_b_0")) + resultAnno should contain(anno("out_b_1")) + resultAnno should contain(anno("w_b_0")) + resultAnno should contain(anno("w_b_1")) + resultAnno should contain(anno("r_b_0")) + resultAnno should contain(anno("r_b_1")) } "Renaming" should "track constprop + dce" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input pred: UInt<1> @@ -291,9 +299,15 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | out <= n |""".stripMargin val annos = Seq( - anno("in.a"), anno("in.b[0]"), anno("in.b[1]"), - anno("out.a"), anno("out.b[0]"), anno("out.b[1]"), - anno("n.a"), anno("n.b[0]"), anno("n.b[1]") + anno("in.a"), + anno("in.b[0]"), + anno("in.b[1]"), + anno("out.a"), + anno("out.b[0]"), + anno("out.b[1]"), + anno("n.a"), + anno("n.b[0]"), + anno("n.b[1]") ) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq @@ -309,18 +323,18 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with resultAnno should not contain (anno("n_a")) resultAnno should not contain (anno("n_b_0")) resultAnno should not contain (anno("n_b_1")) - resultAnno should contain (anno("in_a")) - resultAnno should contain (anno("in_b_0")) - resultAnno should contain (anno("in_b_1")) - resultAnno should contain (anno("out_a")) - resultAnno should contain (anno("out_b_0")) - resultAnno should contain (anno("out_b_1")) + resultAnno should contain(anno("in_a")) + resultAnno should contain(anno("in_b_0")) + resultAnno should contain(anno("in_b_1")) + resultAnno should contain(anno("out_a")) + resultAnno should contain(anno("out_b_0")) + resultAnno should contain(anno("out_b_1")) } "Renaming" should "track deleted modules AND instances in dce" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Dead : | input foo : UInt<8> | output bar : UInt<8> @@ -339,11 +353,17 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with |""".stripMargin val annos = Seq( OptimizableExtModuleAnnotation(ModuleName("DeadExt", CircuitName("Top"))), - manno("Dead"), manno("DeadExt"), manno("Top"), - anno("d"), anno("d2"), - anno("foo", mod = "Top"), anno("bar", mod = "Top"), - anno("foo", mod = "Dead"), anno("bar", mod = "Dead"), - anno("foo", mod = "DeadExt"), anno("bar", mod = "DeadExt") + manno("Dead"), + manno("DeadExt"), + manno("Top"), + anno("d"), + anno("d2"), + anno("foo", mod = "Top"), + anno("bar", mod = "Top"), + anno("foo", mod = "Dead"), + anno("bar", mod = "Dead"), + anno("foo", mod = "DeadExt"), + anno("bar", mod = "DeadExt") ) val result = compiler.transform(CircuitState(parse(input), annos)) /* Uncomment to help debug @@ -354,12 +374,12 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with case Annotation(target, _, _) => println(s"not deleted: $target") } } - */ + */ val resultAnno = result.annotations.toSeq - resultAnno should contain (manno("Top")) - resultAnno should contain (anno("foo", mod = "Top")) - resultAnno should contain (anno("bar", mod = "Top")) + resultAnno should contain(manno("Top")) + resultAnno should contain(anno("foo", mod = "Top")) + resultAnno should contain(anno("bar", mod = "Top")) resultAnno should not contain (manno("Dead")) resultAnno should not contain (manno("DeadExt")) @@ -373,7 +393,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with "Renaming" should "track deduplication" in { val input = - """circuit Top : + """circuit Top : | module Child : | input x : UInt<32> | output y : UInt<32> @@ -392,13 +412,16 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | out <= tail(add(a.y, b.y), 1) |""".stripMargin val annos = Seq( - anno("x", mod = "Child"), anno("y", mod = "Child_1"), manno("Child"), manno("Child_1") + anno("x", mod = "Child"), + anno("y", mod = "Child_1"), + manno("Child"), + manno("Child_1") ) val result = compile(input, annos) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("x", mod = "Child")) - resultAnno should contain (anno("y", mod = "Child")) - resultAnno should contain (manno("Child")) + resultAnno should contain(anno("x", mod = "Child")) + resultAnno should contain(anno("y", mod = "Child")) + resultAnno should contain(manno("Child")) resultAnno should not contain (anno("y", mod = "Child_1")) resultAnno should not contain (manno("Child_1")) } @@ -412,7 +435,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with "Annotations on empty aggregates" should "be deleted" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input x : { foo : UInt<8>, bar : {}, fizz : UInt<8>[0], buzz : UInt<0> } | output y : { foo : UInt<8>, bar : {}, fizz : UInt<8>[0], buzz : UInt<0> } @@ -423,12 +446,19 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | y <= x |""".stripMargin val annos = Seq( - anno("x"), anno("y.bar"), anno("y.fizz"), anno("y.buzz"), anno("a"), anno("b"), anno("c"), - anno("c[0].d"), anno("c[1].d") + anno("x"), + anno("y.bar"), + anno("y.fizz"), + anno("y.buzz"), + anno("a"), + anno("b"), + anno("c"), + anno("c[0].d"), + anno("c[1].d") ) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("x_foo")) + resultAnno should contain(anno("x_foo")) resultAnno should not contain (anno("a")) resultAnno should not contain (anno("b")) // Check both with and without dots because both are wrong @@ -445,8 +475,8 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with resultAnno should not contain (anno("x_fizz")) resultAnno should not contain (anno("x_buzz")) resultAnno should not contain (anno("c")) - resultAnno should contain (anno("c_0_e")) - resultAnno should contain (anno("c_1_e")) + resultAnno should contain(anno("c_0_e")) + resultAnno should contain(anno("c_1_e")) resultAnno should not contain (anno("c[0].d")) resultAnno should not contain (anno("c[1].d")) resultAnno should not contain (anno("c_0_d")) @@ -456,15 +486,14 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with class JsonAnnotationTests extends AnnotationTests { // Helper annotations - case class SimpleAnno(target: ComponentName, value: String) extends - SingleTargetAnnotation[ComponentName] { + case class SimpleAnno(target: ComponentName, value: String) extends SingleTargetAnnotation[ComponentName] { def duplicate(n: ComponentName) = this.copy(target = n) } case class ModuleAnno(target: ModuleName) extends SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(target = n) } - def anno(s: String, value: String ="this is a value", mod: String = "Top"): SimpleAnno = + def anno(s: String, value: String = "this is a value", mod: String = "Top"): SimpleAnno = SimpleAnno(ComponentName(s, ModuleName(mod, CircuitName("Top"))), value) def manno(mod: String): Annotation = ModuleAnno(ModuleName(mod, CircuitName("Top"))) @@ -487,17 +516,17 @@ class JsonAnnotationTests extends AnnotationTests { val readAnnos = JsonProtocol.deserializeTry(text).get - annos should be (readAnnos) + annos should be(readAnnos) } private def setupManager(annoFileText: Option[String]) = { val source = """ - |circuit test : - | module test : - | input x : UInt<1> - | output z : UInt<1> - | z <= x - | node y = x""".stripMargin + |circuit test : + | module test : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | node y = x""".stripMargin val testDir = BackendCompilationUtilities.createTestDirectory(this.getClass.getSimpleName) val annoFile = new File(testDir, "anno.json") @@ -519,58 +548,57 @@ class JsonAnnotationTests extends AnnotationTests { "Annotation file not found" should "give a reasonable error message" in { val manager = setupManager(None) - an [AnnotationFileNotFoundException] shouldBe thrownBy { + an[AnnotationFileNotFoundException] shouldBe thrownBy { Driver.execute(manager) } } "Annotation class not found" should "give a reasonable error message" in { val anno = """ - |[ - | { - | "class":"ThisClassDoesNotExist", - | "target":"test.test.y" - | } - |] """.stripMargin + |[ + | { + | "class":"ThisClassDoesNotExist", + | "target":"test.test.y" + | } + |] """.stripMargin val manager = setupManager(Some(anno)) - the [Exception] thrownBy Driver.execute(manager) should matchPattern { + the[Exception] thrownBy Driver.execute(manager) should matchPattern { case InvalidAnnotationFileException(_, _: AnnotationClassNotFoundException) => } } "Malformed annotation file" should "give a reasonable error message" in { val anno = """ - |[ - | { - | "class": - | "target":"test.test.y" - | } - |] """.stripMargin + |[ + | { + | "class": + | "target":"test.test.y" + | } + |] """.stripMargin val manager = setupManager(Some(anno)) - the [Exception] thrownBy Driver.execute(manager) should matchPattern { + the[Exception] thrownBy Driver.execute(manager) should matchPattern { case InvalidAnnotationFileException(_, _: InvalidAnnotationJSONException) => } } "Non-array annotation file" should "give a reasonable error message" in { val anno = """ - |{ - | "class":"firrtl.transforms.DontTouchAnnotation", - | "target":"test.test.y" - |} - |""".stripMargin + |{ + | "class":"firrtl.transforms.DontTouchAnnotation", + | "target":"test.test.y" + |} + |""".stripMargin val manager = setupManager(Some(anno)) - the [Exception] thrownBy Driver.execute(manager) should matchPattern { - case InvalidAnnotationFileException(_, InvalidAnnotationJSONException(msg)) - if msg.contains("JObject") => + the[Exception] thrownBy Driver.execute(manager) should matchPattern { + case InvalidAnnotationFileException(_, InvalidAnnotationJSONException(msg)) if msg.contains("JObject") => } } object DoNothingTransform extends Transform { - override def inputForm: CircuitForm = UnknownForm + override def inputForm: CircuitForm = UnknownForm override def outputForm: CircuitForm = UnknownForm def execute(state: CircuitState): CircuitState = state @@ -580,9 +608,9 @@ class JsonAnnotationTests extends AnnotationTests { val annos = Seq(anno("a"), anno("b"), anno("c"), anno("d"), anno("e")) val input: String = """circuit Top : - | module Top : - | input a : UInt<1> - | node b = c""".stripMargin + | module Top : + | input a : UInt<1> + | node b = c""".stripMargin val cr = DoNothingTransform.runTransform(CircuitState(parse(input), ChirrtlForm, annos)) cr.annotations.toSeq shouldEqual annos } diff --git a/src/test/scala/firrtlTests/AsyncResetSpec.scala b/src/test/scala/firrtlTests/AsyncResetSpec.scala index 70b28585..04b558e9 100644 --- a/src/test/scala/firrtlTests/AsyncResetSpec.scala +++ b/src/test/scala/firrtlTests/AsyncResetSpec.scala @@ -9,330 +9,313 @@ import FirrtlCheckers._ class AsyncResetSpec extends VerilogTransformSpec { def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } "AsyncReset" should "generate async-reset always blocks" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<8> - |output z : UInt<8> - |reg r : UInt<8>, clock with : (reset => (reset, UInt(123))) - |r <= x - |z <= r""".stripMargin - ) - result should containLine ("always @(posedge clock or posedge reset) begin") + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<8> + |output z : UInt<8> + |reg r : UInt<8>, clock with : (reset => (reset, UInt(123))) + |r <= x + |z <= r""".stripMargin) + result should containLine("always @(posedge clock or posedge reset) begin") } it should "work in nested and flipped aggregates with regular and partial connect" in { val result = compileBody(s""" - |output fizz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] } - |output buzz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] } - |fizz.bar <= fizz.foo - |buzz.bar <- buzz.foo - |""".stripMargin - ) + |output fizz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] } + |output buzz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] } + |fizz.bar <= fizz.foo + |buzz.bar <- buzz.foo + |""".stripMargin) - result should containLine ("assign fizz_foo_0_b = fizz_bar_0_b;") - result should containLine ("assign fizz_foo_1_b = fizz_bar_1_b;") - result should containLine ("assign fizz_bar_0_a = fizz_foo_0_a;") - result should containLine ("assign fizz_bar_1_a = fizz_foo_1_a;") - result should containLine ("assign buzz_foo_0_b = buzz_bar_0_b;") - result should containLine ("assign buzz_foo_1_b = buzz_bar_1_b;") - result should containLine ("assign buzz_bar_0_a = buzz_foo_0_a;") - result should containLine ("assign buzz_bar_1_a = buzz_foo_1_a;") + result should containLine("assign fizz_foo_0_b = fizz_bar_0_b;") + result should containLine("assign fizz_foo_1_b = fizz_bar_1_b;") + result should containLine("assign fizz_bar_0_a = fizz_foo_0_a;") + result should containLine("assign fizz_bar_1_a = fizz_foo_1_a;") + result should containLine("assign buzz_foo_0_b = buzz_bar_0_b;") + result should containLine("assign buzz_foo_1_b = buzz_bar_1_b;") + result should containLine("assign buzz_bar_0_a = buzz_foo_0_a;") + result should containLine("assign buzz_bar_1_a = buzz_foo_1_a;") } it should "support casting to other types" in { val result = compileBody(s""" - |input a : AsyncReset - |output u : Interval[0, 1].0 - |output v : UInt<1> - |output w : SInt<1> - |output x : Clock - |output y : Fixed<1><<0>> - |output z : AsyncReset - |u <= asInterval(a, 0, 1, 0) - |v <= asUInt(a) - |w <= asSInt(a) - |x <= asClock(a) - |y <= asFixedPoint(a, 0) - |z <= asAsyncReset(a) - |""".stripMargin - ) - result should containLine ("assign v = a;") - result should containLine ("assign w = a;") - result should containLine ("assign x = a;") - result should containLine ("assign y = a;") - result should containLine ("assign z = a;") + |input a : AsyncReset + |output u : Interval[0, 1].0 + |output v : UInt<1> + |output w : SInt<1> + |output x : Clock + |output y : Fixed<1><<0>> + |output z : AsyncReset + |u <= asInterval(a, 0, 1, 0) + |v <= asUInt(a) + |w <= asSInt(a) + |x <= asClock(a) + |y <= asFixedPoint(a, 0) + |z <= asAsyncReset(a) + |""".stripMargin) + result should containLine("assign v = a;") + result should containLine("assign w = a;") + result should containLine("assign x = a;") + result should containLine("assign y = a;") + result should containLine("assign z = a;") } "Other types" should "support casting to AsyncReset" in { val result = compileBody(s""" - |input a : UInt<1> - |input b : SInt<1> - |input c : Clock - |input d : Fixed<1><<0>> - |input e : AsyncReset - |input f : Interval[0, 0].0 - |output u : AsyncReset - |output v : AsyncReset - |output w : AsyncReset - |output x : AsyncReset - |output y : AsyncReset - |output z : AsyncReset - |u <= asAsyncReset(a) - |v <= asAsyncReset(b) - |w <= asAsyncReset(c) - |x <= asAsyncReset(d) - |y <= asAsyncReset(e) - |z <= asAsyncReset(f)""".stripMargin - ) - result should containLine ("assign u = a;") - result should containLine ("assign v = b;") - result should containLine ("assign w = c;") - result should containLine ("assign x = d;") - result should containLine ("assign y = e;") - result should containLine ("assign z = f;") + |input a : UInt<1> + |input b : SInt<1> + |input c : Clock + |input d : Fixed<1><<0>> + |input e : AsyncReset + |input f : Interval[0, 0].0 + |output u : AsyncReset + |output v : AsyncReset + |output w : AsyncReset + |output x : AsyncReset + |output y : AsyncReset + |output z : AsyncReset + |u <= asAsyncReset(a) + |v <= asAsyncReset(b) + |w <= asAsyncReset(c) + |x <= asAsyncReset(d) + |y <= asAsyncReset(e) + |z <= asAsyncReset(f)""".stripMargin) + result should containLine("assign u = a;") + result should containLine("assign v = b;") + result should containLine("assign w = c;") + result should containLine("assign x = d;") + result should containLine("assign y = e;") + result should containLine("assign z = f;") } "Non-literals" should "NOT be allowed as reset values for AsyncReset" in { - an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { + an[checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |reg r : UInt<8>, clock with : (reset => (reset, y)) - |r <= x - |z <= r""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |reg r : UInt<8>, clock with : (reset => (reset, y)) + |r <= x + |z <= r""".stripMargin) } } "Self-inits" should "NOT cause infinite loops in CheckResets" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input in : UInt<12> - |output out : UInt<10> - | - |reg a : UInt<10>, clock with : - | reset => (reset, a) - |out <= UInt<5>("h15")""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input in : UInt<12> + |output out : UInt<10> + | + |reg a : UInt<10>, clock with : + | reset => (reset, a) + |out <= UInt<5>("h15")""".stripMargin) result should containLine("assign out = 10'h15;") } "Late non-literals connections" should "NOT be allowed as reset values for AsyncReset" in { - an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { + an[checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |wire a : UInt<8> - |reg r : UInt<8>, clock with : (reset => (reset, a)) - |a <= y - |r <= x - |z <= r""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |wire a : UInt<8> + |reg r : UInt<8>, clock with : (reset => (reset, a)) + |a <= y + |r <= x + |z <= r""".stripMargin) } } "Hidden Non-literals" should "NOT be allowed as reset values for AsyncReset" in { - an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { + an[checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1>[4] - |input y : UInt<1> - |output z : UInt<1>[4] - |wire literal : UInt<1>[4] - |literal[0] <= UInt<1>("h00") - |literal[1] <= y - |literal[2] <= UInt<1>("h00") - |literal[3] <= UInt<1>("h00") - |reg r : UInt<1>[4], clock with : (reset => (reset, literal)) - |r <= x - |z <= r""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1>[4] + |input y : UInt<1> + |output z : UInt<1>[4] + |wire literal : UInt<1>[4] + |literal[0] <= UInt<1>("h00") + |literal[1] <= y + |literal[2] <= UInt<1>("h00") + |literal[3] <= UInt<1>("h00") + |reg r : UInt<1>[4], clock with : (reset => (reset, literal)) + |r <= x + |z <= r""".stripMargin) } } "Wire connected to non-literal" should "NOT be allowed as reset values for AsyncReset" in { - an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { + an[checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |input y : UInt<1> - |input cond : UInt<1> - |output z : UInt<1> - |wire w : UInt<1> - |w <= UInt(1) - |when cond : - | w <= y - |reg r : UInt<1>, clock with : (reset => (reset, w)) - |r <= x - |z <= r""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |input y : UInt<1> + |input cond : UInt<1> + |output z : UInt<1> + |wire w : UInt<1> + |w <= UInt(1) + |when cond : + | w <= y + |reg r : UInt<1>, clock with : (reset => (reset, w)) + |r <= x + |z <= r""".stripMargin) } } "Complex literals" should "be allowed as reset values for AsyncReset" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1>[4] - |output z : UInt<1>[4] - |wire literal : UInt<1>[4] - |literal[0] <= UInt<1>("h00") - |literal[1] <= UInt<1>("h00") - |literal[2] <= UInt<1>("h00") - |literal[3] <= UInt<1>("h00") - |reg r : UInt<1>[4], clock with : (reset => (reset, literal)) - |r <= x - |z <= r""".stripMargin - ) - result should containLine ("always @(posedge clock or posedge reset) begin") + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1>[4] + |output z : UInt<1>[4] + |wire literal : UInt<1>[4] + |literal[0] <= UInt<1>("h00") + |literal[1] <= UInt<1>("h00") + |literal[2] <= UInt<1>("h00") + |literal[3] <= UInt<1>("h00") + |reg r : UInt<1>[4], clock with : (reset => (reset, literal)) + |r <= x + |z <= r""".stripMargin) + result should containLine("always @(posedge clock or posedge reset) begin") } "Complex literals of complex literals" should "be allowed as reset values for AsyncReset" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1>[4] - |output z : UInt<1>[4] - |wire literal : UInt<1>[2] - |literal[0] <= UInt<1>("h01") - |literal[1] <= UInt<1>("h01") - |wire complex_literal : UInt<1>[4] - |complex_literal[0] <= literal[0] - |complex_literal[1] <= literal[1] - |complex_literal[2] <= UInt<1>("h00") - |complex_literal[3] <= UInt<1>("h00") - |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal)) - |r <= x - |z <= r""".stripMargin - ) - result should containLine ("always @(posedge clock or posedge reset) begin") + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1>[4] + |output z : UInt<1>[4] + |wire literal : UInt<1>[2] + |literal[0] <= UInt<1>("h01") + |literal[1] <= UInt<1>("h01") + |wire complex_literal : UInt<1>[4] + |complex_literal[0] <= literal[0] + |complex_literal[1] <= literal[1] + |complex_literal[2] <= UInt<1>("h00") + |complex_literal[3] <= UInt<1>("h00") + |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal)) + |r <= x + |z <= r""".stripMargin) + result should containLine("always @(posedge clock or posedge reset) begin") } "Literals of bundle literals" should "be allowed as reset values for AsyncReset" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1>[4] - |output z : UInt<1>[4] - |wire bundle : {a: UInt<1>, b: UInt<1>} - |bundle.a <= UInt<1>("h01") - |bundle.b <= UInt<1>("h01") - |wire complex_literal : UInt<1>[4] - |complex_literal[0] <= bundle.a - |complex_literal[1] <= bundle.b - |complex_literal[2] <= UInt<1>("h00") - |complex_literal[3] <= UInt<1>("h00") - |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal)) - |r <= x - |z <= r""".stripMargin - ) - result should containLine ("always @(posedge clock or posedge reset) begin") + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1>[4] + |output z : UInt<1>[4] + |wire bundle : {a: UInt<1>, b: UInt<1>} + |bundle.a <= UInt<1>("h01") + |bundle.b <= UInt<1>("h01") + |wire complex_literal : UInt<1>[4] + |complex_literal[0] <= bundle.a + |complex_literal[1] <= bundle.b + |complex_literal[2] <= UInt<1>("h00") + |complex_literal[3] <= UInt<1>("h00") + |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal)) + |r <= x + |z <= r""".stripMargin) + result should containLine("always @(posedge clock or posedge reset) begin") } "Cast literals" should "be allowed as reset values for AsyncReset" in { // This also checks that casts can be across wires and nodes val sintResult = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : SInt<4> - |output y : SInt<4> - |output z : SInt<4> - |reg r : SInt<4>, clock with : (reset => (reset, asSInt(UInt(0)))) - |r <= x - |wire w : SInt<4> - |reg r2 : SInt<4>, clock with : (reset => (reset, w)) - |r2 <= x - |node n = UInt("hf") - |w <= asSInt(n) - |y <= r2 - |z <= r""".stripMargin - ) - sintResult should containLine ("always @(posedge clock or posedge reset) begin") - sintResult should containLine ("r <= 4'sh0;") - sintResult should containLine ("r2 <= -4'sh1;") + |input clock : Clock + |input reset : AsyncReset + |input x : SInt<4> + |output y : SInt<4> + |output z : SInt<4> + |reg r : SInt<4>, clock with : (reset => (reset, asSInt(UInt(0)))) + |r <= x + |wire w : SInt<4> + |reg r2 : SInt<4>, clock with : (reset => (reset, w)) + |r2 <= x + |node n = UInt("hf") + |w <= asSInt(n) + |y <= r2 + |z <= r""".stripMargin) + sintResult should containLine("always @(posedge clock or posedge reset) begin") + sintResult should containLine("r <= 4'sh0;") + sintResult should containLine("r2 <= -4'sh1;") val fixedResult = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : Fixed<2><<0>> - |output z : Fixed<2><<0>> - |reg r : Fixed<2><<0>>, clock with : (reset => (reset, asFixedPoint(UInt(2), 0))) - |r <= x - |z <= r""".stripMargin - ) - fixedResult should containLine ("always @(posedge clock or posedge reset) begin") - fixedResult should containLine ("r <= 2'sh2;") + |input clock : Clock + |input reset : AsyncReset + |input x : Fixed<2><<0>> + |output z : Fixed<2><<0>> + |reg r : Fixed<2><<0>>, clock with : (reset => (reset, asFixedPoint(UInt(2), 0))) + |r <= x + |z <= r""".stripMargin) + fixedResult should containLine("always @(posedge clock or posedge reset) begin") + fixedResult should containLine("r <= 2'sh2;") val intervalResult = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : Interval[0, 4].0 - |output z : Interval[0, 4].0 - |reg r : Interval[0, 4].0, clock with : (reset => (reset, asInterval(UInt(0), 0, 0, 0))) - |r <= x - |z <= r""".stripMargin - ) - intervalResult should containLine ("always @(posedge clock or posedge reset) begin") - intervalResult should containLine ("r <= 4'sh0;") + |input clock : Clock + |input reset : AsyncReset + |input x : Interval[0, 4].0 + |output z : Interval[0, 4].0 + |reg r : Interval[0, 4].0, clock with : (reset => (reset, asInterval(UInt(0), 0, 0, 0))) + |r <= x + |z <= r""".stripMargin) + intervalResult should containLine("always @(posedge clock or posedge reset) begin") + intervalResult should containLine("r <= 4'sh0;") } "CheckResets" should "NOT raise StackOverflow Exception on Combinational Loops (should be caught by firrtl.transforms.CheckCombLoops)" in { - an [firrtl.transforms.CheckCombLoops.CombLoopException] shouldBe thrownBy { + an[firrtl.transforms.CheckCombLoops.CombLoopException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |wire x : UInt<1> - |wire y : UInt<2> - |x <= UInt<1>("h01") - |node ad = add(x, y) - |node adt = tail(ad, 1) - |y <= adt - |reg r : UInt, clock with : (reset => (reset, y)) - |""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |wire x : UInt<1> + |wire y : UInt<2> + |x <= UInt<1>("h01") + |node ad = add(x, y) + |node adt = tail(ad, 1) + |y <= adt + |reg r : UInt, clock with : (reset => (reset, y)) + |""".stripMargin) } } "Every async reset reg" should "generate its own always block" in { val result = compileBody(s""" - |input clock0 : Clock - |input clock1 : Clock - |input syncReset : UInt<1> - |input asyncReset : AsyncReset - |input x : UInt<8>[5] - |output z : UInt<8>[5] - |reg r0 : UInt<8>, clock0 with : (reset => (syncReset, UInt(123))) - |reg r1 : UInt<8>, clock1 with : (reset => (syncReset, UInt(123))) - |reg r2 : UInt<8>, clock0 with : (reset => (asyncReset, UInt(123))) - |reg r3 : UInt<8>, clock0 with : (reset => (asyncReset, UInt(123))) - |reg r4 : UInt<8>, clock1 with : (reset => (asyncReset, UInt(123))) - |r0 <= x[0] - |r1 <= x[1] - |r2 <= x[2] - |r3 <= x[3] - |r4 <= x[4] - |z[0] <= r0 - |z[1] <= r1 - |z[2] <= r2 - |z[3] <= r3 - |z[4] <= r4""".stripMargin - ) - result should containLines ( + |input clock0 : Clock + |input clock1 : Clock + |input syncReset : UInt<1> + |input asyncReset : AsyncReset + |input x : UInt<8>[5] + |output z : UInt<8>[5] + |reg r0 : UInt<8>, clock0 with : (reset => (syncReset, UInt(123))) + |reg r1 : UInt<8>, clock1 with : (reset => (syncReset, UInt(123))) + |reg r2 : UInt<8>, clock0 with : (reset => (asyncReset, UInt(123))) + |reg r3 : UInt<8>, clock0 with : (reset => (asyncReset, UInt(123))) + |reg r4 : UInt<8>, clock1 with : (reset => (asyncReset, UInt(123))) + |r0 <= x[0] + |r1 <= x[1] + |r2 <= x[2] + |r3 <= x[3] + |r4 <= x[4] + |z[0] <= r0 + |z[1] <= r1 + |z[2] <= r2 + |z[3] <= r3 + |z[4] <= r4""".stripMargin) + result should containLines( "always @(posedge clock0) begin", "if (syncReset) begin", "r0 <= 8'h7b;", @@ -341,7 +324,7 @@ class AsyncResetSpec extends VerilogTransformSpec { "end", "end" ) - result should containLines ( + result should containLines( "always @(posedge clock1) begin", "if (syncReset) begin", "r1 <= 8'h7b;", @@ -350,7 +333,7 @@ class AsyncResetSpec extends VerilogTransformSpec { "end", "end" ) - result should containLines ( + result should containLines( "always @(posedge clock0 or posedge asyncReset) begin", "if (asyncReset) begin", "r2 <= 8'h7b;", @@ -359,7 +342,7 @@ class AsyncResetSpec extends VerilogTransformSpec { "end", "end" ) - result should containLines ( + result should containLines( "always @(posedge clock0 or posedge asyncReset) begin", "if (asyncReset) begin", "r3 <= 8'h7b;", @@ -368,7 +351,7 @@ class AsyncResetSpec extends VerilogTransformSpec { "end", "end" ) - result should containLines ( + result should containLines( "always @(posedge clock1 or posedge asyncReset) begin", "if (asyncReset) begin", "r4 <= 8'h7b;", @@ -427,27 +410,26 @@ class AsyncResetSpec extends VerilogTransformSpec { "AsyncReset registers" should "emit 'else' case for reset even for trivial valued registers" in { val withDontTouch = s""" - |circuit m : - | module m : - | input clock : Clock - | input reset : AsyncReset - | input x : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt(123))) - |""".stripMargin + |circuit m : + | module m : + | input clock : Clock + | input reset : AsyncReset + | input x : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt(123))) + |""".stripMargin val annos = Seq(dontTouch("m.r")) // dontTouch prevents ConstantPropagation from fixing this problem val result = (new VerilogCompiler).compileAndEmit(CircuitState(parse(withDontTouch), ChirrtlForm, annos)) - result should containLines ( - "always @(posedge clock or posedge reset) begin", - "if (reset) begin", - "r <= 8'h7b;", - "end else begin", - "r <= 8'h7b;", - "end", - "end" - ) + result should containLines( + "always @(posedge clock or posedge reset) begin", + "if (reset) begin", + "r <= 8'h7b;", + "end else begin", + "r <= 8'h7b;", + "end", + "end" + ) } } class AsyncResetExecutionTest extends ExecutionTest("AsyncResetTester", "/features") - diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala index e4acc735..709e3692 100644 --- a/src/test/scala/firrtlTests/AttachSpec.scala +++ b/src/test/scala/firrtlTests/AttachSpec.scala @@ -9,12 +9,12 @@ import firrtl.testutils._ class InoutVerilogSpec extends FirrtlFlatSpec { - behavior of "Analog" + behavior.of("Analog") it should "attach a module input source directly" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | input an: Analog<3> | inst a of A @@ -25,32 +25,32 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | module B: | input an2: Analog<3> """.stripMargin val check = - """module Attaching( - | inout [2:0] an - |); - | A a ( - | .an1(an) - | ); - | B b ( - | .an2(an) - | ); - |endmodule - |module A( - | inout [2:0] an1 - |); - |endmodule - |module B( - | inout [2:0] an2 - |); - |endmodule - |""".stripMargin.split("\n") map normalized + """module Attaching( + | inout [2:0] an + |); + | A a ( + | .an1(an) + | ); + | B b ( + | .an2(an) + | ); + |endmodule + |module A( + | inout [2:0] an1 + |); + |endmodule + |module B( + | inout [2:0] an2 + |); + |endmodule + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler, Seq(dontDedup("A"), dontDedup("B"))) } it should "attach two instances" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | inst a of A | inst b of B @@ -60,24 +60,24 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | module B: | input an: Analog<3>""".stripMargin val check = - """module Attaching( - |); - | wire [2:0] _GEN_0; - | A a ( - | .an(_GEN_0) - | ); - | B b ( - | .an(_GEN_0) - | ); - |endmodule - |module A( - | inout [2:0] an - |); - |module B( - | inout [2:0] an - |); - |endmodule - |""".stripMargin.split("\n") map normalized + """module Attaching( + |); + | wire [2:0] _GEN_0; + | A a ( + | .an(_GEN_0) + | ); + | B b ( + | .an(_GEN_0) + | ); + |endmodule + |module A( + | inout [2:0] an + |); + |module B( + | inout [2:0] an + |); + |endmodule + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler, Seq(dontTouch("A.an"), dontDedup("A"))) } @@ -85,12 +85,12 @@ class InoutVerilogSpec extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | wire x: Analog - | inst a of A - | attach (x, a.an) - | module A: - | input an: Analog<3> """.stripMargin + | module Attaching : + | wire x: Analog + | inst a of A + | attach (x, a.an) + | module A: + | input an: Analog<3> """.stripMargin val check = """module Attaching( |); @@ -99,7 +99,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | .an(x) | ); |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler, Seq(dontTouch("Attaching.x"))) } @@ -107,14 +107,14 @@ class InoutVerilogSpec extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | input an: Analog<3> - | wire x: Analog - | inst a of A - | attach (x, a.an) - | attach (x, an) - | module A: - | input an: Analog<3> """.stripMargin + | module Attaching : + | input an: Analog<3> + | wire x: Analog + | inst a of A + | attach (x, a.an) + | attach (x, an) + | module A: + | input an: Analog<3> """.stripMargin val check = """module Attaching( | inout [2:0] an @@ -123,20 +123,19 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | .an(an) | ); |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler, Seq(dontTouch("Attaching.x"))) } - it should "attach multiple sources" in { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | input a1 : Analog<3> - | input a2 : Analog<3> - | wire x: Analog<3> - | attach (x, a1, a2)""".stripMargin + | module Attaching : + | input a1 : Analog<3> + | input a2 : Analog<3> + | wire x: Analog<3> + | attach (x, a1, a2)""".stripMargin val check = """module Attaching( | inout [2:0] a1, @@ -151,7 +150,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | alias a1 = a2; | `endif |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } @@ -159,10 +158,10 @@ class InoutVerilogSpec extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | input foo : { b : UInt<3>, a : Analog<3> } - | output bar : { b : UInt<3>, a : Analog<3> } - | bar <- foo""".stripMargin + | module Attaching : + | input foo : { b : UInt<3>, a : Analog<3> } + | output bar : { b : UInt<3>, a : Analog<3> } + | bar <- foo""".stripMargin // Omitting `ifdef SYNTHESIS and `elsif verilator since it's tested above val check = """module Attaching( @@ -174,7 +173,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | assign bar_b = foo_b; | alias bar_a = foo_a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } @@ -182,14 +181,14 @@ class InoutVerilogSpec extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | input a : Analog<32> - | input b : Analog<32> - | input c : Analog<32> - | input d : Analog<32> - | attach (a, b) - | attach (c, b) - | attach (a, d)""".stripMargin + | module Attaching : + | input a : Analog<32> + | input b : Analog<32> + | input c : Analog<32> + | input d : Analog<32> + | attach (a, b) + | attach (c, b) + | attach (a, d)""".stripMargin val check = """module Attaching( | inout [31:0] a, @@ -199,19 +198,19 @@ class InoutVerilogSpec extends FirrtlFlatSpec { |); | alias a = b = c = d; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) val input2 = """circuit Attaching : - | module Attaching : - | input a : Analog<32> - | input b : Analog<32> - | input c : Analog<32> - | input d : Analog<32> - | attach (a, b) - | attach (c, d) - | attach (d, a)""".stripMargin + | module Attaching : + | input a : Analog<32> + | input b : Analog<32> + | input c : Analog<32> + | input d : Analog<32> + | attach (a, b) + | attach (c, d) + | attach (d, a)""".stripMargin val check2 = """module Attaching( | inout [31:0] a, @@ -221,14 +220,14 @@ class InoutVerilogSpec extends FirrtlFlatSpec { |); | alias a = b = c = d; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input2, check2, compiler) } it should "infer widths" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | input an: Analog | inst a of A @@ -236,70 +235,65 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | module A: | input an1: Analog<3>""".stripMargin val check = - """module Attaching( - | inout [2:0] an - |); - | A a ( - | .an1(an) - | ); - |endmodule - |module A( - | inout [2:0] an1 - |); - |endmodule""".stripMargin.split("\n") map normalized + """module Attaching( + | inout [2:0] an + |); + | A a ( + | .an1(an) + | ); + |endmodule + |module A( + | inout [2:0] an1 + |); + |endmodule""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } it should "not error if not isinvalid" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | output an: Analog<3> |""".stripMargin val check = - """module Attaching( - | inout [2:0] an - |); - |endmodule""".stripMargin.split("\n") map normalized + """module Attaching( + | inout [2:0] an + |); + |endmodule""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } it should "not error if isinvalid" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | output an: Analog<3> | an is invalid |""".stripMargin val check = - """module Attaching( - | inout [2:0] an - |); - |endmodule""".stripMargin.split("\n") map normalized + """module Attaching( + | inout [2:0] an + |); + |endmodule""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } } class AttachAnalogSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } "Connecting analog types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -307,38 +301,28 @@ class AttachAnalogSpec extends FirrtlFlatSpec { | output x: Analog<1> | x <= y""".stripMargin intercept[CheckTypes.InvalidConnect] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Declaring register with analog types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : | input clk: Clock | reg r: Analog<2>, clk""".stripMargin intercept[CheckTypes.IllegalAnalogDeclaration] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Declaring memory with analog types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -350,38 +334,28 @@ class AttachAnalogSpec extends FirrtlFlatSpec { | write-latency => 1 | read-under-write => undefined""".stripMargin intercept[CheckTypes.IllegalAnalogDeclaration] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Declaring node with analog types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : | input in: Analog<2> | node n = in """.stripMargin intercept[CheckTypes.IllegalAnalogDeclaration] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Attaching a non-analog expression" should "not be ok" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -394,21 +368,14 @@ class AttachAnalogSpec extends FirrtlFlatSpec { | extmodule B: | input o: Analog<2>""".stripMargin intercept[CheckTypes.OpNotAnalog] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Inequal attach widths" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - new InferWidths, - CheckWidths) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : @@ -418,8 +385,8 @@ class AttachAnalogSpec extends FirrtlFlatSpec { | extmodule A : | output o: Analog<2> """.stripMargin intercept[CheckWidths.AttachWidthsNotEqual] { - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala index 6c9d4047..ce0c0a74 100644 --- a/src/test/scala/firrtlTests/CInferMDirSpec.scala +++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala @@ -14,22 +14,21 @@ class CInferMDirSpec extends LowTransformSpec { def checkStmt(s: Statement): Boolean = s match { case s: DefMemory if s.name == "indices" => (s.readers contains "index") && - (s.writers contains "bar") && - s.readwriters.isEmpty + (s.writers contains "bar") && + s.readwriters.isEmpty case s: Block => - s.stmts exists checkStmt + s.stmts.exists(checkStmt) case _ => false } - def run (c: Circuit) = { + def run(c: Circuit) = { val errors = new Errors - val check = c.modules exists { - case m: Module => checkStmt(m.body) + val check = c.modules.exists { + case m: Module => checkStmt(m.body) case m: ExtModule => false } if (!check) { - errors append new PassException( - "Memory has incorrect port directions!") + errors.append(new PassException("Memory has incorrect port directions!")) errors.trigger } c diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala index 2016e160..d8151142 100644 --- a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala +++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala @@ -12,46 +12,46 @@ import java.nio.file.Paths import firrtl.options.Dependency import firrtl.stage.FirrtlStage -class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops]) ){ +class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops])) { "Loop-free circuit" should "not throw an exception" in { val input = """circuit hasnoloops : - | module thru : - | input in1 : UInt<1> - | input in2 : UInt<1> - | output out1 : UInt<1> - | output out2 : UInt<1> - | out1 <= in1 - | out2 <= in2 - | module hasnoloops : - | input clk : Clock - | input a : UInt<1> - | output b : UInt<1> - | wire x : UInt<1> - | inst inner of thru - | inner.in1 <= a - | x <= inner.out1 - | inner.in2 <= x - | b <= inner.out2 - |""".stripMargin + | module thru : + | input in1 : UInt<1> + | input in2 : UInt<1> + | output out1 : UInt<1> + | output out2 : UInt<1> + | out1 <= in1 + | out2 <= in2 + | module hasnoloops : + | input clk : Clock + | input a : UInt<1> + | output b : UInt<1> + | wire x : UInt<1> + | inst inner of thru + | inner.in1 <= a + | x <= inner.out1 + | inner.in2 <= x + | b <= inner.out2 + |""".stripMargin compile(parse(input)) } "Simple combinational loop" should "throw an exception" in { val input = """circuit hasloops : - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | z <= y - | y <= z - | d <= z - |""".stripMargin + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | z <= y + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -60,12 +60,12 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Single-element combinational loop" should "throw an exception" in { val input = """circuit loop : - | module loop : - | output y : UInt<8> - | wire w : UInt<8> - | w <= w - | y <= w - |""".stripMargin + | module loop : + | output y : UInt<8> + | wire w : UInt<8> + | w <= w + | y <= w + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -74,18 +74,18 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Node combinational loop" should "throw an exception" in { val input = """circuit hasloops : - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | c <= b - | node z = and(c,y) - | y <= z - | d <= z - |""".stripMargin + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | c <= b + | node z = and(c,y) + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -94,29 +94,29 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Combinational loop through a combinational memory read port" should "throw an exception" in { val input = """circuit hasloops : - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | mem m : - | data-type => UInt<1> - | depth => 2 - | read-latency => 0 - | write-latency => 1 - | reader => r - | read-under-write => undefined - | m.r.clk <= clk - | m.r.addr <= y - | m.r.en <= UInt(1) - | z <= m.r.data - | y <= z - | d <= z - |""".stripMargin + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | mem m : + | data-type => UInt<1> + | depth => 2 + | read-latency => 0 + | write-latency => 1 + | reader => r + | read-under-write => undefined + | m.r.clk <= clk + | m.r.addr <= y + | m.r.en <= UInt(1) + | z <= m.r.data + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -125,25 +125,25 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Combination loop through an instance" should "throw an exception" in { val input = """circuit hasloops : - | module thru : - | input in : UInt<1> - | output out : UInt<1> - | out <= in - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | inst inner of thru - | inner.in <= y - | z <= inner.out - | y <= z - | d <= z - |""".stripMargin + | module thru : + | input in : UInt<1> + | output out : UInt<1> + | out <= in + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of thru + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -152,24 +152,24 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Combinational loop through an annotated ExtModule" should "throw an exception" in { val input = """circuit hasloops : - | extmodule blackbox : - | input in : UInt<1> - | output out : UInt<1> - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | inst inner of blackbox - | inner.in <= y - | z <= inner.out - | y <= z - | d <= z - |""".stripMargin + | extmodule blackbox : + | input in : UInt<1> + | output out : UInt<1> + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of blackbox + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin val mt = ModuleTarget("hasloops", "blackbox") val annos = AnnotationSeq(Seq(ExtModulePathAnnotation(mt.ref("in"), mt.ref("out")))) @@ -180,53 +180,56 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Loop-free circuit with ExtModulePathAnnotations" should "not throw an exception" in { val input = """circuit hasnoloops : - | extmodule blackbox : - | input in1 : UInt<1> - | input in2 : UInt<1> - | output out1 : UInt<1> - | output out2 : UInt<1> - | module hasnoloops : - | input clk : Clock - | input a : UInt<1> - | output b : UInt<1> - | wire x : UInt<1> - | inst inner of blackbox - | inner.in1 <= a - | x <= inner.out1 - | inner.in2 <= x - | b <= inner.out2 - |""".stripMargin + | extmodule blackbox : + | input in1 : UInt<1> + | input in2 : UInt<1> + | output out1 : UInt<1> + | output out2 : UInt<1> + | module hasnoloops : + | input clk : Clock + | input a : UInt<1> + | output b : UInt<1> + | wire x : UInt<1> + | inst inner of blackbox + | inner.in1 <= a + | x <= inner.out1 + | inner.in2 <= x + | b <= inner.out2 + |""".stripMargin val mt = ModuleTarget("hasnoloops", "blackbox") - val annos = AnnotationSeq(Seq( - ExtModulePathAnnotation(mt.ref("in1"), mt.ref("out1")), - ExtModulePathAnnotation(mt.ref("in2"), mt.ref("out2")))) + val annos = AnnotationSeq( + Seq( + ExtModulePathAnnotation(mt.ref("in1"), mt.ref("out1")), + ExtModulePathAnnotation(mt.ref("in2"), mt.ref("out2")) + ) + ) compile(parse(input), annos) } "Combinational loop through an output RHS reference" should "throw an exception" in { val input = """circuit hasloops : - | module thru : - | input in : UInt<1> - | output tmp : UInt<1> - | output out : UInt<1> - | tmp <= in - | out <= tmp - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | inst inner of thru - | inner.in <= y - | z <= inner.out - | y <= z - | d <= z - |""".stripMargin + | module thru : + | input in : UInt<1> + | output tmp : UInt<1> + | output out : UInt<1> + | tmp <= in + | out <= tmp + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of thru + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -235,21 +238,21 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Multiple simple loops in one SCC" should "throw an exception" in { val input = """circuit hasloops : - | module hasloops : - | input i : UInt<1> - | output o : UInt<1> - | wire a : UInt<1> - | wire b : UInt<1> - | wire c : UInt<1> - | wire d : UInt<1> - | wire e : UInt<1> - | a <= and(c,i) - | b <= and(a,d) - | c <= b - | d <= and(c,e) - | e <= b - | o <= e - |""".stripMargin + | module hasloops : + | input i : UInt<1> + | output o : UInt<1> + | wire a : UInt<1> + | wire b : UInt<1> + | wire c : UInt<1> + | wire d : UInt<1> + | wire e : UInt<1> + | a <= and(c,i) + | b <= and(a,d) + | c <= b + | d <= and(c,e) + | e <= b + | o <= e + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -280,7 +283,7 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops val cs = compile(parse(input)) val mt = ModuleTarget("hasnoloops", "hasnoloops") val anno = CombinationalPath(mt.ref("b"), Seq(mt.ref("a"))) - cs.annotations.contains(anno) should be (true) + cs.annotations.contains(anno) should be(true) } } @@ -292,7 +295,7 @@ class CheckCombLoopsCommandLineSpec extends FirrtlFlatSpec { val args = Array("-i", inputFile.getAbsolutePath, "-o", outFile.getAbsolutePath, "-X", "verilog") "Combinational loops detection" should "run by default" in { - a [CheckCombLoops.CombLoopException] should be thrownBy { + a[CheckCombLoops.CombLoopException] should be thrownBy { (new FirrtlStage).execute(args, Seq()) } } diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala index 34e0da03..5fd9543e 100644 --- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala +++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala @@ -2,27 +2,27 @@ package firrtlTests -import firrtl.{CircuitState, UnknownForm, Transform} +import firrtl.{CircuitState, Transform, UnknownForm} import firrtl.passes._ import firrtl.testutils._ class CheckInitializationSpec extends FirrtlFlatSpec { private val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - CheckFlows, - new InferWidths, - CheckWidths, - PullMuxes, - ExpandConnects, - RemoveAccesses, - ExpandWhens, - CheckInitialization, - InferTypes + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveFlows, + CheckFlows, + new InferWidths, + CheckWidths, + PullMuxes, + ExpandConnects, + RemoveAccesses, + ExpandWhens, + CheckInitialization, + InferTypes ) "Missing assignment in consequence branch" should "trigger a PassException" in { val input = @@ -33,8 +33,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | when p : | x <= UInt(1)""".stripMargin intercept[CheckInitialization.RefNotInitializedException] { - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } @@ -48,8 +48,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | else : | x <= UInt(1)""".stripMargin intercept[CheckInitialization.RefNotInitializedException] { - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } @@ -64,8 +64,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | x <= UInt(1) | x <= UInt(1) | """.stripMargin - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } @@ -84,8 +84,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | x <= UInt(2) | x <= UInt(1) | """.stripMargin - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } @@ -100,8 +100,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | when p : | c.in <= UInt(1)""".stripMargin intercept[CheckInitialization.RefNotInitializedException] { - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala index 5c38bf30..20a5f969 100644 --- a/src/test/scala/firrtlTests/CheckSpec.scala +++ b/src/test/scala/firrtlTests/CheckSpec.scala @@ -3,17 +3,29 @@ package firrtlTests import org.scalatest._ -import firrtl.{Parser, CircuitState, UnknownForm, Transform} +import firrtl.{CircuitState, Parser, Transform, UnknownForm} import firrtl.ir.Circuit -import firrtl.passes.{Pass,ToWorkingIR,CheckHighForm,ResolveKinds,InferTypes,CheckTypes,PassException,InferWidths,CheckWidths,ResolveFlows,CheckFlows} +import firrtl.passes.{ + CheckFlows, + CheckHighForm, + CheckTypes, + CheckWidths, + InferTypes, + InferWidths, + Pass, + PassException, + ResolveFlows, + ResolveKinds, + ToWorkingIR +} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers class CheckSpec extends AnyFlatSpec with Matchers { val defaultPasses = Seq(ToWorkingIR, CheckHighForm) def checkHighInput(input: String) = { - defaultPasses.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + defaultPasses.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } } @@ -44,9 +56,7 @@ class CheckSpec extends AnyFlatSpec with Matchers { } "Memories with zero write latency" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm) + val passes = Seq(ToWorkingIR, CheckHighForm) val input = """circuit Unit : | module Unit : @@ -56,8 +66,8 @@ class CheckSpec extends AnyFlatSpec with Matchers { | read-latency => 0 | write-latency => 0""".stripMargin intercept[CheckHighForm.IllegalMemLatencyException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } } } @@ -181,90 +191,81 @@ class CheckSpec extends AnyFlatSpec with Matchers { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """ - |circuit TheRealTop : - | - | module Top : - | output io : {flip debug_clk : Clock} - | - | extmodule BlackBoxTop : - | input jtag : {TCK : Clock} - | - | module TheRealTop : - | input clock : Clock - | input reset : UInt<1> - | output io : {flip jtag : {TCK : Clock}} - | - | io is invalid - | inst sub of Top - | sub.io is invalid - | inst bb of BlackBoxTop - | bb.jtag is invalid - | bb.jtag <- io.jtag - | - | sub.io.debug_clk <= io.jtag.TCK - | - |""".stripMargin + |circuit TheRealTop : + | + | module Top : + | output io : {flip debug_clk : Clock} + | + | extmodule BlackBoxTop : + | input jtag : {TCK : Clock} + | + | module TheRealTop : + | input clock : Clock + | input reset : UInt<1> + | output io : {flip jtag : {TCK : Clock}} + | + | io is invalid + | inst sub of Top + | sub.io is invalid + | inst bb of BlackBoxTop + | bb.jtag is invalid + | bb.jtag <- io.jtag + | + | sub.io.debug_clk <= io.jtag.TCK + | + |""".stripMargin passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { (c: CircuitState, p: Transform) => p.runTransform(c) } } "Clocks with types other than ClockType" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """ - |circuit Top : - | - | module Top : - | input clk : UInt<1> - | input i : UInt<1> - | output o : UInt<1> - | - | reg r : UInt<1>, clk - | r <= i - | o <= r - | - |""".stripMargin + |circuit Top : + | + | module Top : + | input clk : UInt<1> + | input i : UInt<1> + | output o : UInt<1> + | + | reg r : UInt<1>, clk + | r <= i + | o <= r + | + |""".stripMargin intercept[CheckTypes.RegReqClk] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Illegal reset type" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """ - |circuit Top : - | - | module Top : - | input clk : Clock - | input reset : UInt<2> - | input i : UInt<1> - | output o : UInt<1> - | - | reg r : UInt<1>, clk with : (reset => (reset, UInt<1>("h00"))) - | r <= i - | o <= r - | - |""".stripMargin + |circuit Top : + | + | module Top : + | input clk : Clock + | input reset : UInt<2> + | input i : UInt<1> + | output o : UInt<1> + | + | reg r : UInt<1>, clk with : (reset => (reset, UInt<1>("h00"))) + | r <= i + | o <= r + | + |""".stripMargin intercept[CheckTypes.IllegalResetType] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } } } @@ -281,7 +282,7 @@ class CheckSpec extends AnyFlatSpec with Matchers { val exception = intercept[PassException] { checkHighInput(input) } - exception.getMessage should include (s"Primop $op argument $amount < 0") + exception.getMessage should include(s"Primop $op argument $amount < 0") } } @@ -301,11 +302,11 @@ class CheckSpec extends AnyFlatSpec with Matchers { } } - behavior of "Uniqueness" + behavior.of("Uniqueness") for ((description, input) <- CheckSpec.nonUniqueExamples) { it should s"be asserted for $description" in { assertThrows[CheckHighForm.NotUniqueException] { - Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) } + Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)) { case (c, tx) => tx.run(c) } } } } @@ -400,7 +401,7 @@ class CheckSpec extends AnyFlatSpec with Matchers { } } - behavior of "CheckHighForm running on circuits containing ExtModules" + behavior.of("CheckHighForm running on circuits containing ExtModules") it should "throw an exception if parameterless ExtModules have the same ports, but different widths" in { val input = @@ -539,19 +540,17 @@ class CheckSpec extends AnyFlatSpec with Matchers { object CheckSpec { val nonUniqueExamples = List( - ("two ports with the same name", - """|circuit Top: - | module Top: - | input a: UInt<1> - | input a: UInt<1>""".stripMargin), - ("two nodes with the same name", - """|circuit Top: - | module Top: - | node a = UInt<1>("h0") - | node a = UInt<1>("h0")""".stripMargin), - ("a port and a node with the same name", - """|circuit Top: - | module Top: - | input a: UInt<1> - | node a = UInt<1>("h0") """.stripMargin) ) - } + ("two ports with the same name", """|circuit Top: + | module Top: + | input a: UInt<1> + | input a: UInt<1>""".stripMargin), + ("two nodes with the same name", """|circuit Top: + | module Top: + | node a = UInt<1>("h0") + | node a = UInt<1>("h0")""".stripMargin), + ("a port and a node with the same name", """|circuit Top: + | module Top: + | input a: UInt<1> + | node a = UInt<1>("h0") """.stripMargin) + ) +} diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index 372ba53b..11a27d65 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -16,37 +16,37 @@ class ChirrtlMemSpec extends LowFirrtlTransformSpec { type Netlist = collection.mutable.HashMap[String, Expression] def buildNetlist(netlist: Netlist)(s: Statement): Statement = { s match { - case s: Connect => Utils.kind(s.loc) match { - case MemKind => netlist(s.loc.serialize) = s.expr - case _ => - } + case s: Connect => + Utils.kind(s.loc) match { + case MemKind => netlist(s.loc.serialize) = s.expr + case _ => + } case _ => } - s map buildNetlist(netlist) + s.map(buildNetlist(netlist)) } // walks on memories and checks whether or not read enables are high def checkStmt(netlist: Netlist)(s: Statement): Boolean = s match { - case s: DefMemory if s.name == "mem" && s.readers.size == 1=> + case s: DefMemory if s.name == "mem" && s.readers.size == 1 => val en = MemPortUtils.memPortField(s, s.readers.head, "en") // memory read enable ?= 1 WrappedExpression.weq(netlist(en.serialize), Utils.one) case s: Block => - s.stmts exists checkStmt(netlist) + s.stmts.exists(checkStmt(netlist)) case _ => false } - def run (c: Circuit) = { + def run(c: Circuit) = { val errors = new Errors - val check = c.modules exists { + val check = c.modules.exists { case m: Module => val netlist = new Netlist checkStmt(netlist)(buildNetlist(netlist)(m.body)) case m: ExtModule => false } if (!check) { - errors append new PassException( - "Enable signal for the read port is incorrect!") + errors.append(new PassException("Enable signal for the read port is incorrect!")) errors.trigger } c @@ -105,18 +105,18 @@ circuit foo : "An mport that refers to an undefined memory" should "have a helpful error message" in { val input = """circuit testTestModule : - | module testTestModule : - | input clock : Clock - | input reset : UInt<1> - | output io : {flip in : UInt<10>, out : UInt<10>} - | - | node _T_10 = bits(io.in, 1, 0) - | read mport _T_11 = m[_T_10], clock - | io.out <= _T_11""".stripMargin + | module testTestModule : + | input clock : Clock + | input reset : UInt<1> + | output io : {flip in : UInt<10>, out : UInt<10>} + | + | node _T_10 = bits(io.in, 1, 0) + | read mport _T_11 = m[_T_10], clock + | io.out <= _T_11""".stripMargin - intercept[PassException]{ + intercept[PassException] { compile(parse(input)) - }.getMessage should startWith ("Undefined memory m referenced by mport _T_11") + }.getMessage should startWith("Undefined memory m referenced by mport _T_11") } ignore should "Memories should not have validif on port clocks when declared in a when" in { @@ -167,9 +167,19 @@ circuit foo : | io.dataOut <= out @[Stack.scala 31:14] """.stripMargin val res = compile(parse(input)) - assert(res search { - case Connect(_, WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_35",_, _), "clk", _, _), WRef("clock", _, _, _)) => true - case Connect(_, WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_17",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + assert(res.search { + case Connect( + _, + WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_35", _, _), "clk", _, _), + WRef("clock", _, _, _) + ) => + true + case Connect( + _, + WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_17", _, _), "clk", _, _), + WRef("clock", _, _, _) + ) => + true }) } @@ -188,8 +198,9 @@ circuit foo : | out <= bar |""".stripMargin val res = compile(parse(input)) - assert(res search { - case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + assert(res.search { + case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar", _, _), "clk", _, _), WRef("clock", _, _, _)) => + true }) } @@ -209,8 +220,9 @@ circuit foo : | out <= bar |""".stripMargin val res = compile(parse(input)) - assert(res search { - case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + assert(res.search { + case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar", _, _), "clk", _, _), WRef("clock", _, _, _)) => + true }) } @@ -230,12 +242,16 @@ circuit foo : | out <= bar |""".stripMargin val res = new LowFirrtlCompiler().compile(CircuitState(parse(input), ChirrtlForm), Seq()).circuit - assert(res search { - case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar",_, _), "clk", _, _), DoPrim(AsClock, Seq(WRef("clock", _, _, _)), Nil, _)) => true + assert(res.search { + case Connect( + _, + WSubField(WSubField(WRef("mem", _, _, _), "bar", _, _), "clk", _, _), + DoPrim(AsClock, Seq(WRef("clock", _, _, _)), Nil, _) + ) => + true }) } - ignore should "Mem non-local nested clock port assignment should be ok" in { val input = """circuit foo : @@ -251,8 +267,13 @@ circuit foo : | out <= bar |""".stripMargin val res = (new HighFirrtlCompiler).compile(CircuitState(parse(input), ChirrtlForm), Seq()).circuit - assert(res search { - case Connect(_, SubField(SubField(Reference("mem", _, _, _), "bar", _, _), "clk", _, _), DoPrim(AsClock, Seq(Reference("clock", _, _, _)), _, _)) => true + assert(res.search { + case Connect( + _, + SubField(SubField(Reference("mem", _, _, _), "bar", _, _), "clk", _, _), + DoPrim(AsClock, Seq(Reference("clock", _, _, _)), _, _) + ) => + true }) } } diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala index dcc8b872..2d13c835 100644 --- a/src/test/scala/firrtlTests/ChirrtlSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala @@ -30,49 +30,49 @@ class ChirrtlSpec extends FirrtlFlatSpec { "Chirrtl memories" should "allow ports with clocks defined after the memory" in { val input = - """circuit Unit : - | module Unit : - | input clock : Clock - | smem ram : UInt<32>[128] - | node newClock = clock - | infer mport x = ram[UInt(2)], newClock - | x <= UInt(3) - | when UInt(1) : - | infer mport y = ram[UInt(4)], newClock - | y <= UInt(5) + """circuit Unit : + | module Unit : + | input clock : Clock + | smem ram : UInt<32>[128] + | node newClock = clock + | infer mport x = ram[UInt(2)], newClock + | x <= UInt(3) + | when UInt(1) : + | infer mport y = ram[UInt(4)], newClock + | y <= UInt(5) """.stripMargin val circuit = Parser.parse(input.split("\n").toIterator) - transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } "Chirrtl" should "catch undeclared wires" in { val input = - """circuit Unit : - | module Unit : - | input clock : Clock - | smem ram : UInt<32>[128] - | node newClock = clock - | infer mport x = ram[UInt(2)], newClock - | x <= UInt(3) - | when UInt(1) : - | infer mport y = ram[UInt(4)], newClock - | y <= z + """circuit Unit : + | module Unit : + | input clock : Clock + | smem ram : UInt<32>[128] + | node newClock = clock + | infer mport x = ram[UInt(2)], newClock + | x <= UInt(3) + | when UInt(1) : + | infer mport y = ram[UInt(4)], newClock + | y <= z """.stripMargin intercept[PassException] { val circuit = Parser.parse(input.split("\n").toIterator) - transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } - behavior of "Uniqueness" + behavior.of("Uniqueness") for ((description, input) <- CheckSpec.nonUniqueExamples) { it should s"be asserted for $description" in { assertThrows[CheckHighForm.NotUniqueException] { - Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) } + Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)) { case (c, tx) => tx.run(c) } } } } diff --git a/src/test/scala/firrtlTests/ClockListTests.scala b/src/test/scala/firrtlTests/ClockListTests.scala index 9233d4d5..c547448b 100644 --- a/src/test/scala/firrtlTests/ClockListTests.scala +++ b/src/test/scala/firrtlTests/ClockListTests.scala @@ -11,12 +11,12 @@ import clocklist._ class ClockListTests extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } @@ -69,19 +69,21 @@ class ClockListTests extends FirrtlFlatSpec { | output clk2: Clock | output clk3: Clock |""".stripMargin - val check = - """Sourcelist: List(h$clkGen$clk1, h$clkGen$clk2, h$clkGen$clk3, clock) - |Good Origin of clock is clock - |Good Origin of h.clock is h$clkGen.clk1 - |Good Origin of h$b.clock is h$clkGen.clk2 - |Good Origin of h$c.clock is h$clkGen.clk3 - |""".stripMargin - val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit + val check = + """Sourcelist: List(h$clkGen$clk1, h$clkGen$clk2, h$clkGen$clk3, clock) + |Good Origin of clock is clock + |Good Origin of h.clock is h$clkGen.clk1 + |Good Origin of h$b.clock is h$clkGen.clk2 + |Good Origin of h$c.clock is h$clkGen.clk3 + |""".stripMargin + val c = passes + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) + } + .circuit val writer = new StringWriter() val retC = new ClockList("HTop", writer).run(c) - (writer.toString) should be (check) + (writer.toString) should be(check) } "A->B->C, and A.clock == C.clock" should "still emit C.clock origin" in { val input = @@ -101,18 +103,20 @@ class ClockListTests extends FirrtlFlatSpec { | input clock: Clock | reg r: UInt<5>, clock |""".stripMargin - val check = - """Sourcelist: List(clock, clkB) - |Good Origin of clock is clock - |Good Origin of b.clock is clkB - |Good Origin of b$c.clock is clock - |""".stripMargin - val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit + val check = + """Sourcelist: List(clock, clkB) + |Good Origin of clock is clock + |Good Origin of b.clock is clkB + |Good Origin of b$c.clock is clock + |""".stripMargin + val c = passes + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) + } + .circuit val writer = new StringWriter() val retC = new ClockList("A", writer).run(c) - (writer.toString) should be (check) + (writer.toString) should be(check) } "Have not circuit main be top of clocklist pass" should "still work" in { val input = @@ -136,15 +140,17 @@ class ClockListTests extends FirrtlFlatSpec { | input clock: Clock |""".stripMargin val check = - """Sourcelist: List(clock, clkC) - |Good Origin of clock is clock - |Good Origin of c.clock is clkC - |""".stripMargin - val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit + """Sourcelist: List(clock, clkC) + |Good Origin of clock is clock + |Good Origin of c.clock is clkC + |""".stripMargin + val c = passes + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) + } + .circuit val writer = new StringWriter() val retC = new ClockList("B", writer).run(c) - (writer.toString) should be (check) + (writer.toString) should be(check) } } diff --git a/src/test/scala/firrtlTests/CompilerTests.scala b/src/test/scala/firrtlTests/CompilerTests.scala index dfa796c4..129ff8f5 100644 --- a/src/test/scala/firrtlTests/CompilerTests.scala +++ b/src/test/scala/firrtlTests/CompilerTests.scala @@ -12,36 +12,36 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers /** - * An example methodology for testing Firrtl compilers. - * - * Given an input Firrtl circuit (expressed as a string), - * the compiler is executed. The output of the compiler - * should be compared against the check string. - */ + * An example methodology for testing Firrtl compilers. + * + * Given an input Firrtl circuit (expressed as a string), + * the compiler is executed. The output of the compiler + * should be compared against the check string. + */ abstract class CompilerSpec(emitter: Dependency[firrtl.Emitter]) extends LeanTransformSpec(Seq(emitter)) { - def input: String - def getOutput: String = compile(input).getEmittedCircuit.value + def input: String + def getOutput: String = compile(input).getEmittedCircuit.value } /** - * An example test for testing the HighFirrtlCompiler. - * - * Given an input Firrtl circuit (expressed as a string), - * the compiler is executed. The output of the compiler - * is parsed again and compared (in-memory) to the parsed - * input. - */ + * An example test for testing the HighFirrtlCompiler. + * + * Given an input Firrtl circuit (expressed as a string), + * the compiler is executed. The output of the compiler + * is parsed again and compared (in-memory) to the parsed + * input. + */ class HighFirrtlCompilerSpec extends CompilerSpec(Dependency[firrtl.HighFirrtlEmitter]) with Matchers { - val input = -"""circuit Top : + val input = + """circuit Top : module Top : input a : UInt<1>[2] node x = a """ - val check = input - "Any circuit" should "match exactly to its input" in { - (parse(getOutput)) should be (parse(check)) - } + val check = input + "Any circuit" should "match exactly to its input" in { + (parse(getOutput)) should be(parse(check)) + } } /** @@ -53,8 +53,8 @@ class HighFirrtlCompilerSpec extends CompilerSpec(Dependency[firrtl.HighFirrtlEm * string compared to the correct lowered circuit. */ class MiddleFirrtlCompilerSpec extends CompilerSpec(Dependency[firrtl.MiddleFirrtlEmitter]) with Matchers { - val input = - """ + val input = + """ circuit Top : module Top : input reset : UInt<1> @@ -64,77 +64,77 @@ circuit Top : when reset : b <= UInt(0) """ - // Verify that Vecs are retained, but widths are inferred and whens are expanded. - val check = Seq( - "circuit Top :", - " module Top :", - " input reset : UInt<1>", - " input a : UInt<1>[2]", - " wire b : UInt<1>", - " node _GEN_0 = mux(reset, UInt<1>(\"h0\"), a[0])", - " b <= _GEN_0\n\n" - ).reduce(_ + "\n" + _) - "A circuit" should "match exactly to its MidForm state" in { - (parse(getOutput)) should be (parse(check)) - } + // Verify that Vecs are retained, but widths are inferred and whens are expanded. + val check = Seq( + "circuit Top :", + " module Top :", + " input reset : UInt<1>", + " input a : UInt<1>[2]", + " wire b : UInt<1>", + " node _GEN_0 = mux(reset, UInt<1>(\"h0\"), a[0])", + " b <= _GEN_0\n\n" + ).reduce(_ + "\n" + _) + "A circuit" should "match exactly to its MidForm state" in { + (parse(getOutput)) should be(parse(check)) + } } /** - * An example test for testing the LoweringCompiler. - * - * Given an input Firrtl circuit (expressed as a string), - * the compiler is executed. The output of the compiler is - * a lowered version of the input circuit. The output is - * string compared to the correct lowered circuit. - */ + * An example test for testing the LoweringCompiler. + * + * Given an input Firrtl circuit (expressed as a string), + * the compiler is executed. The output of the compiler is + * a lowered version of the input circuit. The output is + * string compared to the correct lowered circuit. + */ class LowFirrtlCompilerSpec extends CompilerSpec(Dependency[firrtl.LowFirrtlEmitter]) with Matchers { - val input = -""" + val input = + """ circuit Top : module Top : input a : UInt<1>[2] node x = a """ - val check = Seq( - "circuit Top :", - " module Top :", - " input a_0 : UInt<1>", - " input a_1 : UInt<1>", - " node x_0 = a_0", - " node x_1 = a_1\n\n" - ).reduce(_ + "\n" + _) - "A circuit" should "match exactly to its lowered state" in { - (parse(getOutput)) should be (parse(check)) - } + val check = Seq( + "circuit Top :", + " module Top :", + " input a_0 : UInt<1>", + " input a_1 : UInt<1>", + " node x_0 = a_0", + " node x_1 = a_1\n\n" + ).reduce(_ + "\n" + _) + "A circuit" should "match exactly to its lowered state" in { + (parse(getOutput)) should be(parse(check)) + } } /** - * An example test for testing the VerilogCompiler. - * - * Given an input Firrtl circuit (expressed as a string), - * the compiler is executed. The output of the compiler is - * the corresponding Verilog. The output is string compared - * to the correct Verilog. - */ + * An example test for testing the VerilogCompiler. + * + * Given an input Firrtl circuit (expressed as a string), + * the compiler is executed. The output of the compiler is + * the corresponding Verilog. The output is string compared + * to the correct Verilog. + */ class VerilogCompilerSpec extends CompilerSpec(Dependency[firrtl.VerilogEmitter]) with Matchers { - val input = """circuit Top : - | module Top : - | input a : UInt<1>[2] - | output b : UInt<1>[2] - | b <= a""".stripMargin - val check = """module Top( - | input a_0, - | input a_1, - | output b_0, - | output b_1 - |); - | assign b_0 = a_0; - | assign b_1 = a_1; - |endmodule - |""".stripMargin - "A circuit's verilog output" should "match the given string and not have RANDOMIZE if no invalids" in { - getOutput should be (check) - } + val input = """circuit Top : + | module Top : + | input a : UInt<1>[2] + | output b : UInt<1>[2] + | b <= a""".stripMargin + val check = """module Top( + | input a_0, + | input a_1, + | output b_0, + | output b_1 + |); + | assign b_0 = a_0; + | assign b_1 = a_1; + |endmodule + |""".stripMargin + "A circuit's verilog output" should "match the given string and not have RANDOMIZE if no invalids" in { + getOutput should be(check) + } } class MinimumVerilogCompilerSpec extends CompilerSpec(Dependency[firrtl.MinimumVerilogEmitter]) with Matchers { @@ -166,6 +166,6 @@ class MinimumVerilogCompilerSpec extends CompilerSpec(Dependency[firrtl.MinimumV |endmodule |""".stripMargin "A circuit's minimum Verilog output" should "pad signed RHSes but not reflect any const-prop or DCE" in { - getOutput should be (check) + getOutput should be(check) } } diff --git a/src/test/scala/firrtlTests/CompilerUtilsSpec.scala b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala index bfb53ce1..530a036a 100644 --- a/src/test/scala/firrtlTests/CompilerUtilsSpec.scala +++ b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala @@ -30,39 +30,39 @@ class CompilerUtilsSpec extends FirrtlFlatSpec { val lowToLowTwo = genTransform(LowForm, LowForm) - behavior of "mergeTransforms" + behavior.of("mergeTransforms") it should "do nothing if there are no custom transforms" in { - mergeTransforms(chirrtlToLowList, List.empty) should be (chirrtlToLowList) + mergeTransforms(chirrtlToLowList, List.empty) should be(chirrtlToLowList) } it should "insert transforms at the correct place" in { mergeTransforms(chirrtlToLowList, List(chirrtlToChirrtl)) should be - (chirrtlToChirrtl +: chirrtlToLowList) + (chirrtlToChirrtl +: chirrtlToLowList) mergeTransforms(chirrtlToLowList, List(highToHigh)) should be - (List(chirrtlToHigh, highToHigh, highToMid, midToLow)) + (List(chirrtlToHigh, highToHigh, highToMid, midToLow)) mergeTransforms(chirrtlToLowList, List(midToMid)) should be - (List(chirrtlToHigh, highToMid, midToMid, midToLow)) + (List(chirrtlToHigh, highToMid, midToMid, midToLow)) mergeTransforms(chirrtlToLowList, List(lowToLow)) should be - (chirrtlToLowList :+ lowToLow) + (chirrtlToLowList :+ lowToLow) } it should "insert transforms at the last legal location" in { lowToLow should not be (lowToLowTwo) // sanity check - mergeTransforms(chirrtlToLowList :+ lowToLow, List(lowToLowTwo)).last should be (lowToLowTwo) + mergeTransforms(chirrtlToLowList :+ lowToLow, List(lowToLowTwo)).last should be(lowToLowTwo) } it should "insert multiple transforms correctly" in { mergeTransforms(chirrtlToLowList, List(highToHigh, lowToLow)) should be - (List(chirrtlToHigh, highToHigh, highToMid, midToLow, lowToLow)) + (List(chirrtlToHigh, highToHigh, highToMid, midToLow, lowToLow)) } it should "handle transforms that raise the form" in { mergeTransforms(chirrtlToLowList, List(lowToHigh)) match { case chirrtlToHigh :: highToMid :: midToLow :: lowToHigh :: remainder => // Remainder will be the actual Firrtl lowering transforms - remainder.head.inputForm should be (HighForm) - remainder.last.outputForm should be (LowForm) + remainder.head.inputForm should be(HighForm) + remainder.last.outputForm should be(LowForm) case _ => fail() } } @@ -70,8 +70,7 @@ class CompilerUtilsSpec extends FirrtlFlatSpec { // Order is not always maintained, see note on function Scaladoc it should "maintain order of custom tranforms" in { mergeTransforms(chirrtlToLowList, List(lowToLow, lowToLowTwo)) should be - (chirrtlToLowList ++ List(lowToLow, lowToLowTwo)) + (chirrtlToLowList ++ List(lowToLow, lowToLowTwo)) } } - diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index efe85e48..6ab54159 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -9,24 +9,22 @@ import firrtl.testutils._ import firrtl.annotations.Annotation class ConstantPropagationSpec extends FirrtlFlatSpec { - val transforms: Seq[Transform] = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - new ConstantPropagation) + val transforms: Seq[Transform] = + Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, new ConstantPropagation) protected def exec(input: String, annos: Seq[Annotation] = Nil) = { - transforms.foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit.serialize + transforms + .foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit + .serialize } } class ConstantPropagationMultiModule extends ConstantPropagationSpec { - "ConstProp" should "propagate constant inputs" in { - val input = -"""circuit Top : + "ConstProp" should "propagate constant inputs" in { + val input = + """circuit Top : module Child : input in0 : UInt<1> input in1 : UInt<1> @@ -40,8 +38,8 @@ class ConstantPropagationMultiModule extends ConstantPropagationSpec { c.in1 <= UInt<1>(1) z <= c.out """ - val check = -"""circuit Top : + val check = + """circuit Top : module Child : input in0 : UInt<1> input in1 : UInt<1> @@ -55,12 +53,12 @@ class ConstantPropagationMultiModule extends ConstantPropagationSpec { c.in1 <= UInt<1>(1) z <= c.out """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "propagate constant inputs ONLY if ALL instance inputs get the same value" in { - def circuit(allSame: Boolean) = -s"""circuit Top : + "ConstProp" should "propagate constant inputs ONLY if ALL instance inputs get the same value" in { + def circuit(allSame: Boolean) = + s"""circuit Top : module Bottom : input in : UInt<1> output out : UInt<1> @@ -83,8 +81,8 @@ s"""circuit Top : z <= and(and(b0.out, b1.out), c.out) """ - val resultFromAllSame = -"""circuit Top : + val resultFromAllSame = + """circuit Top : module Bottom : input in : UInt<1> output out : UInt<1> @@ -104,14 +102,14 @@ s"""circuit Top : b1.in <= UInt(1) z <= UInt(1) """ - (parse(exec(circuit(false)))) should be (parse(circuit(false))) - (parse(exec(circuit(true)))) should be (parse(resultFromAllSame)) - } - - // ============================= - "ConstProp" should "do nothing on unrelated modules" in { - val input = -"""circuit foo : + (parse(exec(circuit(false)))) should be(parse(circuit(false))) + (parse(exec(circuit(true)))) should be(parse(resultFromAllSame)) + } + + // ============================= + "ConstProp" should "do nothing on unrelated modules" in { + val input = + """circuit foo : module foo : input dummy : UInt<1> skip @@ -120,14 +118,14 @@ s"""circuit Top : input dummy : UInt<1> skip """ - val check = input - (parse(exec(input))) should be (parse(check)) - } - - // ============================= - "ConstProp" should "propagate module chains not connected to the top" in { - val input = -"""circuit foo : + val check = input + (parse(exec(input))) should be(parse(check)) + } + + // ============================= + "ConstProp" should "propagate module chains not connected to the top" in { + val input = + """circuit foo : module foo : input dummy : UInt<1> skip @@ -151,8 +149,8 @@ s"""circuit Top : output test : UInt<1> test <= UInt<1>(0) """ - val check = -"""circuit foo : + val check = + """circuit foo : module foo : input dummy : UInt<1> skip @@ -176,8 +174,8 @@ s"""circuit Top : output test : UInt<1> test <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } } // Tests the following cases for constant propagation: @@ -188,332 +186,332 @@ s"""circuit Top : // 3) Values are always greater than a number smaller // than their minimum value class ConstantPropagationSingleModule extends ConstantPropagationSpec { - // ============================= - "The rule x >= 0 " should " always be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule x >= 0 " should " always be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= geq(x, UInt(0)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>("h1") """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x < 0 " should " never be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule x < 0 " should " never be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= lt(x, UInt(0)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 0 <= x " should " always be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule 0 <= x " should " always be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= leq(UInt(0),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 0 > x " should " never be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule 0 > x " should " never be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= gt(UInt(0),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 1 < 3 " should " always be true" in { - val input = -"""circuit Top : + // ============================= + "The rule 1 < 3 " should " always be true" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= lt(UInt(0),UInt(3)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x < 8 " should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x < 8 " should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= lt(x,UInt(8)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x <= 7 " should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x <= 7 " should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= leq(x,UInt(7)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 8 > x" should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 8 > x" should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= gt(UInt(8),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 7 >= x" should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 7 >= x" should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= geq(UInt(7),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 10 == 10" should " always be true" in { - val input = -"""circuit Top : + // ============================= + "The rule 10 == 10" should " always be true" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= eq(UInt(10),UInt(10)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x == z " should " not be true even if they have the same number of bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x == z " should " not be true even if they have the same number of bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> input z : UInt<3> output y : UInt<1> y <= eq(x,z) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> input z : UInt<3> output y : UInt<1> y <= eq(x,z) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 10 != 10 " should " always be false" in { - val input = -"""circuit Top : + // ============================= + "The rule 10 != 10 " should " always be false" in { + val input = + """circuit Top : module Top : output y : UInt<1> y <= neq(UInt(10),UInt(10)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : output y : UInt<1> y <= UInt(0) """ - (parse(exec(input))) should be (parse(check)) - } - // ============================= - "The rule 1 >= 3 " should " always be false" in { - val input = -"""circuit Top : + (parse(exec(input))) should be(parse(check)) + } + // ============================= + "The rule 1 >= 3 " should " always be false" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= geq(UInt(1),UInt(3)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x >= 8 " should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x >= 8 " should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= geq(x,UInt(8)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x > 7 " should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x > 7 " should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= gt(x,UInt(7)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 8 <= x" should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 8 <= x" should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= leq(UInt(8),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 7 < x" should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 7 < x" should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= lt(UInt(7),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "work across wires" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "work across wires" in { + val input = + """circuit Top : module Top : input x : UInt<1> output y : UInt<1> @@ -521,8 +519,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { y <= z z <= mux(x, UInt<1>(0), UInt<1>(0)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> output y : UInt<1> @@ -530,13 +528,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { y <= UInt<1>(0) z <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "swap named nodes with temporary nodes that drive them" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "swap named nodes with temporary nodes that drive them" in { + val input = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -545,8 +543,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node n = _T_1 z <= and(n, x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -555,13 +553,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node _T_1 = n z <= and(n, x) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "swap named nodes with temporary wires that drive them" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "swap named nodes with temporary wires that drive them" in { + val input = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -571,8 +569,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n _T_1 <= and(x, y) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -582,13 +580,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n n <= and(x, y) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "swap named nodes with temporary registers that drive them" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "swap named nodes with temporary registers that drive them" in { + val input = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -598,8 +596,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n _T_1 <= x """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -609,13 +607,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n n <= x """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "only swap a given name with one other name" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "only swap a given name with one other name" in { + val input = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -625,8 +623,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node m = _T_1 z <= add(n, m) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -636,12 +634,12 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node m = n z <= add(n, n) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "NOT swap wire names with node names" in { - val input = -"""circuit Top : + "ConstProp" should "NOT swap wire names with node names" in { + val input = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -653,8 +651,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { hit <= _T_2 z <= hit """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -666,12 +664,12 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { hit <= or(x, y) z <= hit """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "propagate constant outputs" in { - val input = -"""circuit Top : + "ConstProp" should "propagate constant outputs" in { + val input = + """circuit Top : module Child : output out : UInt<1> out <= UInt<1>(0) @@ -681,8 +679,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { inst c of Child z <= and(x, c.out) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Child : output out : UInt<1> out <= UInt<1>(0) @@ -692,10 +690,10 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { inst c of Child z <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "propagate constant addition" in { + "ConstProp" should "propagate constant addition" in { val input = """circuit Top : | module Top : @@ -717,7 +715,7 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { (parse(exec(input))) should be(parse(check)) } - "ConstProp" should "propagate addition with zero" in { + "ConstProp" should "propagate addition with zero" in { val input = """circuit Top : | module Top : @@ -779,20 +777,20 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { def castCheck(tpe: String, cast: String): Unit = { val input = - s"""circuit Top : - | module Top : - | input x : $tpe - | output z : $tpe - | z <= $cast(x) + s"""circuit Top : + | module Top : + | input x : $tpe + | output z : $tpe + | z <= $cast(x) """.stripMargin val check = - s"""circuit Top : - | module Top : - | input x : $tpe - | output z : $tpe - | z <= x + s"""circuit Top : + | module Top : + | input x : $tpe + | output z : $tpe + | z <= x """.stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } it should "optimize unnecessary casts" in { castCheck("UInt<4>", "asUInt") @@ -807,218 +805,217 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { def transform = new LowFirrtlOptimization "ConstProp" should "NOT optimize across dontTouch on nodes" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin - val check = input + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Top.z"))) } it should "NOT optimize across nodes marked dontTouch by other annotations" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin - val check = input - val dontTouchRT = annotations.ModuleTarget("Top", "Top").ref("z") + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + val check = input + val dontTouchRT = annotations.ModuleTarget("Top", "Top").ref("z") execute(input, check, Seq(AnnotationWithDontTouches(dontTouchRT))) } it should "NOT optimize across dontTouch on registers" in { - val input = - """circuit Top : - | module Top : - | input clk : Clock - | input reset : UInt<1> - | output y : UInt<1> - | reg z : UInt<1>, clk - | y <= z - | z <= mux(reset, UInt<1>("h0"), z)""".stripMargin - val check = input + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input reset : UInt<1> + | output y : UInt<1> + | reg z : UInt<1>, clk + | y <= z + | z <= mux(reset, UInt<1>("h0"), z)""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Top.z"))) } - it should "NOT optimize across dontTouch on wires" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | wire z : UInt<1> - | y <= z - | z <= x""".stripMargin - val check = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | wire z : UInt<1> + | y <= z + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin execute(input, check, Seq(dontTouch("Top.z"))) } it should "NOT optimize across dontTouch on output ports" in { val input = """circuit Top : - | module Child : - | output out : UInt<1> - | out <= UInt<1>(0) - | module Top : - | input x : UInt<1> - | output z : UInt<1> - | inst c of Child - | z <= and(x, c.out)""".stripMargin - val check = input + | module Child : + | output out : UInt<1> + | out <= UInt<1>(0) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of Child + | z <= and(x, c.out)""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Child.out"))) } it should "NOT optimize across dontTouch on input ports" in { val input = """circuit Top : - | module Child : - | input in0 : UInt<1> - | input in1 : UInt<1> - | output out : UInt<1> - | out <= and(in0, in1) - | module Top : - | input x : UInt<1> - | output z : UInt<1> - | inst c of Child - | z <= c.out - | c.in0 <= x - | c.in1 <= UInt<1>(1)""".stripMargin - val check = input + | module Child : + | input in0 : UInt<1> + | input in1 : UInt<1> + | output out : UInt<1> + | out <= and(in0, in1) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of Child + | z <= c.out + | c.in0 <= x + | c.in1 <= UInt<1>(1)""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Child.in1"))) } it should "still propagate constants even when there is name swapping" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | input y : UInt<1> - | output z : UInt<1> - | node _T_1 = and(and(x, y), UInt<1>(0)) - | node n = _T_1 - | z <= n""".stripMargin - val check = - """circuit Top : - | module Top : - | input x : UInt<1> - | input y : UInt<1> - | output z : UInt<1> - | z <= UInt<1>(0)""".stripMargin + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | node _T_1 = and(and(x, y), UInt<1>(0)) + | node n = _T_1 + | z <= n""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | z <= UInt<1>(0)""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to wires when propagating" in { - val input = - """circuit Top : - | module Top : - | output z : UInt<16> - | wire w : { a : UInt<8>, b : UInt<8> } - | w.a <= UInt<2>("h3") - | w.b <= UInt<2>("h3") - | z <= cat(w.a, w.b)""".stripMargin - val check = - """circuit Top : - | module Top : - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Top : + | output z : UInt<16> + | wire w : { a : UInt<8>, b : UInt<8> } + | w.a <= UInt<2>("h3") + | w.b <= UInt<2>("h3") + | z <= cat(w.a, w.b)""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to registers when propagating" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | reg r : { a : UInt<8>, b : UInt<8> }, clock - | r.a <= UInt<2>("h3") - | r.b <= UInt<2>("h3") - | z <= cat(r.a, r.b)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | reg r : { a : UInt<8>, b : UInt<8> }, clock + | r.a <= UInt<2>("h3") + | r.b <= UInt<2>("h3") + | z <= cat(r.a, r.b)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "pad zero when constant propping a register replaced with zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | reg r : UInt<8>, clock - | r <= or(r, UInt(0)) - | node n = UInt("hab") - | z <= cat(n, r)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | z <= UInt<16>("hab00")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | reg r : UInt<8>, clock + | r <= or(r, UInt(0)) + | node n = UInt("hab") + | z <= cat(n, r)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | z <= UInt<16>("hab00")""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to outputs when propagating" in { - val input = - """circuit Top : - | module Child : - | output x : UInt<8> - | x <= UInt<2>("h3") - | module Top : - | output z : UInt<16> - | inst c of Child - | z <= cat(UInt<2>("h3"), c.x)""".stripMargin - val check = - """circuit Top : - | module Top : - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Child : + | output x : UInt<8> + | x <= UInt<2>("h3") + | module Top : + | output z : UInt<16> + | inst c of Child + | z <= cat(UInt<2>("h3"), c.x)""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to submodule inputs when propagating" in { - val input = - """circuit Top : - | module Child : - | input x : UInt<8> - | output y : UInt<16> - | y <= cat(UInt<2>("h3"), x) - | module Top : - | output z : UInt<16> - | inst c of Child - | c.x <= UInt<2>("h3") - | z <= c.y""".stripMargin - val check = - """circuit Top : - | module Top : - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Child : + | input x : UInt<8> + | output y : UInt<16> + | y <= cat(UInt<2>("h3"), x) + | module Top : + | output z : UInt<16> + | inst c of Child + | c.x <= UInt<2>("h3") + | z <= c.y""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "remove pads if the width is <= the width of the argument" in { def input(w: Int) = - s"""circuit Top : - | module Top : - | input x : UInt<8> - | output z : UInt<8> - | z <= pad(x, $w)""".stripMargin + s"""circuit Top : + | module Top : + | input x : UInt<8> + | output z : UInt<8> + | z <= pad(x, $w)""".stripMargin val check = """circuit Top : | module Top : @@ -1029,247 +1026,246 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input(8), check, Seq.empty) } - "Registers with no reset or connections" should "be replaced with constant zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<8> - | reg r : UInt<8>, clock - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<8> - | z <= UInt<8>(0)""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<8> + | reg r : UInt<8>, clock + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<8> + | z <= UInt<8>(0)""".stripMargin execute(input, check, Seq.empty) } "Registers with ONLY constant reset" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | z <= UInt<8>("hb")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } "Registers async reset and a constant connection" should "NOT be removed" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : AsyncReset - | input en : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | when en : - | r <= UInt<4>("h0") - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : AsyncReset - | input en : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : - | reset => (reset, UInt<8>("hb")) - | z <= r - | r <= mux(en, UInt<8>("h0"), r)""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : AsyncReset + | input en : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | when en : + | r <= UInt<4>("h0") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : AsyncReset + | input en : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : + | reset => (reset, UInt<8>("hb")) + | z <= r + | r <= mux(en, UInt<8>("h0"), r)""".stripMargin execute(input, check, Seq.empty) } "Registers with constant reset and connection to the same constant" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cond : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | when cond : - | r <= UInt<4>("hb") - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cond : UInt<1> - | output z : UInt<8> - | z <= UInt<8>("hb")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cond : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | when cond : + | r <= UInt<4>("hb") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cond : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } "Const prop of registers" should "do limited speculative expansion of optimized muxes to absorb bigger cones" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input en : UInt<1> - | output out : UInt<1> - | reg r1 : UInt<1>, clock - | reg r2 : UInt<1>, clock - | when en : - | r1 <= UInt<1>(1) - | r2 <= UInt<1>(0) - | when en : - | r2 <= r2 - | out <= xor(r1, r2)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input en : UInt<1> - | output out : UInt<1> - | out <= UInt<1>("h1")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input en : UInt<1> + | output out : UInt<1> + | reg r1 : UInt<1>, clock + | reg r2 : UInt<1>, clock + | when en : + | r1 <= UInt<1>(1) + | r2 <= UInt<1>(0) + | when en : + | r2 <= r2 + | out <= xor(r1, r2)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input en : UInt<1> + | output out : UInt<1> + | out <= UInt<1>("h1")""".stripMargin execute(input, check, Seq.empty) } "A register with constant reset and all connection to either itself or the same constant" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cmd : UInt<3> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("h7"))) - | r <= r - | when eq(cmd, UInt<3>("h0")) : - | r <= UInt<3>("h7") - | else : - | when eq(cmd, UInt<3>("h1")) : - | r <= r - | else : - | when eq(cmd, UInt<3>("h2")) : - | r <= UInt<4>("h7") - | else : - | r <= r - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cmd : UInt<3> - | output z : UInt<8> - | z <= UInt<8>("h7")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cmd : UInt<3> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("h7"))) + | r <= r + | when eq(cmd, UInt<3>("h0")) : + | r <= UInt<3>("h7") + | else : + | when eq(cmd, UInt<3>("h1")) : + | r <= r + | else : + | when eq(cmd, UInt<3>("h2")) : + | r <= UInt<4>("h7") + | else : + | r <= r + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cmd : UInt<3> + | output z : UInt<8> + | z <= UInt<8>("h7")""".stripMargin execute(input, check, Seq.empty) } "Registers with ONLY constant connection" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : SInt<8> - | reg r : SInt<8>, clock - | r <= SInt<4>(-5) - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : SInt<8> - | z <= SInt<8>(-5)""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : SInt<8> + | reg r : SInt<8>, clock + | r <= SInt<4>(-5) + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : SInt<8> + | z <= SInt<8>(-5)""".stripMargin execute(input, check, Seq.empty) } "Registers with identical constant reset and connection" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | r <= UInt<4>("hb") - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | z <= UInt<8>("hb")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | r <= UInt<4>("hb") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } "Connections to a node reference" should "be replaced with the rhs of that node" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<8> - | input b : UInt<8> - | input c : UInt<1> - | output z : UInt<8> - | node x = mux(c, a, b) - | z <= x""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<8> - | input b : UInt<8> - | input c : UInt<1> - | output z : UInt<8> - | z <= mux(c, a, b)""".stripMargin + val input = + """circuit Top : + | module Top : + | input a : UInt<8> + | input b : UInt<8> + | input c : UInt<1> + | output z : UInt<8> + | node x = mux(c, a, b) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<8> + | input b : UInt<8> + | input c : UInt<1> + | output z : UInt<8> + | z <= mux(c, a, b)""".stripMargin execute(input, check, Seq.empty) } "Registers connected only to themselves" should "be replaced with zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | reg ra : UInt<8>, clock - | ra <= ra - | a <= ra - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | a <= UInt<8>(0) - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | reg ra : UInt<8>, clock + | ra <= ra + | a <= ra + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | a <= UInt<8>(0) + |""".stripMargin execute(input, check, Seq.empty) } "Registers connected only to themselves from constant propagation" should "be replaced with zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | reg ra : UInt<8>, clock - | ra <= or(ra, UInt(0)) - | a <= ra - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | a <= UInt<8>(0) - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | reg ra : UInt<8>, clock + | ra <= or(ra, UInt(0)) + | a <= ra + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | a <= UInt<8>(0) + |""".stripMargin execute(input, check, Seq.empty) } @@ -1290,7 +1286,7 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input, check, Seq.empty) } - behavior of "ConstProp" + behavior.of("ConstProp") it should "optimize shl of constants" in { val input = @@ -1381,30 +1377,30 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { it should "optimize some binary operations when arguments match" in { // Signedness matters - matchingArgs("sub", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) - matchingArgs("sub", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ ) - matchingArgs("div", "UInt<8>", "UInt<8>", """ UInt<8>("h1") """ ) - matchingArgs("div", "SInt<8>", "SInt<8>", """ SInt<8>("h1") """ ) - matchingArgs("rem", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) - matchingArgs("rem", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ ) - matchingArgs("and", "UInt<8>", "UInt<8>", """ i """ ) - matchingArgs("and", "SInt<8>", "UInt<8>", """ asUInt(i) """ ) + matchingArgs("sub", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """) + matchingArgs("sub", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """) + matchingArgs("div", "UInt<8>", "UInt<8>", """ UInt<8>("h1") """) + matchingArgs("div", "SInt<8>", "SInt<8>", """ SInt<8>("h1") """) + matchingArgs("rem", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """) + matchingArgs("rem", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """) + matchingArgs("and", "UInt<8>", "UInt<8>", """ i """) + matchingArgs("and", "SInt<8>", "UInt<8>", """ asUInt(i) """) // Signedness doesn't matter - matchingArgs("or", "UInt<8>", "UInt<8>", """ i """ ) - matchingArgs("or", "SInt<8>", "UInt<8>", """ asUInt(i) """ ) - matchingArgs("xor", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) - matchingArgs("xor", "SInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) + matchingArgs("or", "UInt<8>", "UInt<8>", """ i """) + matchingArgs("or", "SInt<8>", "UInt<8>", """ asUInt(i) """) + matchingArgs("xor", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """) + matchingArgs("xor", "SInt<8>", "UInt<8>", """ UInt<8>("h0") """) // Always true - matchingArgs("eq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) - matchingArgs("leq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) - matchingArgs("geq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) + matchingArgs("eq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """) + matchingArgs("leq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """) + matchingArgs("geq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """) // Never true - matchingArgs("neq", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) - matchingArgs("lt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) - matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) + matchingArgs("neq", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """) + matchingArgs("lt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """) + matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """) } - behavior of "Reduction operators" + behavior.of("Reduction operators") it should "optimize andr of a literal" in { val input = @@ -1534,7 +1530,6 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { } - class ConstantPropagationEquivalenceSpec extends FirrtlFlatSpec { private val srcDir = "/constant_propagation_tests" private val transforms = Seq(new ConstantPropagation) @@ -1642,15 +1637,15 @@ class ConstantPropagationEquivalenceSpec extends FirrtlFlatSpec { firrtlEquivalenceTest(input, transforms) } - "addition of negative literals" should "be propagated" in { - val input = - s"""circuit AddTester : - | module AddTester : - | output ref : SInt<2> - | ref <= add(SInt<1>("h-1"), SInt<1>("h-1")) - |""".stripMargin - firrtlEquivalenceTest(input, transforms) - } + "addition of negative literals" should "be propagated" in { + val input = + s"""circuit AddTester : + | module AddTester : + | output ref : SInt<2> + | ref <= add(SInt<1>("h-1"), SInt<1>("h-1")) + |""".stripMargin + firrtlEquivalenceTest(input, transforms) + } "propagation of signed expressions" should "have the correct signs" in { val input = diff --git a/src/test/scala/firrtlTests/CustomTransformSpec.scala b/src/test/scala/firrtlTests/CustomTransformSpec.scala index 3e5fd254..6edd212d 100644 --- a/src/test/scala/firrtlTests/CustomTransformSpec.scala +++ b/src/test/scala/firrtlTests/CustomTransformSpec.scala @@ -19,28 +19,28 @@ object CustomTransformSpec { class ReplaceExtModuleTransform extends SeqTransform with FirrtlMatchers { // Simple module val delayModuleString = """ - |circuit Delay : - | module Delay : - | input clock : Clock - | input reset : UInt<1> - | input a : UInt<32> - | input en : UInt<1> - | output b : UInt<32> - | - | reg r : UInt<32>, clock - | r <= r - | when en : - | r <= a - | b <= r - |""".stripMargin + |circuit Delay : + | module Delay : + | input clock : Clock + | input reset : UInt<1> + | input a : UInt<32> + | input en : UInt<1> + | output b : UInt<32> + | + | reg r : UInt<32>, clock + | r <= r + | when en : + | r <= a + | b <= r + |""".stripMargin val delayModuleCircuit = parse(delayModuleString) val delayModule = delayModuleCircuit.modules.find(_.name == delayModuleCircuit.main).get class ReplaceExtModule extends Pass { def run(c: Circuit): Circuit = c.copy( - modules = c.modules map { + modules = c.modules.map { case ExtModule(_, "Delay", _, _, _) => delayModule - case other => other + case other => other } ) } @@ -50,10 +50,10 @@ object CustomTransformSpec { } val input = """ - |circuit test : - | module test : - | output out : UInt - | out <= UInt(123)""".stripMargin + |circuit test : + | module test : + | output out : UInt + | out <= UInt(123)""".stripMargin val errorString = "My Custom Transform failed!" class ErroringTransform extends Transform { def inputForm = HighForm @@ -122,7 +122,7 @@ class CustomTransformSpec extends FirrtlFlatSpec { import CustomTransformSpec._ - behavior of "Custom Transforms" + behavior.of("Custom Transforms") they should "be able to introduce high firrtl" in { runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform)) @@ -130,22 +130,24 @@ class CustomTransformSpec extends FirrtlFlatSpec { they should "not cause \"Internal Errors\"" in { val optionsManager = new ExecutionOptionsManager("test") with HasFirrtlOptions { - firrtlOptions = FirrtlExecutionOptions( - firrtlSource = Some(input), - customTransforms = List(new ErroringTransform)) + firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input), customTransforms = List(new ErroringTransform)) } - (the [java.lang.IllegalArgumentException] thrownBy { + (the[java.lang.IllegalArgumentException] thrownBy { Driver.execute(optionsManager) - }).getMessage should include (errorString) + }).getMessage should include(errorString) } they should "preserve the input order" in { - runFirrtlTest("CustomTransform", "/features", customTransforms = List( - new FirstTransform, - new SecondTransform, - new ThirdTransform, - new ReplaceExtModuleTransform - )) + runFirrtlTest( + "CustomTransform", + "/features", + customTransforms = List( + new FirstTransform, + new SecondTransform, + new ThirdTransform, + new ReplaceExtModuleTransform + ) + ) } they should "run right before the emitter* when inputForm=LowForm" in { @@ -159,11 +161,10 @@ class CustomTransformSpec extends FirrtlFlatSpec { val custom = Dependency[IdentityLowForm] val tm = new firrtl.stage.transforms.Compiler(custom :: emitter :: Nil) info(s"when using ${emitter.getObject.name}") - tm - .flattenedTransformOrder + tm.flattenedTransformOrder .map(Dependency.fromTransform) .sliding(2) - .toList should contain (Seq(custom, emitter)) + .toList should contain(Seq(custom, emitter)) } } diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala index b309467a..a9084f0b 100644 --- a/src/test/scala/firrtlTests/DCETests.scala +++ b/src/test/scala/firrtlTests/DCETests.scala @@ -13,7 +13,8 @@ import java.io.File import java.nio.file.Paths case class AnnotationWithDontTouches(target: ReferenceTarget) - extends SingleTargetAnnotation[ReferenceTarget] with HasDontTouches { + extends SingleTargetAnnotation[ReferenceTarget] + with HasDontTouches { def targets = Seq(target) def duplicate(n: ReferenceTarget) = this.copy(n) def dontTouches: Seq[ReferenceTarget] = targets @@ -31,9 +32,9 @@ class DCETests extends FirrtlFlatSpec { val finalState = (new LowFirrtlCompiler).compileAndEmit(state, customTransforms) val res = finalState.getEmittedCircuit.value // Convert to sets for comparison - val resSet = Set(parse(res).serialize.split("\n"):_*) - val checkSet = Set(parse(check).serialize.split("\n"):_*) - resSet should be (checkSet) + val resSet = Set(parse(res).serialize.split("\n"): _*) + val checkSet = Set(parse(check).serialize.split("\n"): _*) + resSet should be(checkSet) } "Unread wire" should "be deleted" in { @@ -418,7 +419,7 @@ class DCETests extends FirrtlFlatSpec { exec(input, check) } // This currently does NOT work - behavior of "Single dead instances" + behavior.of("Single dead instances") ignore should "should be deleted" in { val input = """circuit Top : @@ -469,9 +470,9 @@ class DCETests extends FirrtlFlatSpec { val result = (new VerilogCompiler).compileAndEmit(state, List.empty) val verilog = result.getEmittedCircuit.value // Check that mux is removed! - verilog shouldNot include regex ("""a \? x : r;""") + (verilog shouldNot include).regex("""a \? x : r;""") // Check for register update - verilog should include regex ("""(?m)if \(a\) begin\n\s*r <= x;\s*end""") + (verilog should include).regex("""(?m)if \(a\) begin\n\s*r <= x;\s*end""") } "Emitted Verilog" should "not contain dead print or stop statements" in { @@ -487,8 +488,8 @@ class DCETests extends FirrtlFlatSpec { val state = CircuitState(input, ChirrtlForm) val result = (new VerilogCompiler).compileAndEmit(state, List.empty) val verilog = result.getEmittedCircuit.value - verilog shouldNot include regex ("""fwrite""") - verilog shouldNot include regex ("""fatal""") + (verilog shouldNot include).regex("""fwrite""") + (verilog shouldNot include).regex("""fatal""") } } @@ -502,7 +503,7 @@ class DCECommandLineSpec extends FirrtlFlatSpec { "Dead Code Elimination" should "run by default" in { firrtl.Driver.execute(args) match { case FirrtlExecutionSuccess(_, verilog) => - verilog should not include regex ("wire +a") + (verilog should not).include(regex("wire +a")) case _ => fail("Unexpected compilation failure") } } @@ -510,7 +511,7 @@ class DCECommandLineSpec extends FirrtlFlatSpec { it should "not run when given --no-dce option" in { firrtl.Driver.execute(args :+ "--no-dce") match { case FirrtlExecutionSuccess(_, verilog) => - verilog should include regex ("wire +a") + (verilog should include).regex("wire +a") case _ => fail("Unexpected compilation failure") } } diff --git a/src/test/scala/firrtlTests/DriverSpec.scala b/src/test/scala/firrtlTests/DriverSpec.scala index 400bf314..5352fadf 100644 --- a/src/test/scala/firrtlTests/DriverSpec.scala +++ b/src/test/scala/firrtlTests/DriverSpec.scala @@ -85,15 +85,13 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit optionsManager.commonOptions.programArgs should be("fox" :: "tardigrade" :: "stomatopod" :: Nil) optionsManager.commonOptions = CommonOptions() - optionsManager.parse( - Array("dog", "stomatopod")) should be(true) + optionsManager.parse(Array("dog", "stomatopod")) should be(true) info(s"programArgs ${optionsManager.commonOptions.programArgs}") optionsManager.commonOptions.programArgs.length should be(2) optionsManager.commonOptions.programArgs should be("dog" :: "stomatopod" :: Nil) optionsManager.commonOptions = CommonOptions() - optionsManager.parse( - Array("fox", "--top-name", "dog", "tardigrade", "stomatopod")) should be(true) + optionsManager.parse(Array("fox", "--top-name", "dog", "tardigrade", "stomatopod")) should be(true) info(s"programArgs ${optionsManager.commonOptions.programArgs}") optionsManager.commonOptions.programArgs.length should be(3) optionsManager.commonOptions.programArgs should be("fox" :: "tardigrade" :: "stomatopod" :: Nil) @@ -130,11 +128,11 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit outputFileName should be("carol.v") } val input = """ - |circuit Top : - | module Top : - | input x : UInt<8> - | output y : UInt<8> - | y <= x""".stripMargin + |circuit Top : + | module Top : + | input x : UInt<8> + | output y : UInt<8> + | y <= x""".stripMargin val circuit = Parser.parse(input.split("\n").toIterator) "firrtl source can be provided directly" in { val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { @@ -153,18 +151,15 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit "Only one of inputFileNameOverride, firrtlSource, and firrtlCircuit can be used at a time" in { val manager1 = new ExecutionOptionsManager("test") with HasFirrtlOptions { commonOptions = CommonOptions(topName = "Top") - firrtlOptions = FirrtlExecutionOptions(firrtlCircuit = Some(circuit), - firrtlSource = Some(input)) + firrtlOptions = FirrtlExecutionOptions(firrtlCircuit = Some(circuit), firrtlSource = Some(input)) } val manager2 = new ExecutionOptionsManager("test") with HasFirrtlOptions { commonOptions = CommonOptions(topName = "Top") - firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "hi", - firrtlSource = Some(input)) + firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "hi", firrtlSource = Some(input)) } val manager3 = new ExecutionOptionsManager("test") with HasFirrtlOptions { commonOptions = CommonOptions(topName = "Top") - firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "hi", - firrtlCircuit = Some(circuit)) + firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "hi", firrtlCircuit = Some(circuit)) } assert(firrtl.Driver.getCircuit(manager1).isFailure) assert(firrtl.Driver.getCircuit(manager2).isFailure) @@ -273,26 +268,25 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit "verilog" -> "./Foo.v", "mverilog" -> "./Foo.v", "sverilog" -> "./Foo.sv" - ).foreach { case (compilerName, expectedOutputFileName) => - info(s"$compilerName -> $expectedOutputFileName") - val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { - commonOptions = CommonOptions(topName = "Foo") - firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input), compilerName = compilerName) - } - - firrtl.Driver.execute(manager) match { - case success: FirrtlExecutionSuccess => - success.emitted.size should not be (0) - success.circuitState.annotations.length should be > (0) - case a: FirrtlExecutionFailure => - fail(s"Got a FirrtlExecutionFailure! Expected FirrtlExecutionSuccess. Full message:\n${a.message}") - } - - - - val file = new File(expectedOutputFileName) - file.exists() should be(true) - file.delete() + ).foreach { + case (compilerName, expectedOutputFileName) => + info(s"$compilerName -> $expectedOutputFileName") + val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { + commonOptions = CommonOptions(topName = "Foo") + firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input), compilerName = compilerName) + } + + firrtl.Driver.execute(manager) match { + case success: FirrtlExecutionSuccess => + success.emitted.size should not be (0) + success.circuitState.annotations.length should be > (0) + case a: FirrtlExecutionFailure => + fail(s"Got a FirrtlExecutionFailure! Expected FirrtlExecutionSuccess. Full message:\n${a.message}") + } + + val file = new File(expectedOutputFileName) + file.exists() should be(true) + file.delete() } } "To a single file per module if OneFilePerModule is specified" in { @@ -304,27 +298,30 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit "verilog" -> Seq("./Top.v", "./Child.v"), "mverilog" -> Seq("./Top.v", "./Child.v"), "sverilog" -> Seq("./Top.sv", "./Child.sv") - ).foreach { case (compilerName, expectedOutputFileNames) => - info(s"$compilerName -> $expectedOutputFileNames") - val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { - firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input), - compilerName = compilerName, - emitOneFilePerModule = true) - } - - firrtl.Driver.execute(manager) match { - case success: FirrtlExecutionSuccess => - success.emitted.size should not be (0) - success.circuitState.annotations.length should be > (0) - case failure: FirrtlExecutionFailure => - fail(s"Got a FirrtlExecutionFailure! Expected FirrtlExecutionSuccess. Full message:\n${failure.message}") - } - - for (name <- expectedOutputFileNames) { - val file = new File(name) - file.exists() should be(true) - file.delete() - } + ).foreach { + case (compilerName, expectedOutputFileNames) => + info(s"$compilerName -> $expectedOutputFileNames") + val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { + firrtlOptions = FirrtlExecutionOptions( + firrtlSource = Some(input), + compilerName = compilerName, + emitOneFilePerModule = true + ) + } + + firrtl.Driver.execute(manager) match { + case success: FirrtlExecutionSuccess => + success.emitted.size should not be (0) + success.circuitState.annotations.length should be > (0) + case failure: FirrtlExecutionFailure => + fail(s"Got a FirrtlExecutionFailure! Expected FirrtlExecutionSuccess. Full message:\n${failure.message}") + } + + for (name <- expectedOutputFileNames) { + val file = new File(name) + file.exists() should be(true) + file.delete() + } } } } @@ -348,7 +345,7 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit "Both paths do the same thing" in { val s1 = FileUtils.getText(verilogFromFir) val s2 = FileUtils.getText(verilogFromPb) - s1 should equal (s2) + s1 should equal(s2) } } @@ -378,12 +375,12 @@ class VcdSuppressionSpec extends FirrtlFlatSpec { copyResourceToFile(cppHarnessResourceName, harness) verilogToCpp(prefix, testDir, Seq.empty, harness, suppress) #&& - cppToExe(prefix, testDir) ! loggingProcessLogger + cppToExe(prefix, testDir) ! loggingProcessLogger assert(executeExpectingSuccess(prefix, testDir)) val vcdFile = new File(s"$testDir/dump.vcd") - vcdFile.exists() should be(! suppress) + vcdFile.exists() should be(!suppress) } testIfVcdCreated(suppress = false) diff --git a/src/test/scala/firrtlTests/ExecutionOptionsManagerSpec.scala b/src/test/scala/firrtlTests/ExecutionOptionsManagerSpec.scala index 863b6900..9f756927 100644 --- a/src/test/scala/firrtlTests/ExecutionOptionsManagerSpec.scala +++ b/src/test/scala/firrtlTests/ExecutionOptionsManagerSpec.scala @@ -10,32 +10,36 @@ class ExecutionOptionsManagerSpec extends AnyFreeSpec with Matchers { "ExecutionOptionsManager is a container for one more more ComposableOptions Block" - { "It has a default CommonOptionsBlock" in { val manager = new ExecutionOptionsManager("test") - manager.topName should be ("") - manager.targetDirName should be (".") - manager.commonOptions.topName should be ("") - manager.commonOptions.targetDirName should be (".") + manager.topName should be("") + manager.targetDirName should be(".") + manager.commonOptions.topName should be("") + manager.commonOptions.targetDirName should be(".") } "But can override defaults like this" in { - val manager = new ExecutionOptionsManager("test") { commonOptions = CommonOptions(topName = "dog", targetDirName = "a/b/c") } - manager.commonOptions shouldBe a [CommonOptions] - manager.topName should be ("dog") - manager.targetDirName should be ("a/b/c") - manager.commonOptions.topName should be ("dog") - manager.commonOptions.targetDirName should be ("a/b/c") + val manager = new ExecutionOptionsManager("test") { + commonOptions = CommonOptions(topName = "dog", targetDirName = "a/b/c") + } + manager.commonOptions shouldBe a[CommonOptions] + manager.topName should be("dog") + manager.targetDirName should be("a/b/c") + manager.commonOptions.topName should be("dog") + manager.commonOptions.targetDirName should be("a/b/c") } "The add method should put a new version of a given type the manager" in { - val manager = new ExecutionOptionsManager("test") { commonOptions = CommonOptions(topName = "dog", targetDirName = "a/b/c") } + val manager = new ExecutionOptionsManager("test") { + commonOptions = CommonOptions(topName = "dog", targetDirName = "a/b/c") + } val initialCommon = manager.commonOptions - initialCommon.topName should be ("dog") - initialCommon.targetDirName should be ("a/b/c") + initialCommon.topName should be("dog") + initialCommon.targetDirName should be("a/b/c") manager.commonOptions = CommonOptions(topName = "cat", targetDirName = "d/e/f") val afterCommon = manager.commonOptions - afterCommon.topName should be ("cat") - afterCommon.targetDirName should be ("d/e/f") - initialCommon.topName should be ("dog") - initialCommon.targetDirName should be ("a/b/c") + afterCommon.topName should be("cat") + afterCommon.targetDirName should be("d/e/f") + initialCommon.topName should be("dog") + initialCommon.targetDirName should be("a/b/c") } "multiple composable blocks should be separable" in { val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { @@ -43,8 +47,8 @@ class ExecutionOptionsManagerSpec extends AnyFreeSpec with Matchers { firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "fork") } - manager.firrtlOptions.inputFileNameOverride should be ("fork") - manager.commonOptions.topName should be ("spoon") + manager.firrtlOptions.inputFileNameOverride should be("fork") + manager.commonOptions.topName should be("spoon") } } } diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala index 3616397f..6737643a 100644 --- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala +++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala @@ -22,54 +22,55 @@ class ExpandWhensSpec extends FirrtlFlatSpec { PullMuxes, ExpandConnects, RemoveAccesses, - ExpandWhens) + ExpandWhens + ) private def executeTest(input: String, check: String, expected: Boolean) = { val circuit = Parser.parse(input.split("\n").toIterator) - val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } val c = result.circuit - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) if (expected) { - c.serialize.contains(check) should be (true) + c.serialize.contains(check) should be(true) } else { - lines.foreach(_.contains(check) should be (false)) + lines.foreach(_.contains(check) should be(false)) } } "Expand Whens" should "not emit INVALID" in { val input = - """|circuit Tester : - | module Tester : - | input p : UInt<1> - | when p : - | wire a : {b : UInt<64>, c : UInt<64>} - | a is invalid - | a.b <= UInt<64>("h04000000000000000")""".stripMargin + """|circuit Tester : + | module Tester : + | input p : UInt<1> + | when p : + | wire a : {b : UInt<64>, c : UInt<64>} + | a is invalid + | a.b <= UInt<64>("h04000000000000000")""".stripMargin val check = "INVALID" executeTest(input, check, false) } it should "void unwritten memory fields" in { val input = - """|circuit Tester : - | module Tester : - | input clk : Clock - | mem memory: - | data-type => UInt<32> - | depth => 32 - | reader => r0 - | writer => w0 - | read-latency => 0 - | write-latency => 1 - | read-under-write => undefined - | memory.r0.addr <= UInt<1>(1) - | memory.r0.en <= UInt<1>(1) - | memory.r0.clk <= clk - | memory.w0.addr <= UInt<1>(1) - | memory.w0.data <= UInt<1>(1) - | memory.w0.en <= UInt<1>(1) - | memory.w0.clk <= clk - | """.stripMargin + """|circuit Tester : + | module Tester : + | input clk : Clock + | mem memory: + | data-type => UInt<32> + | depth => 32 + | reader => r0 + | writer => w0 + | read-latency => 0 + | write-latency => 1 + | read-under-write => undefined + | memory.r0.addr <= UInt<1>(1) + | memory.r0.en <= UInt<1>(1) + | memory.r0.clk <= clk + | memory.w0.addr <= UInt<1>(1) + | memory.w0.data <= UInt<1>(1) + | memory.w0.en <= UInt<1>(1) + | memory.w0.clk <= clk + | """.stripMargin val check = "VOID" executeTest(input, check, true) } diff --git a/src/test/scala/firrtlTests/ExtModuleSpec.scala b/src/test/scala/firrtlTests/ExtModuleSpec.scala index 7379f1aa..c684e57b 100644 --- a/src/test/scala/firrtlTests/ExtModuleSpec.scala +++ b/src/test/scala/firrtlTests/ExtModuleSpec.scala @@ -4,13 +4,12 @@ package firrtlTests import firrtl.testutils._ -class SimpleExtModuleExecutionTest extends ExecutionTest("SimpleExtModuleTester", "/blackboxes", - Seq("SimpleExtModule")) -class MultiExtModuleExecutionTest extends ExecutionTest("MultiExtModuleTester", "/blackboxes", - Seq("SimpleExtModule", "AdderExtModule")) -class RenamedExtModuleExecutionTest extends ExecutionTest("RenamedExtModuleTester", "/blackboxes", - Seq("SimpleExtModule")) -class ParameterizedExtModuleExecutionTest extends ExecutionTest( - "ParameterizedExtModuleTester", "/blackboxes", Seq("ParameterizedExtModule")) +class SimpleExtModuleExecutionTest extends ExecutionTest("SimpleExtModuleTester", "/blackboxes", Seq("SimpleExtModule")) +class MultiExtModuleExecutionTest + extends ExecutionTest("MultiExtModuleTester", "/blackboxes", Seq("SimpleExtModule", "AdderExtModule")) +class RenamedExtModuleExecutionTest + extends ExecutionTest("RenamedExtModuleTester", "/blackboxes", Seq("SimpleExtModule")) +class ParameterizedExtModuleExecutionTest + extends ExecutionTest("ParameterizedExtModuleTester", "/blackboxes", Seq("ParameterizedExtModule")) class LargeParamExecutionTest extends ExecutionTest("LargeParamTester", "/blackboxes", Seq("LargeParam")) diff --git a/src/test/scala/firrtlTests/ExtModuleTests.scala b/src/test/scala/firrtlTests/ExtModuleTests.scala index 9ab3429e..5a58df2b 100644 --- a/src/test/scala/firrtlTests/ExtModuleTests.scala +++ b/src/test/scala/firrtlTests/ExtModuleTests.scala @@ -20,7 +20,6 @@ class ExtModuleTests extends FirrtlFlatSpec { | parameter TYP = 'bit' | """.stripMargin val parsed = parse(input) - (parse(parsed.serialize)) should be (parsed) + (parse(parsed.serialize)) should be(parsed) } } - diff --git a/src/test/scala/firrtlTests/FeatureSpec.scala b/src/test/scala/firrtlTests/FeatureSpec.scala index c7c8f4ac..4972eeb5 100644 --- a/src/test/scala/firrtlTests/FeatureSpec.scala +++ b/src/test/scala/firrtlTests/FeatureSpec.scala @@ -6,4 +6,3 @@ import firrtl.testutils.ExecutionTest // Miscellaneous Feature Checks class NestedSubAccessExecutionTest extends ExecutionTest("NestedSubAccessTester", "/features") - diff --git a/src/test/scala/firrtlTests/FileUtilsSpec.scala b/src/test/scala/firrtlTests/FileUtilsSpec.scala index 43d35048..5a438251 100644 --- a/src/test/scala/firrtlTests/FileUtilsSpec.scala +++ b/src/test/scala/firrtlTests/FileUtilsSpec.scala @@ -2,17 +2,16 @@ package firrtlTests - import firrtl.FileUtils import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers class FileUtilsSpec extends AnyFlatSpec with Matchers { - private val sampleAnnotations: String = "annotations/SampleAnnotations.anno.json" + private val sampleAnnotations: String = "annotations/SampleAnnotations.anno.json" private val sampleAnnotationsFileName: String = s"src/test/resources/$sampleAnnotations" - behavior of "FileUtils.getLines" + behavior.of("FileUtils.getLines") it should "read from a string filename" in { FileUtils.getLines(sampleAnnotationsFileName).size should be > 0 @@ -22,7 +21,7 @@ class FileUtilsSpec extends AnyFlatSpec with Matchers { FileUtils.getLines(new java.io.File(sampleAnnotationsFileName)).size should be > 0 } - behavior of "FileUtils.getText" + behavior.of("FileUtils.getText") it should "read from a string filename" in { FileUtils.getText(sampleAnnotationsFileName).size should be > 0 @@ -32,13 +31,13 @@ class FileUtilsSpec extends AnyFlatSpec with Matchers { FileUtils.getText(new java.io.File(sampleAnnotationsFileName)).size should be > 0 } - behavior of "FileUtils.getLinesResource" + behavior.of("FileUtils.getLinesResource") it should "read from a resource" in { FileUtils.getLinesResource(s"/$sampleAnnotations").size should be > 0 } - behavior of "FileUtils.getTextResource" + behavior.of("FileUtils.getTextResource") it should "read from a resource" in { FileUtils.getTextResource(s"/$sampleAnnotations").split("\n").size should be > 0 diff --git a/src/test/scala/firrtlTests/FlattenTests.scala b/src/test/scala/firrtlTests/FlattenTests.scala index 34edfe58..53604ee5 100644 --- a/src/test/scala/firrtlTests/FlattenTests.scala +++ b/src/test/scala/firrtlTests/FlattenTests.scala @@ -3,12 +3,12 @@ package firrtlTests import firrtl.annotations.{Annotation, CircuitName, ComponentName, ModuleName} -import firrtl.transforms.{FlattenAnnotation, Flatten, NoCircuitDedupAnnotation} +import firrtl.transforms.{Flatten, FlattenAnnotation, NoCircuitDedupAnnotation} import firrtl.testutils._ /** - * Tests deep inline transformation - */ + * Tests deep inline transformation + */ class FlattenTests extends LowTransformSpec { def transform = new Flatten def flatten(mod: String): Annotation = { @@ -19,204 +19,204 @@ class FlattenTests extends LowTransformSpec { } "The modules inside Top " should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline1 - | i.a <= a - | b <= i.b - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | i_b <= i_a - | b <= i_b - | i_a <= a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(flatten("Top"))) + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline1 + | i.a <= a + | b <= i.b + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | i_b <= i_a + | b <= i_b + | i_a <= a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Top"))) } "Two instances of the same module inside Top " should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i1 of Inline1 - | inst i2 of Inline1 - | i1.a <= a - | node tmp = i1.b - | i2.a <= tmp - | b <= i2.b - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i1_a : UInt<32> - | wire i1_b : UInt<32> - | i1_b <= i1_a - | wire i2_a : UInt<32> - | wire i2_b : UInt<32> - | i2_b <= i2_a - | node tmp = i1_b - | b <= i2_b - | i1_a <= a - | i2_a <= tmp - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(flatten("Top"))) + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i1 of Inline1 + | inst i2 of Inline1 + | i1.a <= a + | node tmp = i1.b + | i2.a <= tmp + | b <= i2.b + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i1_a : UInt<32> + | wire i1_b : UInt<32> + | i1_b <= i1_a + | wire i2_a : UInt<32> + | wire i2_b : UInt<32> + | i2_b <= i2_a + | node tmp = i1_b + | b <= i2_b + | i1_a <= a + | i2_a <= tmp + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Top"))) } "The module instance i in Top " should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | input na : UInt<32> - | output b : UInt<32> - | output nb : UInt<32> - | inst i of Inline1 - | inst ni of NotInline1 - | i.a <= a - | b <= i.b - | ni.a <= na - | nb <= ni.b - | module NotInline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | i.a <= a - | b <= i.a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | i.a <= a - | b <= i.a - | module Inline2 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | input na : UInt<32> - | output b : UInt<32> - | output nb : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | wire i_i_a : UInt<32> - | wire i_i_b : UInt<32> - | i_i_b <= i_i_a - | i_b <= i_i_a - | i_i_a <= i_a - | inst ni of NotInline1 - | b <= i_b - | nb <= ni.b - | i_a <= a - | ni.a <= na - | module NotInline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | b <= i.a - | i.a <= a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | b <= i.a - | i.a <= a - | module Inline2 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(flatten("Top.i"), NoCircuitDedupAnnotation)) + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | inst i of Inline1 + | inst ni of NotInline1 + | i.a <= a + | b <= i.b + | ni.a <= na + | nb <= ni.b + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | wire i_i_a : UInt<32> + | wire i_i_b : UInt<32> + | i_i_b <= i_i_a + | i_b <= i_i_a + | i_i_a <= i_a + | inst ni of NotInline1 + | b <= i_b + | nb <= ni.b + | i_a <= a + | ni.a <= na + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | b <= i.a + | i.a <= a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | b <= i.a + | i.a <= a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Top.i"), NoCircuitDedupAnnotation)) } "The module Inline1" should "be inlined" in { val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | input na : UInt<32> - | output b : UInt<32> - | output nb : UInt<32> - | inst i of Inline1 - | inst ni of NotInline1 - | i.a <= a - | b <= i.b - | ni.a <= na - | nb <= ni.b - | module NotInline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | i.a <= a - | b <= i.a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | i.a <= a - | b <= i.a - | module Inline2 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | input na : UInt<32> - | output b : UInt<32> - | output nb : UInt<32> - | inst i of Inline1 - | inst ni of NotInline1 - | b <= i.b - | nb <= ni.b - | i.a <= a - | ni.a <= na - | module NotInline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | b <= i.a - | i.a <= a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | i_b <= i_a - | b <= i_a - | i_a <= a - | module Inline2 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(flatten("Inline1"), NoCircuitDedupAnnotation)) + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | inst i of Inline1 + | inst ni of NotInline1 + | i.a <= a + | b <= i.b + | ni.a <= na + | nb <= ni.b + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | inst i of Inline1 + | inst ni of NotInline1 + | b <= i.b + | nb <= ni.b + | i.a <= a + | ni.a <= na + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | b <= i.a + | i.a <= a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | i_b <= i_a + | b <= i_a + | i_a <= a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Inline1"), NoCircuitDedupAnnotation)) } - "The Flatten transform" should "do nothing if no flatten annotations are present" in{ + "The Flatten transform" should "do nothing if no flatten annotations are present" in { val input = """|circuit Foo: | module Foo: @@ -229,46 +229,46 @@ class FlattenTests extends LowTransformSpec { "The Flatten transform" should "ignore extmodules" in { val input = """ - |circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline - | i.a <= a - | b <= i.b - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | inst i of ExternalMod - | i.a <= a - | b <= i.b - | extmodule ExternalMod : - | input a : UInt<32> - | output b : UInt<32> - | defname = ExternalMod + |circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst i of ExternalMod + | i.a <= a + | b <= i.b + | extmodule ExternalMod : + | input a : UInt<32> + | output b : UInt<32> + | defname = ExternalMod """.stripMargin val check = """ - |circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | inst i_i of ExternalMod - | i_b <= i_i.b - | i_i.a <= i_a - | b <= i_b - | i_a <= a - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | inst i of ExternalMod - | b <= i.b - | i.a <= a - | extmodule ExternalMod : - | input a : UInt<32> - | output b : UInt<32> - | defname = ExternalMod + |circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | inst i_i of ExternalMod + | i_b <= i_i.b + | i_i.a <= i_a + | b <= i_b + | i_a <= a + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst i of ExternalMod + | b <= i.b + | i.a <= a + | extmodule ExternalMod : + | input a : UInt<32> + | output b : UInt<32> + | defname = ExternalMod """.stripMargin execute(input, check, Seq(flatten("Top"))) } diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index e8be70ad..81f2df33 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -10,8 +10,7 @@ import firrtl.testutils._ import firrtl.testutils.FirrtlCheckers._ class InferReadWriteSpec extends SimpleTransformSpec { - class InferReadWriteCheckException extends PassException( - "Readwrite ports are not found!") + class InferReadWriteCheckException extends PassException("Readwrite ports are not found!") object InferReadWriteCheck extends Pass { override def prerequisites = Forms.MidForm @@ -23,18 +22,18 @@ class InferReadWriteSpec extends SimpleTransformSpec { case s: DefMemory if s.readLatency > 0 && s.readwriters.size == 1 => s.name == "mem" && s.readwriters.head == "rw" case s: Block => - s.stmts exists findReadWrite + s.stmts.exists(findReadWrite) case _ => false } - def run (c: Circuit) = { + def run(c: Circuit) = { val errors = new Errors - val foundReadWrite = c.modules exists { - case m: Module => findReadWrite(m.body) + val foundReadWrite = c.modules.exists { + case m: Module => findReadWrite(m.body) case m: ExtModule => false } if (!foundReadWrite) { - errors append new InferReadWriteCheckException + errors.append(new InferReadWriteCheckException) errors.trigger } c @@ -176,6 +175,6 @@ circuit sram6t : val annos = Seq(memlib.InferReadWriteAnnotation) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl - res should containLine (s"mem.rw.wmode <= wen") + res should containLine(s"mem.rw.wmode <= wen") } } diff --git a/src/test/scala/firrtlTests/InferResetsSpec.scala b/src/test/scala/firrtlTests/InferResetsSpec.scala index b607fb46..057fb3b0 100644 --- a/src/test/scala/firrtlTests/InferResetsSpec.scala +++ b/src/test/scala/firrtlTests/InferResetsSpec.scala @@ -4,7 +4,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ -import firrtl.passes.{CheckHighForm, CheckTypes, CheckInitialization} +import firrtl.passes.{CheckHighForm, CheckInitialization, CheckTypes} import firrtl.transforms.{CheckCombLoops, InferResets} import firrtl.testutils._ import firrtl.testutils.FirrtlCheckers._ @@ -16,95 +16,93 @@ class InferResetsSpec extends FirrtlFlatSpec { def compile(input: String, compiler: Compiler = new MiddleFirrtlCompiler): CircuitState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) - behavior of "ResetType" + behavior.of("ResetType") val BoolType = UIntType(IntWidth(1)) it should "support casting to other types" in { val result = compile(s""" - |circuit top: - | module top: - | input a : UInt<1> - | output v : UInt<1> - | output w : SInt<1> - | output x : Clock - | output y : Fixed<1><<0>> - | output z : AsyncReset - | wire r : Reset - | r <= a - | v <= asUInt(r) - | w <= asSInt(r) - | x <= asClock(r) - | y <= asFixedPoint(r, 0) - | z <= asAsyncReset(r)""".stripMargin - ) - result should containLine ("wire r : UInt<1>") - result should containLine ("r <= a") - result should containLine ("v <= asUInt(r)") - result should containLine ("w <= asSInt(r)") - result should containLine ("x <= asClock(r)") - result should containLine ("y <= asSInt(r)") - result should containLine ("z <= asAsyncReset(r)") + |circuit top: + | module top: + | input a : UInt<1> + | output v : UInt<1> + | output w : SInt<1> + | output x : Clock + | output y : Fixed<1><<0>> + | output z : AsyncReset + | wire r : Reset + | r <= a + | v <= asUInt(r) + | w <= asSInt(r) + | x <= asClock(r) + | y <= asFixedPoint(r, 0) + | z <= asAsyncReset(r)""".stripMargin) + result should containLine("wire r : UInt<1>") + result should containLine("r <= a") + result should containLine("v <= asUInt(r)") + result should containLine("w <= asSInt(r)") + result should containLine("x <= asClock(r)") + result should containLine("y <= asSInt(r)") + result should containLine("z <= asAsyncReset(r)") } it should "work across Module boundaries" in { val result = compile(s""" - |circuit top : - | module child : - | input clock : Clock - | input childReset : Reset - | input x : UInt<8> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) - | r <= x - | z <= r - | module top : - | input clock : Clock - | input reset : UInt<1> - | input x : UInt<8> - | output z : UInt<8> - | inst c of child - | c.clock <= clock - | c.childReset <= reset - | c.x <= x - | z <= c.z - |""".stripMargin - ) + |circuit top : + | module child : + | input clock : Clock + | input childReset : Reset + | input x : UInt<8> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) + | r <= x + | z <= r + | module top : + | input clock : Clock + | input reset : UInt<1> + | input x : UInt<8> + | output z : UInt<8> + | inst c of child + | c.clock <= clock + | c.childReset <= reset + | c.x <= x + | z <= c.z + |""".stripMargin) result should containTree { case Port(_, "childReset", Input, BoolType) => true } } it should "work across multiple Module boundaries" in { val result = compile(s""" - |circuit top : - | module child : - | input resetIn : Reset - | output resetOut : Reset - | resetOut <= resetIn - | module top : - | input clock : Clock - | input reset : UInt<1> - | input x : UInt<8> - | output z : UInt<8> - | inst c of child - | c.resetIn <= reset - | reg r : UInt<8>, clock with : (reset => (c.resetOut, UInt(123))) - | r <= x - | z <= r - |""".stripMargin - ) + |circuit top : + | module child : + | input resetIn : Reset + | output resetOut : Reset + | resetOut <= resetIn + | module top : + | input clock : Clock + | input reset : UInt<1> + | input x : UInt<8> + | output z : UInt<8> + | inst c of child + | c.resetIn <= reset + | reg r : UInt<8>, clock with : (reset => (c.resetOut, UInt(123))) + | r <= x + | z <= r + |""".stripMargin) result should containTree { case Port(_, "resetIn", Input, BoolType) => true } result should containTree { case Port(_, "resetOut", Output, BoolType) => true } } it should "work in nested and flipped aggregates with regular and partial connect" in { - val result = compile(s""" - |circuit top : - | module top : - | output fizz : { flip foo : { a : AsyncReset, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset }[2] } - | output buzz : { flip foo : { a : AsyncReset, c: UInt<1>, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset, c: UInt<8> }[2] } - | fizz.bar <= fizz.foo - | buzz.bar <- buzz.foo - |""".stripMargin, + val result = compile( + s""" + |circuit top : + | module top : + | output fizz : { flip foo : { a : AsyncReset, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset }[2] } + | output buzz : { flip foo : { a : AsyncReset, c: UInt<1>, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset, c: UInt<8> }[2] } + | fizz.bar <= fizz.foo + | buzz.bar <- buzz.foo + |""".stripMargin, new LowFirrtlCompiler ) result should containTree { case Port(_, "fizz_foo_0_a", Input, AsyncResetType) => true } @@ -126,386 +124,370 @@ class InferResetsSpec extends FirrtlFlatSpec { } it should "not crash if a ResetType has no drivers" in { - a [CheckInitialization.RefNotInitializedException] shouldBe thrownBy { + a[CheckInitialization.RefNotInitializedException] shouldBe thrownBy { compile(s""" - |circuit test : - | module test : - | output out : Reset - | wire w : Reset - | out <= w - | out <= UInt(1) - |""".stripMargin - ) + |circuit test : + | module test : + | output out : Reset + | wire w : Reset + | out <= w + | out <= UInt(1) + |""".stripMargin) } } it should "NOT allow last connect semantics to pick the right type for Reset" in { - an [InferResets.InferResetsException] shouldBe thrownBy { + an[InferResets.InferResetsException] shouldBe thrownBy { compile(s""" - |circuit top : - | module top : - | input reset0 : AsyncReset - | input reset1 : UInt<1> - | output out : Reset - | wire w0 : Reset - | wire w1 : Reset - | w0 <= reset0 - | w1 <= reset1 - | out <= w0 - | out <= w1 - |""".stripMargin - ) + |circuit top : + | module top : + | input reset0 : AsyncReset + | input reset1 : UInt<1> + | output out : Reset + | wire w0 : Reset + | wire w1 : Reset + | w0 <= reset0 + | w1 <= reset1 + | out <= w0 + | out <= w1 + |""".stripMargin) } } it should "NOT support last connect semantics across whens" in { - an [InferResets.InferResetsException] shouldBe thrownBy { + an[InferResets.InferResetsException] shouldBe thrownBy { compile(s""" - |circuit top : - | module top : - | input reset0 : AsyncReset - | input reset1 : AsyncReset - | input reset2 : UInt<1> - | input en : UInt<1> - | output out : Reset - | wire w0 : Reset - | wire w1 : Reset - | wire w2 : Reset - | w0 <= reset0 - | w1 <= reset1 - | w2 <= reset2 - | out <= w2 - | when en : - | out <= w0 - | else : - | out <= w1 - |""".stripMargin - ) + |circuit top : + | module top : + | input reset0 : AsyncReset + | input reset1 : AsyncReset + | input reset2 : UInt<1> + | input en : UInt<1> + | output out : Reset + | wire w0 : Reset + | wire w1 : Reset + | wire w2 : Reset + | w0 <= reset0 + | w1 <= reset1 + | w2 <= reset2 + | out <= w2 + | when en : + | out <= w0 + | else : + | out <= w1 + |""".stripMargin) } } it should "not allow different Reset Types to drive a single Reset" in { - an [InferResets.InferResetsException] shouldBe thrownBy { + an[InferResets.InferResetsException] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input reset0 : AsyncReset - | input reset1 : UInt<1> - | input en : UInt<1> - | output out : Reset - | wire w1 : Reset - | wire w2 : Reset - | w1 <= reset0 - | w2 <= reset1 - | out <= w1 - | when en : - | out <= w2 - |""".stripMargin - ) + |circuit top : + | module top : + | input reset0 : AsyncReset + | input reset1 : UInt<1> + | input en : UInt<1> + | output out : Reset + | wire w1 : Reset + | wire w2 : Reset + | w1 <= reset0 + | w2 <= reset1 + | out <= w1 + | when en : + | out <= w2 + |""".stripMargin) } } it should "allow concrete reset types to overrule invalidation" in { val result = compile(s""" - |circuit test : - | module test : - | input in : AsyncReset - | output out : Reset - | out is invalid - | out <= in - |""".stripMargin) + |circuit test : + | module test : + | input in : AsyncReset + | output out : Reset + | out is invalid + | out <= in + |""".stripMargin) result should containTree { case Port(_, "out", Output, AsyncResetType) => true } } it should "default to BoolType for Resets that are only invalidated" in { val result = compile(s""" - |circuit test : - | module test : - | output out : Reset - | out is invalid - |""".stripMargin) + |circuit test : + | module test : + | output out : Reset + | out is invalid + |""".stripMargin) result should containTree { case Port(_, "out", Output, BoolType) => true } } it should "not error if component of ResetType is invalidated and connected to an AsyncResetType" in { val result = compile(s""" - |circuit test : - | module test : - | input cond : UInt<1> - | input in : AsyncReset - | output out : Reset - | out is invalid - | when cond : - | out <= in - |""".stripMargin) + |circuit test : + | module test : + | input cond : UInt<1> + | input in : AsyncReset + | output out : Reset + | out is invalid + | when cond : + | out <= in + |""".stripMargin) result should containTree { case Port(_, "out", Output, AsyncResetType) => true } } it should "allow ResetType to drive AsyncResets or UInt<1>" in { val result1 = compile(s""" - |circuit top : - | module top : - | input in : UInt<1> - | output out : UInt<1> - | wire w : Reset - | w <= in - | out <= w - |""".stripMargin - ) + |circuit top : + | module top : + | input in : UInt<1> + | output out : UInt<1> + | wire w : Reset + | w <= in + | out <= w + |""".stripMargin) result1 should containTree { case DefWire(_, "w", BoolType) => true } val result2 = compile(s""" - |circuit top : - | module top : - | output foo : { flip a : UInt<1> } - | input bar : { flip a : UInt<1> } - | wire w : { flip a : Reset } - | foo <= w - | w <= bar - |""".stripMargin - ) + |circuit top : + | module top : + | output foo : { flip a : UInt<1> } + | input bar : { flip a : UInt<1> } + | wire w : { flip a : Reset } + | foo <= w + | w <= bar + |""".stripMargin) val AggType = BundleType(Seq(Field("a", Flip, BoolType))) result2 should containTree { case DefWire(_, "w", AggType) => true } val result3 = compile(s""" - |circuit top : - | module top : - | input in : UInt<1> - | output out : UInt<1> - | wire w : Reset - | w <- in - | out <- w - |""".stripMargin - ) + |circuit top : + | module top : + | input in : UInt<1> + | output out : UInt<1> + | wire w : Reset + | w <- in + | out <- w + |""".stripMargin) result3 should containTree { case DefWire(_, "w", BoolType) => true } } it should "error if a ResetType driving UInt<1> infers to AsyncReset" in { - an [Exception] shouldBe thrownBy { + an[Exception] shouldBe thrownBy { compile(s""" - |circuit top : - | module top : - | input in : AsyncReset - | output out : UInt<1> - | wire w : Reset - | w <= in - | out <= w - |""".stripMargin - ) + |circuit top : + | module top : + | input in : AsyncReset + | output out : UInt<1> + | wire w : Reset + | w <= in + | out <= w + |""".stripMargin) } } it should "error if a ResetType driving AsyncReset infers to UInt<1>" in { - an [Exception] shouldBe thrownBy { + an[Exception] shouldBe thrownBy { compile(s""" - |circuit top : - | module top : - | input in : UInt<1> - | output out : AsyncReset - | wire w : Reset - | w <= in - | out <= w - |""".stripMargin - ) + |circuit top : + | module top : + | input in : UInt<1> + | output out : AsyncReset + | wire w : Reset + | w <= in + | out <= w + |""".stripMargin) } } it should "not allow ResetType as an Input or ExtModule output" in { // TODO what exception should be thrown here? - an [CheckHighForm.ResetInputException] shouldBe thrownBy { + an[CheckHighForm.ResetInputException] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input in : { foo : Reset } - | output out : Reset - | out <= in.foo - |""".stripMargin - ) + |circuit top : + | module top : + | input in : { foo : Reset } + | output out : Reset + | out <= in.foo + |""".stripMargin) } - an [CheckHighForm.ResetExtModuleOutputException] shouldBe thrownBy { + an[CheckHighForm.ResetExtModuleOutputException] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | extmodule ext : - | output out : { foo : Reset } - | module top : - | output out : Reset - | inst e of ext - | out <= e.out.foo - |""".stripMargin - ) + |circuit top : + | extmodule ext : + | output out : { foo : Reset } + | module top : + | output out : Reset + | inst e of ext + | out <= e.out.foo + |""".stripMargin) } } it should "not allow Vecs to infer different Reset Types" in { - an [CheckTypes.InvalidConnect] shouldBe thrownBy { + an[CheckTypes.InvalidConnect] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input reset0 : AsyncReset - | input reset1 : UInt<1> - | output out : Reset[2] - | out[0] <= reset0 - | out[1] <= reset1 - |""".stripMargin - ) + |circuit top : + | module top : + | input reset0 : AsyncReset + | input reset1 : UInt<1> + | output out : Reset[2] + | out[0] <= reset0 + | out[1] <= reset1 + |""".stripMargin) } } // Or is this actually an error? The behavior is that out is inferred as AsyncReset[2] ignore should "not allow Vecs only be partially inferred" in { // Some exception should be thrown, TODO figure out which one - an [Exception] shouldBe thrownBy { + an[Exception] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input reset : AsyncReset - | output out : Reset[2] - | out is invalid - | out[0] <= reset - |""".stripMargin - ) + |circuit top : + | module top : + | input reset : AsyncReset + | output out : Reset[2] + | out is invalid + | out[0] <= reset + |""".stripMargin) } } - it should "support inferring modules that would dedup differently" in { val result = compile(s""" - |circuit top : - | module child : - | input clock : Clock - | input childReset : Reset - | input x : UInt<8> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) - | r <= x - | z <= r - | module child_1 : - | input clock : Clock - | input childReset : Reset - | input x : UInt<8> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) - | r <= x - | z <= r - | module top : - | input clock : Clock - | input reset1 : UInt<1> - | input reset2 : AsyncReset - | input x : UInt<8>[2] - | output z : UInt<8>[2] - | inst c of child - | c.clock <= clock - | c.childReset <= reset1 - | c.x <= x[0] - | z[0] <= c.z - | inst c2 of child_1 - | c2.clock <= clock - | c2.childReset <= reset2 - | c2.x <= x[1] - | z[1] <= c2.z - |""".stripMargin - ) + |circuit top : + | module child : + | input clock : Clock + | input childReset : Reset + | input x : UInt<8> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) + | r <= x + | z <= r + | module child_1 : + | input clock : Clock + | input childReset : Reset + | input x : UInt<8> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) + | r <= x + | z <= r + | module top : + | input clock : Clock + | input reset1 : UInt<1> + | input reset2 : AsyncReset + | input x : UInt<8>[2] + | output z : UInt<8>[2] + | inst c of child + | c.clock <= clock + | c.childReset <= reset1 + | c.x <= x[0] + | z[0] <= c.z + | inst c2 of child_1 + | c2.clock <= clock + | c2.childReset <= reset2 + | c2.x <= x[1] + | z[1] <= c2.z + |""".stripMargin) result should containTree { case Port(_, "childReset", Input, BoolType) => true } result should containTree { case Port(_, "childReset", Input, AsyncResetType) => true } } it should "infer based on what a component *drives* not just what drives it" in { val result = compile(s""" - |circuit top : - | module top : - | input in : AsyncReset - | output out : Reset - | wire w : Reset - | w is invalid - | out <= w - | out <= in - |""".stripMargin) + |circuit top : + | module top : + | input in : AsyncReset + | output out : Reset + | wire w : Reset + | w is invalid + | out <= w + | out <= in + |""".stripMargin) result should containTree { case DefWire(_, "w", AsyncResetType) => true } } it should "infer from connections, ignoring the fact that the invalidation wins" in { val result = compile(s""" - |circuit top : - | module top : - | input in : AsyncReset - | output out : Reset - | out <= in - | out is invalid - |""".stripMargin) + |circuit top : + | module top : + | input in : AsyncReset + | output out : Reset + | out <= in + | out is invalid + |""".stripMargin) result should containTree { case Port(_, "out", Output, AsyncResetType) => true } } // The backwards type propagation constrains `w` to be the same as both `out0` and `out1` it should "not allow an invalidated Wire to drive both a UInt<1> and an AsyncReset" in { - an [InferResets.InferResetsException] shouldBe thrownBy { + an[InferResets.InferResetsException] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input in0 : AsyncReset - | input in1 : UInt<1> - | output out0 : Reset - | output out1 : Reset - | wire w : Reset - | w is invalid - | out0 <= w - | out1 <= w - | out0 <= in0 - | out1 <= in1 - |""".stripMargin - ) + |circuit top : + | module top : + | input in0 : AsyncReset + | input in1 : UInt<1> + | output out0 : Reset + | output out1 : Reset + | wire w : Reset + | w is invalid + | out0 <= w + | out1 <= w + | out0 <= in0 + | out1 <= in1 + |""".stripMargin) } } it should "not propagate type info from downstream across a cast" in { val result = compile(s""" - |circuit top : - | module top : - | input in0 : AsyncReset - | input in1 : UInt<1> - | output out0 : Reset - | output out1 : Reset - | wire w : Reset - | w is invalid - | out0 <= asAsyncReset(w) - | out1 <= w - | out0 <= in0 - | out1 <= in1 - |""".stripMargin - ) + |circuit top : + | module top : + | input in0 : AsyncReset + | input in1 : UInt<1> + | output out0 : Reset + | output out1 : Reset + | wire w : Reset + | w is invalid + | out0 <= asAsyncReset(w) + | out1 <= w + | out0 <= in0 + | out1 <= in1 + |""".stripMargin) result should containTree { case Port(_, "out0", Output, AsyncResetType) => true } } // This tests for a bug unrelated to support or lackthereof for last connect in inference it should "take into account both internal and external constraints on Module port types" in { val result = compile(s""" - |circuit top : - | module child : - | input i : AsyncReset - | output o : Reset - | o <= i - | module top : - | input in : AsyncReset - | output out : AsyncReset - | inst c of child - | c.o is invalid - | c.i <= in - | out <= c.o - |""".stripMargin) + |circuit top : + | module child : + | input i : AsyncReset + | output o : Reset + | o <= i + | module top : + | input in : AsyncReset + | output out : AsyncReset + | inst c of child + | c.o is invalid + | c.i <= in + | out <= c.o + |""".stripMargin) result should containTree { case Port(_, "o", Output, AsyncResetType) => true } } it should "not crash on combinational loops" in { - a [CheckCombLoops.CombLoopException] shouldBe thrownBy { - val result = compile(s""" - |circuit top : - | module top : - | input in : AsyncReset - | output out : Reset - | wire w0 : Reset - | wire w1 : Reset - | w0 <= in - | w0 <= w1 - | w1 <= w0 - | out <= in - |""".stripMargin, + a[CheckCombLoops.CombLoopException] shouldBe thrownBy { + val result = compile( + s""" + |circuit top : + | module top : + | input in : AsyncReset + | output out : Reset + | wire w0 : Reset + | wire w1 : Reset + | w0 <= in + | w0 <= w1 + | w1 <= w0 + | out <= in + |""".stripMargin, compiler = new LowFirrtlCompiler ) } diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala index a2410f9d..172ddfb9 100644 --- a/src/test/scala/firrtlTests/InfoSpec.scala +++ b/src/test/scala/firrtlTests/InfoSpec.scala @@ -13,9 +13,9 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } @@ -27,156 +27,151 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { "Source locators on module ports" should "be propagated to Verilog" in { val result = compileBody(s""" - |input x : UInt<8> $Info1 - |output y : UInt<8> $Info2 - |y <= x""".stripMargin - ) + |input x : UInt<8> $Info1 + |output y : UInt<8> $Info2 + |y <= x""".stripMargin) result should containTree { case Port(Info1, "x", Input, _) => true } - result should containLine (s"input [7:0] x, //$Info1") + result should containLine(s"input [7:0] x, //$Info1") result should containTree { case Port(Info2, "y", Output, _) => true } - result should containLine (s"output [7:0] y //$Info2") + result should containLine(s"output [7:0] y //$Info2") } "Source locators on aggregates" should "be propagated to Verilog" in { val result = compileBody(s""" - |input io : { x : UInt<8>, flip y : UInt<8> } $Info1 - |io.y <= io.x""".stripMargin - ) + |input io : { x : UInt<8>, flip y : UInt<8> } $Info1 + |io.y <= io.x""".stripMargin) result should containTree { case Port(Info1, "io_x", Input, _) => true } - result should containLine (s"input [7:0] io_x, //$Info1") + result should containLine(s"input [7:0] io_x, //$Info1") result should containTree { case Port(Info1, "io_y", Output, _) => true } - result should containLine (s"output [7:0] io_y //$Info1") + result should containLine(s"output [7:0] io_y //$Info1") } "Source locators" should "be propagated on declarations" in { val result = compileBody(s""" - |input clock : Clock - |input x : UInt<8> - |output y : UInt<8> - |reg r : UInt<8>, clock $Info1 - |wire w : UInt<8> $Info2 - |node n = or(w, x) $Info3 - |w <= and(x, r) - |r <= or(n, r) - |y <= r""".stripMargin - ) - result should containTree { case DefRegister(Info1, "r", _,_,_,_) => true } - result should containLine (s"reg [7:0] r; //$Info1") + |input clock : Clock + |input x : UInt<8> + |output y : UInt<8> + |reg r : UInt<8>, clock $Info1 + |wire w : UInt<8> $Info2 + |node n = or(w, x) $Info3 + |w <= and(x, r) + |r <= or(n, r) + |y <= r""".stripMargin) + result should containTree { case DefRegister(Info1, "r", _, _, _, _) => true } + result should containLine(s"reg [7:0] r; //$Info1") result should containTree { case DefNode(Info2, "w", _) => true } - result should containLine (s"wire [7:0] w = x & r; //$Info2") // Node "w" declaration in Verilog + result should containLine(s"wire [7:0] w = x & r; //$Info2") // Node "w" declaration in Verilog result should containTree { case DefNode(Info3, "n", _) => true } - result should containLine (s"wire [7:0] n = w | x; //$Info3") + result should containLine(s"wire [7:0] n = w | x; //$Info3") } it should "be propagated on memories" in { val result = compileBody(s""" - |input clock : Clock - |input addr : UInt<5> - |output z : UInt<8> - |mem m: $Info1 - | data-type => UInt<8> - | depth => 32 - | read-latency => 0 - | write-latency => 1 - | reader => r - | writer => w - |m.r.clk <= clock - |m.r.addr <= addr - |m.r.en <= UInt(1) - |m.w.clk <= clock - |m.w.addr <= addr - |m.w.en <= UInt(0) - |m.w.data <= UInt(0) - |m.w.mask <= UInt(0) - |z <= m.r.data - |""".stripMargin - ) + |input clock : Clock + |input addr : UInt<5> + |output z : UInt<8> + |mem m: $Info1 + | data-type => UInt<8> + | depth => 32 + | read-latency => 0 + | write-latency => 1 + | reader => r + | writer => w + |m.r.clk <= clock + |m.r.addr <= addr + |m.r.en <= UInt(1) + |m.w.clk <= clock + |m.w.addr <= addr + |m.w.en <= UInt(0) + |m.w.data <= UInt(0) + |m.w.mask <= UInt(0) + |z <= m.r.data + |""".stripMargin) - result should containTree { case DefMemory(Info1, "m", _,_,_,_,_,_,_,_) => true } - result should containLine (s"reg [7:0] m [0:31]; //$Info1") - result should containLine (s"wire [7:0] m_r_data; //$Info1") - result should containLine (s"wire [4:0] m_r_addr; //$Info1") - result should containLine (s"wire [7:0] m_w_data; //$Info1") - result should containLine (s"wire [4:0] m_w_addr; //$Info1") - result should containLine (s"wire m_w_mask; //$Info1") - result should containLine (s"wire m_w_en; //$Info1") - result should containLine (s"assign m_r_data = m[m_r_addr]; //$Info1") - result should containLine (s"m[m_w_addr] <= m_w_data; //$Info1") + result should containTree { case DefMemory(Info1, "m", _, _, _, _, _, _, _, _) => true } + result should containLine(s"reg [7:0] m [0:31]; //$Info1") + result should containLine(s"wire [7:0] m_r_data; //$Info1") + result should containLine(s"wire [4:0] m_r_addr; //$Info1") + result should containLine(s"wire [7:0] m_w_data; //$Info1") + result should containLine(s"wire [4:0] m_w_addr; //$Info1") + result should containLine(s"wire m_w_mask; //$Info1") + result should containLine(s"wire m_w_en; //$Info1") + result should containLine(s"assign m_r_data = m[m_r_addr]; //$Info1") + result should containLine(s"m[m_w_addr] <= m_w_data; //$Info1") } it should "be propagated on instances" in { val result = compile(s""" - |circuit Test : - | module Child : - | output io : { flip in : UInt<8>, out : UInt<8> } - | io.out <= io.in - | module Test : - | output io : { flip in : UInt<8>, out : UInt<8> } - | inst c of Child $Info1 - | io <= c.io - |""".stripMargin - ) + |circuit Test : + | module Child : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module Test : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of Child $Info1 + | io <= c.io + |""".stripMargin) result should containTree { case WDefInstance(Info1, "c", "Child", _) => true } - result should containLine (s"Child c ( //$Info1") + result should containLine(s"Child c ( //$Info1") } it should "be propagated across direct node assignments and connections" in { val result = compile(s""" - |circuit Test : - | module Test : - | input in : UInt<8> - | output out : UInt<8> - | node a = in $Info1 - | node b = a - | out <= b - |""".stripMargin - ) - result should containTree { case Connect(Info1, Reference("out", _,_,_), Reference("in", _,_,_)) => true } - result should containLine (s"assign out = in; //$Info1") + |circuit Test : + | module Test : + | input in : UInt<8> + | output out : UInt<8> + | node a = in $Info1 + | node b = a + | out <= b + |""".stripMargin) + result should containTree { case Connect(Info1, Reference("out", _, _, _), Reference("in", _, _, _)) => true } + result should containLine(s"assign out = in; //$Info1") } "source locators" should "be propagated through ExpandWhens" in { - val input = """ - |;buildInfoPackage: chisel3, version: 3.1-SNAPSHOT, scalaVersion: 2.11.7, sbtVersion: 0.13.11, builtAtString: 2016-11-26 18:48:38.030, builtAtMillis: 1480186118030 - |circuit GCD : - | module GCD : - | input clock : Clock - | input reset : UInt<1> - | output io : {flip a : UInt<32>, flip b : UInt<32>, flip e : UInt<1>, z : UInt<32>, v : UInt<1>} - | - | io is invalid - | io is invalid - | reg x : UInt<32>, clock @[GCD.scala 15:14] - | reg y : UInt<32>, clock @[GCD.scala 16:14] - | node _T_14 = gt(x, y) @[GCD.scala 17:11] - | when _T_14 : @[GCD.scala 17:18] - | node _T_15 = sub(x, y) @[GCD.scala 17:27] - | node _T_16 = tail(_T_15, 1) @[GCD.scala 17:27] - | x <= _T_16 @[GCD.scala 17:22] - | skip @[GCD.scala 17:18] - | node _T_18 = eq(_T_14, UInt<1>("h00")) @[GCD.scala 17:18] - | when _T_18 : @[GCD.scala 18:18] - | node _T_19 = sub(y, x) @[GCD.scala 18:27] - | node _T_20 = tail(_T_19, 1) @[GCD.scala 18:27] - | y <= _T_20 @[GCD.scala 18:22] - | skip @[GCD.scala 18:18] - | when io.e : @[GCD.scala 19:15] - | x <= io.a @[GCD.scala 19:19] - | y <= io.b @[GCD.scala 19:30] - | skip @[GCD.scala 19:15] - | io.z <= x @[GCD.scala 20:8] - | node _T_22 = eq(y, UInt<1>("h00")) @[GCD.scala 21:13] - | io.v <= _T_22 @[GCD.scala 21:8] - | + val input = + """ + |;buildInfoPackage: chisel3, version: 3.1-SNAPSHOT, scalaVersion: 2.11.7, sbtVersion: 0.13.11, builtAtString: 2016-11-26 18:48:38.030, builtAtMillis: 1480186118030 + |circuit GCD : + | module GCD : + | input clock : Clock + | input reset : UInt<1> + | output io : {flip a : UInt<32>, flip b : UInt<32>, flip e : UInt<1>, z : UInt<32>, v : UInt<1>} + | + | io is invalid + | io is invalid + | reg x : UInt<32>, clock @[GCD.scala 15:14] + | reg y : UInt<32>, clock @[GCD.scala 16:14] + | node _T_14 = gt(x, y) @[GCD.scala 17:11] + | when _T_14 : @[GCD.scala 17:18] + | node _T_15 = sub(x, y) @[GCD.scala 17:27] + | node _T_16 = tail(_T_15, 1) @[GCD.scala 17:27] + | x <= _T_16 @[GCD.scala 17:22] + | skip @[GCD.scala 17:18] + | node _T_18 = eq(_T_14, UInt<1>("h00")) @[GCD.scala 17:18] + | when _T_18 : @[GCD.scala 18:18] + | node _T_19 = sub(y, x) @[GCD.scala 18:27] + | node _T_20 = tail(_T_19, 1) @[GCD.scala 18:27] + | y <= _T_20 @[GCD.scala 18:22] + | skip @[GCD.scala 18:18] + | when io.e : @[GCD.scala 19:15] + | x <= io.a @[GCD.scala 19:19] + | y <= io.b @[GCD.scala 19:30] + | skip @[GCD.scala 19:15] + | io.z <= x @[GCD.scala 20:8] + | node _T_22 = eq(y, UInt<1>("h00")) @[GCD.scala 21:13] + | io.v <= _T_22 @[GCD.scala 21:8] + | """.stripMargin val result = (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) - result should containLine ("node _GEN_0 = mux(_T_14, _T_16, x) @[GCD.scala 17:18 GCD.scala 17:22 GCD.scala 15:14]") - result should containLine ("node _GEN_2 = mux(io_e, io_a, _GEN_0) @[GCD.scala 19:15 GCD.scala 19:19]") - result should containLine ("x <= _GEN_2") - result should containLine ("node _GEN_1 = mux(_T_18, _T_20, y) @[GCD.scala 18:18 GCD.scala 18:22 GCD.scala 16:14]") - result should containLine ("node _GEN_3 = mux(io_e, io_b, _GEN_1) @[GCD.scala 19:15 GCD.scala 19:30]") - result should containLine ("y <= _GEN_3") + result should containLine("node _GEN_0 = mux(_T_14, _T_16, x) @[GCD.scala 17:18 GCD.scala 17:22 GCD.scala 15:14]") + result should containLine("node _GEN_2 = mux(io_e, io_a, _GEN_0) @[GCD.scala 19:15 GCD.scala 19:19]") + result should containLine("x <= _GEN_2") + result should containLine("node _GEN_1 = mux(_T_18, _T_20, y) @[GCD.scala 18:18 GCD.scala 18:22 GCD.scala 16:14]") + result should containLine("node _GEN_3 = mux(io_e, io_b, _GEN_1) @[GCD.scala 19:15 GCD.scala 19:30]") + result should containLine("y <= _GEN_3") } "source locators for append option" should "use multiinfo" in { @@ -195,71 +190,68 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { "source locators for basic register updates" should "be propagated to Verilog" in { val result = compileBody(s""" - |input clock : Clock - |input reset : UInt<1> - |output io : { flip in : UInt<8>, out : UInt<8>} - |reg r : UInt<8>, clock - |r <= io.in $Info1 - |io.out <= r - |""".stripMargin - ) - result should containLine (s"r <= io_in; //$Info1") + |input clock : Clock + |input reset : UInt<1> + |output io : { flip in : UInt<8>, out : UInt<8>} + |reg r : UInt<8>, clock + |r <= io.in $Info1 + |io.out <= r + |""".stripMargin) + result should containLine(s"r <= io_in; //$Info1") } "source locators for register reset" should "be propagated to Verilog" in { val result = compileBody(s""" - |input clock : Clock - |input reset : UInt<1> - |output io : { flip in : UInt<8>, out : UInt<8>} - |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info3 - |r <= io.in $Info1 - |io.out <= r - |""".stripMargin - ) - result should containLine (s"if (reset) begin //$Info3") - result should containLine (s"r <= 8'h0; //$Info3") - result should containLine (s"r <= io_in; //$Info1") + |input clock : Clock + |input reset : UInt<1> + |output io : { flip in : UInt<8>, out : UInt<8>} + |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info3 + |r <= io.in $Info1 + |io.out <= r + |""".stripMargin) + result should containLine(s"if (reset) begin //$Info3") + result should containLine(s"r <= 8'h0; //$Info3") + result should containLine(s"r <= io_in; //$Info1") } "source locators for complex register updates" should "be propagated to Verilog" in { val result = compileBody(s""" - |input clock : Clock - |input reset : UInt<1> - |output io : { flip in : UInt<8>, flip a : UInt<1>, out : UInt<8>} - |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info1 - |r <= UInt<2>(2) $Info2 - |when io.a : $Info3 - | r <= io.in $Info4 - |io.out <= r - |""".stripMargin - ) - result should containLine (s"if (reset) begin //$Info1") - result should containLine (s"r <= 8'h0; //$Info1") - result should containLine (s"end else if (io_a) begin //$Info3") - result should containLine (s"r <= io_in; //$Info4") - result should containLine (s"r <= 8'h2; //$Info2") + |input clock : Clock + |input reset : UInt<1> + |output io : { flip in : UInt<8>, flip a : UInt<1>, out : UInt<8>} + |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info1 + |r <= UInt<2>(2) $Info2 + |when io.a : $Info3 + | r <= io.in $Info4 + |io.out <= r + |""".stripMargin) + result should containLine(s"if (reset) begin //$Info1") + result should containLine(s"r <= 8'h0; //$Info1") + result should containLine(s"end else if (io_a) begin //$Info3") + result should containLine(s"r <= io_in; //$Info4") + result should containLine(s"r <= 8'h2; //$Info2") } "FileInfo" should "be able to contain a escaped characters" in { def input(info: String): String = s"""circuit m: @[$info] - | module m: - | skip - |""".stripMargin + | module m: + | skip + |""".stripMargin def parseInfo(info: String): FileInfo = { firrtl.Parser.parse(input(info)).info.asInstanceOf[FileInfo] } - parseInfo("test\\ntest").escaped should be ("test\\ntest") - parseInfo("test\\ntest").unescaped should be ("test\ntest") - parseInfo("test\\ttest").escaped should be ("test\\ttest") - parseInfo("test\\ttest").unescaped should be ("test\ttest") - parseInfo("test\\\\test").escaped should be ("test\\\\test") - parseInfo("test\\\\test").unescaped should be ("test\\test") - parseInfo("test\\]test").escaped should be ("test\\]test") - parseInfo("test\\]test").unescaped should be ("test]test") - parseInfo("test[\\][\\]test").escaped should be ("test[\\][\\]test") - parseInfo("test[\\][\\]test").unescaped should be ("test[][]test") + parseInfo("test\\ntest").escaped should be("test\\ntest") + parseInfo("test\\ntest").unescaped should be("test\ntest") + parseInfo("test\\ttest").escaped should be("test\\ttest") + parseInfo("test\\ttest").unescaped should be("test\ttest") + parseInfo("test\\\\test").escaped should be("test\\\\test") + parseInfo("test\\\\test").unescaped should be("test\\test") + parseInfo("test\\]test").escaped should be("test\\]test") + parseInfo("test\\]test").unescaped should be("test]test") + parseInfo("test[\\][\\]test").escaped should be("test[\\][\\]test") + parseInfo("test[\\][\\]test").unescaped should be("test[][]test") } } diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 27102785..e4f711ed 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -12,8 +12,8 @@ import firrtl.stage.TransformManager import firrtl.options.Dependency /** - * Tests inline instances transformation - */ + * Tests inline instances transformation + */ class InlineInstancesTests extends LowTransformSpec { def transform = new InlineInstances def inline(mod: String): Annotation = { @@ -22,181 +22,181 @@ class InlineInstancesTests extends LowTransformSpec { val name = if (parts.size == 1) modName else ComponentName(parts.tail.mkString("."), modName) InlineAnnotation(name) } - // Set this to debug, this will apply to all tests - // Logger.setLevel(this.getClass, Debug) - "The module Inline" should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline - | i.a <= a - | b <= i.b - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | i_b <= i_a - | b <= i_b - | i_a <= a""".stripMargin - execute(input, check, Seq(inline("Inline"))) - } + // Set this to debug, this will apply to all tests + // Logger.setLevel(this.getClass, Debug) + "The module Inline" should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | i_b <= i_a + | b <= i_b + | i_a <= a""".stripMargin + execute(input, check, Seq(inline("Inline"))) + } - "The all instances of Simple" should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i0 of Simple - | inst i1 of Simple - | i0.a <= a - | i1.a <= i0.b - | b <= i1.b - | module Simple : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i0_a : UInt<32> - | wire i0_b : UInt<32> - | i0_b <= i0_a - | wire i1_a : UInt<32> - | wire i1_b : UInt<32> - | i1_b <= i1_a - | b <= i1_b - | i0_a <= a - | i1_a <= i0_b""".stripMargin - execute(input, check, Seq(inline("Simple"))) - } + "The all instances of Simple" should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i0 of Simple + | inst i1 of Simple + | i0.a <= a + | i1.a <= i0.b + | b <= i1.b + | module Simple : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i0_a : UInt<32> + | wire i0_b : UInt<32> + | i0_b <= i0_a + | wire i1_a : UInt<32> + | wire i1_b : UInt<32> + | i1_b <= i1_a + | b <= i1_b + | i0_a <= a + | i1_a <= i0_b""".stripMargin + execute(input, check, Seq(inline("Simple"))) + } - "Only one instance of Simple" should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i0 of Simple - | inst i1 of Simple - | i0.a <= a - | i1.a <= i0.b - | b <= i1.b - | module Simple : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i0_a : UInt<32> - | wire i0_b : UInt<32> - | i0_b <= i0_a - | inst i1 of Simple - | b <= i1.b - | i0_a <= a - | i1.a <= i0_b - | module Simple : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(inline("Top.i0"))) - } + "Only one instance of Simple" should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i0 of Simple + | inst i1 of Simple + | i0.a <= a + | i1.a <= i0.b + | b <= i1.b + | module Simple : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i0_a : UInt<32> + | wire i0_b : UInt<32> + | i0_b <= i0_a + | inst i1 of Simple + | b <= i1.b + | i0_a <= a + | i1.a <= i0_b + | module Simple : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(inline("Top.i0"))) + } - "All instances of A" should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i0 of A - | inst i1 of B - | i0.a <= a - | i1.a <= i0.b - | b <= i1.b - | module A : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - | module B : - | input a : UInt<32> - | output b : UInt<32> - | inst i of A - | i.a <= a - | b <= i.b""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i0_a : UInt<32> - | wire i0_b : UInt<32> - | i0_b <= i0_a - | inst i1 of B - | b <= i1.b - | i0_a <= a - | i1.a <= i0_b - | module B : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | i_b <= i_a - | b <= i_b - | i_a <= a""".stripMargin - execute(input, check, Seq(inline("A"))) - } + "All instances of A" should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i0 of A + | inst i1 of B + | i0.a <= a + | i1.a <= i0.b + | b <= i1.b + | module A : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | module B : + | input a : UInt<32> + | output b : UInt<32> + | inst i of A + | i.a <= a + | b <= i.b""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i0_a : UInt<32> + | wire i0_b : UInt<32> + | i0_b <= i0_a + | inst i1 of B + | b <= i1.b + | i0_a <= a + | i1.a <= i0_b + | module B : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | i_b <= i_a + | b <= i_b + | i_a <= a""".stripMargin + execute(input, check, Seq(inline("A"))) + } - "Non-inlined instances" should "still prepend prefix" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of A - | i.a <= a - | b <= i.b - | module A : - | input a : UInt<32> - | output b : UInt<32> - | inst i of B - | i.a <= a - | b <= i.b - | module B : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | inst i_i of B - | i_b <= i_i.b - | i_i.a <= i_a - | b <= i_b - | i_a <= a - | module B : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(inline("A"))) - } + "Non-inlined instances" should "still prepend prefix" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of A + | i.a <= a + | b <= i.b + | module A : + | input a : UInt<32> + | output b : UInt<32> + | inst i of B + | i.a <= a + | b <= i.b + | module B : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | inst i_i of B + | i_b <= i_i.b + | i_i.a <= i_a + | b <= i_b + | i_a <= a + | module B : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(inline("A"))) + } "A module with nested inlines" should "still prepend prefixes" in { val input = @@ -291,57 +291,57 @@ class InlineInstancesTests extends LowTransformSpec { execute(input, check, Seq(inline("Foo"), inline("Foo.bar"))) } - // ---- Errors ---- - // 1) ext module - "External module" should "not be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of A - | i.a <= a - | b <= i.b - | extmodule A : - | input a : UInt<32> - | output b : UInt<32>""".stripMargin - failingexecute(input, Seq(inline("A"))) - } - // 2) ext instance - "External instance" should "not be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of A - | i.a <= a - | b <= i.b - | extmodule A : - | input a : UInt<32> - | output b : UInt<32>""".stripMargin - failingexecute(input, Seq(inline("A"))) - } - // 3) no module - "Inlined module" should "exist" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - failingexecute(input, Seq(inline("A"))) - } - // 4) no inst - "Inlined instance" should "exist" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - failingexecute(input, Seq(inline("A"))) - } + // ---- Errors ---- + // 1) ext module + "External module" should "not be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of A + | i.a <= a + | b <= i.b + | extmodule A : + | input a : UInt<32> + | output b : UInt<32>""".stripMargin + failingexecute(input, Seq(inline("A"))) + } + // 2) ext instance + "External instance" should "not be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of A + | i.a <= a + | b <= i.b + | extmodule A : + | input a : UInt<32> + | output b : UInt<32>""".stripMargin + failingexecute(input, Seq(inline("A"))) + } + // 3) no module + "Inlined module" should "exist" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + failingexecute(input, Seq(inline("A"))) + } + // 4) no inst + "Inlined instance" should "exist" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + failingexecute(input, Seq(inline("A"))) + } "Jack's Bug" should "not fail" in { @@ -384,163 +384,167 @@ class InlineInstancesTests extends LowTransformSpec { override def duplicate(n: ReferenceTarget): Annotation = DummyAnno(n) } "annotations" should "be renamed" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline - | i.a <= a - | b <= i.b - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | inst foo of NestedInline - | inst bar of NestedNoInline - | foo.a <= a - | bar.a <= foo.b - | b <= bar.b - | module NestedInline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - | module NestedNoInline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | wire i_foo_a : UInt<32> - | wire i_foo_b : UInt<32> - | i_foo_b <= i_foo_a - | inst i_bar of NestedNoInline - | i_b <= i_bar.b - | i_foo_a <= i_a - | i_bar.a <= i_foo_b - | b <= i_b - | i_a <= a - | module NestedNoInline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of NestedInline + | inst bar of NestedNoInline + | foo.a <= a + | bar.a <= foo.b + | b <= bar.b + | module NestedInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | wire i_foo_a : UInt<32> + | wire i_foo_b : UInt<32> + | i_foo_b <= i_foo_a + | inst i_bar of NestedNoInline + | i_b <= i_bar.b + | i_foo_a <= i_a + | i_bar.a <= i_foo_b + | b <= i_b + | i_a <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + |""".stripMargin val top = CircuitTarget("Top").module("Top") val inlined = top.instOf("i", "Inline") val nestedInlined = top.instOf("i", "Inline").instOf("foo", "NestedInline") val nestedNotInlined = top.instOf("i", "Inline").instOf("bar", "NestedNoInline") - executeWithAnnos(input, check, - Seq( - inline("Inline"), - inline("NestedInline"), - NoCircuitDedupAnnotation, - DummyAnno(inlined.ref("a")), - DummyAnno(inlined.ref("b")), - DummyAnno(nestedInlined.ref("a")), - DummyAnno(nestedInlined.ref("b")), - DummyAnno(nestedNotInlined.ref("a")), - DummyAnno(nestedNotInlined.ref("b")) - ), - Seq( - DummyAnno(top.ref("i_a")), - DummyAnno(top.ref("i_b")), - DummyAnno(top.ref("i_foo_a")), - DummyAnno(top.ref("i_foo_b")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")) - ) - ) + executeWithAnnos( + input, + check, + Seq( + inline("Inline"), + inline("NestedInline"), + NoCircuitDedupAnnotation, + DummyAnno(inlined.ref("a")), + DummyAnno(inlined.ref("b")), + DummyAnno(nestedInlined.ref("a")), + DummyAnno(nestedInlined.ref("b")), + DummyAnno(nestedNotInlined.ref("a")), + DummyAnno(nestedNotInlined.ref("b")) + ), + Seq( + DummyAnno(top.ref("i_a")), + DummyAnno(top.ref("i_b")), + DummyAnno(top.ref("i_foo_a")), + DummyAnno(top.ref("i_foo_b")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")) + ) + ) } "inlining both grandparent and grandchild" should "should work" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline - | i.a <= a - | b <= i.b - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | inst foo of NestedInline - | inst bar of NestedNoInline - | foo.a <= a - | bar.a <= foo.b - | b <= bar.b - | module NestedInline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - | module NestedNoInline : - | input a : UInt<32> - | output b : UInt<32> - | inst foo of NestedInline - | foo.a <= a - | b <= foo.b - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | wire i_foo_a : UInt<32> - | wire i_foo_b : UInt<32> - | i_foo_b <= i_foo_a - | inst i_bar of NestedNoInline - | i_b <= i_bar.b - | i_foo_a <= i_a - | i_bar.a <= i_foo_b - | b <= i_b - | i_a <= a - | module NestedNoInline : - | input a : UInt<32> - | output b : UInt<32> - | wire foo_a : UInt<32> - | wire foo_b : UInt<32> - | foo_b <= foo_a - | b <= foo_b - | foo_a <= a - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of NestedInline + | inst bar of NestedNoInline + | foo.a <= a + | bar.a <= foo.b + | b <= bar.b + | module NestedInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of NestedInline + | foo.a <= a + | b <= foo.b + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | wire i_foo_a : UInt<32> + | wire i_foo_b : UInt<32> + | i_foo_b <= i_foo_a + | inst i_bar of NestedNoInline + | i_b <= i_bar.b + | i_foo_a <= i_a + | i_bar.a <= i_foo_b + | b <= i_b + | i_a <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | wire foo_a : UInt<32> + | wire foo_b : UInt<32> + | foo_b <= foo_a + | b <= foo_b + | foo_a <= a + |""".stripMargin val top = CircuitTarget("Top").module("Top") val inlined = top.instOf("i", "Inline") val nestedInlined = inlined.instOf("foo", "NestedInline") val nestedNotInlined = inlined.instOf("bar", "NestedNoInline") val innerNestedInlined = nestedNotInlined.instOf("foo", "NestedInline") - executeWithAnnos(input, check, - Seq( - inline("Inline"), - inline("NestedInline"), - DummyAnno(inlined.ref("a")), - DummyAnno(inlined.ref("b")), - DummyAnno(nestedInlined.ref("a")), - DummyAnno(nestedInlined.ref("b")), - DummyAnno(nestedNotInlined.ref("a")), - DummyAnno(nestedNotInlined.ref("b")), - DummyAnno(innerNestedInlined.ref("a")), - DummyAnno(innerNestedInlined.ref("b")) - ), - Seq( - DummyAnno(top.ref("i_a")), - DummyAnno(top.ref("i_b")), - DummyAnno(top.ref("i_foo_a")), - DummyAnno(top.ref("i_foo_b")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_a")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_b")) - ) - ) + executeWithAnnos( + input, + check, + Seq( + inline("Inline"), + inline("NestedInline"), + DummyAnno(inlined.ref("a")), + DummyAnno(inlined.ref("b")), + DummyAnno(nestedInlined.ref("a")), + DummyAnno(nestedInlined.ref("b")), + DummyAnno(nestedNotInlined.ref("a")), + DummyAnno(nestedNotInlined.ref("b")), + DummyAnno(innerNestedInlined.ref("a")), + DummyAnno(innerNestedInlined.ref("b")) + ), + Seq( + DummyAnno(top.ref("i_a")), + DummyAnno(top.ref("i_b")), + DummyAnno(top.ref("i_foo_a")), + DummyAnno(top.ref("i_foo_b")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_a")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_b")) + ) + ) } "InlineInstances" should "properly invalidate ResolveKinds" in { @@ -562,7 +566,7 @@ class InlineInstancesTests extends LowTransformSpec { val result = manager.execute(state) result shouldNot containTree { case WRef("i_a", _, PortKind, _) => true } - result should containTree { case WRef("i_a", _, WireKind, _) => true } + result should containTree { case WRef("i_a", _, WireKind, _) => true } } } diff --git a/src/test/scala/firrtlTests/IntegrationSpec.scala b/src/test/scala/firrtlTests/IntegrationSpec.scala index b399923f..ff21a90b 100644 --- a/src/test/scala/firrtlTests/IntegrationSpec.scala +++ b/src/test/scala/firrtlTests/IntegrationSpec.scala @@ -23,10 +23,11 @@ class GCDSplitEmissionExecutionTest extends FirrtlFlatSpec { val optionsManager = new ExecutionOptionsManager("GCDTesterSplitEmission") with HasFirrtlOptions { commonOptions = CommonOptions(topName = top, targetDirName = testDir.getPath) firrtlOptions = FirrtlExecutionOptions( - inputFileNameOverride = sourceFile.getPath, - compilerName = "verilog", - infoModeName = "ignore", - emitOneFilePerModule = true) + inputFileNameOverride = sourceFile.getPath, + compilerName = "verilog", + infoModeName = "ignore", + emitOneFilePerModule = true + ) } firrtl.Driver.execute(optionsManager) @@ -42,7 +43,7 @@ class GCDSplitEmissionExecutionTest extends FirrtlFlatSpec { // topFile will be compiled by Verilator command by default but we need to also include dutFile verilogToCpp(top, testDir, Seq(dutFile), harness) #&& - cppToExe(top, testDir) ! loggingProcessLogger + cppToExe(top, testDir) ! loggingProcessLogger assert(executeExpectingSuccess(top, testDir)) } } @@ -53,14 +54,14 @@ class ICacheCompilationTest extends CompilationTest("ICache", "/regress") class FPUCompilationTest extends CompilationTest("FPU", "/regress") class HwachaSequencerCompilationTest extends CompilationTest("HwachaSequencer", "/regress") -abstract class CommonSubexprEliminationEquivTest(name: String, dir: String) extends - EquivalenceTest(Seq(firrtl.passes.CommonSubexpressionElimination), name, dir) -abstract class DeadCodeEliminationEquivTest(name: String, dir: String) extends - EquivalenceTest(Seq(new firrtl.transforms.DeadCodeElimination), name, dir) -abstract class ConstantPropagationEquivTest(name: String, dir: String) extends - EquivalenceTest(Seq(new firrtl.transforms.ConstantPropagation), name, dir) -abstract class LowFirrtlOptimizationEquivTest(name: String, dir: String) extends - EquivalenceTest(Seq(new LowFirrtlOptimization), name, dir) +abstract class CommonSubexprEliminationEquivTest(name: String, dir: String) + extends EquivalenceTest(Seq(firrtl.passes.CommonSubexpressionElimination), name, dir) +abstract class DeadCodeEliminationEquivTest(name: String, dir: String) + extends EquivalenceTest(Seq(new firrtl.transforms.DeadCodeElimination), name, dir) +abstract class ConstantPropagationEquivTest(name: String, dir: String) + extends EquivalenceTest(Seq(new firrtl.transforms.ConstantPropagation), name, dir) +abstract class LowFirrtlOptimizationEquivTest(name: String, dir: String) + extends EquivalenceTest(Seq(new LowFirrtlOptimization), name, dir) class OpsCommonSubexprEliminationTest extends CommonSubexprEliminationEquivTest("Ops", "/regress") class OpsDeadCodeEliminationTest extends DeadCodeEliminationEquivTest("Ops", "/regress") diff --git a/src/test/scala/firrtlTests/LegalizeSpec.scala b/src/test/scala/firrtlTests/LegalizeSpec.scala index 22fef730..aa6458ba 100644 --- a/src/test/scala/firrtlTests/LegalizeSpec.scala +++ b/src/test/scala/firrtlTests/LegalizeSpec.scala @@ -5,4 +5,3 @@ package firrtlTests import firrtl.testutils.ExecutionTest class LegalizeExecutionTest extends ExecutionTest("Legalize", "/passes/Legalize") - diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 648c6b36..0d020252 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -21,14 +21,14 @@ class LowerTypesSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String]) = { val fir = Parser.parse(input.split("\n").toIterator) val c = compiler.runTransform(CircuitState(fir, Seq())).circuit - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } - behavior of "Lower Types" + behavior.of("Lower Types") it should "lower ports" in { val input = @@ -39,9 +39,23 @@ class LowerTypesSpec extends FirrtlFlatSpec { | input y : UInt<1>[4] | input z : { c : { d : UInt<1>, e : UInt<1>}, f : UInt<1>[2] }[2] """.stripMargin - val expected = Seq("w", "x_a", "x_b", "y_0", "y_1", "y_2", "y_3", "z_0_c_d", - "z_0_c_e", "z_0_f_0", "z_0_f_1", "z_1_c_d", "z_1_c_e", "z_1_f_0", - "z_1_f_1") map (x => s"input $x : UInt<1>") map normalized + val expected = Seq( + "w", + "x_a", + "x_b", + "y_0", + "y_1", + "y_2", + "y_3", + "z_0_c_d", + "z_0_c_e", + "z_0_f_0", + "z_0_f_1", + "z_1_c_d", + "z_1_c_e", + "z_1_f_0", + "z_1_f_1" + ).map(x => s"input $x : UInt<1>").map(normalized) executeTest(input, expected) } @@ -56,7 +70,7 @@ class LowerTypesSpec extends FirrtlFlatSpec { val expected = Seq( "output foo_0_a : UInt<1>", "input foo_0_b : UInt<1>" - ) map normalized + ).map(normalized) executeTest(input, expected) } @@ -72,29 +86,47 @@ class LowerTypesSpec extends FirrtlFlatSpec { | reg y : UInt<1>[4], clock | reg z : { c : { d : UInt<1>, e : UInt<1>}, f : UInt<1>[2] }[2], clock """.stripMargin - val expected = Seq("w", "x_a", "x_b", "y_0", "y_1", "y_2", "y_3", "z_0_c_d", - "z_0_c_e", "z_0_f_0", "z_0_f_1", "z_1_c_d", "z_1_c_e", "z_1_f_0", - "z_1_f_1") map (x => s"reg $x : UInt<1>, clock with :") map normalized + val expected = Seq( + "w", + "x_a", + "x_b", + "y_0", + "y_1", + "y_2", + "y_3", + "z_0_c_d", + "z_0_c_e", + "z_0_f_0", + "z_0_f_1", + "z_1_c_d", + "z_1_c_e", + "z_1_f_0", + "z_1_f_1" + ).map(x => s"reg $x : UInt<1>, clock with :").map(normalized) executeTest(input, expected) } it should "lower registers with aggregate initialization" in { val input = - """circuit Test : - | module Test : - | input clock : Clock - | input reset : UInt<1> - | input init : { a : UInt<1>, b : UInt<1>}[2] - | reg x : { a : UInt<1>, b : UInt<1>}[2], clock with : - | reset => (reset, init) + """circuit Test : + | module Test : + | input clock : Clock + | input reset : UInt<1> + | input init : { a : UInt<1>, b : UInt<1>}[2] + | reg x : { a : UInt<1>, b : UInt<1>}[2], clock with : + | reset => (reset, init) """.stripMargin val expected = Seq( - "reg x_0_a : UInt<1>, clock with :", "reset => (reset, init_0_a)", - "reg x_0_b : UInt<1>, clock with :", "reset => (reset, init_0_b)", - "reg x_1_a : UInt<1>, clock with :", "reset => (reset, init_1_a)", - "reg x_1_b : UInt<1>, clock with :", "reset => (reset, init_1_b)" - ) map normalized + "reg x_0_a : UInt<1>, clock with :", + "reset => (reset, init_0_a)", + "reg x_0_b : UInt<1>, clock with :", + "reset => (reset, init_0_b)", + "reg x_1_a : UInt<1>, clock with :", + "reset => (reset, init_1_a)", + "reg x_1_b : UInt<1>, clock with :", + "reset => (reset, init_1_b)" + ).map(normalized) executeTest(input, expected) } @@ -112,77 +144,87 @@ class LowerTypesSpec extends FirrtlFlatSpec { val expected = Seq( "reg foo : UInt<4>, clock_1 with :", "reset => (reset_a, init_3_b_1_d)" - ) map normalized + ).map(normalized) executeTest(input, expected) } it should "lower DefInstances (but not too far!)" in { val input = - """circuit Test : - | module Other : - | input a : { b : UInt<1>, c : UInt<1>} - | output d : UInt<1>[2] - | d[0] <= a.b - | d[1] <= a.c - | module Test : - | input x : UInt<1> - | inst mod of Other - | mod.a.b <= x - | mod.a.c <= x - | node y = mod.d[0] + """circuit Test : + | module Other : + | input a : { b : UInt<1>, c : UInt<1>} + | output d : UInt<1>[2] + | d[0] <= a.b + | d[1] <= a.c + | module Test : + | input x : UInt<1> + | inst mod of Other + | mod.a.b <= x + | mod.a.c <= x + | node y = mod.d[0] """.stripMargin - val expected = Seq( - "mod.a_b <= x", - "mod.a_c <= x", - "node y = mod.d_0") map normalized + val expected = Seq("mod.a_b <= x", "mod.a_c <= x", "node y = mod.d_0").map(normalized) executeTest(input, expected) } it should "lower aggregate memories" in { val input = - """circuit Test : - | module Test : - | input clock : Clock - | mem m : - | data-type => { a : UInt<8>, b : UInt<8>}[2] - | depth => 32 - | read-latency => 0 - | write-latency => 1 - | reader => read - | writer => write - | m.read.clk <= clock - | m.read.en <= UInt<1>(1) - | m.read.addr is invalid - | node x = m.read.data - | node y = m.read.data[0].b - | - | m.write.clk <= clock - | m.write.en <= UInt<1>(0) - | m.write.mask is invalid - | m.write.addr is invalid - | wire w : { a : UInt<8>, b : UInt<8>}[2] - | w[0].a <= UInt<4>(2) - | w[0].b <= UInt<4>(3) - | w[1].a <= UInt<4>(4) - | w[1].b <= UInt<4>(5) - | m.write.data <= w + """circuit Test : + | module Test : + | input clock : Clock + | mem m : + | data-type => { a : UInt<8>, b : UInt<8>}[2] + | depth => 32 + | read-latency => 0 + | write-latency => 1 + | reader => read + | writer => write + | m.read.clk <= clock + | m.read.en <= UInt<1>(1) + | m.read.addr is invalid + | node x = m.read.data + | node y = m.read.data[0].b + | + | m.write.clk <= clock + | m.write.en <= UInt<1>(0) + | m.write.mask is invalid + | m.write.addr is invalid + | wire w : { a : UInt<8>, b : UInt<8>}[2] + | w[0].a <= UInt<4>(2) + | w[0].b <= UInt<4>(3) + | w[1].a <= UInt<4>(4) + | w[1].b <= UInt<4>(5) + | m.write.data <= w """.stripMargin val expected = Seq( - "mem m_0_a :", "mem m_0_b :", "mem m_1_a :", "mem m_1_b :", - "m_0_a.read.clk <= clock", "m_0_b.read.clk <= clock", - "m_1_a.read.clk <= clock", "m_1_b.read.clk <= clock", - "m_0_a.read.addr is invalid", "m_0_b.read.addr is invalid", - "m_1_a.read.addr is invalid", "m_1_b.read.addr is invalid", - "node x_0_a = m_0_a.read.data", "node x_0_b = m_0_b.read.data", - "node x_1_a = m_1_a.read.data", "node x_1_b = m_1_b.read.data", - "m_0_a.write.mask is invalid", "m_0_b.write.mask is invalid", - "m_1_a.write.mask is invalid", "m_1_b.write.mask is invalid", - "m_0_a.write.data <= w_0_a", "m_0_b.write.data <= w_0_b", - "m_1_a.write.data <= w_1_a", "m_1_b.write.data <= w_1_b" - ) map normalized + "mem m_0_a :", + "mem m_0_b :", + "mem m_1_a :", + "mem m_1_b :", + "m_0_a.read.clk <= clock", + "m_0_b.read.clk <= clock", + "m_1_a.read.clk <= clock", + "m_1_b.read.clk <= clock", + "m_0_a.read.addr is invalid", + "m_0_b.read.addr is invalid", + "m_1_a.read.addr is invalid", + "m_1_b.read.addr is invalid", + "node x_0_a = m_0_a.read.data", + "node x_0_b = m_0_b.read.data", + "node x_1_a = m_1_a.read.data", + "node x_1_b = m_1_b.read.data", + "m_0_a.write.mask is invalid", + "m_0_b.write.mask is invalid", + "m_1_a.write.mask is invalid", + "m_1_b.write.mask is invalid", + "m_0_a.write.data <= w_0_a", + "m_0_b.write.data <= w_0_b", + "m_1_a.write.data <= w_1_a", + "m_1_b.write.data <= w_1_b" + ).map(normalized) executeTest(input, expected) } @@ -192,12 +234,17 @@ class LowerTypesSpec extends FirrtlFlatSpec { class LowerTypesUniquifySpec extends FirrtlFlatSpec { private val compiler = new TransformManager(Seq(Dependency(firrtl.passes.LowerTypes))) - private def executeTest(input: String, expected: Seq[String]): Unit = executeTest(input, expected, Seq.empty, Seq.empty) - private def executeTest(input: String, expected: Seq[String], - inputAnnos: Seq[Annotation], expectedAnnos: Seq[Annotation]): Unit = { + private def executeTest(input: String, expected: Seq[String]): Unit = + executeTest(input, expected, Seq.empty, Seq.empty) + private def executeTest( + input: String, + expected: Seq[String], + inputAnnos: Seq[Annotation], + expectedAnnos: Seq[Annotation] + ): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) val result = compiler.runTransform(CircuitState(circuit, inputAnnos)) - val lines = result.circuit.serialize.split("\n") map normalized + val lines = result.circuit.serialize.split("\n").map(normalized) expected.map(normalized).foreach { e => assert(lines.contains(e), f"Failed to find $e in ${lines.mkString("\n")}") @@ -206,7 +253,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { result.annotations.toSeq should equal(expectedAnnos) } - behavior of "LowerTypes" + behavior.of("LowerTypes") it should "rename colliding ports" in { val input = @@ -221,17 +268,16 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { "input a___0_c__0_d : UInt<2>", "output a___0_c__0_e : UInt<3>", "output a_0_c_ : UInt<5>", - "output a__0 : UInt<6>") + "output a__0 : UInt<6>" + ) val m = CircuitTarget("Test").module("Test") val inputAnnos = Seq( DontTouchAnnotation(m.ref("a").index(0).field("b")), - DontTouchAnnotation(m.ref("a").index(0).field("c").index(0).field("e"))) - - val expectedAnnos = Seq( - DontTouchAnnotation(m.ref("a___0_b")), - DontTouchAnnotation(m.ref("a___0_c__0_e"))) + DontTouchAnnotation(m.ref("a").index(0).field("c").index(0).field("e")) + ) + val expectedAnnos = Seq(DontTouchAnnotation(m.ref("a___0_b")), DontTouchAnnotation(m.ref("a___0_c__0_e"))) executeTest(input, expected, inputAnnos, expectedAnnos) } @@ -250,7 +296,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { "reg a___1_c__1_e : UInt<3>, clock with :", "reg a___0_c_1_e : UInt<4>, clock with :", "reg a_0_c_ : UInt<5>, clock with :", - "reg a__0 : UInt<6>, clock with :") + "reg a__0 : UInt<6>, clock with :" + ) executeTest(input, expected) } @@ -274,7 +321,6 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { executeTest(input, expected) } - it should "rename DefRegister expressions: clock, reset, and init" in { val input = """circuit Test : @@ -368,9 +414,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | node foo = data.a | node bar = data.b[1] """.stripMargin - val expected = Seq( - "node foo = data___a", - "node bar = data___b_1") + val expected = Seq("node foo = data___a", "node bar = data___b_1") executeTest(input, expected) } @@ -439,7 +483,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { "mem mem__0_b_0 :", "node mem_0_b_0 = mem__0_b_0.read.data", "node mem_0_b_1 = mem__0_b_1.read.data", - "mem__0_b_0.read.addr is invalid") + "mem__0_b_0.read.addr is invalid" + ) executeTest(input, expected) } @@ -467,12 +512,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | mem.write.en <= UInt(0) | mem.write.clk <= clock """.stripMargin - val expected = Seq( - "mem mem_a :", - "mem mem_b__0 :", - "mem mem_b__1 :", - "mem mem_b_0 :", - "node x = mem_b__0.read.data") + val expected = + Seq("mem mem_a :", "mem mem_b__0 :", "mem mem_b__1 :", "mem mem_b_0 :", "node x = mem_b__0.read.data") executeTest(input, expected) } @@ -492,11 +533,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | mod.a.c <= x | node mod_a_b = mod.a_b """.stripMargin - val expected = Seq( - "inst mod_ of Other", - "mod_.a__b <= x", - "mod_.a__c <= x", - "node mod_a_b = mod_.a_b") + val expected = Seq("inst mod_ of Other", "mod_.a__b <= x", "mod_.a__c <= x", "node mod_a_b = mod_.a_b") executeTest(input, expected) } @@ -515,7 +552,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { // Run the "quick" test three times and choose the longest time as the basis. val nCalibrationRuns = 3 def mkType(i: Int): String = { - if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" + if (i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" } val timesMs = ( for (depth <- (List.fill(nCalibrationRuns)(1) :+ depth)) yield { @@ -528,12 +565,11 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { val (ms, _) = Utils.time(compileToVerilog(input)) ms } - ).toArray + ).toArray // The baseMs will be the maximum of the first calibration runs val baseMs = timesMs.slice(0, nCalibrationRuns - 1).max val renameMs = timesMs(nCalibrationRuns) if (TestOptions.accurateTiming) - renameMs shouldBe < (baseMs * threshold) + renameMs shouldBe <(baseMs * threshold) } } - diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index f0f2042e..46416619 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -21,115 +21,127 @@ object Transforms { } import firrtl.{ChirrtlForm => C, HighForm => H, MidForm => M, LowForm => L, UnknownForm => U} class ChirrtlToChirrtl extends IdentityTransformDiff(C, C) - class HighToChirrtl extends IdentityTransformDiff(H, C) - class HighToHigh extends IdentityTransformDiff(H, H) - class MidToMid extends IdentityTransformDiff(M, M) - class MidToChirrtl extends IdentityTransformDiff(M, C) - class MidToHigh extends IdentityTransformDiff(M, H) - class LowToChirrtl extends IdentityTransformDiff(L, C) - class LowToHigh extends IdentityTransformDiff(L, H) - class LowToMid extends IdentityTransformDiff(L, M) - class LowToLow extends IdentityTransformDiff(L, L) + class HighToChirrtl extends IdentityTransformDiff(H, C) + class HighToHigh extends IdentityTransformDiff(H, H) + class MidToMid extends IdentityTransformDiff(M, M) + class MidToChirrtl extends IdentityTransformDiff(M, C) + class MidToHigh extends IdentityTransformDiff(M, H) + class LowToChirrtl extends IdentityTransformDiff(L, C) + class LowToHigh extends IdentityTransformDiff(L, H) + class LowToMid extends IdentityTransformDiff(L, M) + class LowToLow extends IdentityTransformDiff(L, L) } class LoweringCompilersSpec extends AnyFlatSpec with Matchers { def legacyTransforms(a: CoreTransform): Seq[Transform] = a match { - case _: ChirrtlToHighFirrtl => Seq( - new firrtl.stage.transforms.CheckScalaVersion, - firrtl.passes.CheckChirrtl, - firrtl.passes.CInferTypes, - firrtl.passes.CInferMDir, - firrtl.passes.RemoveCHIRRTL) + case _: ChirrtlToHighFirrtl => + Seq( + new firrtl.stage.transforms.CheckScalaVersion, + firrtl.passes.CheckChirrtl, + firrtl.passes.CInferTypes, + firrtl.passes.CInferMDir, + firrtl.passes.RemoveCHIRRTL + ) case _: IRToWorkingIR => Seq(firrtl.passes.ToWorkingIR) - case _: ResolveAndCheck => Seq( - firrtl.passes.CheckHighForm, - firrtl.passes.ResolveKinds, - firrtl.passes.InferTypes, - firrtl.passes.CheckTypes, - firrtl.passes.Uniquify, - firrtl.passes.ResolveKinds, - firrtl.passes.InferTypes, - firrtl.passes.ResolveFlows, - firrtl.passes.CheckFlows, - new firrtl.passes.InferBinaryPoints, - new firrtl.passes.TrimIntervals, - new firrtl.passes.InferWidths, - firrtl.passes.CheckWidths, - new firrtl.transforms.InferResets) - case _: HighFirrtlToMiddleFirrtl => Seq( - firrtl.passes.PullMuxes, - firrtl.passes.ReplaceAccesses, - firrtl.passes.ExpandConnects, - firrtl.passes.ZeroLengthVecs, - firrtl.passes.RemoveAccesses, - firrtl.passes.Uniquify, - firrtl.passes.ExpandWhens, - firrtl.passes.CheckInitialization, - firrtl.passes.ResolveKinds, - firrtl.passes.InferTypes, - firrtl.passes.CheckTypes, - firrtl.passes.ResolveFlows, - new firrtl.passes.InferWidths, - firrtl.passes.CheckWidths, - new firrtl.passes.RemoveIntervals, - firrtl.passes.ConvertFixedToSInt, - firrtl.passes.ZeroWidth, - firrtl.passes.InferTypes) - case _: MiddleFirrtlToLowFirrtl => Seq( - firrtl.passes.LowerTypes, - firrtl.passes.ResolveKinds, - firrtl.passes.InferTypes, - firrtl.passes.ResolveFlows, - new firrtl.passes.InferWidths, - firrtl.passes.Legalize, - firrtl.transforms.RemoveReset, - firrtl.passes.ResolveFlows, - new firrtl.transforms.CheckCombLoops, - new checks.CheckResets, - new firrtl.transforms.RemoveWires) - case _: LowFirrtlOptimization => Seq( - firrtl.passes.RemoveValidIf, - new firrtl.transforms.ConstantPropagation, - firrtl.passes.PadWidths, - new firrtl.transforms.ConstantPropagation, - firrtl.passes.Legalize, - firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - new firrtl.transforms.ConstantPropagation, - firrtl.passes.SplitExpressions, - new firrtl.transforms.CombineCats, - firrtl.passes.CommonSubexpressionElimination, - new firrtl.transforms.DeadCodeElimination) - case _: MinimumLowFirrtlOptimization => Seq( - firrtl.passes.RemoveValidIf, - firrtl.passes.PadWidths, - firrtl.passes.Legalize, - firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - firrtl.passes.SplitExpressions) + case _: ResolveAndCheck => + Seq( + firrtl.passes.CheckHighForm, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.CheckTypes, + firrtl.passes.Uniquify, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.ResolveFlows, + firrtl.passes.CheckFlows, + new firrtl.passes.InferBinaryPoints, + new firrtl.passes.TrimIntervals, + new firrtl.passes.InferWidths, + firrtl.passes.CheckWidths, + new firrtl.transforms.InferResets + ) + case _: HighFirrtlToMiddleFirrtl => + Seq( + firrtl.passes.PullMuxes, + firrtl.passes.ReplaceAccesses, + firrtl.passes.ExpandConnects, + firrtl.passes.ZeroLengthVecs, + firrtl.passes.RemoveAccesses, + firrtl.passes.Uniquify, + firrtl.passes.ExpandWhens, + firrtl.passes.CheckInitialization, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.CheckTypes, + firrtl.passes.ResolveFlows, + new firrtl.passes.InferWidths, + firrtl.passes.CheckWidths, + new firrtl.passes.RemoveIntervals, + firrtl.passes.ConvertFixedToSInt, + firrtl.passes.ZeroWidth, + firrtl.passes.InferTypes + ) + case _: MiddleFirrtlToLowFirrtl => + Seq( + firrtl.passes.LowerTypes, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.ResolveFlows, + new firrtl.passes.InferWidths, + firrtl.passes.Legalize, + firrtl.transforms.RemoveReset, + firrtl.passes.ResolveFlows, + new firrtl.transforms.CheckCombLoops, + new checks.CheckResets, + new firrtl.transforms.RemoveWires + ) + case _: LowFirrtlOptimization => + Seq( + firrtl.passes.RemoveValidIf, + new firrtl.transforms.ConstantPropagation, + firrtl.passes.PadWidths, + new firrtl.transforms.ConstantPropagation, + firrtl.passes.Legalize, + firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter + new firrtl.transforms.ConstantPropagation, + firrtl.passes.SplitExpressions, + new firrtl.transforms.CombineCats, + firrtl.passes.CommonSubexpressionElimination, + new firrtl.transforms.DeadCodeElimination + ) + case _: MinimumLowFirrtlOptimization => + Seq( + firrtl.passes.RemoveValidIf, + firrtl.passes.PadWidths, + firrtl.passes.Legalize, + firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter + firrtl.passes.SplitExpressions + ) } def compare(a: Seq[Transform], b: TransformManager, patches: Seq[PatchAction] = Seq.empty): Unit = { info(s"""Transform Order:\n${b.prettyPrint(" ")}""") val m = new scala.collection.mutable.HashMap[Int, Seq[Dependency[Transform]]].withDefault(_ => Seq.empty) - a.map(Dependency.fromTransform).zipWithIndex.foreach{ case (t, idx) => m(idx) = Seq(t) } + a.map(Dependency.fromTransform).zipWithIndex.foreach { case (t, idx) => m(idx) = Seq(t) } patches.foreach { case Add(line, txs) => m(line - 1) = m(line - 1) ++ txs case Del(line) => m.remove(line - 1) } - val patched = scala.collection.immutable.TreeMap(m.toArray:_*).values.flatten + val patched = scala.collection.immutable.TreeMap(m.toArray: _*).values.flatten patched .zip(b.flattenedTransformOrder.map(Dependency.fromTransform)) - .foreach{ case (aa, bb) => bb should be (aa) } + .foreach { case (aa, bb) => bb should be(aa) } info(s"found ${b.flattenedTransformOrder.size} transforms") - patched.size should be (b.flattenedTransformOrder.size) + patched.size should be(b.flattenedTransformOrder.size) } - behavior of "ChirrtlToHighFirrtl" + behavior.of("ChirrtlToHighFirrtl") it should "replicate the old order" in { val tm = new TransformManager(Forms.MinimalHighForm, Forms.ChirrtlForm) @@ -139,26 +151,28 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { compare(legacyTransforms(new firrtl.ChirrtlToHighFirrtl), tm, patches) } - behavior of "IRToWorkingIR" + behavior.of("IRToWorkingIR") it should "replicate the old order" in { val tm = new TransformManager(Forms.WorkingIR, Forms.MinimalHighForm) compare(legacyTransforms(new firrtl.IRToWorkingIR), tm) } - behavior of "ResolveAndCheck" + behavior.of("ResolveAndCheck") it should "replicate the old order" in { val tm = new TransformManager(Forms.Resolved, Forms.WorkingIR) val patches = Seq( // Uniquify is now part of [[firrtl.passes.LowerTypes]] - Del(5), Del(6), Del(7), + Del(5), + Del(6), + Del(7), Add(14, Seq(Dependency.fromTransform(firrtl.passes.CheckTypes))) ) compare(legacyTransforms(new ResolveAndCheck), tm, patches) } - behavior of "HighFirrtlToMiddleFirrtl" + behavior.of("HighFirrtlToMiddleFirrtl") it should "replicate the old order" in { val tm = new TransformManager(Forms.MidForm, Forms.Deduped) @@ -174,56 +188,54 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { Del(11), Del(12), Del(13), - Add(12, Seq(Dependency(firrtl.passes.ResolveFlows), - Dependency[firrtl.passes.InferWidths])), + Add(12, Seq(Dependency(firrtl.passes.ResolveFlows), Dependency[firrtl.passes.InferWidths])), Del(14), - Add(15, Seq(Dependency(firrtl.passes.ResolveKinds), - Dependency(firrtl.passes.InferTypes))), + Add(15, Seq(Dependency(firrtl.passes.ResolveKinds), Dependency(firrtl.passes.InferTypes))), // TODO Add(17, Seq(Dependency[firrtl.transforms.formal.AssertSubmoduleAssumptions])) ) compare(legacyTransforms(new HighFirrtlToMiddleFirrtl), tm, patches) } - behavior of "MiddleFirrtlToLowFirrtl" + behavior.of("MiddleFirrtlToLowFirrtl") it should "replicate the old order" in { val tm = new TransformManager(Forms.LowForm, Forms.MidForm) val patches = Seq( // Uniquify is now part of [[firrtl.passes.LowerTypes]] - Del(2), Del(3), Del(5), + Del(2), + Del(3), + Del(5), // RemoveWires now visibly invalidates ResolveKinds Add(11, Seq(Dependency(firrtl.passes.ResolveKinds))) ) compare(legacyTransforms(new MiddleFirrtlToLowFirrtl), tm, patches) } - behavior of "MinimumLowFirrtlOptimization" + behavior.of("MinimumLowFirrtlOptimization") it should "replicate the old order" in { val tm = new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm) val patches = Seq( Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))), - Add(6, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], - Dependency(firrtl.passes.ResolveKinds))) + Add(6, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds))) ) compare(legacyTransforms(new MinimumLowFirrtlOptimization), tm, patches) } - behavior of "LowFirrtlOptimization" + behavior.of("LowFirrtlOptimization") it should "replicate the old order" in { val tm = new TransformManager(Forms.LowFormOptimized, Forms.LowForm) val patches = Seq( Add(6, Seq(Dependency(firrtl.passes.ResolveFlows))), Add(7, Seq(Dependency(firrtl.passes.Legalize))), - Add(8, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], - Dependency(firrtl.passes.ResolveKinds))) + Add(8, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds))) ) compare(legacyTransforms(new LowFirrtlOptimization), tm, patches) } - behavior of "VerilogMinimumOptimized" + behavior.of("VerilogMinimumOptimized") it should "replicate the old order" in { val legacy = Seq( @@ -238,12 +250,13 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { firrtl.passes.VerilogModulusCleanup, new firrtl.transforms.VerilogRename, firrtl.passes.VerilogPrep, - new firrtl.AddDescriptionNodes) + new firrtl.AddDescriptionNodes + ) val tm = new TransformManager(Forms.VerilogMinimumOptimized, (new firrtl.VerilogEmitter).prerequisites) compare(legacy, tm) } - behavior of "VerilogOptimized" + behavior.of("VerilogOptimized") it should "replicate the old order" in { val legacy = Seq( @@ -259,12 +272,13 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { firrtl.passes.VerilogModulusCleanup, new firrtl.transforms.VerilogRename, firrtl.passes.VerilogPrep, - new firrtl.AddDescriptionNodes) + new firrtl.AddDescriptionNodes + ) val tm = new TransformManager(Forms.VerilogOptimized, Forms.LowFormOptimized) compare(legacy, tm) } - behavior of "Legacy Custom Transforms" + behavior.of("Legacy Custom Transforms") it should "work for Chirrtl -> Chirrtl" in { val expected = new Transforms.ChirrtlToChirrtl :: new firrtl.ChirrtlEmitter :: Nil diff --git a/src/test/scala/firrtlTests/MemLatencySpec.scala b/src/test/scala/firrtlTests/MemLatencySpec.scala index 79986cc2..8a04eeef 100644 --- a/src/test/scala/firrtlTests/MemLatencySpec.scala +++ b/src/test/scala/firrtlTests/MemLatencySpec.scala @@ -6,8 +6,8 @@ object MemLatencySpec { case class Write(addr: Int, data: Int, mask: Option[Boolean] = None) case class Read(addr: Int, expectedValue: Int) case class MemAccess(w: Option[Write], r: Option[Read]) - def writeOnly(addr: Int, data: Int) = MemAccess(Some(Write(addr, data)), None) - def readOnly(addr: Int, expectedValue: Int) = MemAccess(None, Some(Read(addr, expectedValue))) + def writeOnly(addr: Int, data: Int) = MemAccess(Some(Write(addr, data)), None) + def readOnly(addr: Int, expectedValue: Int) = MemAccess(None, Some(Read(addr, expectedValue))) } abstract class MemLatencySpec(rLatency: Int, wLatency: Int, ruw: String) @@ -36,7 +36,7 @@ abstract class MemLatencySpec(rLatency: Int, wLatency: Int, ruw: String) def mask2Poke(m: Option[Boolean]) = m match { case Some(false) => Poke("m.w.mask", 0) - case _ => Poke("m.w.mask", 1) + case _ => Poke("m.w.mask", 1) } def wPokes = memAccesses.map { @@ -47,24 +47,25 @@ abstract class MemLatencySpec(rLatency: Int, wLatency: Int, ruw: String) def rPokes = memAccesses.map { case MemAccess(_, Some(Read(a, _))) => Seq(Poke("m.r.en", 1), Poke("m.r.addr", a)) - case _ => Seq(Poke("m.r.en", 0), Invalidate("m.r.addr")) + case _ => Seq(Poke("m.r.en", 0), Invalidate("m.r.addr")) } // Need to idle for <rLatency> cycles at the end val idle = Seq(Poke("m.w.en", 0), Poke("m.r.en", 0)) - def pokes = (wPokes zip rPokes).map { case (wp, rp) => wp ++ rp } ++ Seq.fill(rLatency)(idle) + def pokes = (wPokes.zip(rPokes)).map { case (wp, rp) => wp ++ rp } ++ Seq.fill(rLatency)(idle) // Need to delay read value expects by <rLatency> def expects = Seq.fill(rLatency)(Seq(Step(1))) ++ memAccesses.map { case MemAccess(_, Some(Read(_, expected))) => Seq(Expect("m.r.data", expected), Step(1)) - case _ => Seq(Step(1)) + case _ => Seq(Step(1)) } - def commands: Seq[SimpleTestCommand] = (pokes zip expects).flatMap { case (p, e) => p ++ e } + def commands: Seq[SimpleTestCommand] = (pokes.zip(expects)).flatMap { case (p, e) => p ++ e } } trait ToggleMaskAndEnable { import MemLatencySpec._ + /** * A canonical sequence of memory accesses for sanity checking memories of different latencies. * The shortest true "RAW" hazard is reading address 14 two accesses after writing it. Since this @@ -76,19 +77,19 @@ trait ToggleMaskAndEnable { * @note Write-first mems should return expected values for (write-latency <= read-latency + 2) */ val memAccesses: Seq[MemAccess] = Seq( - MemAccess(Some(Write(6, 32)), None), - MemAccess(Some(Write(14, 87)), None), - MemAccess(None, None), - MemAccess(Some(Write(19, 63)), Some(Read(14, 87))), - MemAccess(Some(Write(22, 49)), None), - MemAccess(Some(Write(11, 99)), Some(Read(6, 32))), - MemAccess(Some(Write(42, 42)), None), - MemAccess(Some(Write(77, 81)), None), - MemAccess(Some(Write(6, 7)), Some(Read(19, 63))), - MemAccess(Some(Write(39, 5)), Some(Read(42, 42))), + MemAccess(Some(Write(6, 32)), None), + MemAccess(Some(Write(14, 87)), None), + MemAccess(None, None), + MemAccess(Some(Write(19, 63)), Some(Read(14, 87))), + MemAccess(Some(Write(22, 49)), None), + MemAccess(Some(Write(11, 99)), Some(Read(6, 32))), + MemAccess(Some(Write(42, 42)), None), + MemAccess(Some(Write(77, 81)), None), + MemAccess(Some(Write(6, 7)), Some(Read(19, 63))), + MemAccess(Some(Write(39, 5)), Some(Read(42, 42))), MemAccess(Some(Write(39, 6, Some(false))), Some(Read(77, 81))), // set mask to zero, should not write - MemAccess(None, Some(Read(6, 7))), // also read a twice-written address - MemAccess(None, Some(Read(39, 5))) // ensure masked writes didn't happen + MemAccess(None, Some(Read(6, 7))), // also read a twice-written address + MemAccess(None, Some(Read(39, 5))) // ensure masked writes didn't happen ) } @@ -111,20 +112,34 @@ class WriteFirstMemToggleSpec extends MemLatencySpec(rLatency = 1, wLatency = 1, class ReadFirstMemToggleSpec extends MemLatencySpec(rLatency = 1, wLatency = 1, ruw = "old") with ToggleMaskAndEnable // Read latency 2 -class WriteFirstMemToggleSpecRL2 extends MemLatencySpec(rLatency = 2, wLatency = 1, ruw = "new") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecRL2 + extends MemLatencySpec(rLatency = 2, wLatency = 1, ruw = "new") + with ToggleMaskAndEnable class ReadFirstMemToggleSpecRL2 extends MemLatencySpec(rLatency = 2, wLatency = 1, ruw = "old") with ToggleMaskAndEnable // Write latency 2 -class WriteFirstMemToggleSpecWL2 extends MemLatencySpec(rLatency = 1, wLatency = 2, ruw = "new") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecWL2 + extends MemLatencySpec(rLatency = 1, wLatency = 2, ruw = "new") + with ToggleMaskAndEnable class ReadFirstMemToggleSpecWL2 extends MemLatencySpec(rLatency = 1, wLatency = 2, ruw = "old") with ToggleMaskAndEnable // Read latency 2, write latency 2 -class WriteFirstMemToggleSpecRL2WL2 extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "new") with ToggleMaskAndEnable -class ReadFirstMemToggleSpecRL2WL2 extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "old") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecRL2WL2 + extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "new") + with ToggleMaskAndEnable +class ReadFirstMemToggleSpecRL2WL2 + extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "old") + with ToggleMaskAndEnable // Read latency 3, write latency 2 -class WriteFirstMemToggleSpecRL3WL2 extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "new") with ToggleMaskAndEnable -class ReadFirstMemToggleSpecRL3WL2 extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "old") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecRL3WL2 + extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "new") + with ToggleMaskAndEnable +class ReadFirstMemToggleSpecRL3WL2 + extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "old") + with ToggleMaskAndEnable // Read latency 2, write latency 4 -> ToggleSpec pattern only valid for write-first at this combo -class WriteFirstMemToggleSpecRL2WL4 extends MemLatencySpec(rLatency = 2, wLatency = 4, ruw = "new") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecRL2WL4 + extends MemLatencySpec(rLatency = 2, wLatency = 4, ruw = "new") + with ToggleMaskAndEnable diff --git a/src/test/scala/firrtlTests/MemSpec.scala b/src/test/scala/firrtlTests/MemSpec.scala index c7ab8db7..e05aca86 100644 --- a/src/test/scala/firrtlTests/MemSpec.scala +++ b/src/test/scala/firrtlTests/MemSpec.scala @@ -50,7 +50,7 @@ class MemSpec extends FirrtlPropSpec with FirrtlMatchers { """.stripMargin val result = (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm, List.empty)) // TODO Not great that it includes the sparse comment for VCS - result should containLine (s"reg /* sparse */ [7:0] m [0:$addrWidth'd${memSize-1}];") + result should containLine(s"reg /* sparse */ [7:0] m [0:$addrWidth'd${memSize - 1}];") } property("Very large CHIRRTL memories should be supported") { @@ -76,7 +76,6 @@ class MemSpec extends FirrtlPropSpec with FirrtlMatchers { """.stripMargin val result = (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm, List.empty)) // TODO Not great that it includes the sparse comment for VCS - result should containLine (s"reg /* sparse */ [7:0] m [0:$addrWidth'd${memSize-1}];") + result should containLine(s"reg /* sparse */ [7:0] m [0:$addrWidth'd${memSize - 1}];") } } - diff --git a/src/test/scala/firrtlTests/MemoryInitSpec.scala b/src/test/scala/firrtlTests/MemoryInitSpec.scala index 5598e58b..984bf0b4 100644 --- a/src/test/scala/firrtlTests/MemoryInitSpec.scala +++ b/src/test/scala/firrtlTests/MemoryInitSpec.scala @@ -11,37 +11,37 @@ import firrtlTests.execution._ class MemInitSpec extends FirrtlFlatSpec { def input(tpe: String): String = s""" - |circuit MemTest: - | module MemTest: - | input clock : Clock - | input rAddr : UInt<5> - | input rEnable : UInt<1> - | input wAddr : UInt<5> - | input wData : $tpe - | input wEnable : UInt<1> - | output rData : $tpe - | - | mem m: - | data-type => $tpe - | depth => 32 - | reader => r - | writer => w - | read-latency => 1 - | write-latency => 1 - | read-under-write => new - | - | m.r.clk <= clock - | m.r.addr <= rAddr - | m.r.en <= rEnable - | rData <= m.r.data - | - | m.w.clk <= clock - | m.w.addr <= wAddr - | m.w.en <= wEnable - | m.w.data <= wData - | m.w.mask is invalid - | - |""".stripMargin + |circuit MemTest: + | module MemTest: + | input clock : Clock + | input rAddr : UInt<5> + | input rEnable : UInt<1> + | input wAddr : UInt<5> + | input wData : $tpe + | input wEnable : UInt<1> + | output rData : $tpe + | + | mem m: + | data-type => $tpe + | depth => 32 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => new + | + | m.r.clk <= clock + | m.r.addr <= rAddr + | m.r.en <= rEnable + | rData <= m.r.data + | + | m.w.clk <= clock + | m.w.addr <= wAddr + | m.w.en <= wEnable + | m.w.data <= wData + | m.w.mask is invalid + | + |""".stripMargin val mRef = CircuitTarget("MemTest").module("MemTest").ref("m") def compile(annos: AnnotationSeq, tpe: String = "UInt<32>"): CircuitState = { @@ -51,13 +51,13 @@ class MemInitSpec extends FirrtlFlatSpec { "NoAnnotation" should "create a randomized initialization" in { val annos = Seq() val result = compile(annos) - result should containLine (" m[initvar] = _RAND_0[31:0];") + result should containLine(" m[initvar] = _RAND_0[31:0];") } "MemoryRandomInitAnnotation" should "create a randomized initialization" in { val annos = Seq(MemoryRandomInitAnnotation(mRef)) val result = compile(annos) - result should containLine (" m[initvar] = _RAND_0[31:0];") + result should containLine(" m[initvar] = _RAND_0[31:0];") } "MemoryScalarInitAnnotation w/ 0" should "create an initialization with all zeros" in { @@ -79,8 +79,9 @@ class MemInitSpec extends FirrtlFlatSpec { val values = Seq.tabulate(32)(ii => 2 * ii + 5).map(BigInt(_)) val annos = Seq(MemoryArrayInitAnnotation(mRef, values)) val result = compile(annos) - values.zipWithIndex.foreach { case (value, addr) => - result should containLine(s" m[$addr] = $value;") + values.zipWithIndex.foreach { + case (value, addr) => + result should containLine(s" m[$addr] = $value;") } } @@ -137,7 +138,9 @@ class MemInitSpec extends FirrtlFlatSpec { val annos = Seq(MemoryScalarInitAnnotation(mRef, 0)) compile(annos, "{real: SInt<10>, imag: SInt<10>}") } - assert(caught.getMessage.endsWith("Cannot initialize memory m of non ground type { real : SInt<10>, imag : SInt<10>}")) + assert( + caught.getMessage.endsWith("Cannot initialize memory m of non ground type { real : SInt<10>, imag : SInt<10>}") + ) } private def jsonAnno(name: String, suffix: String): String = @@ -165,39 +168,46 @@ class MemInitSpec extends FirrtlFlatSpec { } abstract class MemInitExecutionSpec(values: Seq[Int], init: ReferenceTarget => Annotation) - extends SimpleExecutionTest with VerilogExecution { + extends SimpleExecutionTest + with VerilogExecution { override val body: String = s""" - |mem m: - | data-type => UInt<32> - | depth => ${values.length} - | reader => r - | read-latency => 1 - | write-latency => 1 - | read-under-write => new - |m.r.clk <= clock - |m.r.en <= UInt<1>(1) - |""".stripMargin + |mem m: + | data-type => UInt<32> + | depth => ${values.length} + | reader => r + | read-latency => 1 + | write-latency => 1 + | read-under-write => new + |m.r.clk <= clock + |m.r.en <= UInt<1>(1) + |""".stripMargin val mRef = CircuitTarget("dut").module("dut").ref("m") override val customAnnotations: AnnotationSeq = Seq(init(mRef)) - override def commands: Seq[SimpleTestCommand] = (Seq(-1) ++ values).zipWithIndex.map { case (value, addr) => - if(value == -1) { Seq(Poke("m.r.addr", addr)) } - else if(addr >= values.length) { Seq(Expect("m.r.data", value)) } - else { Seq(Poke("m.r.addr", addr), Expect("m.r.data", value)) } + override def commands: Seq[SimpleTestCommand] = (Seq(-1) ++ values).zipWithIndex.map { + case (value, addr) => + if (value == -1) { Seq(Poke("m.r.addr", addr)) } + else if (addr >= values.length) { Seq(Expect("m.r.data", value)) } + else { Seq(Poke("m.r.addr", addr), Expect("m.r.data", value)) } }.flatMap(_ ++ Seq(Step(1))) } -class MemScalarInit0ExecutionSpec extends MemInitExecutionSpec( - Seq.tabulate(31)(_ => 0), r => MemoryScalarInitAnnotation(r, 0) -) {} - -class MemScalarInit17ExecutionSpec extends MemInitExecutionSpec( - Seq.tabulate(31)(_ => 17), r => MemoryScalarInitAnnotation(r, 17) -) {} - -class MemArrayInitExecutionSpec extends MemInitExecutionSpec( - Seq.tabulate(31)(ii => ii * 5 + 7), - r => MemoryArrayInitAnnotation(r, Seq.tabulate(31)(ii => ii * 5 + 7).map(BigInt(_))) -) {} +class MemScalarInit0ExecutionSpec + extends MemInitExecutionSpec( + Seq.tabulate(31)(_ => 0), + r => MemoryScalarInitAnnotation(r, 0) + ) {} + +class MemScalarInit17ExecutionSpec + extends MemInitExecutionSpec( + Seq.tabulate(31)(_ => 17), + r => MemoryScalarInitAnnotation(r, 17) + ) {} + +class MemArrayInitExecutionSpec + extends MemInitExecutionSpec( + Seq.tabulate(31)(ii => ii * 5 + 7), + r => MemoryArrayInitAnnotation(r, Seq.tabulate(31)(ii => ii * 5 + 7).map(BigInt(_))) + ) {} diff --git a/src/test/scala/firrtlTests/MultiThreadingSpec.scala b/src/test/scala/firrtlTests/MultiThreadingSpec.scala index c7b18624..6ec1a2bd 100644 --- a/src/test/scala/firrtlTests/MultiThreadingSpec.scala +++ b/src/test/scala/firrtlTests/MultiThreadingSpec.scala @@ -24,7 +24,8 @@ class MultiThreadingSpec extends FirrtlPropSpec { new firrtl.HighFirrtlCompiler, new firrtl.MiddleFirrtlCompiler, new firrtl.LowFirrtlCompiler, - new firrtl.VerilogCompiler) + new firrtl.VerilogCompiler + ) val inputFilePath = s"/integration/GCDTester.fir" // arbitrary val numThreads = 64 // arbitrary @@ -35,20 +36,20 @@ class MultiThreadingSpec extends FirrtlPropSpec { import ExecutionContext.Implicits.global try { // Use try-catch because error can manifest in many ways // Execute for each compiler - val compilerResults = compilers map { compiler => + val compilerResults = compilers.map { compiler => // Run compiler serially once val serialResult = runCompiler(inputStrings, compiler) Future { - val threadFutures = (0 until numThreads) map { i => - Future { - runCompiler(inputStrings, compiler) == serialResult - } + val threadFutures = (0 until numThreads).map { i => + Future { + runCompiler(inputStrings, compiler) == serialResult } + } Await.result(Future.sequence(threadFutures), Duration.Inf) } } val results = Await.result(Future.sequence(compilerResults), Duration.Inf) - assert(results.flatten reduce (_ && _)) // check all true (ie. success) + assert(results.flatten.reduce(_ && _)) // check all true (ie. success) } catch { case _: Throwable => fail("The Compiler is not thread safe") } diff --git a/src/test/scala/firrtlTests/NamespaceSpec.scala b/src/test/scala/firrtlTests/NamespaceSpec.scala index a9bb844d..bf7cb019 100644 --- a/src/test/scala/firrtlTests/NamespaceSpec.scala +++ b/src/test/scala/firrtlTests/NamespaceSpec.scala @@ -9,19 +9,19 @@ class NamespaceSpec extends FirrtlFlatSpec { "A Namespace" should "not allow collisions" in { val namespace = Namespace() - namespace.newName("foo") should be ("foo") - namespace.newName("foo") should be ("foo_0") + namespace.newName("foo") should be("foo") + namespace.newName("foo") should be("foo_0") } it should "start temps with a suffix of 0" in { - Namespace().newTemp.last should be ('0') + Namespace().newTemp.last should be('0') } it should "handle multiple prefixes with independent suffixes" in { val namespace = Namespace() - namespace.newName("foo") should be ("foo") - namespace.newName("foo") should be ("foo_0") - namespace.newName("bar") should be ("bar") - namespace.newName("bar") should be ("bar_0") + namespace.newName("foo") should be("foo") + namespace.newName("foo") should be("foo_0") + namespace.newName("bar") should be("bar") + namespace.newName("bar") should be("bar_0") } } diff --git a/src/test/scala/firrtlTests/ParserSpec.scala b/src/test/scala/firrtlTests/ParserSpec.scala index 3d377901..25e52e57 100644 --- a/src/test/scala/firrtlTests/ParserSpec.scala +++ b/src/test/scala/firrtlTests/ParserSpec.scala @@ -12,16 +12,17 @@ class ParserSpec extends FirrtlFlatSpec { private object MemTests { val prelude = Seq("circuit top :", " module top :", " mem m : ") - val fields = Map("data-type" -> "UInt<32>", - "depth" -> "4", - "read-latency" -> "1", - "write-latency" -> "1", - "reader" -> "a", - "writer" -> "b", - "readwriter" -> "c" - ) + val fields = Map( + "data-type" -> "UInt<32>", + "depth" -> "4", + "read-latency" -> "1", + "write-latency" -> "1", + "reader" -> "a", + "writer" -> "b", + "readwriter" -> "c" + ) def fieldsToSeq(m: Map[String, String]): Seq[String] = - m.map { case (k,v) => s" ${k} => ${v}" }.toSeq + m.map { case (k, v) => s" ${k} => ${v}" }.toSeq } private object RegTests { @@ -36,11 +37,51 @@ class ParserSpec extends FirrtlFlatSpec { private object KeywordTests { val prelude = Seq("circuit top :", " module top :") - val keywords = Seq("circuit", "module", "extmodule", "parameter", "input", "output", "UInt", - "SInt", "Analog", "Fixed", "flip", "Clock", "wire", "reg", "reset", "with", "mem", "depth", - "reader", "writer", "readwriter", "inst", "of", "node", "is", "invalid", "when", "else", - "stop", "printf", "skip", "old", "new", "undefined", "mux", "validif", "cmem", "smem", - "mport", "infer", "read", "write", "rdwr") ++ PrimOps.listing + val keywords = Seq( + "circuit", + "module", + "extmodule", + "parameter", + "input", + "output", + "UInt", + "SInt", + "Analog", + "Fixed", + "flip", + "Clock", + "wire", + "reg", + "reset", + "with", + "mem", + "depth", + "reader", + "writer", + "readwriter", + "inst", + "of", + "node", + "is", + "invalid", + "when", + "else", + "stop", + "printf", + "skip", + "old", + "new", + "undefined", + "mux", + "validif", + "cmem", + "smem", + "mport", + "infer", + "read", + "write", + "rdwr" + ) ++ PrimOps.listing } // ********** Memories ********** @@ -48,7 +89,7 @@ class ParserSpec extends FirrtlFlatSpec { val fields = MemTests.fieldsToSeq(MemTests.fields) val golden = firrtl.Parser.parse((MemTests.prelude ++ fields)) - fields.permutations foreach { permutation => + fields.permutations.foreach { permutation => val circuit = firrtl.Parser.parse((MemTests.prelude ++ permutation)) assert(golden === circuit) } @@ -56,13 +97,13 @@ class ParserSpec extends FirrtlFlatSpec { it should "have exactly one of each: data-type, depth, read-latency, and write-latency" in { import MemTests._ - def parseWithoutField(s: String) = firrtl.Parser.parse((prelude ++ fieldsToSeq(fields - s))) + def parseWithoutField(s: String) = firrtl.Parser.parse((prelude ++ fieldsToSeq(fields - s))) def parseWithDuplicate(k: String, v: String) = firrtl.Parser.parse((prelude ++ fieldsToSeq(fields) :+ s" ${k} => ${v}")) - Seq("data-type", "depth", "read-latency", "write-latency") foreach { field => - an [ParameterNotSpecifiedException] should be thrownBy { parseWithoutField(field) } - an [ParameterRedefinedException] should be thrownBy { parseWithDuplicate(field, fields(field)) } + Seq("data-type", "depth", "read-latency", "write-latency").foreach { field => + an[ParameterNotSpecifiedException] should be thrownBy { parseWithoutField(field) } + an[ParameterRedefinedException] should be thrownBy { parseWithDuplicate(field, fields(field)) } } } @@ -86,7 +127,7 @@ class ParserSpec extends FirrtlFlatSpec { import RegTests._ val res = firrtl.Parser.parse((prelude :+ s"${reg} with : (${reset}) $finfo" :+ " wire a : UInt")) CircuitState(res, Nil) should containTree { - case DefRegister(`fileInfo`, `regName`, _,_,_,_) => true + case DefRegister(`fileInfo`, `regName`, _, _, _, _) => true } } @@ -94,7 +135,7 @@ class ParserSpec extends FirrtlFlatSpec { import RegTests._ val res = firrtl.Parser.parse((prelude :+ s"${reg} with :\n (${reset}) $finfo")) CircuitState(res, Nil) should containTree { - case DefRegister(`fileInfo`, `regName`, _,_,_,_) => true + case DefRegister(`fileInfo`, `regName`, _, _, _, _) => true } } @@ -102,35 +143,34 @@ class ParserSpec extends FirrtlFlatSpec { import RegTests._ val res = firrtl.Parser.parse((prelude :+ s"${reg} $finfo")) CircuitState(res, Nil) should containTree { - case DefRegister(`fileInfo`, `regName`, _,_,_,_) => true + case DefRegister(`fileInfo`, `regName`, _, _, _, _) => true } } // ********** Keywords ********** "Keywords" should "be allowed as Ids" in { import KeywordTests._ - keywords foreach { keyword => + keywords.foreach { keyword => firrtl.Parser.parse((prelude :+ s" wire ${keyword} : UInt")) } } it should "be allowed on lhs in connects" in { import KeywordTests._ - keywords foreach { keyword => - firrtl.Parser.parse((prelude ++ Seq(s" wire ${keyword} : UInt", - s" ${keyword} <= ${keyword}"))) + keywords.foreach { keyword => + firrtl.Parser.parse((prelude ++ Seq(s" wire ${keyword} : UInt", s" ${keyword} <= ${keyword}"))) } } // ********** Digits as Fields ********** "Digits" should "be legal fields in bundles and in subexpressions" in { val input = """ - |circuit Test : - | module Test : - | input in : { 0 : { 0 : UInt<32>, flip 1 : UInt<32> } } - | input in2 : { 4 : { 23 : { foo : UInt<32>, bar : { flip 123 : UInt<32> } } } } - | in.0.1 <= in.0.0 - | in2.4.23.bar.123 <= in2.4.23.foo + |circuit Test : + | module Test : + | input in : { 0 : { 0 : UInt<32>, flip 1 : UInt<32> } } + | input in2 : { 4 : { 23 : { foo : UInt<32>, bar : { flip 123 : UInt<32> } } } } + | in.0.1 <= in.0.0 + | in2.4.23.bar.123 <= in2.4.23.foo """.stripMargin val c = firrtl.Parser.parse(input) firrtl.Parser.parse(c.serialize) @@ -148,7 +188,7 @@ class ParserSpec extends FirrtlFlatSpec { } def check(inFormat: String, ref: Integer): Unit = { - (circuit(inFormat)) should be (circuit(ref.toString)) + (circuit(inFormat)) should be(circuit(ref.toString)) } val checks = Map( @@ -166,25 +206,25 @@ class ParserSpec extends FirrtlFlatSpec { ) checks.foreach { case (k, v) => check(k, v) } - } + } // ********** Doubles as parameters ********** "Doubles" should "be legal parameters for extmodules" in { val nums = Seq("1.0", "7.6", "3.00004", "1.0E10", "1.0023E-17") val signs = Seq("", "+", "-") - val tests = "0.0" +: (signs flatMap (s => nums map (n => s + n))) + val tests = "0.0" +: (signs.flatMap(s => nums.map(n => s + n))) for (test <- tests) { val input = s""" - |circuit Test : - | extmodule Ext : - | input foo : UInt<32> - | - | defname = MyExtModule - | parameter REAL = $test - | - | module Test : - | input foo : UInt<32> - | output bar : UInt<32> + |circuit Test : + | extmodule Ext : + | input foo : UInt<32> + | + | defname = MyExtModule + | parameter REAL = $test + | + | module Test : + | input foo : UInt<32> + | output bar : UInt<32> """.stripMargin val c = firrtl.Parser.parse(input) firrtl.Parser.parse(c.serialize) @@ -193,16 +233,16 @@ class ParserSpec extends FirrtlFlatSpec { "Strings" should "be legal parameters for extmodules" in { val input = s""" - |circuit Test : - | extmodule Ext : - | input foo : UInt<32> - | - | defname = MyExtModule - | parameter STR = "hello=%d" - | - | module Test : - | input foo : UInt<32> - | output bar : UInt<32> + |circuit Test : + | extmodule Ext : + | input foo : UInt<32> + | + | defname = MyExtModule + | parameter STR = "hello=%d" + | + | module Test : + | input foo : UInt<32> + | output bar : UInt<32> """.stripMargin val c = firrtl.Parser.parse(input) firrtl.Parser.parse(c.serialize) @@ -210,37 +250,37 @@ class ParserSpec extends FirrtlFlatSpec { "Parsing errors" should "be reported as normal exceptions" in { val input = s""" - |circuit Test - | module Test : + |circuit Test + | module Test : - |""".stripMargin + |""".stripMargin val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input)) } - a [SyntaxErrorsException] shouldBe thrownBy { + a[SyntaxErrorsException] shouldBe thrownBy { Driver.execute(manager) } } "Trailing syntax errors" should "be caught in the parser" in { val input = s""" - |circuit Foo: - | module Bar: - | input a: UInt<1> - |output b: UInt<1> - | b <- a - | - | module Foo: - | input a: UInt<1> - | output b: UInt<1> - | inst bar of Bar - | bar.a <- a - | b <- bar.b + |circuit Foo: + | module Bar: + | input a: UInt<1> + |output b: UInt<1> + | b <- a + | + | module Foo: + | input a: UInt<1> + | output b: UInt<1> + | inst bar of Bar + | bar.a <- a + | b <- bar.b """.stripMargin val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input)) } - a [SyntaxErrorsException] shouldBe thrownBy { + a[SyntaxErrorsException] shouldBe thrownBy { Driver.execute(manager) } } @@ -250,9 +290,9 @@ class ParserSpec extends FirrtlFlatSpec { val info = ir.MultiInfo(Seq(ir.MultiInfo(Seq(ir.FileInfo("a"))), ir.FileInfo("b"), ir.FileInfo("c"))) val input = s"""circuit m:${info.serialize} - | module m: - | skip - |""".stripMargin + | module m: + | skip + |""".stripMargin val c = firrtl.Parser.parse(input) assert(c.info == ir.FileInfo("a b c")) } @@ -272,14 +312,14 @@ class ParserPropSpec extends FirrtlPropSpec { } yield (x :: xs).mkString property("Identifiers should allow [A-Za-z0-9_$] but not allow starting with a digit or $") { - forAll (identifier) { id => + forAll(identifier) { id => whenever(id.nonEmpty) { val input = s""" - |circuit Test : - | module Test : - | input $id : UInt<32> - |""".stripMargin - firrtl.Parser.parse(input split "\n") + |circuit Test : + | module Test : + | input $id : UInt<32> + |""".stripMargin + firrtl.Parser.parse(input.split("\n")) } } } @@ -289,15 +329,16 @@ class ParserPropSpec extends FirrtlPropSpec { } yield xs.mkString property("Bundle fields should allow [A-Za-z0-9_] including starting with a digit or $") { - forAll (identifier, bundleField) { case (id, field) => - whenever(id.nonEmpty && field.nonEmpty) { - val input = s""" - |circuit Test : - | module Test : - | input $id : { $field : UInt<32> } - |""".stripMargin - firrtl.Parser.parse(input split "\n") - } + forAll(identifier, bundleField) { + case (id, field) => + whenever(id.nonEmpty && field.nonEmpty) { + val input = s""" + |circuit Test : + | module Test : + | input $id : { $field : UInt<32> } + |""".stripMargin + firrtl.Parser.parse(input.split("\n")) + } } } } diff --git a/src/test/scala/firrtlTests/PresetSpec.scala b/src/test/scala/firrtlTests/PresetSpec.scala index 689a910d..9fa64647 100644 --- a/src/test/scala/firrtlTests/PresetSpec.scala +++ b/src/test/scala/firrtlTests/PresetSpec.scala @@ -13,156 +13,178 @@ class PresetSpec extends FirrtlFlatSpec { def compile(input: String, annos: AnnotationSeq): CircuitState = (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos), List.empty) def compileBody(modules: ModuleSeq) = { - val annos = Seq(new PresetAnnotation(CircuitTarget("Test").module("Test").ref("reset")), firrtl.transforms.NoDCEAnnotation) + val annos = + Seq(new PresetAnnotation(CircuitTarget("Test").module("Test").ref("reset")), firrtl.transforms.NoDCEAnnotation) var str = """ - |circuit Test : - |""".stripMargin - modules foreach ((m: Mod) => { + |circuit Test : + |""".stripMargin + modules.foreach((m: Mod) => { val header = "|module " + m(0) + " :" - str += header.stripMargin.stripMargin.split("\n").mkString(" ", "\n ", "") + str += header.stripMargin.stripMargin.split("\n").mkString(" ", "\n ", "") str += m(1).split("\n").mkString(" ", "\n ", "") str += """ - |""".stripMargin + |""".stripMargin }) - compile(str,annos) + compile(str, annos) } "Preset" should """behave properly given a `Preset` annotated `AsyncReset` INPUT reset: - replace AsyncReset specific blocks by standard Register blocks - add inline declaration of all registers connected to reset - remove the useless input port""" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result shouldNot containLine ("input reset,") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result shouldNot containLine("input reset,") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } - + it should """behave properly given a `Preset` annotated `AsyncReset` WIRE reset: - replace AsyncReset specific blocks by standard Register blocks - add inline declaration of all registers connected to reset - remove the useless wire declaration and assignation""" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input x : UInt<1> - |output z : UInt<1> - |wire reset : AsyncReset - |reset <= asAsyncReset(UInt(0)) - |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input x : UInt<1> + |output z : UInt<1> + |wire reset : AsyncReset + |reset <= asAsyncReset(UInt(0)) + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") - // it should also remove useless asyncReset signal, all along the path down to registers - result shouldNot containLine ("wire reset;") - result shouldNot containLine ("assign reset = 1'h0;") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") + // it should also remove useless asyncReset signal, all along the path down to registers + result shouldNot containLine("wire reset;") + result shouldNot containLine("assign reset = 1'h0;") } it should "raise TreeCleanUpOrphantException on cast of annotated AsyncReset" in { - an [firrtl.transforms.PropagatePresetAnnotations.TreeCleanUpOrphanException] shouldBe thrownBy { - compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input x : UInt<1> - |output z : UInt<1> - |output sz : UInt<1> - |wire reset : AsyncReset - |reset <= asAsyncReset(UInt(0)) - |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) - |wire sreset : UInt<1> - |sreset <= asUInt(reset) ; this is FORBIDDEN - |reg s : UInt<1>, clock with : (reset => (sreset, UInt(0))) - |r <= x - |s <= x - |z <= r - |sz <= s""".stripMargin)) + an[firrtl.transforms.PropagatePresetAnnotations.TreeCleanUpOrphanException] shouldBe thrownBy { + compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input x : UInt<1> + |output z : UInt<1> + |output sz : UInt<1> + |wire reset : AsyncReset + |reset <= asAsyncReset(UInt(0)) + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |wire sreset : UInt<1> + |sreset <= asUInt(reset) ; this is FORBIDDEN + |reg s : UInt<1>, clock with : (reset => (sreset, UInt(0))) + |r <= x + |s <= x + |z <= r + |sz <= s""".stripMargin + ) + ) ) } } - + it should "propagate through bundles" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} - |bundle.in_rst <= reset - |bundle.out_rst <= bundle.in_rst - |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} + |bundle.in_rst <= reset + |bundle.out_rst <= bundle.in_rst + |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("input reset,") - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("input reset,") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } it should "propagate through vectors" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |wire vector : AsyncReset[2] - |vector[0] <= reset - |vector[1] <= vector[0] - |reg r : UInt<1>, clock with : (reset => (vector[1], UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire vector : AsyncReset[2] + |vector[0] <= reset + |vector[1] <= vector[0] + |reg r : UInt<1>, clock with : (reset => (vector[1], UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("input reset,") - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("input reset,") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } - + it should "propagate through bundles of vectors" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |wire bundle : {in_rst: AsyncReset[2], out_rst:AsyncReset} - |bundle.in_rst[0] <= reset - |bundle.in_rst[1] <= bundle.in_rst[0] - |bundle.out_rst <= bundle.in_rst[1] - |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire bundle : {in_rst: AsyncReset[2], out_rst:AsyncReset} + |bundle.in_rst[0] <= reset + |bundle.in_rst[1] <= bundle.in_rst[0] + |bundle.out_rst <= bundle.in_rst[1] + |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("input reset,") - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("input reset,") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } it should """propagate properly accross modules: - replace AsyncReset specific blocks by standard Register blocks @@ -171,70 +193,79 @@ class PresetSpec extends FirrtlFlatSpec { - remove the useless instance connections - remove wires and assignations used in instance connections """ in { - val result = compileBody(Seq( - Seq("TestA",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) - |r <= x - |z <= r - |""".stripMargin), - Seq("Test",s""" - |input clock : Clock - |input x : UInt<1> - |output z : UInt<1> - |wire reset : AsyncReset - |reset <= asAsyncReset(UInt(0)) - |inst i of TestA - |i.clock <= clock - |i.reset <= reset - |i.x <= x - |z <= i.z""".stripMargin) - )) + val result = compileBody( + Seq( + Seq( + "TestA", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |r <= x + |z <= r + |""".stripMargin + ), + Seq( + "Test", + s""" + |input clock : Clock + |input x : UInt<1> + |output z : UInt<1> + |wire reset : AsyncReset + |reset <= asAsyncReset(UInt(0)) + |inst i of TestA + |i.clock <= clock + |i.reset <= reset + |i.x <= x + |z <= i.z""".stripMargin + ) + ) + ) // assess that all useless connections are not emitted - result shouldNot containLine ("wire i_reset;") - result shouldNot containLine (".reset(i_reset),") - result shouldNot containLine ("assign i_reset = reset;") - result shouldNot containLine ("input reset,") - - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("wire i_reset;") + result shouldNot containLine(".reset(i_reset),") + result shouldNot containLine("assign i_reset = reset;") + result shouldNot containLine("input reset,") + + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } it should "propagate even through disordonned statements" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} - |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) - |bundle.out_rst <= bundle.in_rst - |bundle.in_rst <= reset - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} + |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) + |bundle.out_rst <= bundle.in_rst + |bundle.in_rst <= reset + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("input reset,") - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("input reset,") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } } -class PresetExecutionTest extends ExecutionTest( - "PresetTester", - "/features", - annotations = Seq(new PresetAnnotation(CircuitTarget("PresetTester").module("PresetTester").ref("preset"))) -) +class PresetExecutionTest + extends ExecutionTest( + "PresetTester", + "/features", + annotations = Seq(new PresetAnnotation(CircuitTarget("PresetTester").module("PresetTester").ref("preset"))) + ) diff --git a/src/test/scala/firrtlTests/ProtoBufSpec.scala b/src/test/scala/firrtlTests/ProtoBufSpec.scala index 3a94ec3f..7cfdc4dc 100644 --- a/src/test/scala/firrtlTests/ProtoBufSpec.scala +++ b/src/test/scala/firrtlTests/ProtoBufSpec.scala @@ -44,50 +44,50 @@ class ProtoBufSpec extends FirrtlFlatSpec { val cistream = com.google.protobuf.CodedInputStream.newInstance(istream) cistream.setRecursionLimit(Integer.MAX_VALUE) val protobuf2 = firrtl.FirrtlProtos.Firrtl.parseFrom(cistream) - protobuf2 should equal (protobuf) + protobuf2 should equal(protobuf) // Test that our faster serialization matches generated serialization val ostream2 = new java.io.ByteArrayOutputStream proto.ToProto.writeToStream(ostream2, circuit) - ostream2.toByteArray.toList should equal (ostream.toByteArray.toList) + ostream2.toByteArray.toList should equal(ostream.toByteArray.toList) } } // ********** Focused Tests ********** // The goal is to fill coverage holes left after the above - behavior of "ProtoBuf serialization and deserialization" + behavior.of("ProtoBuf serialization and deserialization") import firrtl.proto._ it should "support UnknownWidth" in { // Note that this has to be handled in the parent object so we need to test everything that has a width val uint = ir.UIntType(ir.UnknownWidth) - FromProto.convert(ToProto.convert(uint).build) should equal (uint) + FromProto.convert(ToProto.convert(uint).build) should equal(uint) val sint = ir.SIntType(ir.UnknownWidth) - FromProto.convert(ToProto.convert(sint).build) should equal (sint) + FromProto.convert(ToProto.convert(sint).build) should equal(sint) val ftpe = ir.FixedType(ir.UnknownWidth, ir.UnknownWidth) - FromProto.convert(ToProto.convert(ftpe).build) should equal (ftpe) + FromProto.convert(ToProto.convert(ftpe).build) should equal(ftpe) val atpe = ir.AnalogType(ir.UnknownWidth) - FromProto.convert(ToProto.convert(atpe).build) should equal (atpe) + FromProto.convert(ToProto.convert(atpe).build) should equal(atpe) val ulit = ir.UIntLiteral(123, ir.UnknownWidth) - FromProto.convert(ToProto.convert(ulit).build) should equal (ulit) + FromProto.convert(ToProto.convert(ulit).build) should equal(ulit) val slit = ir.SIntLiteral(-123, ir.UnknownWidth) - FromProto.convert(ToProto.convert(slit).build) should equal (slit) + FromProto.convert(ToProto.convert(slit).build) should equal(slit) val flit = ir.FixedLiteral(-123, ir.UnknownWidth, ir.UnknownWidth) - FromProto.convert(ToProto.convert(flit).build) should equal (flit) + FromProto.convert(ToProto.convert(flit).build) should equal(flit) } it should "support all Primops" in { val builtInOps = PrimOps.listing.map(PrimOps.fromString(_)) for (op <- builtInOps) { val expr = DoPrim(op, List.empty, List.empty, ir.UnknownType) - FromProto.convert(ToProto.convert(expr).build) should equal (expr) + FromProto.convert(ToProto.convert(expr).build) should equal(expr) } } @@ -103,25 +103,25 @@ class ProtoBufSpec extends FirrtlFlatSpec { RawStringParam("param4", "get some raw strings") ) val ext = ir.ExtModule(ir.NoInfo, "MyModule", ports, "DefNameHere", params) - FromProto.convert(ToProto.convert(ext).build) should equal (ext) + FromProto.convert(ToProto.convert(ext).build) should equal(ext) } it should "support FixedType" in { val ftpe = ir.FixedType(IntWidth(8), IntWidth(4)) - FromProto.convert(ToProto.convert(ftpe).build) should equal (ftpe) + FromProto.convert(ToProto.convert(ftpe).build) should equal(ftpe) } it should "support FixedLiteral" in { val flit = ir.FixedLiteral(3, IntWidth(8), IntWidth(4)) - FromProto.convert(ToProto.convert(flit).build) should equal (flit) + FromProto.convert(ToProto.convert(flit).build) should equal(flit) } it should "support Analog and Attach" in { val analog = ir.AnalogType(IntWidth(8)) - FromProto.convert(ToProto.convert(analog).build) should equal (analog) + FromProto.convert(ToProto.convert(analog).build) should equal(analog) val attach = ir.Attach(ir.NoInfo, Seq(Reference("hi", ir.UnknownType))) - FromProto.convert(ToProto.convert(attach).head.build) should equal (attach) + FromProto.convert(ToProto.convert(attach).head.build) should equal(attach) } // Regression tests were generated before Chisel could emit else @@ -129,12 +129,12 @@ class ProtoBufSpec extends FirrtlFlatSpec { val expr = Reference("hi", ir.UnknownType) val stmt = Connect(ir.NoInfo, expr, expr) val when = ir.Conditionally(ir.NoInfo, expr, stmt, stmt) - FromProto.convert(ToProto.convert(when).head.build) should equal (when) + FromProto.convert(ToProto.convert(when).head.build) should equal(when) } it should "support SIntLiteral with a width" in { val slit = ir.SIntLiteral(-123) - FromProto.convert(ToProto.convert(slit).build) should equal (slit) + FromProto.convert(ToProto.convert(slit).build) should equal(slit) } // Backwards compatibility @@ -143,18 +143,21 @@ class ProtoBufSpec extends FirrtlFlatSpec { val mem = DefMemory(NoInfo, "m", UIntType(IntWidth(8)), size, 1, 1, List("r"), List("w"), List("rw")) val builder = ToProto.convert(mem).head val defaultProto = builder.build() - val oldProto = Firrtl.Statement.newBuilder().setMemory( - builder.getMemoryBuilder.clearDepth().setUintDepth(size) - ).build() + val oldProto = Firrtl.Statement + .newBuilder() + .setMemory( + builder.getMemoryBuilder.clearDepth().setUintDepth(size) + ) + .build() // These Proto messages are not the same - defaultProto shouldNot equal (oldProto) + defaultProto shouldNot equal(oldProto) val defaultMem = FromProto.convert(defaultProto) val oldMem = FromProto.convert(oldProto) // But they both deserialize to the original! - defaultMem should equal (mem) - oldMem should equal (mem) + defaultMem should equal(mem) + oldMem should equal(mem) } // Backwards compatibility @@ -164,43 +167,46 @@ class ProtoBufSpec extends FirrtlFlatSpec { val vtpe = ToProto.convert(VectorType(UIntType(IntWidth(8)), size)) val builder = ToProto.convert(cmem).head val defaultProto = builder.build() - val oldProto = Firrtl.Statement.newBuilder().setCmemory( - builder.getCmemoryBuilder.clearTypeAndDepth().setVectorType(vtpe) - ).build() + val oldProto = Firrtl.Statement + .newBuilder() + .setCmemory( + builder.getCmemoryBuilder.clearTypeAndDepth().setVectorType(vtpe) + ) + .build() // These Proto messages are not the same - defaultProto shouldNot equal (oldProto) + defaultProto shouldNot equal(oldProto) val defaultCMem = FromProto.convert(defaultProto) val oldCMem = FromProto.convert(oldProto) // But they both deserialize to the original! - defaultCMem should equal (cmem) - oldCMem should equal (cmem) + defaultCMem should equal(cmem) + oldCMem should equal(cmem) } // readunderwrite support it should "support readunderwrite parameters" in { val m1 = DefMemory(NoInfo, "m", UIntType(IntWidth(8)), 128, 1, 1, List("r"), List("w"), Nil, ir.ReadUnderWrite.Old) - FromProto.convert(ToProto.convert(m1).head.build) should equal (m1) + FromProto.convert(ToProto.convert(m1).head.build) should equal(m1) val m2 = m1.copy(readUnderWrite = ir.ReadUnderWrite.New) - FromProto.convert(ToProto.convert(m2).head.build) should equal (m2) + FromProto.convert(ToProto.convert(m2).head.build) should equal(m2) val cm1 = CDefMemory(NoInfo, "m", UIntType(IntWidth(8)), 128, true, ir.ReadUnderWrite.Old) - FromProto.convert(ToProto.convert(cm1).head.build) should equal (cm1) + FromProto.convert(ToProto.convert(cm1).head.build) should equal(cm1) val cm2 = cm1.copy(readUnderWrite = ir.ReadUnderWrite.New) - FromProto.convert(ToProto.convert(cm2).head.build) should equal (cm2) + FromProto.convert(ToProto.convert(cm2).head.build) should equal(cm2) } it should "support AsyncResetTypes" in { val port = ir.Port(ir.NoInfo, "reset", ir.Input, ir.AsyncResetType) - FromProto.convert(ToProto.convert(port).build) should equal (port) + FromProto.convert(ToProto.convert(port).build) should equal(port) } it should "support ResetTypes" in { val port = ir.Port(ir.NoInfo, "reset", ir.Input, ir.ResetType) - FromProto.convert(ToProto.convert(port).build) should equal (port) + FromProto.convert(ToProto.convert(port).build) should equal(port) } it should "support ValidIf" in { @@ -209,7 +215,7 @@ class ProtoBufSpec extends FirrtlFlatSpec { val vi = ir.ValidIf(en, value, value.tpe) // Deserialized has almost nothing filled in val expected = ir.ValidIf(ir.Reference("en"), ir.Reference("x"), UnknownType) - FromProto.convert(ToProto.convert(vi).build) should equal (expected) + FromProto.convert(ToProto.convert(vi).build) should equal(expected) } it should "appropriately escape and unescape FileInfo strings" in { @@ -220,10 +226,11 @@ class ProtoBufSpec extends FirrtlFlatSpec { "test\\]test" -> "test]test" ) - pairs.foreach { case (escaped, unescaped) => - val info = ir.FileInfo(escaped) - ToProto.convert(info).build().getText should equal (unescaped) - FromProto.convert(ToProto.convert(info).build) should equal (info) + pairs.foreach { + case (escaped, unescaped) => + val info = ir.FileInfo(escaped) + ToProto.convert(info).build().getText should equal(unescaped) + FromProto.convert(ToProto.convert(info).build) should equal(info) } } } diff --git a/src/test/scala/firrtlTests/RegisterUpdateSpec.scala b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala index dfef5955..d335becc 100644 --- a/src/test/scala/firrtlTests/RegisterUpdateSpec.scala +++ b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala @@ -22,7 +22,8 @@ object RegisterUpdateSpec { override def invalidates(a: Transform): Boolean = false def execute(state: CircuitState): CircuitState = { val emittedAnno = EmittedFirrtlCircuitAnnotation( - EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, ".fir")) + EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, ".fir") + ) val capturedState = state.copy(annotations = emittedAnno +: state.annotations) state.copy(annotations = CaptureStateAnno(capturedState) +: state.annotations) } @@ -37,64 +38,61 @@ class RegisterUpdateSpec extends FirrtlFlatSpec { } def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } "Register update logic" should "not duplicate common subtrees" in { val result = compileBody(s""" - |input clock : Clock - |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} - |reg r : UInt<8>, clock - |when io.a : - | r <= io.in - |when io.b : - | when io.c : - | r <= UInt(2) - |io.out <= r""".stripMargin - ) + |input clock : Clock + |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} + |reg r : UInt<8>, clock + |when io.a : + | r <= io.in + |when io.b : + | when io.c : + | r <= UInt(2) + |io.out <= r""".stripMargin) // Checking intermediate state between FlattenRegUpdate and Verilog emission val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get - fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") + fstate should containLine("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") // Checking the Verilog val verilog = result.getEmittedCircuit.value - result shouldNot containLine ("r <= io_in;") - verilog shouldNot include ("if (io_a) begin") - result should containLine ("r <= _GEN_0;") + result shouldNot containLine("r <= io_in;") + verilog shouldNot include("if (io_a) begin") + result should containLine("r <= _GEN_0;") } it should "not let duplicate subtrees on one register affect another" in { val result = compileBody(s""" - |input clock : Clock - |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} + |input clock : Clock + |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} - |reg r : UInt<8>, clock - |reg r2 : UInt<8>, clock - |when io.a : - | r <= io.in - | r2 <= io.in - |when io.b : - | r2 <= UInt(3) - | when io.c : - | r <= UInt(2) - |io.out <= and(r, r2)""".stripMargin - ) + |reg r : UInt<8>, clock + |reg r2 : UInt<8>, clock + |when io.a : + | r <= io.in + | r2 <= io.in + |when io.b : + | r2 <= UInt(3) + | when io.c : + | r <= UInt(2) + |io.out <= and(r, r2)""".stripMargin) // Checking intermediate state between FlattenRegUpdate and Verilog emission val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get - fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") - fstate should containLine ("""r2 <= mux(io_b, UInt<8>("h3"), mux(io_a, io_in, r2))""") + fstate should containLine("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") + fstate should containLine("""r2 <= mux(io_b, UInt<8>("h3"), mux(io_a, io_in, r2))""") // Checking the Verilog val verilog = result.getEmittedCircuit.value - result shouldNot containLine ("r <= io_in;") - result should containLine ("r <= _GEN_0;") - result should containLine ("r2 <= io_in;") - verilog should include ("if (io_a) begin") // For r2 + result shouldNot containLine("r <= io_in;") + result should containLine("r <= _GEN_0;") + result should containLine("r2 <= io_in;") + verilog should include("if (io_a) begin") // For r2 // 1 time for r2, old versions would have 3 occurences - Regex.quote("if (io_a) begin").r.findAllMatchIn(verilog).size should be (1) + Regex.quote("if (io_a) begin").r.findAllMatchIn(verilog).size should be(1) } } - diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala index df3ceef6..3438da67 100644 --- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala +++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala @@ -14,9 +14,9 @@ class RemoveWiresSpec extends FirrtlFlatSpec { (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } @@ -26,7 +26,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec { val nodes = mutable.ArrayBuffer.empty[DefNode] val wires = mutable.ArrayBuffer.empty[DefWire] def onStmt(stmt: Statement): Statement = { - stmt map onStmt match { + stmt.map(onStmt) match { case node: DefNode => nodes += node case wire: DefWire => wires += wire case _ => @@ -35,7 +35,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec { } circuit.modules.head match { - case Module(_,_,_, body) => onStmt(body) + case Module(_, _, _, body) => onStmt(body) } (nodes.toSeq, wires.toSeq) } @@ -44,98 +44,90 @@ class RemoveWiresSpec extends FirrtlFlatSpec { require(circuit.modules.size == 1) val names = mutable.ArrayBuffer.empty[String] def onStmt(stmt: Statement): Statement = { - stmt map onStmt match { - case reg: DefRegister => names += reg.name - case wire: DefWire => names += wire.name - case node: DefNode => names += node.name + stmt.map(onStmt) match { + case reg: DefRegister => names += reg.name + case wire: DefWire => names += wire.name + case node: DefNode => names += node.name case _ => } stmt } circuit.modules.head match { - case Module(_,_,_, body) => onStmt(body) + case Module(_, _, _, body) => onStmt(body) } names.toSeq } "Remove Wires" should "turn wires and their single connect into nodes" in { val result = compileBody(s""" - |input a : UInt<8> - |output b : UInt<8> - |wire w : UInt<8> - |w <= a - |b <= w""".stripMargin - ) + |input a : UInt<8> + |output b : UInt<8> + |wire w : UInt<8> + |w <= a + |b <= w""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) + wires.size should be(0) - nodes.map(_.serialize) should be (Seq("node w = a")) + nodes.map(_.serialize) should be(Seq("node w = a")) } it should "order nodes in a legal, flow-forward way" in { val result = compileBody(s""" - |input a : UInt<8> - |output b : UInt<8> - |wire w : UInt<8> - |wire x : UInt<8> - |node y = x - |x <= w - |w <= a - |b <= y""".stripMargin - ) + |input a : UInt<8> + |output b : UInt<8> + |wire w : UInt<8> + |wire x : UInt<8> + |node y = x + |x <= w + |w <= a + |b <= y""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) - nodes.map(_.serialize) should be ( - Seq("node w = a", - "node x = w", - "node y = x") + wires.size should be(0) + nodes.map(_.serialize) should be( + Seq("node w = a", "node x = w", "node y = x") ) } it should "properly pad rhs of introduced nodes if necessary" in { val result = compileBody(s""" - |output b : UInt<8> - |wire w : UInt<8> - |w <= UInt(2) - |b <= w""".stripMargin - ) + |output b : UInt<8> + |wire w : UInt<8> + |w <= UInt(2) + |b <= w""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) - nodes.map(_.serialize) should be ( + wires.size should be(0) + nodes.map(_.serialize) should be( Seq("""node w = pad(UInt<2>("h2"), 8)""") ) } it should "support arbitrary expression for wire connection rhs" in { val result = compileBody(s""" - |input a : UInt<8> - |input b : UInt<8> - |output c : UInt<8> - |wire w : UInt<8> - |w <= tail(add(a, b), 1) - |c <= w""".stripMargin - ) + |input a : UInt<8> + |input b : UInt<8> + |output c : UInt<8> + |wire w : UInt<8> + |w <= tail(add(a, b), 1) + |c <= w""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) - nodes.map(_.serialize) should be ( + wires.size should be(0) + nodes.map(_.serialize) should be( Seq("""node w = tail(add(a, b), 1)""") ) } it should "do a reasonable job preserving input order for unrelatd logic" in { val result = compileBody(s""" - |input a : UInt<8> - |input b : UInt<8> - |output z : UInt<8> - |node x = not(a) - |node y = not(b) - |z <= and(x, y)""".stripMargin - ) + |input a : UInt<8> + |input b : UInt<8> + |output z : UInt<8> + |node x = not(a) + |node y = not(b) + |z <= and(x, y)""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) - nodes.map(_.serialize) should be ( - Seq("node x = not(a)", - "node y = not(b)") + wires.size should be(0) + nodes.map(_.serialize) should be( + Seq("node x = not(a)", "node y = not(b)") ) } @@ -148,52 +140,49 @@ class RemoveWiresSpec extends FirrtlFlatSpec { |""".stripMargin ) val names = orderedNames(result.circuit) - names should be (Seq("a", "clock2", "b")) + names should be(Seq("a", "clock2", "b")) } it should "order registers correctly" in { val result = compileBody(s""" - |input clock : Clock - |input a : UInt<8> - |output c : UInt<8> - |wire w : UInt<8> - |node n = tail(add(w, UInt(1)), 1) - |reg r : UInt<8>, clock - |w <= tail(add(r, a), 1) - |c <= n""".stripMargin - ) + |input clock : Clock + |input a : UInt<8> + |output c : UInt<8> + |wire w : UInt<8> + |node n = tail(add(w, UInt(1)), 1) + |reg r : UInt<8>, clock + |w <= tail(add(r, a), 1) + |c <= n""".stripMargin) // Check declaration before use is maintained firrtl.passes.CheckHighForm.execute(result) } it should "order registers with async reset correctly" in { val result = compileBody(s""" - |input clock : Clock - |input reset : UInt<1> - |input in : UInt<8> - |output out : UInt<8> - |wire areset : AsyncReset - |reg r : UInt<8>, clock with : (reset => (areset, UInt(0))) - |areset <= asAsyncReset(reset) - |r <= in - |out <= r - |""".stripMargin - ) + |input clock : Clock + |input reset : UInt<1> + |input in : UInt<8> + |output out : UInt<8> + |wire areset : AsyncReset + |reg r : UInt<8>, clock with : (reset => (areset, UInt(0))) + |areset <= asAsyncReset(reset) + |r <= in + |out <= r + |""".stripMargin) // Check declaration before use is maintained firrtl.passes.CheckHighForm.execute(result) } it should "order registers respecting initializations" in { - val result = compileBody( - s"""|input clock : Clock - |input foo : UInt<2> - |output bar : UInt<2> - |wire y_fault : UInt<2> - |reg y : UInt<2>, clock with : - | reset => (UInt<1>("h0"), y_fault) - |y_fault <= foo - |bar <= y - |""".stripMargin) + val result = compileBody(s"""|input clock : Clock + |input foo : UInt<2> + |output bar : UInt<2> + |wire y_fault : UInt<2> + |reg y : UInt<2>, clock with : + | reset => (UInt<1>("h0"), y_fault) + |y_fault <= foo + |bar <= y + |""".stripMargin) // Check declaration before use is maintained firrtl.passes.CheckHighForm.execute(result) } diff --git a/src/test/scala/firrtlTests/RenameMapSpec.scala b/src/test/scala/firrtlTests/RenameMapSpec.scala index 7931b94f..609d8eef 100644 --- a/src/test/scala/firrtlTests/RenameMapSpec.scala +++ b/src/test/scala/firrtlTests/RenameMapSpec.scala @@ -8,10 +8,10 @@ import firrtl.annotations._ import firrtl.testutils._ class RenameMapSpec extends FirrtlFlatSpec { - val cir = CircuitTarget("Top") - val cir2 = CircuitTarget("Pot") - val cir3 = CircuitTarget("Cir3") - val modA = cir.module("A") + val cir = CircuitTarget("Top") + val cir2 = CircuitTarget("Pot") + val cir3 = CircuitTarget("Cir3") + val modA = cir.module("A") val modA2 = cir2.module("A") val modB = cir.module("B") val foo = modA.ref("foo") @@ -26,69 +26,69 @@ class RenameMapSpec extends FirrtlFlatSpec { val middle = cir.module("Middle") val middle2 = cir.module("Middle2") - behavior of "RenameMap" + behavior.of("RenameMap") it should "return None if it does not rename something" in { val renames = RenameMap() - renames.get(modA) should be (None) - renames.get(foo) should be (None) + renames.get(modA) should be(None) + renames.get(foo) should be(None) } it should "return a Seq of renamed things if it does rename something" in { val renames = RenameMap() renames.record(foo, bar) - renames.get(foo) should be (Some(Seq(bar))) + renames.get(foo) should be(Some(Seq(bar))) } it should "allow something to be renamed to multiple things" in { val renames = RenameMap() renames.record(foo, bar) renames.record(foo, fizz) - renames.get(foo) should be (Some(Seq(bar, fizz))) + renames.get(foo) should be(Some(Seq(bar, fizz))) } it should "allow something to be renamed to nothing (ie. deleted)" in { val renames = RenameMap() renames.record(foo, Seq()) - renames.get(foo) should be (Some(Seq())) + renames.get(foo) should be(Some(Seq())) } it should "return None if something is renamed to itself" in { val renames = RenameMap() renames.record(foo, foo) - renames.get(foo) should be (None) + renames.get(foo) should be(None) } it should "allow targets to change module" in { val renames = RenameMap() renames.record(foo, fooB) - renames.get(foo) should be (Some(Seq(fooB))) + renames.get(foo) should be(Some(Seq(fooB))) } it should "rename targets if their module is renamed" in { val renames = RenameMap() renames.record(modA, modB) - renames.get(foo) should be (Some(Seq(fooB))) - renames.get(bar) should be (Some(Seq(barB))) + renames.get(foo) should be(Some(Seq(fooB))) + renames.get(bar) should be(Some(Seq(barB))) } it should "not rename already renamed targets if the module of the target is renamed" in { val renames = RenameMap() renames.record(modA, modB) renames.record(foo, bar) - renames.get(foo) should be (Some(Seq(bar))) + renames.get(foo) should be(Some(Seq(bar))) } it should "rename modules if their circuit is renamed" in { val renames = RenameMap() renames.record(cir, cir2) - renames.get(modA) should be (Some(Seq(modA2))) + renames.get(modA) should be(Some(Seq(modA2))) } it should "rename targets if their circuit is renamed" in { val renames = RenameMap() renames.record(cir, cir2) - renames.get(foo) should be (Some(Seq(foo2))) + renames.get(foo) should be(Some(Seq(foo2))) } val TopCircuit = cir @@ -105,44 +105,44 @@ class RenameMapSpec extends FirrtlFlatSpec { it should "rename targets if modules in the path are renamed" in { val renames = RenameMap() renames.record(Middle, Middle2) - renames.get(Top_m) should be (Some(Seq(Top.instOf("m", "Middle2")))) + renames.get(Top_m) should be(Some(Seq(Top.instOf("m", "Middle2")))) } it should "rename only the instance if instance and module in the path are renamed" in { val renames = RenameMap() renames.record(Middle, Middle2) renames.record(Top.instOf("m", "Middle"), Top.instOf("m2", "Middle")) - renames.get(Top_m) should be (Some(Seq(Top.instOf("m2", "Middle")))) + renames.get(Top_m) should be(Some(Seq(Top.instOf("m2", "Middle")))) } it should "rename targets if instance in the path are renamed" in { val renames = RenameMap() renames.record(Top.instOf("m", "Middle"), Top.instOf("m2", "Middle")) - renames.get(Top_m) should be (Some(Seq(Top.instOf("m2", "Middle")))) + renames.get(Top_m) should be(Some(Seq(Top.instOf("m2", "Middle")))) } it should "rename targets if instance and ofmodule in the path are renamed" in { val renames = RenameMap() val Top_m2 = Top.instOf("m2", "Middle2") renames.record(Top_m, Top_m2) - renames.get(Top_m) should be (Some(Seq(Top_m2))) + renames.get(Top_m) should be(Some(Seq(Top_m2))) } it should "properly do nothing if no remaps" in { val renames = RenameMap() - renames.get(Top_m_l_a) should be (None) + renames.get(Top_m_l_a) should be(None) } it should "properly rename if leaf is inlined" in { val renames = RenameMap() renames.record(Middle_l_a, Middle_la) - renames.get(Top_m_l_a) should be (Some(Seq(Top_m_la))) + renames.get(Top_m_l_a) should be(Some(Seq(Top_m_la))) } it should "properly rename if middle is inlined" in { val renames = RenameMap() renames.record(Top_m_l, Top.instOf("m_l", "Leaf")) - renames.get(Top_m_l_a) should be (Some(Seq(Top.instOf("m_l", "Leaf").ref("a")))) + renames.get(Top_m_l_a) should be(Some(Seq(Top.instOf("m_l", "Leaf").ref("a")))) } it should "properly rename if leaf and middle are inlined" in { @@ -151,18 +151,20 @@ class RenameMapSpec extends FirrtlFlatSpec { renames.record(Top_m_l_a, inlined) renames.record(Top_m_l, Nil) renames.record(Top_m, Nil) - renames.get(Top_m_l_a) should be (Some(Seq(inlined))) + renames.get(Top_m_l_a) should be(Some(Seq(inlined))) } it should "quickly rename a target with a long path" in { (0 until 50 by 10).foreach { endIdx => val renames = RenameMap() renames.record(TopCircuit.module("Y0"), TopCircuit.module("X0")) - val deepTarget = (0 until endIdx).foldLeft(Top: IsModule) { (t, idx) => - t.instOf("a", "A" + idx) - }.ref("ref") + val deepTarget = (0 until endIdx) + .foldLeft(Top: IsModule) { (t, idx) => + t.instOf("a", "A" + idx) + } + .ref("ref") val (millis, rename) = firrtl.Utils.time(renames.get(deepTarget)) - //rename should be(None) + //rename should be(None) } } @@ -171,7 +173,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val Middle2 = cir.module("Middle2") renames.record(Middle, Middle2) renames.record(Middle.ref("l"), Middle.ref("lx")) - renames.get(Middle.ref("l")) should be (Some(Seq(Middle.ref("lx")))) + renames.get(Middle.ref("l")) should be(Some(Seq(Middle.ref("lx")))) } it should "rename with fields" in { @@ -181,7 +183,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val Middle_i_f = Middle.ref("i").field("f") val renames = RenameMap() renames.record(Middle_o, Middle_i) - renames.get(Middle_o_f) should be (Some(Seq(Middle_i_f))) + renames.get(Middle_o_f) should be(Some(Seq(Middle_i_f))) } it should "rename instances with same ofModule" in { @@ -189,7 +191,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val Middle_i = Middle.instOf("i", "O") val renames = RenameMap() renames.record(Middle_o, Middle_i) - renames.get(Middle.instOf("o", "O")) should be (Some(Seq(Middle.instOf("i", "O")))) + renames.get(Middle.instOf("o", "O")) should be(Some(Seq(Middle.instOf("i", "O")))) } it should "not treat references as instances targets" in { @@ -197,14 +199,14 @@ class RenameMapSpec extends FirrtlFlatSpec { val Middle_i = Middle.ref("i") val renames = RenameMap() renames.record(Middle_o, Middle_i) - renames.get(Middle.instOf("o", "O")) should be (None) + renames.get(Middle.instOf("o", "O")) should be(None) } it should "be able to rename weird stuff" in { // Renaming `from` to each of the `tos` at the same time should be ok case class BadRename(from: CompleteTarget, tos: Seq[CompleteTarget]) val badRenames = - Seq(//BadRename(foo, Seq(cir)), + Seq( //BadRename(foo, Seq(cir)), //BadRename(foo, Seq(modB)), //BadRename(modA, Seq(fooB)), //BadRename(modA, Seq(cir)), @@ -217,17 +219,17 @@ class RenameMapSpec extends FirrtlFlatSpec { val fromN = from val tosN = tos.mkString(", ") //it should s"error if a $fromN is renamed to $tosN" in { - val renames = RenameMap() - for (to <- tos) { - (from, to) match { - case (f: CircuitTarget, t: CircuitTarget) => renames.record(f, t) - case (f: IsMember, t: IsMember) => renames.record(f, t) - case _ => sys.error("Unexpected!") - } + val renames = RenameMap() + for (to <- tos) { + (from, to) match { + case (f: CircuitTarget, t: CircuitTarget) => renames.record(f, t) + case (f: IsMember, t: IsMember) => renames.record(f, t) + case _ => sys.error("Unexpected!") } - //a [FIRRTLException] shouldBe thrownBy { - renames.get(from) - //} + } + //a [FIRRTLException] shouldBe thrownBy { + renames.get(from) + //} //} } } @@ -247,8 +249,8 @@ class RenameMapSpec extends FirrtlFlatSpec { val top = CircuitTarget("Top") renames.record(top.module("A"), top.module("B")) renames.record(top.module("B"), top.module("A")) - renames.get(top.module("A")) should be (Some(Seq(top.module("B")))) - renames.get(top.module("B")) should be (Some(Seq(top.module("A")))) + renames.get(top.module("A")) should be(Some(Seq(top.module("B")))) + renames.get(top.module("B")) should be(Some(Seq(top.module("A")))) } it should "error if a reference is renamed to a module and vice versa" in { @@ -256,10 +258,10 @@ class RenameMapSpec extends FirrtlFlatSpec { val top = CircuitTarget("Top") renames.record(top.module("A").ref("ref"), top.module("B")) renames.record(top.module("C"), top.module("D").ref("ref")) - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(top.module("C")) } - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(top.module("A").ref("ref").field("field")) } renames.get(top.module("A").instOf("ref", "R")) should be(None) @@ -270,7 +272,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val top = CircuitTarget("Top") renames.record(top.module("C"), top.module("D").ref("x")) - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(top.module("A").instOf("c", "C")) } } @@ -281,7 +283,7 @@ class RenameMapSpec extends FirrtlFlatSpec { renames.record(top.module("E").instOf("f", "F"), top.module("E").ref("g")) - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(top.module("E").instOf("f", "F").ref("g")) } } @@ -403,7 +405,7 @@ class RenameMapSpec extends FirrtlFlatSpec { .ref("ref") .field("f1") .field("f2") - val to2 = modA + val to2 = modA .instOf("b", "B") .instOf("c", "C") .ref("ref") @@ -417,7 +419,7 @@ class RenameMapSpec extends FirrtlFlatSpec { .instOf("c", "C") .ref("ref") .field("f1") - val to3 = modB + val to3 = modB .instOf("c", "C") .ref("ref") .field("f11") @@ -426,7 +428,7 @@ class RenameMapSpec extends FirrtlFlatSpec { // to: ~Top|C>refref // renamed last because it has no path val from4 = modC.ref("ref") - val to4 = modC.ref("refref") + val to4 = modC.ref("refref") val renames1 = RenameMap() renames1.record(from1, to1) @@ -435,14 +437,17 @@ class RenameMapSpec extends FirrtlFlatSpec { renames1.record(from4, to4) renames1.get(from1) should be { - Some(Seq(modA - .instOf("b", "B") - .instOf("c", "C") - .ref("ref") - .field("f1") - .field("f2") - .field("f33") - )) + Some( + Seq( + modA + .instOf("b", "B") + .instOf("c", "C") + .ref("ref") + .field("f1") + .field("f2") + .field("f33") + ) + ) } val renames2 = RenameMap() @@ -451,14 +456,17 @@ class RenameMapSpec extends FirrtlFlatSpec { renames2.record(from4, to4) renames2.get(from1) should be { - Some(Seq(modA - .instOf("b", "B") - .instOf("c", "C") - .ref("ref") - .field("f1") - .field("f22") - .field("f3") - )) + Some( + Seq( + modA + .instOf("b", "B") + .instOf("c", "C") + .ref("ref") + .field("f1") + .field("f22") + .field("f3") + ) + ) } val renames3 = RenameMap() @@ -466,14 +474,17 @@ class RenameMapSpec extends FirrtlFlatSpec { renames3.record(from4, to4) renames3.get(from1) should be { - Some(Seq(modA - .instOf("b", "B") - .instOf("c", "C") - .ref("ref") - .field("f11") - .field("f2") - .field("f3") - )) + Some( + Seq( + modA + .instOf("b", "B") + .instOf("c", "C") + .ref("ref") + .field("f11") + .field("f2") + .field("f3") + ) + ) } } @@ -498,8 +509,18 @@ class RenameMapSpec extends FirrtlFlatSpec { val to = cir.module("D").instOf("e", "E").instOf("f", "F").ref("foo").field("foo") renames.record(from, to) renames.get(cir.module("A").instOf("b", "B").instOf("c", "C").ref("foo").field("bar")) should be { - Some(Seq(cir.module("A").instOf("b", "B").instOf("c", "D") - .instOf("e", "E").instOf("f", "F").ref("foo").field("foo"))) + Some( + Seq( + cir + .module("A") + .instOf("b", "B") + .instOf("c", "D") + .instOf("e", "E") + .instOf("f", "F") + .ref("foo") + .field("foo") + ) + ) } } @@ -509,7 +530,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val from = top.instOf("a", "A") val to = top.ref("b") renames.record(from, to) - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(from) } } @@ -520,7 +541,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val from = top.ref("a") val to = top.ref("b") renames.record(from, to) - renames.get(top.instOf("a", "Foo")) should be (None) + renames.get(top.instOf("a", "Foo")) should be(None) } it should "correctly chain renames together" in { @@ -651,8 +672,8 @@ class RenameMapSpec extends FirrtlFlatSpec { val dupMod1 = top.module("A1") val dupMod2 = top.module("A2") - val relPath1 = dupMod1.addHierarchy("Foo", "a")//top.module("Foo").instOf("a", "A1") - val relPath2 = dupMod2.addHierarchy("Foo", "a")//top.module("Foo").instOf("a", "A2") + val relPath1 = dupMod1.addHierarchy("Foo", "a") //top.module("Foo").instOf("a", "A1") + val relPath2 = dupMod2.addHierarchy("Foo", "a") //top.module("Foo").instOf("a", "A2") val absPath1 = relPath1.addHierarchy("Top", "foo") val absPath2 = relPath2.addHierarchy("Top", "foo") @@ -766,7 +787,7 @@ class RenameMapSpec extends FirrtlFlatSpec { r.record(foo, foo) r.get(foo) should not be (empty) - r.get(foo).get should contain allOf (foo, bar) + (r.get(foo).get should contain).allOf(foo, bar) } it should "not record the same rename multiple times" in { @@ -807,7 +828,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val r = RenameMap() r.delete(Mod) - r.get(foo) should be (Some(Nil)) + r.get(foo) should be(Some(Nil)) } it should "rename an instance if it has been renamed" in { @@ -818,8 +839,8 @@ class RenameMapSpec extends FirrtlFlatSpec { val i = top.instOf("i", "child") val i_ = top.instOf("i_", "child") r.record(i, i_) - r.get(i) should be (Some(Seq(i_))) - r.get(i.ref("a")) should be (Some(Seq(i_.ref("a")))) + r.get(i) should be(Some(Seq(i_))) + r.get(i.ref("a")) should be(Some(Seq(i_.ref("a")))) } it should "rename references to an instance's ports if the ports of the module have been renamed" in { @@ -830,7 +851,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val r = RenameMap() r.record(child.ref("a"), Seq(child.ref("a_0"), child.ref("a_1"))) val i = top.instOf("i", "child") - r.get(i.ref("a")) should be (Some(Seq(i.ref("a_0"), i.ref("a_1")))) + r.get(i.ref("a")) should be(Some(Seq(i.ref("a_0"), i.ref("a_1")))) } it should "rename references to renamed instance's ports if the ports of the module have been renamed" in { @@ -848,6 +869,6 @@ class RenameMapSpec extends FirrtlFlatSpec { // The port and instance renames must be *explicitly* chained! val r = portRenames.andThen(instanceRenames) - r.get(i.ref("a")) should be (Some(Seq(i_.ref("a_0"), i_.ref("a_1")))) + r.get(i.ref("a")) should be(Some(Seq(i_.ref("a_0"), i_.ref("a_1")))) } } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index cd2fdb05..17f4dcfd 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -25,7 +25,8 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(new ConstantPropagation, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) + def transforms = + Seq(new ConstantPropagation, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) } ) @@ -35,7 +36,12 @@ class ReplSeqMemSpec extends SimpleTransformSpec { // Verify that this does not throw an exception val fromConf = MemConf.fromString(text) // Verify the mems in the conf are the same as the expected ones - require(Set(fromConf: _*) == mems, "Parsed conf set:\n {\n " + fromConf.mkString(" ") + " }\n must be the same as reference conf set: \n {\n " + mems.toSeq.mkString(" ") + " }\n") + require( + Set(fromConf: _*) == mems, + "Parsed conf set:\n {\n " + fromConf.mkString( + " " + ) + " }\n must be the same as reference conf set: \n {\n " + mems.toSeq.mkString(" ") + " }\n" + ) } "ReplSeqMem" should "generate blackbox wrappers for mems of bundle type" in { @@ -63,7 +69,7 @@ circuit Top : MemConf("entries_info_ext", 24, 30, Map(WritePort -> 1, ReadPort -> 1), None) ) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl parse(res.getEmittedCircuit.value) @@ -88,7 +94,7 @@ circuit Top : """.stripMargin val mems = Set(MemConf("mem_ext", 32, 64, Map(MaskedWritePort -> 1), Some(64))) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl parse(res.getEmittedCircuit.value) @@ -116,7 +122,7 @@ circuit CustomMemory : """.stripMargin val mems = Set(MemConf("mem_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl parse(res.getEmittedCircuit.value) @@ -144,7 +150,7 @@ circuit CustomMemory : """.stripMargin val mems = Set(MemConf("mem_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl parse(res.getEmittedCircuit.value) @@ -153,8 +159,8 @@ circuit CustomMemory : (new java.io.File(confLoc)).delete() } - "ReplSeqMem Utility -- getConnectOrigin" should - "determine connect origin across nodes/PrimOps even if ConstProp isn't performed" in { + "ReplSeqMem Utility -- getConnectOrigin" should + "determine connect origin across nodes/PrimOps even if ConstProp isn't performed" in { def checkConnectOrigin(hurdle: String, origin: String) = { val input = s""" circuit Top : @@ -172,7 +178,7 @@ circuit Top : val circuit = InferTypes.run(ToWorkingIR.run(parse(input))) val m = circuit.modules.head.asInstanceOf[ir.Module] val connects = AnalysisUtils.getConnects(m) - val calculatedOrigin = AnalysisUtils.getOrigin(connects, "f").serialize + val calculatedOrigin = AnalysisUtils.getOrigin(connects, "f").serialize require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !") } @@ -195,7 +201,7 @@ circuit Top : "validif(a, b)" -> "b" ) - tests foreach { case(hurdle, origin) => checkConnectOrigin(hurdle, origin) } + tests.foreach { case (hurdle, origin) => checkConnectOrigin(hurdle, origin) } } @@ -226,16 +232,17 @@ circuit CustomMemory : ) val confLoc = "ReplSeqMemTests.confTEMP" val annos = Seq( - ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("CustomMemory",CircuitName("CustomMemory"))))) + ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), + NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("CustomMemory", CircuitName("CustomMemory")))) + ) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl val circuit = parse(res.getEmittedCircuit.value) val numExtMods = circuit.modules.count { - case e: ExtModule => true + case e: ExtModule => true case _ => false } - numExtMods should be (2) + numExtMods should be(2) // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -272,16 +279,17 @@ circuit CustomMemory : ) val confLoc = "ReplSeqMemTests.confTEMP" val annos = Seq( - ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - NoDedupMemAnnotation(ComponentName("mem_1", ModuleName("CustomMemory",CircuitName("CustomMemory"))))) + ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), + NoDedupMemAnnotation(ComponentName("mem_1", ModuleName("CustomMemory", CircuitName("CustomMemory")))) + ) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl val circuit = parse(res.getEmittedCircuit.value) val numExtMods = circuit.modules.count { - case e: ExtModule => true + case e: ExtModule => true case _ => false } - numExtMods should be (2) + numExtMods should be(2) // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -329,20 +337,21 @@ circuit CustomMemory : ) val confLoc = "ReplSeqMemTests.confTEMP" val annos = Seq( - ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("ChildMemory",CircuitName("CustomMemory"))))) + ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), + NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("ChildMemory", CircuitName("CustomMemory")))) + ) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl val circuit = parse(res.getEmittedCircuit.value) val numExtMods = circuit.modules.count { - case e: ExtModule => true + case e: ExtModule => true case _ => false } // Note that there are 3 identical SeqMems in this test // If the NoDedupMemAnnotation were ignored, we'd end up with just 1 ExtModule // If the NoDedupMemAnnotation were handled incorrectly as it was prior to this test, there // would be 3 ExtModules - numExtMods should be (2) + numExtMods should be(2) // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -371,12 +380,12 @@ circuit CustomMemory : """ val mems = Set(MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl val circuit = parse(res.getEmittedCircuit.value) val numExtMods = circuit.modules.count { - case e: ExtModule => true + case e: ExtModule => true case _ => false } require(numExtMods == 1) @@ -400,9 +409,9 @@ circuit CustomMemory : """ val mems = Set(MemConf("mem_ext", 1024, 16, Map(WritePort -> 1, ReadPort -> 1), None)) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) - res.getEmittedCircuit.value shouldNot include ("mask") + res.getEmittedCircuit.value shouldNot include("mask") // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -428,11 +437,11 @@ circuit CustomMemory : """ val mems = Set(MemConf("mem_ext", 1024, 16, Map(MaskedWritePort -> 1, ReadPort -> 1), Some(8))) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask - res should containLine ("mem.W0_mask_0 <= validif(io_en, io_mask_0)") - res should containLine ("mem.W0_mask_1 <= validif(io_en, io_mask_1)") + res should containLine("mem.W0_mask_0 <= validif(io_en, io_mask_0)") + res should containLine("mem.W0_mask_1 <= validif(io_en, io_mask_1)") // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -462,12 +471,11 @@ circuit CustomMemory : """ val mems = Set(MemConf("mem_ext", 1024, 16, Map(MaskedReadWritePort -> 1), Some(8))) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - InferReadWriteAnnotation) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), InferReadWriteAnnotation) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask - res should containLine ("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)") - res should containLine ("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)") + res should containLine("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)") + res should containLine("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)") // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -487,15 +495,14 @@ circuit NoMemsHere : """ val mems = Set.empty[MemConf] val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - InferReadWriteAnnotation) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), InferReadWriteAnnotation) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() } - "ReplSeqMem" should "throw an exception when encountering masks with variable granularity" in { + "ReplSeqMem" should "throw an exception when encountering masks with variable granularity" in { val input = """ circuit Top : module Top : @@ -518,10 +525,9 @@ circuit Top : """.stripMargin intercept[ReplaceMemMacros.UnsupportedBlackboxMemoryException] { val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) } } } - diff --git a/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala b/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala index fcf36876..588b7c39 100644 --- a/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala +++ b/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala @@ -7,17 +7,14 @@ import firrtl.passes._ import firrtl.testutils._ class ReplaceAccessesSpec extends FirrtlFlatSpec { - val transforms = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - ReplaceAccesses) + val transforms = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, ReplaceAccesses) protected def exec(input: String) = { - transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit.serialize + transforms + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit + .serialize } } @@ -40,7 +37,7 @@ class ReplaceAccessesMultiDim extends ReplaceAccessesSpec { reset => (UInt<1>(0), r_vec) out <= r_vec[2][1] """ - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } "ReplacesAccesses" should "NOT generate out-of-bounds indices" in { @@ -61,6 +58,6 @@ class ReplaceAccessesMultiDim extends ReplaceAccessesSpec { reset => (UInt<1>(0), r_vec) out <= r_vec[1][UInt<3>(8)] """ - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } } diff --git a/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala b/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala index 05a5fe29..ea01ca00 100644 --- a/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala +++ b/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala @@ -11,50 +11,46 @@ class ReplaceTruncatingArithmeticSpec extends FirrtlFlatSpec { (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } "Truncting addition" should "be inferred and emitted in Verilog" in { val result = compileBody(s""" - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |z <= tail(add(x, y), 1)""".stripMargin - ) - result should containLine (s"assign z = x + y;") + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |z <= tail(add(x, y), 1)""".stripMargin) + result should containLine(s"assign z = x + y;") } it should "be inferred and emitted in Verilog even with an intermediate node" in { val result = compileBody(s""" - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |node n = add(x, y) - |z <= tail(n, 1)""".stripMargin - ) - result should containLine (s"assign z = x + y;") + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |node n = add(x, y) + |z <= tail(n, 1)""".stripMargin) + result should containLine(s"assign z = x + y;") } "Truncting subtraction" should "be inferred and emitted in Verilog" in { val result = compileBody(s""" - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |z <= tail(sub(x, y), 1)""".stripMargin - ) - result should containLine (s"assign z = x - y;") + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |z <= tail(sub(x, y), 1)""".stripMargin) + result should containLine(s"assign z = x - y;") } "Tailing more than 1" should "not result in a truncating operator" in { val result = compileBody(s""" - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<7> - |node n = sub(x, y) - |z <= tail(n, 2)""".stripMargin - ) - result should containLine (s"wire [8:0] n = x - y;") - result should containLine (s"assign z = n[6:0];") + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<7> + |node n = sub(x, y) + |z <= tail(n, 2)""".stripMargin) + result should containLine(s"wire [8:0] n = x - y;") + result should containLine(s"assign z = n[6:0];") } } diff --git a/src/test/scala/firrtlTests/SimplifyMemsSpec.scala b/src/test/scala/firrtlTests/SimplifyMemsSpec.scala index ec947ecf..c7d04d46 100644 --- a/src/test/scala/firrtlTests/SimplifyMemsSpec.scala +++ b/src/test/scala/firrtlTests/SimplifyMemsSpec.scala @@ -12,73 +12,73 @@ class SimplifyMemsSpec extends ConstantPropagationSpec { "SimplifyMems" should "lower aggregate memories" in { val input = - """circuit Test : - | module Test : - | input clock : Clock - | input wen : UInt<1> - | input wdata : { a : UInt<8>, b : UInt<8> } - | output rdata : { a : UInt<8>, b : UInt<8> } - | mem m : - | data-type => { a : UInt<8>, b : UInt<8>} - | depth => 32 - | read-latency => 1 - | write-latency => 1 - | reader => read - | writer => write - | m.read.clk <= clock - | m.read.en <= UInt<1>(1) - | m.read.addr is invalid - | rdata <= m.read.data - | m.write.clk <= clock - | m.write.en <= wen - | m.write.mask.a <= UInt<1>(1) - | m.write.mask.b <= UInt<1>(1) - | m.write.addr is invalid - | m.write.data <= wdata + """circuit Test : + | module Test : + | input clock : Clock + | input wen : UInt<1> + | input wdata : { a : UInt<8>, b : UInt<8> } + | output rdata : { a : UInt<8>, b : UInt<8> } + | mem m : + | data-type => { a : UInt<8>, b : UInt<8>} + | depth => 32 + | read-latency => 1 + | write-latency => 1 + | reader => read + | writer => write + | m.read.clk <= clock + | m.read.en <= UInt<1>(1) + | m.read.addr is invalid + | rdata <= m.read.data + | m.write.clk <= clock + | m.write.en <= wen + | m.write.mask.a <= UInt<1>(1) + | m.write.mask.b <= UInt<1>(1) + | m.write.addr is invalid + | m.write.data <= wdata """.stripMargin val check = - """circuit Test : - | module Test : - | input clock : Clock - | input wen : UInt<1> - | input wdata : { a : UInt<8>, b : UInt<8>} - | output rdata : { a : UInt<8>, b : UInt<8>} - | - | wire m : { flip read : { addr : UInt<5>, en : UInt<1>, clk : Clock, flip data : { a : UInt<8>, b : UInt<8>}}, flip write : { addr : UInt<5>, en : UInt<1>, clk : Clock, data : { a : UInt<8>, b : UInt<8>}, mask : { a : UInt<1>, b : UInt<1>}}} - | mem m_flattened : - | data-type => UInt<16> - | depth => 32 - | read-latency => 1 - | write-latency => 1 - | reader => read - | writer => write - | read-under-write => undefined - | m_flattened.read.addr <= m.read.addr - | m_flattened.read.en <= m.read.en - | m_flattened.read.clk <= m.read.clk - | m.read.data.b <= asUInt(bits(m_flattened.read.data, 7, 0)) - | m.read.data.a <= asUInt(bits(m_flattened.read.data, 15, 8)) - | m_flattened.write.addr <= m.write.addr - | m_flattened.write.en <= m.write.en - | m_flattened.write.clk <= m.write.clk - | m_flattened.write.data <= cat(asUInt(m.write.data.a), asUInt(m.write.data.b)) - | m_flattened.write.mask <= UInt<1>("h1") - | rdata.a <= m.read.data.a - | rdata.b <= m.read.data.b - | m.read.addr is invalid - | m.read.en <= UInt<1>("h1") - | m.read.clk <= clock - | m.write.addr is invalid - | m.write.en <= wen - | m.write.clk <= clock - | m.write.data.a <= wdata.a - | m.write.data.b <= wdata.b - | m.write.mask.a <= UInt<1>("h1") - | m.write.mask.b <= UInt<1>("h1") + """circuit Test : + | module Test : + | input clock : Clock + | input wen : UInt<1> + | input wdata : { a : UInt<8>, b : UInt<8>} + | output rdata : { a : UInt<8>, b : UInt<8>} + | + | wire m : { flip read : { addr : UInt<5>, en : UInt<1>, clk : Clock, flip data : { a : UInt<8>, b : UInt<8>}}, flip write : { addr : UInt<5>, en : UInt<1>, clk : Clock, data : { a : UInt<8>, b : UInt<8>}, mask : { a : UInt<1>, b : UInt<1>}}} + | mem m_flattened : + | data-type => UInt<16> + | depth => 32 + | read-latency => 1 + | write-latency => 1 + | reader => read + | writer => write + | read-under-write => undefined + | m_flattened.read.addr <= m.read.addr + | m_flattened.read.en <= m.read.en + | m_flattened.read.clk <= m.read.clk + | m.read.data.b <= asUInt(bits(m_flattened.read.data, 7, 0)) + | m.read.data.a <= asUInt(bits(m_flattened.read.data, 15, 8)) + | m_flattened.write.addr <= m.write.addr + | m_flattened.write.en <= m.write.en + | m_flattened.write.clk <= m.write.clk + | m_flattened.write.data <= cat(asUInt(m.write.data.a), asUInt(m.write.data.b)) + | m_flattened.write.mask <= UInt<1>("h1") + | rdata.a <= m.read.data.a + | rdata.b <= m.read.data.b + | m.read.addr is invalid + | m.read.en <= UInt<1>("h1") + | m.read.clk <= clock + | m.write.addr is invalid + | m.write.en <= wen + | m.write.clk <= clock + | m.write.data.a <= wdata.a + | m.write.data.b <= wdata.b + | m.write.mask.a <= UInt<1>("h1") + | m.write.mask.b <= UInt<1>("h1") """.stripMargin - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } } diff --git a/src/test/scala/firrtlTests/StringSpec.scala b/src/test/scala/firrtlTests/StringSpec.scala index 30535466..fc2fa486 100644 --- a/src/test/scala/firrtlTests/StringSpec.scala +++ b/src/test/scala/firrtlTests/StringSpec.scala @@ -21,7 +21,7 @@ class PrintfSpec extends FirrtlPropSpec { copyResourceToFile(cppHarnessResourceName, harness) verilogToCpp(prefix, testDir, Seq(), harness) #&& - cppToExe(prefix, testDir) ! loggingProcessLogger + cppToExe(prefix, testDir) ! loggingProcessLogger // Check for correct Printf: // Count up from 0, match decimal, hex, and binary @@ -31,7 +31,7 @@ class PrintfSpec extends FirrtlPropSpec { var expected = 0 var error = false val ret = Process(s"./V${prefix}", testDir) ! - ProcessLogger( line => { + ProcessLogger(line => { line match { case regex(dec, hex, bin) => { if (!done) { @@ -57,7 +57,7 @@ class StringSpec extends FirrtlPropSpec { // Whitelist is [0x20 - 0x7e] val whitelist = """ !\"#$%&\''()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ""" + - """[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~""" + """[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~""" property(s"Character whitelist should be supported: [$whitelist] ") { val lit = StringLit.unescape(whitelist) @@ -102,7 +102,7 @@ class StringSpec extends FirrtlPropSpec { val legalFormats = "HhDdOoBbCcLlVvMmSsTtUuZz%".toSet def isValidVerilogFormat(str: String): Boolean = str.toSeq.sliding(2).forall { case Seq('%', char) if legalFormats contains char => true - case _ => true + case _ => true } // Generators for legal Firrtl format strings @@ -112,8 +112,8 @@ class StringSpec extends FirrtlPropSpec { val genFragment = Gen.frequency((10, genChar), (1, genFormat), (1, genEsc)).map(_.mkString) val genString = Gen.listOf[String](genFragment).map(_.mkString) - property ("Firrtl Format Strings with Unicode chars should emit as legal Verilog Strings") { - forAll (genString) { str => + property("Firrtl Format Strings with Unicode chars should emit as legal Verilog Strings") { + forAll(genString) { str => val verilogStr = StringLit(str).verilogFormat.verilogEscape assert(isValidVerilogString(verilogStr)) assert(isValidVerilogFormat(verilogStr)) diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala index 074da256..19ae75fc 100644 --- a/src/test/scala/firrtlTests/UniquifySpec.scala +++ b/src/test/scala/firrtlTests/UniquifySpec.scala @@ -21,24 +21,29 @@ class UniquifySpec extends FirrtlFlatSpec { Uniquify ) - private def executeTest(input: String, expected: Seq[String]): Unit = executeTest(input, expected, Seq.empty, Seq.empty) - private def executeTest(input: String, expected: Seq[String], - inputAnnos: Seq[Annotation], expectedAnnos: Seq[Annotation]): Unit = { + private def executeTest(input: String, expected: Seq[String]): Unit = + executeTest(input, expected, Seq.empty, Seq.empty) + private def executeTest( + input: String, + expected: Seq[String], + inputAnnos: Seq[Annotation], + expectedAnnos: Seq[Annotation] + ): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) val result = transforms.foldLeft(CircuitState(circuit, UnknownForm, inputAnnos)) { (c: CircuitState, p: Transform) => p.runTransform(c) } val c = result.circuit - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } result.annotations.toSeq should equal(expectedAnnos) } - behavior of "Uniquify" + behavior.of("Uniquify") it should "rename colliding ports" in { val input = @@ -51,13 +56,22 @@ class UniquifySpec extends FirrtlFlatSpec { val expected = Seq( "input a__ : { flip b : UInt<1>, c_ : { d : UInt<2>, flip e : UInt<3>}[2], c_1_e : UInt<4>}[2]", "output a_0_c_ : UInt<5>", - "output a__0 : UInt<6>") map normalized - - val inputAnnos = Seq(DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a", Seq(Index(0), Field("b")))), - DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a", Seq(Index(0), Field("c"), Index(0), Field("e"))))) - - val expectedAnnos = Seq(DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a__", Seq(Index(0), Field("b")))), - DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a__", Seq(Index(0), Field("c_"), Index(0), Field("e"))))) + "output a__0 : UInt<6>" + ).map(normalized) + + val inputAnnos = Seq( + DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a", Seq(Index(0), Field("b")))), + DontTouchAnnotation( + ReferenceTarget("Test", "Test", Seq.empty, "a", Seq(Index(0), Field("c"), Index(0), Field("e"))) + ) + ) + + val expectedAnnos = Seq( + DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a__", Seq(Index(0), Field("b")))), + DontTouchAnnotation( + ReferenceTarget("Test", "Test", Seq.empty, "a__", Seq(Index(0), Field("c_"), Index(0), Field("e"))) + ) + ) executeTest(input, expected, inputAnnos, expectedAnnos) } @@ -74,7 +88,8 @@ class UniquifySpec extends FirrtlFlatSpec { val expected = Seq( "reg a__ : { b : UInt<1>, c_ : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2], clock with :", "reg a_0_c_ : UInt<5>, clock with :", - "reg a__0 : UInt<6>, clock with :") map normalized + "reg a__0 : UInt<6>, clock with :" + ).map(normalized) executeTest(input, expected) } @@ -89,12 +104,11 @@ class UniquifySpec extends FirrtlFlatSpec { | node a_0_c_ = a[0].b | node a__0 = a[1].c[0].d """.stripMargin - val expected = Seq("node a__ = x") map normalized + val expected = Seq("node a__ = x").map(normalized) executeTest(input, expected) } - it should "rename DefRegister expressions: clock, reset, and init" in { val input = """circuit Test : @@ -111,7 +125,7 @@ class UniquifySpec extends FirrtlFlatSpec { val expected = Seq( "reg foo : UInt<4>, clock_[1] with :", "reset => (reset_.a, init_[3].b_[1].d)" - ) map normalized + ).map(normalized) executeTest(input, expected) } @@ -126,7 +140,7 @@ class UniquifySpec extends FirrtlFlatSpec { val expected = Seq( "input data : { a : UInt<4>, b : UInt<4>}[2]", "node data_0_a_ = data[0].a" - ) map normalized + ).map(normalized) executeTest(input, expected) } @@ -141,9 +155,7 @@ class UniquifySpec extends FirrtlFlatSpec { | node foo = data.a | node bar = data.b[1] """.stripMargin - val expected = Seq( - "node foo = data__.a", - "node bar = data__.b[1]") map normalized + val expected = Seq("node foo = data__.a", "node bar = data__.b[1]").map(normalized) executeTest(input, expected) } @@ -158,25 +170,22 @@ class UniquifySpec extends FirrtlFlatSpec { | a_0_b <= a[0].b | a[0].c <- a__0_c_ """.stripMargin - val expected = Seq( - "a_0_b <= a__[0].b", - "a__[0].c_ <- a__0_c_") map normalized + val expected = Seq("a_0_b <= a__[0].b", "a__[0].c_ <- a__0_c_").map(normalized) executeTest(input, expected) } it should "rename SubAccesses" in { val input = - """circuit Test : - | module Test : - | input a : { b : UInt<1>, c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2] - | output a_0_b : UInt<2> - | input i : UInt<1>[2] - | output i_0 : UInt<1> - | a_0_b <= a.c[i[1]].d + """circuit Test : + | module Test : + | input a : { b : UInt<1>, c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2] + | output a_0_b : UInt<2> + | input i : UInt<1>[2] + | output i_0 : UInt<1> + | a_0_b <= a.c[i[1]].d """.stripMargin - val expected = Seq( - "a_0_b <= a_.c_[i_[1]].d") map normalized + val expected = Seq("a_0_b <= a_.c_[i_[1]].d").map(normalized) executeTest(input, expected) } @@ -192,7 +201,7 @@ class UniquifySpec extends FirrtlFlatSpec { """.stripMargin val expected = Seq( "a_0_b <= mux(a__[UInt<1>(\"h0\")].c_1_e, or(a__[or(a__[0].b, a__[1].b)].b, xorr(a__[0].c_1_e)), orr(cat(a__0_c_[0].e, a__[1].c_1_e)))" - ) map normalized + ).map(normalized) executeTest(input, expected) } @@ -220,10 +229,7 @@ class UniquifySpec extends FirrtlFlatSpec { | mem.write.en <= UInt(0) | mem.write.clk <= clock """.stripMargin - val expected = Seq( - "mem mem_ :", - "node mem_0_b = mem_.read.data[0].b", - "mem_.read.addr is invalid") map normalized + val expected = Seq("mem mem_ :", "node mem_0_b = mem_.read.data[0].b", "mem_.read.addr is invalid").map(normalized) executeTest(input, expected) } @@ -251,33 +257,29 @@ class UniquifySpec extends FirrtlFlatSpec { | mem.write.en <= UInt(0) | mem.write.clk <= clock """.stripMargin - val expected = Seq( - "data-type => { a : UInt<8>, b_ : UInt<8>[2], b_0 : UInt<8>}", - "node x = mem.read.data.b_[0]") map normalized + val expected = + Seq("data-type => { a : UInt<8>, b_ : UInt<8>[2], b_0 : UInt<8>}", "node x = mem.read.data.b_[0]").map(normalized) executeTest(input, expected) } it should "rename instances and their ports" in { val input = - """circuit Test : - | module Other : - | input a : { b : UInt<4>, c : UInt<4> } - | output a_b : UInt<4> - | a_b <= a.b - | - | module Test : - | node x = UInt(6) - | inst mod of Other - | mod.a.b <= x - | mod.a.c <= x - | node mod_a_b = mod.a_b + """circuit Test : + | module Other : + | input a : { b : UInt<4>, c : UInt<4> } + | output a_b : UInt<4> + | a_b <= a.b + | + | module Test : + | node x = UInt(6) + | inst mod of Other + | mod.a.b <= x + | mod.a.c <= x + | node mod_a_b = mod.a_b """.stripMargin - val expected = Seq( - "inst mod_ of Other", - "mod_.a_.b <= x", - "mod_.a_.c <= x", - "node mod_a_b = mod_.a_b") map normalized + val expected = + Seq("inst mod_ of Other", "mod_.a_.b <= x", "mod_.a_.c <= x", "node mod_a_b = mod_.a_b").map(normalized) executeTest(input, expected) } @@ -296,7 +298,7 @@ class UniquifySpec extends FirrtlFlatSpec { // Run the "quick" test three times and choose the longest time as the basis. val nCalibrationRuns = 3 def mkType(i: Int): String = { - if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" + if (i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" } val timesMs = ( for (depth <- (List.fill(nCalibrationRuns)(1) :+ depth)) yield { @@ -308,12 +310,12 @@ class UniquifySpec extends FirrtlFlatSpec { |""".stripMargin val (ms, _) = Utils.time(compileToVerilog(input)) ms - } + } ).toArray // The baseMs will be the maximum of the first calibration runs val baseMs = timesMs.slice(0, nCalibrationRuns - 1).max val renameMs = timesMs(nCalibrationRuns) if (TestOptions.accurateTiming) - renameMs shouldBe < (baseMs * threshold) + renameMs shouldBe <(baseMs * threshold) } } diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 288bf336..8f128274 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -12,50 +12,41 @@ import FirrtlCheckers._ class UnitTests extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], transforms: Seq[Transform]) = { - val lines = execute(input, transforms).circuit.serialize.split("\n") map normalized + val lines = execute(input, transforms).circuit.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } private def executeTest(input: String, expected: String, transforms: Seq[Transform]) = { - execute(input, transforms).circuit should be (parse(expected)) + execute(input, transforms).circuit should be(parse(expected)) } def execute(input: String, transforms: Seq[Transform]): CircuitState = { - val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit + val c = transforms + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit CircuitState(c, UnknownForm, Seq(), None) } "Pull muxes" should "not be exponential in runtime" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - PullMuxes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, PullMuxes) val input = """circuit Unit : | module Unit : | input _2: UInt<1> | output x: UInt<32> | x <= cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat( _2, cat(_2, cat(_2, cat(_2, _2)))))))))))))))))))))))))))))))""".stripMargin - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } "Connecting bundles of different types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -63,96 +54,78 @@ class UnitTests extends FirrtlFlatSpec { | output x: {a : UInt<1>, b : UInt<1>} | x <= y""".stripMargin intercept[CheckTypes.InvalidConnect] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Initializing a register with a different type" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = - """circuit Unit : - | module Unit : - | input clock : Clock - | input reset : UInt<1> - | wire x : { valid : UInt<1> } - | reg y : { valid : UInt<1>, bits : UInt<3> }, clock with : - | reset => (reset, x)""".stripMargin + """circuit Unit : + | module Unit : + | input clock : Clock + | input reset : UInt<1> + | wire x : { valid : UInt<1> } + | reg y : { valid : UInt<1>, bits : UInt<3> }, clock with : + | reset => (reset, x)""".stripMargin intercept[CheckTypes.InvalidRegInit] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Partial connection two bundle types whose relative flips don't match but leaf node directions do" should "connect correctly" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - ExpandConnects) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, ExpandConnects) val input = - """circuit Unit : - | module Unit : - | wire x : { flip a: { b: UInt<32> } } - | wire y : { a: { flip b: UInt<32> } } - | x <- y""".stripMargin + """circuit Unit : + | module Unit : + | wire x : { flip a: { b: UInt<32> } } + | wire y : { a: { flip b: UInt<32> } } + | x <- y""".stripMargin val check = - """circuit Unit : - | module Unit : - | wire x : { flip a: { b: UInt<32> } } - | wire y : { a: { flip b: UInt<32> } } - | y.a.b <= x.a.b""".stripMargin - val c_result = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + """circuit Unit : + | module Unit : + | wire x : { flip a: { b: UInt<32> } } + | wire y : { a: { flip b: UInt<32> } } + | y.a.b <= x.a.b""".stripMargin + val c_result = passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } val writer = new StringWriter() (new HighFirrtlEmitter).emit(CircuitState(c_result, HighForm), writer) - (parse(writer.toString())) should be (parse(check)) + (parse(writer.toString())) should be(parse(check)) } val splitExpTestCode = - """ - |circuit Unit : - | module Unit : - | input a : UInt<1> - | input b : UInt<2> - | input c : UInt<2> - | output out : UInt<1> - | out <= bits(mux(a, b, c), 0, 0) - |""".stripMargin + """ + |circuit Unit : + | module Unit : + | input a : UInt<1> + | input b : UInt<2> + | input c : UInt<2> + | output out : UInt<1> + | out <= bits(mux(a, b, c), 0, 0) + |""".stripMargin "Emitting a nested expression" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - InferTypes, - ResolveKinds) + val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds) intercept[PassException] { val c = Parser.parse(splitExpTestCode.split("\n").toIterator) - val c2 = passes.foldLeft(c)((c, p) => p run c) + val c2 = passes.foldLeft(c)((c, p) => p.run(c)) val writer = new StringWriter() (new VerilogEmitter).emit(CircuitState(c2, LowForm), writer) } } "After splitting, emitting a nested expression" should "compile" in { - val passes = Seq( - ToWorkingIR, - SplitExpressions, - InferTypes) + val passes = Seq(ToWorkingIR, SplitExpressions, InferTypes) val c = Parser.parse(splitExpTestCode.split("\n").toIterator) - val c2 = passes.foldLeft(c)((c, p) => p run c) - val writer = new StringWriter() - (new VerilogEmitter).emit(CircuitState(c2, LowForm), writer) + val c2 = passes.foldLeft(c)((c, p) => p.run(c)) + val writer = new StringWriter() + (new VerilogEmitter).emit(CircuitState(c2, LowForm), writer) } "Simple compound expressions" should "be split" in { @@ -166,12 +139,12 @@ class UnitTests extends FirrtlFlatSpec { ) val input = """circuit Top : - | module Top : - | input a : UInt<32> - | input b : UInt<32> - | input d : UInt<32> - | output c : UInt<1> - | c <= geq(add(a, b),d)""".stripMargin + | module Top : + | input a : UInt<32> + | input b : UInt<32> + | input d : UInt<32> + | output c : UInt<1> + | c <= geq(add(a, b),d)""".stripMargin val check = Seq( "node _GEN_0 = add(a, b)", "c <= geq(_GEN_0, d)" @@ -190,14 +163,14 @@ class UnitTests extends FirrtlFlatSpec { ) val input = """circuit Top : - | module Top : - | input a : UInt<32> - | input b : UInt<20> - | input pred : UInt<1> - | output c : UInt<32> - | c <= mux(pred,a,b)""".stripMargin - val check = Seq("c <= mux(pred, a, pad(b, 32))") - executeTest(input, check, passes) + | module Top : + | input a : UInt<32> + | input b : UInt<20> + | input pred : UInt<1> + | output c : UInt<32> + | c <= mux(pred,a,b)""".stripMargin + val check = Seq("c <= mux(pred, a, pad(b, 32))") + executeTest(input, check, passes) } "Indexes into sub-accesses" should "be dealt with" in { @@ -214,40 +187,34 @@ class UnitTests extends FirrtlFlatSpec { ) val input = """circuit AssignViaDeref : - | module AssignViaDeref : - | input clock : Clock - | input reset : UInt<1> - | output io : {a : UInt<8>, sel : UInt<1>} - | - | io is invalid - | reg table : {a : UInt<8>}[2], clock - | reg otherTable : {a : UInt<8>}[2], clock - | otherTable[table[UInt<1>("h01")].a].a <= UInt<1>("h00")""".stripMargin - //TODO(azidar): I realize this is brittle, but unfortunately there - // isn't a better way to test this pass - val check = Seq( - """wire _table_1 : { a : UInt<8>}""", - """_table_1.a is invalid""", - """when UInt<1>("h1") :""", - """_table_1.a <= table[1].a""", - """wire _otherTable_table_1_a_a : UInt<8>""", - """when eq(UInt<1>("h0"), _table_1.a) :""", - """otherTable[0].a <= _otherTable_table_1_a_a""", - """when eq(UInt<1>("h1"), _table_1.a) :""", - """otherTable[1].a <= _otherTable_table_1_a_a""", - """_otherTable_table_1_a_a <= UInt<1>("h0")""" - ) - executeTest(input, check, passes) + | module AssignViaDeref : + | input clock : Clock + | input reset : UInt<1> + | output io : {a : UInt<8>, sel : UInt<1>} + | + | io is invalid + | reg table : {a : UInt<8>}[2], clock + | reg otherTable : {a : UInt<8>}[2], clock + | otherTable[table[UInt<1>("h01")].a].a <= UInt<1>("h00")""".stripMargin + //TODO(azidar): I realize this is brittle, but unfortunately there + // isn't a better way to test this pass + val check = Seq( + """wire _table_1 : { a : UInt<8>}""", + """_table_1.a is invalid""", + """when UInt<1>("h1") :""", + """_table_1.a <= table[1].a""", + """wire _otherTable_table_1_a_a : UInt<8>""", + """when eq(UInt<1>("h0"), _table_1.a) :""", + """otherTable[0].a <= _otherTable_table_1_a_a""", + """when eq(UInt<1>("h1"), _table_1.a) :""", + """otherTable[1].a <= _otherTable_table_1_a_a""", + """_otherTable_table_1_a_a <= UInt<1>("h0")""" + ) + executeTest(input, check, passes) } "Oversized bit select" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val passes = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : @@ -260,13 +227,7 @@ class UnitTests extends FirrtlFlatSpec { } "Oversized head select" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val passes = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : @@ -279,14 +240,8 @@ class UnitTests extends FirrtlFlatSpec { } "zero head select" should "return an empty module" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths, - new DeadCodeElimination) + val passes = + Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths, new DeadCodeElimination) val input = """circuit Unit : | module Unit : @@ -299,13 +254,7 @@ class UnitTests extends FirrtlFlatSpec { } "Oversized tail select" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val passes = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : @@ -318,14 +267,8 @@ class UnitTests extends FirrtlFlatSpec { } "max tail select" should "return an empty module" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths, - new DeadCodeElimination) + val passes = + Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths, new DeadCodeElimination) val input = """circuit Unit : | module Unit : @@ -338,11 +281,7 @@ class UnitTests extends FirrtlFlatSpec { } "Partial connecting incompatable types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -419,13 +358,12 @@ class UnitTests extends FirrtlFlatSpec { """assign negSInt = -5'shd;""" ) val out = compileToVerilog(input) - val lines = out.split("\n") map normalized - expected foreach { e => + val lines = out.split("\n").map(normalized) + expected.foreach { e => lines should contain(e) } } - "Out of bound accesses" should "be invalid" in { val passes = Seq( ToWorkingIR, @@ -460,8 +398,9 @@ class UnitTests extends FirrtlFlatSpec { val index = WRef("index", ut2, PortKind, SourceFlow) val out = WRef("out", ut16, PortKind, SinkFlow) - def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, ut1) - def array(v: Int): Expression = WSubIndex(WRef("array", VectorType(ut16, 3), WireKind, SourceFlow), v, ut16, SourceFlow) + def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, ut1) + def array(v: Int): Expression = + WSubIndex(WRef("array", VectorType(ut16, 3), WireKind, SourceFlow), v, ut16, SourceFlow) result should containTree { case DefWire(_, "_array_index", `ut16`) => true } result should containTree { case IsInvalid(_, `fgen`) => true } @@ -490,6 +429,6 @@ class UnitTests extends FirrtlFlatSpec { | out <= shl(in, 4) |""".stripMargin val res = (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm)) - res should containLine ("assign out = {in, 4'h0};") + res should containLine("assign out = {in, 4'h0};") } } diff --git a/src/test/scala/firrtlTests/UtilsSpec.scala b/src/test/scala/firrtlTests/UtilsSpec.scala index 8ea69460..483b7dbd 100644 --- a/src/test/scala/firrtlTests/UtilsSpec.scala +++ b/src/test/scala/firrtlTests/UtilsSpec.scala @@ -8,28 +8,31 @@ import org.scalatest.flatspec.AnyFlatSpec class UtilsSpec extends AnyFlatSpec { - behavior of "Utils.expandPrefix" + behavior.of("Utils.expandPrefix") val expandPrefixTests = List( ("return a name without prefixes", "_", "foo", Set("foo")), ("expand a name ending with prefixes", "_", "foo__", Set("foo__")), ("expand a name with on prefix", "_", "foo_bar", Set("foo_bar", "foo_")), - ("expand a name with complex prefixes", "_", - "foo__$ba9_9X__$$$$$_", Set("foo__$ba9_9X__$$$$$_", "foo__$ba9_9X__", "foo__$ba9_", "foo__")), + ( + "expand a name with complex prefixes", + "_", + "foo__$ba9_9X__$$$$$_", + Set("foo__$ba9_9X__$$$$$_", "foo__$ba9_9X__", "foo__$ba9_", "foo__") + ), ("expand a name starting with a delimiter", "_", "__foo_bar", Set("__", "__foo_", "__foo_bar")), ("expand a name with a $ delimiter", "$", "foo$bar$$$baz", Set("foo$", "foo$bar$$$", "foo$bar$$$baz")), ("expand a name with a multi-character delimiter", "FOO", "fooFOOFOOFOObar", Set("fooFOOFOOFOO", "fooFOOFOOFOObar")) ) for ((description, delimiter, in, out) <- expandPrefixTests) { - it should description in { Utils.expandPrefixes(in, delimiter).toSet should be (out)} + it should description in { Utils.expandPrefixes(in, delimiter).toSet should be(out) } } "expandRef" should "return intermediate expressions" in { val bTpe = VectorType(Utils.BoolType, 2) val topTpe = BundleType(Seq(Field("a", Default, Utils.BoolType), Field("b", Default, bTpe))) val wr = WRef("out", topTpe, PortKind, SourceFlow) - val expected = Seq( wr, @@ -39,6 +42,6 @@ class UtilsSpec extends AnyFlatSpec { WSubIndex(WSubField(wr, "b", bTpe, SourceFlow), 1, Utils.BoolType, SourceFlow) ) - (Utils.expandRef(wr)) should be (expected) + (Utils.expandRef(wr)) should be(expected) } } diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index 21d7075e..9840229e 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -31,7 +31,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { |); | assign b = ^a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "Andr" should "emit correctly" in { @@ -49,7 +49,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { |); | assign b = &a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "Orr" should "emit correctly" in { @@ -67,7 +67,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { |); | assign b = |a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "Not" should "emit correctly" in { @@ -85,7 +85,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { |); | assign b = ~a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "inline Bits" should "emit correctly" in { @@ -179,7 +179,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { | assign t = a[2:1]; | assign u = a[3]; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "Rem" should "emit correctly" in { @@ -199,7 +199,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { | wire [7:0] _GEN_0 = in % 8'h1; | assign out = _GEN_0[0]; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "nested cats" should "emit correctly" in { @@ -225,12 +225,12 @@ class DoPrimVerilog extends FirrtlFlatSpec { | wire [5:0] _GEN_1 = {in3,in2,in1}; | assign out = {in4,_GEN_1}; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm), Seq(new CombineCats())) - val lines = finalState.getEmittedCircuit.value split "\n" map normalized + val lines = finalState.getEmittedCircuit.value.split("\n").map(normalized) for (e <- check) { - lines should contain (e) + lines should contain(e) } } } @@ -240,9 +240,9 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) private def compileBody(body: String): CircuitState = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } @@ -273,7 +273,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { compiler.compile(CircuitState(parse(input), ChirrtlForm), writer) val lines = writer.toString.split("\n") for (c <- check) { - lines should contain (c) + lines should contain(c) } } "The Verilog Emitter" should "support Modules with no ports" in { @@ -302,7 +302,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { |); | assign out = in; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "The Verilog Emitter" should "support pads with width <= the width of the argument" in { @@ -325,7 +325,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val emitter = new VerilogEmitter val result = emitter.execute(state) - result should containLine ("assign out = in;") + result should containLine("assign out = in;") } } @@ -368,22 +368,22 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val moduleMap = state.circuit.modules.map(m => m.name -> m).toMap - val module = state.circuit.modules.filter(module => module.name == "Test").collectFirst { case m: firrtl.ir.Module => m }.get + val module = + state.circuit.modules.filter(module => module.name == "Test").collectFirst { case m: firrtl.ir.Module => m }.get val renderer = emitter.getRenderer(module, moduleMap)(writer) - renderer.emitVerilogBind("BindsToTest", - """ - |$readmemh("file", memory); - | - |""".stripMargin) + renderer.emitVerilogBind("BindsToTest", """ + |$readmemh("file", memory); + | + |""".stripMargin) val lines = writer.toString.split("\n") val outString = writer.toString // This confirms that the module io's were emitted for (c <- check) { - lines should contain (c) + lines should contain(c) } } @@ -401,16 +401,20 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { """.stripMargin val state = CircuitState(parse(input), ChirrtlForm) val result = (new VerilogCompiler).compileAndEmit(state, List()) - result should containLines ("`ifndef SYNTHESIS", - "`ifdef FIRRTL_BEFORE_INITIAL", - "`FIRRTL_BEFORE_INITIAL", - "`endif", - "initial begin") - result should containLines ("end // initial", - "`ifdef FIRRTL_AFTER_INITIAL", - "`FIRRTL_AFTER_INITIAL", - "`endif", - "`endif // SYNTHESIS") + result should containLines( + "`ifndef SYNTHESIS", + "`ifdef FIRRTL_BEFORE_INITIAL", + "`FIRRTL_BEFORE_INITIAL", + "`endif", + "initial begin" + ) + result should containLines( + "end // initial", + "`ifdef FIRRTL_AFTER_INITIAL", + "`FIRRTL_AFTER_INITIAL", + "`endif", + "`endif // SYNTHESIS" + ) } "Verilog name conflicts" should "be resolved" in { @@ -455,14 +459,14 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { | fork_ <= const_ |""".stripMargin val state = CircuitState(parse(input), UnknownForm, Seq.empty, None) - val output = Seq( ToWorkingIR, ResolveKinds, InferTypes, new VerilogRename ) - .foldLeft(state){ case (c, tx) => tx.runTransform(c) } - Seq( CheckHighForm ) - .foldLeft(output.circuit){ case (c, tx) => tx.run(c) } - output.circuit.serialize should be (parse(check_firrtl).serialize) + val output = Seq(ToWorkingIR, ResolveKinds, InferTypes, new VerilogRename) + .foldLeft(state) { case (c, tx) => tx.runTransform(c) } + Seq(CheckHighForm) + .foldLeft(output.circuit) { case (c, tx) => tx.run(c) } + output.circuit.serialize should be(parse(check_firrtl).serialize) } - behavior of "Register Updates" + behavior.of("Register Updates") they should "emit using 'else if' constructs" in { val input = @@ -484,10 +488,10 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val circuit = Seq(ToWorkingIR, ResolveKinds, InferTypes).foldLeft(parse(input)) { case (c, p) => p.run(c) } val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val result = (new VerilogEmitter).execute(state) - result should containLine ("if (sel == 2'h0) begin") - result should containLine ("end else if (sel == 2'h1) begin" ) - result should containLine ("end else if (sel == 2'h2) begin") - result should containLine ("end else begin") + result should containLine("if (sel == 2'h0) begin") + result should containLine("end else if (sel == 2'h1) begin") + result should containLine("end else if (sel == 2'h2) begin") + result should containLine("end else begin") } they should "ignore self assignments in false conditions" in { @@ -505,7 +509,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val circuit = Seq(ToWorkingIR, ResolveKinds, InferTypes).foldLeft(parse(input)) { case (c, p) => p.run(c) } val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val result = (new VerilogEmitter).execute(state) - result should not (containLine ("tmp <= tmp")) + result should not(containLine("tmp <= tmp")) } they should "ignore self assignments in true conditions and invert condition" in { @@ -523,8 +527,8 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val circuit = Seq(ToWorkingIR, ResolveKinds, InferTypes).foldLeft(parse(input)) { case (c, p) => p.run(c) } val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val result = (new VerilogEmitter).execute(state) - result should containLine ("if (!(sel == 1'h0)) begin") - result should not (containLine ("tmp <= tmp")) + result should containLine("if (!(sel == 1'h0)) begin") + result should not(containLine("tmp <= tmp")) } they should "ignore self assignments in both true and false conditions" in { @@ -542,8 +546,8 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val circuit = Seq(ToWorkingIR, ResolveKinds, InferTypes).foldLeft(parse(input)) { case (c, p) => p.run(c) } val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val result = (new VerilogEmitter).execute(state) - result should not (containLine ("tmp <= tmp")) - result should not (containLine ("always @(posedge clock) begin")) + result should not(containLine("tmp <= tmp")) + result should not(containLine("always @(posedge clock) begin")) } they should "properly indent muxes in either the true or false condition" in { @@ -583,24 +587,24 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val result = (new VerilogEmitter).execute(state) /* The Verilog string is used to check for no whitespace between "else" and "if". */ val verilogString = result.getEmittedCircuit.value - result should containLine ("if (sel == 3'h0) begin") - verilogString should include ("end else if (sel == 3'h1) begin") - result should containLine ("if (sel == 3'h2) begin") - verilogString should include ("end else if (sel == 3'h3) begin") - result should containLine ("if (sel == 3'h4) begin") - verilogString should include ("end else if (sel == 3'h5) begin") - result should containLine ("if (sel == 3'h6) begin") - verilogString should include ("end else if (sel == 3'h7) begin") - result should containLine ("tmp <= in_0;") - result should containLine ("tmp <= in_1;") - result should containLine ("tmp <= in_2;") - result should containLine ("tmp <= in_3;") - result should containLine ("tmp <= in_4;") - result should containLine ("tmp <= in_5;") - result should containLine ("tmp <= in_6;") - result should containLine ("tmp <= in_7;") - result should containLine ("tmp <= in_8;") - result should containLine ("tmp <= in_9;") + result should containLine("if (sel == 3'h0) begin") + verilogString should include("end else if (sel == 3'h1) begin") + result should containLine("if (sel == 3'h2) begin") + verilogString should include("end else if (sel == 3'h3) begin") + result should containLine("if (sel == 3'h4) begin") + verilogString should include("end else if (sel == 3'h5) begin") + result should containLine("if (sel == 3'h6) begin") + verilogString should include("end else if (sel == 3'h7) begin") + result should containLine("tmp <= in_0;") + result should containLine("tmp <= in_1;") + result should containLine("tmp <= in_2;") + result should containLine("tmp <= in_3;") + result should containLine("tmp <= in_4;") + result should containLine("tmp <= in_5;") + result should containLine("tmp <= in_6;") + result should containLine("tmp <= in_7;") + result should containLine("tmp <= in_8;") + result should containLine("tmp <= in_9;") } "SInt addition" should "have casts" in { @@ -700,7 +704,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { |""".stripMargin ) result shouldNot containLine("assign z = $signed(x) + -8'sh2;") - result should containLine("assign z = $signed(x) - 8'sh2;") + result should containLine("assign z = $signed(x) - 8'sh2;") } it should "subtract positive literals even with max negative literal" in { @@ -712,7 +716,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { |""".stripMargin ) result shouldNot containLine("assign z = $signed(x) + -2'sh2;") - result should containLine("assign z = $signed(x) - 3'sh2;") + result should containLine("assign z = $signed(x) - 3'sh2;") } it should "subtract positive literals even with max negative literal with no carryout" in { @@ -724,16 +728,16 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { |""".stripMargin ) result shouldNot containLine("assign z = $signed(x) + -2'sh2;") - result should containLine("wire [2:0] _GEN_0 = $signed(x) - 3'sh2;") - result should containLine("assign z = _GEN_0[1:0];") + result should containLine("wire [2:0] _GEN_0 = $signed(x) - 3'sh2;") + result should containLine("assign z = _GEN_0[1:0];") } it should "emit FileInfo as Verilog comment" in { def result(info: String): CircuitState = compileBody( s"""input x : UInt<2> - |output z : UInt<2> - |z <= x @[$info] - |""".stripMargin + |output z : UInt<2> + |z <= x @[$info] + |""".stripMargin ) result("test") should containLine(" assign z = x; // @[test]") // newlines currently are supposed to be escaped for both firrtl and Verilog @@ -772,7 +776,8 @@ class VerilogDescriptionEmitterSpec extends FirrtlFlatSpec { val modName = ModuleName("Test", CircuitName("Test")) val annos = Seq( DocStringAnnotation(ComponentName("a", modName), "multi\nline"), - DocStringAnnotation(ComponentName("b", modName), "single line")) + DocStringAnnotation(ComponentName("b", modName), "single line") + ) val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos), Seq.empty) val output = finalState.getEmittedCircuit.value for (c <- check) { @@ -816,7 +821,8 @@ class VerilogDescriptionEmitterSpec extends FirrtlFlatSpec { val annos = Seq( DocStringAnnotation(ComponentName("d", modName), "multi\nline"), DocStringAnnotation(ComponentName("e", modName), "multi\nline"), - DocStringAnnotation(ComponentName("f", modName), "single line")) + DocStringAnnotation(ComponentName("f", modName), "single line") + ) val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos), Seq.empty) val output = finalState.getEmittedCircuit.value for (c <- check) { @@ -940,8 +946,8 @@ class EmittedMacroSpec extends FirrtlPropSpec { ProcessLogger(line => { line match { case "printing from FIRRTL_BEFORE_INITIAL macro" => saw_before = true - case "printing from FIRRTL_AFTER_INITIAL macro" => saw_after = true - case _ => // Do Nothing + case "printing from FIRRTL_AFTER_INITIAL macro" => saw_after = true + case _ => // Do Nothing } }) diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala index 4b0bc5e5..b8fb3955 100644 --- a/src/test/scala/firrtlTests/WidthSpec.scala +++ b/src/test/scala/firrtlTests/WidthSpec.scala @@ -8,24 +8,20 @@ import firrtl.testutils._ class WidthSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { - val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit - val lines = c.serialize.split("\n") map normalized + val c = passes + .foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + .circuit + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } - private val inferPasses = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - new InferWidths) + private val inferPasses = + Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, new InferWidths) private val inferAndCheckPasses = inferPasses :+ CheckWidths @@ -42,13 +38,13 @@ class WidthSpec extends FirrtlFlatSpec { LiteralWidthCheck(4, Some(3), 4) ) for (LiteralWidthCheck(lit, uwo, sw) <- litChecks) { - import firrtl.ir.{UIntLiteral, SIntLiteral, IntWidth} + import firrtl.ir.{IntWidth, SIntLiteral, UIntLiteral} s"$lit" should s"have signed width $sw" in { - SIntLiteral(lit).width should equal (IntWidth(sw)) + SIntLiteral(lit).width should equal(IntWidth(sw)) } uwo.foreach { uw => it should s"have unsigned width $uw" in { - UIntLiteral(lit).width should equal (IntWidth(uw)) + UIntLiteral(lit).width should equal(IntWidth(uw)) } } } @@ -75,7 +71,7 @@ class WidthSpec extends FirrtlFlatSpec { | input i: UInt<2> | node x = asClock(i)""".stripMargin intercept[CheckWidths.MultiBitAsClock] { - executeTest(input, Nil, inferAndCheckPasses) + executeTest(input, Nil, inferAndCheckPasses) } } @@ -86,15 +82,15 @@ class WidthSpec extends FirrtlFlatSpec { | input i: UInt<2> | node x = asAsyncReset(i)""".stripMargin intercept[CheckWidths.MultiBitAsAsyncReset] { - executeTest(input, Nil, inferAndCheckPasses) + executeTest(input, Nil, inferAndCheckPasses) } } "Width >= MaxWidth" should "result in an error" in { val input = - s"""circuit Unit : - | module Unit : - | input x: UInt<${CheckWidths.MaxWidth}> + s"""circuit Unit : + | module Unit : + | input x: UInt<${CheckWidths.MaxWidth}> """.stripMargin intercept[CheckWidths.WidthTooBig] { executeTest(input, Nil, inferAndCheckPasses) @@ -124,7 +120,7 @@ class WidthSpec extends FirrtlFlatSpec { | input y: SInt<2> | output z: SInt | z <= add(x, y)""".stripMargin - val check = Seq( "output z : SInt<4>") + val check = Seq("output z : SInt<4>") intercept[PassExceptions] { executeTest(input, check, inferPasses) } @@ -138,13 +134,13 @@ class WidthSpec extends FirrtlFlatSpec { | input y: SInt<2> | output z: SInt | z <= sub(y, x)""".stripMargin - val check = Seq( "output z : SInt<5>") + val check = Seq("output z : SInt<5>") intercept[PassExceptions] { executeTest(input, check, inferPasses) } } - behavior of "CheckWidths.UniferredWidth" + behavior.of("CheckWidths.UniferredWidth") it should "provide a good error message with a full target if a user forgets an assign" in { val input = @@ -155,9 +151,10 @@ class WidthSpec extends FirrtlFlatSpec { | module Bar : | wire a: { b : UInt<1>, c : { d : UInt<1>, e : UInt } } |""".stripMargin - val msg = intercept[CheckWidths.UninferredWidth] { executeTest(input, Nil, inferAndCheckPasses) } - .getMessage should include ("""| circuit Foo: - | └── module Bar: - | └── a.c.e""".stripMargin) + val msg = intercept[CheckWidths.UninferredWidth] { + executeTest(input, Nil, inferAndCheckPasses) + }.getMessage should include("""| circuit Foo: + | └── module Bar: + | └── a.c.e""".stripMargin) } } diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala index 8ec6d5ce..0c5be2e0 100644 --- a/src/test/scala/firrtlTests/WiringTests.scala +++ b/src/test/scala/firrtlTests/WiringTests.scala @@ -9,15 +9,14 @@ import annotations._ import wiring._ class WiringTests extends FirrtlFlatSpec { - private def executeTest(input: String, - expected: String, - passes: Seq[Transform], - annos: Seq[Annotation]): Unit = { - val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm, annos)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit - - (parse(c.serialize).serialize) should be (parse(expected).serialize) + private def executeTest(input: String, expected: String, passes: Seq[Transform], annos: Seq[Annotation]): Unit = { + val c = passes + .foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm, annos)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + .circuit + + (parse(c.serialize).serialize) should be(parse(expected).serialize) } private def executeTest(input: String, expected: String, passes: Seq[Transform]): Unit = { @@ -405,8 +404,10 @@ class WiringTests extends FirrtlFlatSpec { } it should "wire multiple sinks in the same module" in { - val sinks = Seq(ComponentName("s", ModuleName("A", CircuitName("Top"))), - ComponentName("t", ModuleName("A", CircuitName("Top")))) + val sinks = Seq( + ComponentName("s", ModuleName("A", CircuitName("Top"))), + ComponentName("t", ModuleName("A", CircuitName("Top"))) + ) val source = ComponentName("r", ModuleName("A", CircuitName("Top"))) val sas = WiringInfo(source, sinks, "pin") val input = @@ -741,8 +742,7 @@ class WiringTests extends FirrtlFlatSpec { | bundle_0 <= bundle | module B : | input clock : Clock - | input pin : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }""" - .stripMargin + | input pin : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }""".stripMargin val wiringXForm = new WiringTransform() executeTest(input, check, passes :+ wiringXForm, Seq(source, sink)) @@ -753,9 +753,7 @@ class WiringTests extends FirrtlFlatSpec { val sourceX = ComponentName("r.x", ModuleName("A", CircuitName("Top"))) val sinkY = Seq(ModuleName("Y", CircuitName("Top"))) val sourceY = ComponentName("r.x", ModuleName("A", CircuitName("Top"))) - val wiSeq = Seq( - WiringInfo(sourceX, sinkX, "pin"), - WiringInfo(sourceY, sinkY, "pin")) + val wiSeq = Seq(WiringInfo(sourceX, sinkX, "pin"), WiringInfo(sourceY, sinkY, "pin")) val input = """|circuit Top : | module Top : @@ -809,9 +807,7 @@ class WiringTests extends FirrtlFlatSpec { val sink = ComponentName("s", ModuleName("Top", CircuitName("Top"))) val source1 = ComponentName("r", ModuleName("Top", CircuitName("Top"))) val source2 = ComponentName("r2", ModuleName("Top", CircuitName("Top"))) - val annos = Seq(SourceAnnotation(source1, "pin"), - SourceAnnotation(source2, "pin"), - SinkAnnotation(sink, "pin")) + val annos = Seq(SourceAnnotation(source1, "pin"), SourceAnnotation(source2, "pin"), SinkAnnotation(sink, "pin")) val input = """|circuit Top : | module Top : @@ -820,7 +816,7 @@ class WiringTests extends FirrtlFlatSpec { | reg r: UInt<5>, clock | reg r2: UInt<5>, clock |""".stripMargin - a [WiringException] shouldBe thrownBy { + a[WiringException] shouldBe thrownBy { executeTest(input, "", passes :+ new WiringTransform, annos) } } diff --git a/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala b/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala index 715714dd..48eb24c1 100644 --- a/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala +++ b/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala @@ -7,18 +7,14 @@ import firrtl.passes._ import firrtl.testutils.FirrtlFlatSpec class ZeroLengthVecsSpec extends FirrtlFlatSpec { - val transforms = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - ZeroLengthVecs, - CheckTypes) + val transforms = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, ZeroLengthVecs, CheckTypes) protected def exec(input: String) = { - transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit.serialize + transforms + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit + .serialize } "ZeroLengthVecs" should "drop subaccesses to zero-length vectors" in { @@ -42,7 +38,7 @@ class ZeroLengthVecsSpec extends FirrtlFlatSpec { | skip | o <= validif(UInt<1>(0), UInt<8>(0)) |""".stripMargin - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } "ZeroLengthVecs" should "handle intervals correctly" in { @@ -62,7 +58,7 @@ class ZeroLengthVecsSpec extends FirrtlFlatSpec { | output o : Interval[3,4].0 | o <= validif(UInt<1>(0), clip(asInterval(SInt<1>(0), 0, 0, 0), i[sel])) |""".stripMargin - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } } diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala index b53f55ea..3c3df5ca 100644 --- a/src/test/scala/firrtlTests/ZeroWidthTests.scala +++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala @@ -7,20 +7,17 @@ import firrtl.passes._ import firrtl.testutils._ class ZeroWidthTests extends FirrtlFlatSpec { - def transforms = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - ZeroWidth) - private def exec (input: String) = { + def transforms = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, ZeroWidth) + private def exec(input: String) = { val circuit = parse(input) - transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit.serialize - } - // ============================= + transforms + .foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) + } + .circuit + .serialize + } + // ============================= "Zero width port" should " be deleted" in { val input = """circuit Top : @@ -30,10 +27,10 @@ class ZeroWidthTests extends FirrtlFlatSpec { | x <= y""".stripMargin val check = """circuit Top : - | module Top : - | output x : UInt<1> - | x <= UInt<1>(0)""".stripMargin - (parse(exec(input))) should be (parse(check)) + | module Top : + | output x : UInt<1> + | x <= UInt<1>(0)""".stripMargin + (parse(exec(input))) should be(parse(check)) } "Add of <0> and <2> " should " put in zero" in { val input = @@ -47,7 +44,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | output x : UInt<3> | x <= add(UInt<1>(0), UInt<2>(2))""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Mux on <0>" should "put in zero" in { val input = @@ -61,7 +58,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | output x : UInt<2> | x <= mux(UInt<1>(0), UInt<2>(2), UInt<2>(1))""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Bundle with field of <0>" should "get deleted" in { val input = @@ -75,7 +72,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | output x : { b: UInt<1> } | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Vector with type of <0>" should "get deleted" in { val input = @@ -88,7 +85,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Node with <0>" should "be removed" in { val input = @@ -100,7 +97,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "IsInvalid on <0>" should "be deleted" in { val input = @@ -112,7 +109,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Expression in node with type <0>" should "be replaced by UInt<1>(0)" in { val input = @@ -126,7 +123,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | input x: UInt<1> | node z = add(x, UInt<1>(0))""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Expression in cat with type <0>" should "be removed" in { val input = @@ -140,7 +137,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | input x: UInt<1> | node z = x""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Nested cats with type <0>" should "be removed" in { val input = @@ -154,7 +151,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Nested cats where one has type <0>" should "be unaffected" in { val input = @@ -170,7 +167,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input x: UInt<1> | input z: UInt<1> | node a = cat(x, z)""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Stop with type <0>" should "be replaced with UInt(0)" in { val input = @@ -188,7 +185,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input x: UInt<1> | input z: UInt<1> | stop(clk, UInt(0), 1)""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Print with type <0>" should "be replaced with UInt(0)" in { val input = @@ -206,7 +203,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input x: UInt<1> | input z: UInt<1> | printf(clk, UInt(1), "%d %d %d\n", x, UInt(0), z)""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Andr of zero-width expression" should "return true" in { @@ -218,10 +215,10 @@ class ZeroWidthTests extends FirrtlFlatSpec { | x <= andr(y)""".stripMargin val check = """circuit Top : - | module Top : - | output x : UInt<1> - | x <= UInt<1>(1)""".stripMargin - (parse(exec(input))) should be (parse(check)) + | module Top : + | output x : UInt<1> + | x <= UInt<1>(1)""".stripMargin + (parse(exec(input))) should be(parse(check)) } } @@ -230,17 +227,17 @@ class ZeroWidthVerilog extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Top : - | module Top : - | input y: UInt<0> - | output x: UInt<3> - | x <= y""".stripMargin + | module Top : + | input y: UInt<0> + | output x: UInt<3> + | x <= y""".stripMargin val check = """module Top( | output [2:0] x |); | assign x = 3'h0; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } } diff --git a/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala b/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala index 79922fa9..0f0d5d47 100644 --- a/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala +++ b/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala @@ -11,40 +11,42 @@ import firrtl.testutils.FirrtlFlatSpec import firrtl.{ChirrtlForm, CircuitState, FileUtils, UnknownForm} class CircuitGraphSpec extends FirrtlFlatSpec { - "CircuitGraph" should "find paths with deep hierarchy quickly" in { - def mkChild(n: Int): String = - s""" module Child${n} : - | input in: UInt<8> - | output out: UInt<8> - | inst c1 of Child${n+1} - | inst c2 of Child${n+1} - | c1.in <= in - | c2.in <= c1.out - | out <= c2.out + "CircuitGraph" should "find paths with deep hierarchy quickly" in { + def mkChild(n: Int): String = + s""" module Child${n} : + | input in: UInt<8> + | output out: UInt<8> + | inst c1 of Child${n + 1} + | inst c2 of Child${n + 1} + | c1.in <= in + | c2.in <= c1.out + | out <= c2.out """.stripMargin - def mkLeaf(n: Int): String = - s""" module Child${n} : - | input in: UInt<8> - | output out: UInt<8> - | wire middle: UInt<8> - | middle <= in - | out <= middle + def mkLeaf(n: Int): String = + s""" module Child${n} : + | input in: UInt<8> + | output out: UInt<8> + | wire middle: UInt<8> + | middle <= in + | out <= middle """.stripMargin - (2 until 23 by 2).foreach { n => - val input = new StringBuilder() - input ++= - """circuit Child0: - |""".stripMargin - (0 until n).foreach { i => input ++= mkChild(i); input ++= "\n" } - input ++= mkLeaf(n) - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( + (2 until 23 by 2).foreach { n => + val input = new StringBuilder() + input ++= + """circuit Child0: + |""".stripMargin + (0 until n).foreach { i => input ++= mkChild(i); input ++= "\n" } + input ++= mkLeaf(n) + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( CircuitState(parse(input.toString()), UnknownForm) - ).circuit - val circuitGraph = CircuitGraph(circuit) - val C = CircuitTarget("Child0") - val Child0 = C.module("Child0") - circuitGraph.connectionPath(Child0.ref("in"), Child0.ref("out")) - } + ) + .circuit + val circuitGraph = CircuitGraph(circuit) + val C = CircuitTarget("Child0") + val Child0 = C.module("Child0") + circuitGraph.connectionPath(Child0.ref("in"), Child0.ref("out")) } + } } diff --git a/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala b/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala index 06f59a3c..e08b7efc 100644 --- a/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala +++ b/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala @@ -14,9 +14,11 @@ class ConnectionGraphSpec extends FirrtlFlatSpec { "ConnectionGraph" should "build connection graph for rocket-chip" in { ConnectionGraph( - new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(FileUtils.getTextResource("/regress/RocketCore.fir")), UnknownForm) - ).circuit + new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(FileUtils.getTextResource("/regress/RocketCore.fir")), UnknownForm) + ) + .circuit ) } @@ -44,9 +46,11 @@ class ConnectionGraphSpec extends FirrtlFlatSpec { | out <= in |""".stripMargin - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit "ConnectionGraph" should "work with pathsInDAG" in { val Test = ModuleTarget("Test", "Test") diff --git a/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala b/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala index 50ee75ac..b1e9fd73 100644 --- a/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala +++ b/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala @@ -10,7 +10,6 @@ import firrtl.passes.ExpandWhensAndCheck import firrtl.stage.{Forms, TransformManager} import firrtl.testutils.FirrtlFlatSpec - class IRLookupSpec extends FirrtlFlatSpec { "IRLookup" should "return declarations" in { @@ -38,9 +37,11 @@ class IRLookupSpec extends FirrtlFlatSpec { | out <= UInt(1) |""".stripMargin - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit val irLookup = IRLookup(circuit) val Test = ModuleTarget("Test", "Test") val uint8 = UIntType(IntWidth(8)) @@ -49,7 +50,10 @@ class IRLookupSpec extends FirrtlFlatSpec { irLookup.declaration(Test.ref("clk")) shouldBe Port(NoInfo, "clk", Input, ClockType) irLookup.declaration(Test.ref("reset")) shouldBe Port(NoInfo, "reset", Input, UIntType(IntWidth(1))) - val out = Port(NoInfo, "out", Output, + val out = Port( + NoInfo, + "out", + Output, BundleType(Seq(Field("a", Default, uint8), Field("b", Default, VectorType(uint8, 2)))) ) irLookup.declaration(Test.ref("out")) shouldBe out @@ -73,7 +77,8 @@ class IRLookupSpec extends FirrtlFlatSpec { irLookup.declaration(Test.ref("y")) shouldBe DefWire(NoInfo, "y", uint8) irLookup.declaration(Test.ref("@and#0")) shouldBe - DoPrim(PrimOps.And, + DoPrim( + PrimOps.And, Seq(WRef("y", uint8, WireKind, SourceFlow), DoPrim(AsUInt, Seq(SIntLiteral(-1)), Nil, UIntType(IntWidth(1)))), Nil, uint8 @@ -84,12 +89,14 @@ class IRLookupSpec extends FirrtlFlatSpec { irLookup.declaration(Test.ref("child").field("out")) shouldBe inst irLookup.declaration(Test.instOf("child", "Child").ref("out")) shouldBe Port(NoInfo, "out", Output, uint8) - intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Child").ref("missing")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Missing").ref("out")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("missing", "Child").ref("out")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.ref("missing")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.ref("out").field("c")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Child").ref("out").field("missing")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.instOf("child", "Child").ref("missing")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.instOf("child", "Missing").ref("out")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.instOf("missing", "Child").ref("out")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.ref("missing")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.ref("out").field("c")) } + intercept[IllegalArgumentException] { + irLookup.declaration(Test.instOf("child", "Child").ref("out").field("missing")) + } } "IRLookup" should "return mem declarations" in { @@ -152,9 +159,11 @@ class IRLookupSpec extends FirrtlFlatSpec { val Readwriter = Mem.field("rw") val allSignals = readerTargets(Reader) ++ writerTargets(Writer) ++ readwriterTargets(Readwriter) - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit val irLookup = IRLookup(circuit) val uint8 = UIntType(IntWidth(8)) val mem = DefMemory(NoInfo, "m", uint8, 2, 1, 0, Seq("r"), Seq("w"), Seq("rw")) @@ -188,9 +197,11 @@ class IRLookupSpec extends FirrtlFlatSpec { | out <= UInt(1) |""".stripMargin - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit val irLookup = IRLookup(circuit) val Test = ModuleTarget("Test", "Test") val uint8 = UIntType(IntWidth(8)) @@ -209,7 +220,8 @@ class IRLookupSpec extends FirrtlFlatSpec { val out = Test.ref("out") val outExpr = - WRef("out", + WRef( + "out", BundleType(Seq(Field("a", Default, uint8), Field("b", Default, VectorType(uint8, 2)))), PortKind, SinkFlow @@ -237,8 +249,10 @@ class IRLookupSpec extends FirrtlFlatSpec { check(Test.ref("y"), WRef("y", uint8, WireKind, DuplexFlow)) - check(Test.ref("@and#0"), - DoPrim(PrimOps.And, + check( + Test.ref("@and#0"), + DoPrim( + PrimOps.And, Seq(WRef("y", uint8, WireKind, SourceFlow), DoPrim(AsUInt, Seq(SIntLiteral(-1)), Nil, UIntType(IntWidth(1)))), Nil, uint8 @@ -247,33 +261,34 @@ class IRLookupSpec extends FirrtlFlatSpec { val child = WRef("child", BundleType(Seq(Field("out", Default, uint8))), InstanceKind, SourceFlow) check(Test.ref("child"), child) - check(Test.ref("child").field("out"), - WSubField(child, "out", uint8, SourceFlow) - ) + check(Test.ref("child").field("out"), WSubField(child, "out", uint8, SourceFlow)) } "IRLookup" should "cache expressions" in { def mkType(i: Int): String = { - if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" + if (i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" } val depth = 500 val input = s"""circuit Test: - | module Test : - | input in: ${mkType(depth)} - | output out: ${mkType(depth)} - | out <= in - |""".stripMargin + | module Test : + | input in: ${mkType(depth)} + | output out: ${mkType(depth)} + | out <= in + |""".stripMargin - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit val Test = ModuleTarget("Test", "Test") val irLookup = IRLookup(circuit) def mkReferences(parent: ReferenceTarget, i: Int): Seq[ReferenceTarget] = { - if(i == 0) Seq(parent) else { + if (i == 0) Seq(parent) + else { val newParent = parent.field("x") newParent +: mkReferences(newParent, i - 1) } diff --git a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala index a0d444b3..e134f6e5 100644 --- a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala +++ b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala @@ -9,10 +9,10 @@ import firrtl.testutils._ class InstanceGraphTests extends FirrtlFlatSpec { private def getEdgeSet(graph: DiGraph[String]): collection.Map[String, collection.Set[String]] = { - (graph.getVertices map {v => (v, graph.getEdges(v))}).toMap + (graph.getVertices.map { v => (v, graph.getEdges(v)) }).toMap } - behavior of "InstanceGraph" + behavior.of("InstanceGraph") it should "recognize a simple hierarchy" in { val input = """ @@ -33,7 +33,13 @@ circuit Top : """ val circuit = ToWorkingIR.run(parse(input)) val graph = new InstanceGraph(circuit).graph.transformNodes(_.module) - getEdgeSet(graph) shouldBe Map("Top" -> Set("Child1", "Child2"), "Child1" -> Set("Child1a", "Child1b"), "Child2" -> Set(), "Child1a" -> Set(), "Child1b" -> Set()) + getEdgeSet(graph) shouldBe Map( + "Top" -> Set("Child1", "Child2"), + "Child1" -> Set("Child1a", "Child1b"), + "Child2" -> Set(), + "Child1a" -> Set(), + "Child1b" -> Set() + ) } it should "find hierarchical instances correctly in disconnected hierarchies" in { @@ -97,12 +103,20 @@ circuit Top : """ val circuit = ToWorkingIR.run(parse(input)) val graph = new InstanceGraph(circuit).graph.transformNodes(_.module) - getEdgeSet(graph) shouldBe Map("Top" -> Set("Child1"), "Top2" -> Set("Child2", "Child3"), "Child2" -> Set("Child2a", "Child2b"), "Child1" -> Set(), "Child2a" -> Set(), "Child2b" -> Set(), "Child3" -> Set()) + getEdgeSet(graph) shouldBe Map( + "Top" -> Set("Child1"), + "Top2" -> Set("Child2", "Child3"), + "Child2" -> Set("Child2a", "Child2b"), + "Child1" -> Set(), + "Child2a" -> Set(), + "Child2b" -> Set(), + "Child3" -> Set() + ) } it should "not drop duplicate nodes when they collide as a result of transformNodes" in { val input = -"""circuit Top : + """circuit Top : module Buzz : skip module Fizz : @@ -134,70 +148,70 @@ circuit Top : // experience non-determinism it should "preserve Module declaration order" in { val input = """ - |circuit Top : - | module Top : - | inst c1 of Child1 - | inst c2 of Child2 - | module Child1 : - | inst a of Child1a - | inst b of Child1b - | skip - | module Child1a : - | skip - | module Child1b : - | skip - | module Child2 : - | skip - |""".stripMargin + |circuit Top : + | module Top : + | inst c1 of Child1 + | inst c2 of Child2 + | module Child1 : + | inst a of Child1a + | inst b of Child1b + | skip + | module Child1a : + | skip + | module Child1b : + | skip + | module Child2 : + | skip + |""".stripMargin val circuit = ToWorkingIR.run(parse(input)) val instGraph = new InstanceGraph(circuit) val childMap = instGraph.getChildrenInstances - childMap.keys.toSeq should equal (Seq("Top", "Child1", "Child1a", "Child1b", "Child2")) + childMap.keys.toSeq should equal(Seq("Top", "Child1", "Child1a", "Child1b", "Child2")) } // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to // experience non-determinism it should "preserve Instance declaration order" in { val input = """ - |circuit Top : - | module Top : - | inst a of Child - | inst b of Child - | inst c of Child - | inst d of Child - | inst e of Child - | inst f of Child - | module Child : - | skip - |""".stripMargin + |circuit Top : + | module Top : + | inst a of Child + | inst b of Child + | inst c of Child + | inst d of Child + | inst e of Child + | inst f of Child + | module Child : + | skip + |""".stripMargin val circuit = ToWorkingIR.run(parse(input)) val instGraph = new InstanceGraph(circuit) val childMap = instGraph.getChildrenInstances val insts = childMap("Top").toSeq.map(_.name) - insts should equal (Seq("a", "b", "c", "d", "e", "f")) + insts should equal(Seq("a", "b", "c", "d", "e", "f")) } // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to // experience non-determinism it should "have defined fullHierarchy order" in { val input = """ - |circuit Top : - | module Top : - | inst a of Child - | inst b of Child - | inst c of Child - | inst d of Child - | inst e of Child - | module Child : - | skip - |""".stripMargin + |circuit Top : + | module Top : + | inst a of Child + | inst b of Child + | inst c of Child + | inst d of Child + | inst e of Child + | module Child : + | skip + |""".stripMargin val circuit = ToWorkingIR.run(parse(input)) val instGraph = new InstanceGraph(circuit) val hier = instGraph.fullHierarchy - hier.keys.toSeq.map(_.name) should equal (Seq("Top", "a", "b", "c", "d", "e")) + hier.keys.toSeq.map(_.name) should equal(Seq("Top", "a", "b", "c", "d", "e")) } - behavior of "InstanceGraph.staticInstanceCount" + behavior.of("InstanceGraph.staticInstanceCount") it should "report that there is one instance of the top module" in { val input = @@ -207,7 +221,7 @@ circuit Top : |""".stripMargin val iGraph = new InstanceGraph(ToWorkingIR.run(parse(input))) val expectedCounts = Map(OfModule("Foo") -> 1) - iGraph.staticInstanceCount should be (expectedCounts) + iGraph.staticInstanceCount should be(expectedCounts) } it should "report correct number of instances for a sample circuit" in { @@ -225,10 +239,8 @@ circuit Top : | inst bar2 of Bar |""".stripMargin val iGraph = new InstanceGraph(ToWorkingIR.run(parse(input))) - val expectedCounts = Map(OfModule("Foo") -> 1, - OfModule("Bar") -> 2, - OfModule("Baz") -> 3) - iGraph.staticInstanceCount should be (expectedCounts) + val expectedCounts = Map(OfModule("Foo") -> 1, OfModule("Bar") -> 2, OfModule("Baz") -> 3) + iGraph.staticInstanceCount should be(expectedCounts) } it should "report zero instances for dead modules" in { @@ -240,12 +252,11 @@ circuit Top : | skip |""".stripMargin val iGraph = new InstanceGraph(ToWorkingIR.run(parse(input))) - val expectedCounts = Map(OfModule("Foo") -> 1, - OfModule("Bar") -> 0) - iGraph.staticInstanceCount should be (expectedCounts) + val expectedCounts = Map(OfModule("Foo") -> 1, OfModule("Bar") -> 0) + iGraph.staticInstanceCount should be(expectedCounts) } - behavior of "Reachable/Unreachable helper methods" + behavior.of("Reachable/Unreachable helper methods") they should "report correct reachable/unreachable counts" in { val input = diff --git a/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala b/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala index ec403259..1e486fe4 100644 --- a/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala +++ b/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala @@ -9,10 +9,10 @@ import firrtl.graph.DiGraph import firrtl.testutils.FirrtlFlatSpec class InstanceKeyGraphSpec extends FirrtlFlatSpec { - behavior of "InstanceKeyGraph.graph" + behavior.of("InstanceKeyGraph.graph") private def getEdgeSet(graph: DiGraph[String]): collection.Map[String, collection.Set[String]] = { - (graph.getVertices map {v => (v, graph.getEdges(v))}).toMap + (graph.getVertices.map { v => (v, graph.getEdges(v)) }).toMap } it should "recognize a simple hierarchy" in { @@ -37,7 +37,10 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { getEdgeSet(graph) shouldBe Map( "Top" -> Set("Child1", "Child2"), "Child1" -> Set("Child1a", "Child1b"), - "Child2" -> Set(), "Child1a" -> Set(), "Child1b" -> Set()) + "Child2" -> Set(), + "Child1a" -> Set(), + "Child1b" -> Set() + ) } it should "recognize disconnected hierarchies" in { @@ -69,7 +72,11 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { "Top" -> Set("Child1"), "Top2" -> Set("Child2", "Child3"), "Child2" -> Set("Child2a", "Child2b"), - "Child1" -> Set(), "Child2a" -> Set(), "Child2b" -> Set(), "Child3" -> Set()) + "Child1" -> Set(), + "Child2a" -> Set(), + "Child2b" -> Set(), + "Child3" -> Set() + ) } it should "not drop duplicate nodes when they collide as a result of transformNodes" in { @@ -101,8 +108,7 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { g2.getEdges("Fizz") shouldBe Set("Foo", "Bar") } - - behavior of "InstanceKeyGraph.getChildInstances" + behavior.of("InstanceKeyGraph.getChildInstances") // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to // experience non-determinism @@ -126,7 +132,7 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { val circuit = parse(input) val instGraph = InstanceKeyGraph(circuit) val childMap = instGraph.getChildInstances - childMap.map(_._1) should equal (Seq("Top", "Child1", "Child1a", "Child1b", "Child2")) + childMap.map(_._1) should equal(Seq("Top", "Child1", "Child1a", "Child1b", "Child2")) } // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to @@ -148,10 +154,10 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { val instGraph = InstanceKeyGraph(circuit) val childMap = instGraph.getChildInstances.toMap val insts = childMap("Top").map(_.name) - insts should equal (Seq("a", "b", "c", "d", "e", "f")) + insts should equal(Seq("a", "b", "c", "d", "e", "f")) } - behavior of "InstanceKeyGraph.moduleOrder" + behavior.of("InstanceKeyGraph.moduleOrder") it should "compute a correct and deterministic module order" in { val input = """ @@ -180,10 +186,10 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { val instGraph = InstanceKeyGraph(circuit) val order = instGraph.moduleOrder.map(_.name) // Where it has freedom, the instance declaration order will be reversed. - order should equal (Seq("Top", "Child3", "Child4", "Child2", "Child1", "Child1b", "Child1a")) + order should equal(Seq("Top", "Child3", "Child4", "Child2", "Child1", "Child1b", "Child1a")) } - behavior of "InstanceKeyGraph.findInstancesInHierarchy" + behavior.of("InstanceKeyGraph.findInstancesInHierarchy") it should "find hierarchical instances correctly in disconnected hierarchies" in { val input = @@ -221,7 +227,7 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { iGraph.findInstancesInHierarchy("Child3") shouldBe Nil } - behavior of "InstanceKeyGraph.staticInstanceCount" + behavior.of("InstanceKeyGraph.staticInstanceCount") it should "report that there is one instance of the top module" in { val input = @@ -231,7 +237,7 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { |""".stripMargin val iGraph = InstanceKeyGraph(parse(input)) val expectedCounts = Map(OfModule("Foo") -> 1) - iGraph.staticInstanceCount should be (expectedCounts) + iGraph.staticInstanceCount should be(expectedCounts) } it should "report correct number of instances for a sample circuit" in { @@ -249,10 +255,8 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { | inst bar2 of Bar |""".stripMargin val iGraph = InstanceKeyGraph(parse(input)) - val expectedCounts = Map(OfModule("Foo") -> 1, - OfModule("Bar") -> 2, - OfModule("Baz") -> 3) - iGraph.staticInstanceCount should be (expectedCounts) + val expectedCounts = Map(OfModule("Foo") -> 1, OfModule("Bar") -> 2, OfModule("Baz") -> 3) + iGraph.staticInstanceCount should be(expectedCounts) } it should "report zero instances for dead modules" in { @@ -264,12 +268,11 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { | skip |""".stripMargin val iGraph = InstanceKeyGraph(parse(input)) - val expectedCounts = Map(OfModule("Foo") -> 1, - OfModule("Bar") -> 0) - iGraph.staticInstanceCount should be (expectedCounts) + val expectedCounts = Map(OfModule("Foo") -> 1, OfModule("Bar") -> 0) + iGraph.staticInstanceCount should be(expectedCounts) } - behavior of "InstanceKeyGraph.getChildInstanceMap" + behavior.of("InstanceKeyGraph.getChildInstanceMap") it should "preserve Module declaration order" in { val input = """ @@ -302,15 +305,17 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { assert(childMap(OfModule("Child1b")).isEmpty) assert(childMap(OfModule("Child2")).isEmpty) - val topInstances = childMap(OfModule("Top")).map { case (k,v) => k.value -> v.value}.toSeq - assert(topInstances == - Seq("c1" -> "Child1", "c2" -> "Child2", "c3" -> "Child1", "c4" -> "Child1", "c5" -> "Child1")) + val topInstances = childMap(OfModule("Top")).map { case (k, v) => k.value -> v.value }.toSeq + assert( + topInstances == + Seq("c1" -> "Child1", "c2" -> "Child2", "c3" -> "Child1", "c4" -> "Child1", "c5" -> "Child1") + ) - val child1Instance = childMap(OfModule("Child1")).map { case (k,v) => k.value -> v.value}.toSeq + val child1Instance = childMap(OfModule("Child1")).map { case (k, v) => k.value -> v.value }.toSeq assert(child1Instance == Seq("a" -> "Child1a", "b" -> "Child1b")) } - behavior of "InstanceKeyGraph.fullHierarchy" + behavior.of("InstanceKeyGraph.fullHierarchy") // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to // experience non-determinism @@ -329,10 +334,10 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { val instGraph = InstanceKeyGraph(parse(input)) val hier = instGraph.fullHierarchy - hier.keys.toSeq.map(_.name) should equal (Seq("Top", "a", "b", "c", "d", "e")) + hier.keys.toSeq.map(_.name) should equal(Seq("Top", "a", "b", "c", "d", "e")) } - behavior of "Reachable/Unreachable helper methods" + behavior.of("Reachable/Unreachable helper methods") they should "report correct reachable/unreachable counts" in { val input = diff --git a/src/test/scala/firrtlTests/annotationTests/CleanupNamedTargetsSpec.scala b/src/test/scala/firrtlTests/annotationTests/CleanupNamedTargetsSpec.scala index 58cb3d11..67408bb7 100644 --- a/src/test/scala/firrtlTests/annotationTests/CleanupNamedTargetsSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/CleanupNamedTargetsSpec.scala @@ -11,7 +11,8 @@ import firrtl.annotations.{ MultiTargetAnnotation, ReferenceTarget, SingleTargetAnnotation, - Target} + Target +} import firrtl.annotations.transforms.CleanupNamedTargets import org.scalatest.flatspec.AnyFlatSpec @@ -56,7 +57,7 @@ class CleanupNamedTargetsSpec extends AnyFlatSpec with Matchers { } - behavior of "CleanupNamedTargets" + behavior.of("CleanupNamedTargets") it should "convert a SingleTargetAnnotation[ReferenceTarget] of an instance to an InstanceTarget" in new F { val annotations: AnnotationSeq = Seq(SingleReferenceAnnotation(barTarget)) @@ -71,10 +72,10 @@ class CleanupNamedTargetsSpec extends AnyFlatSpec with Matchers { val renames = transform.transform(circuitState(annotations)).renames.get - renames.get(barTarget) should be (Some(Seq(foo.instOf("bar", "Bar")))) + renames.get(barTarget) should be(Some(Seq(foo.instOf("bar", "Bar")))) info("and not touch a true ReferenceAnnotation") - renames.get(bazTarget) should be (None) + renames.get(bazTarget) should be(None) } diff --git a/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala index 73f36cf0..bb833f0b 100644 --- a/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala @@ -6,7 +6,7 @@ import firrtl._ import firrtl.annotations._ import firrtl.annotations.analysis.DuplicationHelper import firrtl.annotations.transforms.{NoSuchTargetException} -import firrtl.transforms.{DontTouchAnnotation, DedupedResult} +import firrtl.transforms.{DedupedResult, DontTouchAnnotation} import firrtl.testutils.{FirrtlMatchers, FirrtlPropSpec} object EliminateTargetPathsSpec { @@ -15,7 +15,7 @@ object EliminateTargetPathsSpec { override def duplicate(n: Target): Annotation = DummyAnnotation(n) } class DummyTransform() extends Transform with ResolvedAnnotationPaths { - override def inputForm: CircuitForm = LowForm + override def inputForm: CircuitForm = LowForm override def outputForm: CircuitForm = LowForm override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DummyAnnotation]) @@ -72,40 +72,47 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { property("Hierarchical tokens should be expanded properly") { val dupMap = DuplicationHelper(inputState.circuit.modules.map(_.name).toSet) - // Only a few instance references dupMap.expandHierarchy(Top_m1_l1_a) dupMap.expandHierarchy(Top_m2_l1_a) dupMap.expandHierarchy(Middle_l1_a) - dupMap.makePathless(Top_m1_l1_a).foreach {Set(TopCircuit.module("Leaf___Top_m1_l1").ref("a")) should contain (_)} - dupMap.makePathless(Top_m2_l1_a).foreach {Set(TopCircuit.module("Leaf___Top_m2_l1").ref("a")) should contain (_)} - dupMap.makePathless(Top_m1_l2_a).foreach {Set(Leaf_a) should contain (_)} - dupMap.makePathless(Top_m2_l2_a).foreach {Set(Leaf_a) should contain (_)} - dupMap.makePathless(Middle_l1_a).foreach {Set( - TopCircuit.module("Leaf___Top_m1_l1").ref("a"), - TopCircuit.module("Leaf___Top_m2_l1").ref("a"), - TopCircuit.module("Leaf___Middle_l1").ref("a") - ) should contain (_) } - dupMap.makePathless(Middle_l2_a).foreach {Set(Leaf_a) should contain (_)} - dupMap.makePathless(Leaf_a).foreach {Set( - TopCircuit.module("Leaf___Top_m1_l1").ref("a"), - TopCircuit.module("Leaf___Top_m2_l1").ref("a"), - TopCircuit.module("Leaf___Middle_l1").ref("a"), - Leaf_a - ) should contain (_)} - dupMap.makePathless(Top).foreach {Set(Top) should contain (_)} - dupMap.makePathless(Middle).foreach {Set( - TopCircuit.module("Middle___Top_m1"), - TopCircuit.module("Middle___Top_m2"), - Middle - ) should contain (_)} - dupMap.makePathless(Leaf).foreach {Set( - TopCircuit.module("Leaf___Top_m1_l1"), - TopCircuit.module("Leaf___Top_m2_l1"), - TopCircuit.module("Leaf___Middle_l1"), - Leaf - ) should contain (_) } + dupMap.makePathless(Top_m1_l1_a).foreach { Set(TopCircuit.module("Leaf___Top_m1_l1").ref("a")) should contain(_) } + dupMap.makePathless(Top_m2_l1_a).foreach { Set(TopCircuit.module("Leaf___Top_m2_l1").ref("a")) should contain(_) } + dupMap.makePathless(Top_m1_l2_a).foreach { Set(Leaf_a) should contain(_) } + dupMap.makePathless(Top_m2_l2_a).foreach { Set(Leaf_a) should contain(_) } + dupMap.makePathless(Middle_l1_a).foreach { + Set( + TopCircuit.module("Leaf___Top_m1_l1").ref("a"), + TopCircuit.module("Leaf___Top_m2_l1").ref("a"), + TopCircuit.module("Leaf___Middle_l1").ref("a") + ) should contain(_) + } + dupMap.makePathless(Middle_l2_a).foreach { Set(Leaf_a) should contain(_) } + dupMap.makePathless(Leaf_a).foreach { + Set( + TopCircuit.module("Leaf___Top_m1_l1").ref("a"), + TopCircuit.module("Leaf___Top_m2_l1").ref("a"), + TopCircuit.module("Leaf___Middle_l1").ref("a"), + Leaf_a + ) should contain(_) + } + dupMap.makePathless(Top).foreach { Set(Top) should contain(_) } + dupMap.makePathless(Middle).foreach { + Set( + TopCircuit.module("Middle___Top_m1"), + TopCircuit.module("Middle___Top_m2"), + Middle + ) should contain(_) + } + dupMap.makePathless(Leaf).foreach { + Set( + TopCircuit.module("Leaf___Top_m1_l1"), + TopCircuit.module("Leaf___Top_m2_l1"), + TopCircuit.module("Leaf___Middle_l1"), + Leaf + ) should contain(_) + } } property("Hierarchical donttouch should be resolved properly") { @@ -159,10 +166,10 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { | m2.i <= m1.o | """.stripMargin - canonicalize(outputState.circuit).serialize should be (canonicalize(parse(check)).serialize) + canonicalize(outputState.circuit).serialize should be(canonicalize(parse(check)).serialize) outputState.annotations.collect { case x: DontTouchAnnotation => x.target - } should be (Seq(Top.circuitTarget.module("Leaf___Top_m1_l1").ref("a"))) + } should be(Seq(Top.circuitTarget.module("Leaf___Top_m1_l1").ref("a"))) } property("No name conflicts between old and new modules") { @@ -199,7 +206,7 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputState = new LowFirrtlCompiler().compile(inputState, customTransforms) val outputLines = outputState.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } } @@ -239,7 +246,7 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputLines = outputState.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } checks.foreach { line => outputLines should not contain (" module Middle :") @@ -267,19 +274,19 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { | m2.i <= m1.o | o <= m2.o """.stripMargin - val e1 = the [CustomTransformException] thrownBy { + val e1 = the[CustomTransformException] thrownBy { val Top_m1 = Top.instOf("m1", "MiddleX") val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m1))) new LowFirrtlCompiler().compile(inputState, customTransforms) } - e1.cause shouldBe a [NoSuchTargetException] + e1.cause shouldBe a[NoSuchTargetException] - val e2 = the [CustomTransformException] thrownBy { + val e2 = the[CustomTransformException] thrownBy { val Top_m2 = Top.instOf("x2", "Middle") val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m2))) new LowFirrtlCompiler().compile(inputState, customTransforms) } - e2.cause shouldBe a [NoSuchTargetException] + e2.cause shouldBe a[NoSuchTargetException] } property("No name conflicts between two new modules") { @@ -320,11 +327,12 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { | module Leaf____Middle__l :""".stripMargin.split("\n") val Middle_l1 = CircuitTarget("Top").module("Middle").instOf("_l", "Leaf") val Middle_l2 = CircuitTarget("Top").module("Middle_").instOf("l", "Leaf") - val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Middle_l1), DummyAnnotation(Middle_l2))) + val inputState = + CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Middle_l1), DummyAnnotation(Middle_l2))) val outputState = new LowFirrtlCompiler().compile(inputState, customTransforms) val outputLines = outputState.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } } @@ -362,12 +370,12 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputState = new VerilogCompiler().compile(inputState, customTransforms) val outputLines = outputState.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } } property("It should remove ResolvePaths annotations") { - val input = + val input = """|circuit Foo: | module Bar: | skip @@ -378,7 +386,7 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { CircuitState(passes.ToWorkingIR.run(Parser.parse(input)), UnknownForm, Nil) .resolvePaths(Seq(CircuitTarget("Foo").module("Foo").instOf("bar", "Bar"))) .annotations - .collect{ case a: firrtl.annotations.transforms.ResolvePaths => a } should be (empty) + .collect { case a: firrtl.annotations.transforms.ResolvePaths => a } should be(empty) } property("It should rename module annotations") { @@ -404,16 +412,14 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val parsedCheck = Parser.parse(check) info(output.circuit.serialize) - (output.circuit.serialize) should be (parsedCheck.serialize) + (output.circuit.serialize) should be(parsedCheck.serialize) val newBar_x = CircuitTarget("Foo").module("Bar___Foo_bar").ref("x") - output - .annotations - .filter{ - case _: DeletedAnnotation => false - case _ => true - } should contain allOf (DontTouchAnnotation(newBar_x), DontTouchAnnotation(Bar_x)) + (output.annotations.filter { + case _: DeletedAnnotation => false + case _ => true + } should contain).allOf(DontTouchAnnotation(newBar_x), DontTouchAnnotation(Bar_x)) } property("It should not rename lone instances") { @@ -440,10 +446,10 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { info(output.circuit.serialize) - output.circuit.serialize should be (inputCircuit.serialize) - output.annotations.collect { + output.circuit.serialize should be(inputCircuit.serialize) + (output.annotations.collect { case a: DontTouchAnnotation => a - } should contain allOf ( + } should contain).allOf( DontTouchAnnotation(ModuleTarget("Foo", "Foo").ref("foo")), DontTouchAnnotation(ModuleTarget("Foo", "Bar").ref("foo")), DontTouchAnnotation(ModuleTarget("Foo", "Baz").ref("foo")) @@ -481,12 +487,12 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputLines = output.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } - output.annotations.collect { + (output.annotations.collect { case a: DontTouchAnnotation => a - } should contain allOf ( + } should contain).allOf( DontTouchAnnotation(ModuleTarget("FooBar", "Bar___Foo_bar").ref("baz")), DontTouchAnnotation(ModuleTarget("FooBar", "Bar___Foo_barBar").ref("baz")) ) @@ -527,11 +533,11 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputLines = output.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } - output.annotations.collect { + (output.annotations.collect { case a: DontTouchAnnotation => a - } should contain allOf ( + } should contain).allOf( DontTouchAnnotation(ModuleTarget("Top", "Baz_0").ref("foo")), DontTouchAnnotation(ModuleTarget("Top", "Baz_1").ref("foo")) ) @@ -563,11 +569,13 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { info(output.circuit.serialize) - output.annotations.collect { case a: DontTouchAnnotation => a } should be (Seq( - DontTouchAnnotation(ModuleTarget("Top", "Baz___Bar_asdf").ref("foo")), - DontTouchAnnotation(ModuleTarget("Top", "Baz___Bar_lkj").ref("foo")), - DontTouchAnnotation(baz.ref("foo")) - )) + output.annotations.collect { case a: DontTouchAnnotation => a } should be( + Seq( + DontTouchAnnotation(ModuleTarget("Top", "Baz___Bar_asdf").ref("foo")), + DontTouchAnnotation(ModuleTarget("Top", "Baz___Bar_lkj").ref("foo")), + DontTouchAnnotation(baz.ref("foo")) + ) + ) } property("It should properly rename modules with multiple instances") { @@ -600,6 +608,6 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val checkDontTouches = (1 to 4).map { i => DummyAnnotation(ModuleTarget("Top", s"Core___System_core_$i")) } - output.annotations.collect { case a: DummyAnnotation => a } should be (checkDontTouches) + output.annotations.collect { case a: DummyAnnotation => a } should be(checkDontTouches) } } diff --git a/src/test/scala/firrtlTests/annotationTests/JsonProtocolSpec.scala b/src/test/scala/firrtlTests/annotationTests/JsonProtocolSpec.scala index 2c817c23..54a94edb 100644 --- a/src/test/scala/firrtlTests/annotationTests/JsonProtocolSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/JsonProtocolSpec.scala @@ -6,20 +6,20 @@ import firrtl._ import firrtl.annotations.{JsonProtocol, NoTargetAnnotation} import firrtl.ir._ import firrtl.options.Dependency -import _root_.logger.{Logger, LogLevel, LogLevelAnnotation} +import _root_.logger.{LogLevel, LogLevelAnnotation, Logger} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should._ case class AnAnnotation( - info: Info, - cir: Circuit, - mod: DefModule, - port: Port, - statement: Statement, - expr: Expression, - tpe: Type, - groundType: GroundType -) extends NoTargetAnnotation + info: Info, + cir: Circuit, + mod: DefModule, + port: Port, + statement: Statement, + expr: Expression, + tpe: Type, + groundType: GroundType) + extends NoTargetAnnotation class AnnoInjector extends Transform with DependencyAPIMigration { override def optionalPrerequisiteOf = Dependency[ChirrtlEmitter] :: Nil @@ -51,16 +51,16 @@ class JsonProtocolSpec extends AnyFlatSpec with Matchers { val inputAnnos = Seq(AnAnnotation(cir.info, cir, mod, port, stmt, expr, tpe, groundType)) val annosString = JsonProtocol.serialize(inputAnnos) val outputAnnos = JsonProtocol.deserialize(annosString) - inputAnnos should be (outputAnnos) + inputAnnos should be(outputAnnos) } "Annotation serialization during logging" should "not throw an exception" in { val compiler = new firrtl.stage.transforms.Compiler(Seq(Dependency[AnnoInjector])) val circuit = Parser.parse(""" - |circuit test : - | module test : - | output out : UInt<1> - | out <= UInt(0) + |circuit test : + | module test : + | output out : UInt<1> + | out <= UInt(0) """.stripMargin) Logger.makeScope(LogLevelAnnotation(LogLevel.Trace) :: Nil) { compiler.execute(CircuitState(circuit, Nil)) diff --git a/src/test/scala/firrtlTests/annotationTests/MorphismSpec.scala b/src/test/scala/firrtlTests/annotationTests/MorphismSpec.scala index ccf930ba..ac4f2b63 100644 --- a/src/test/scala/firrtlTests/annotationTests/MorphismSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/MorphismSpec.scala @@ -16,15 +16,15 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } case class AnAnnotation( - target: Option[CompleteTarget], - from: Option[AnAnnotation] = None, - cause: Option[String] = None - ) extends Annotation { + target: Option[CompleteTarget], + from: Option[AnAnnotation] = None, + cause: Option[String] = None) + extends Annotation { override def update(renames: RenameMap): Seq[AnAnnotation] = { if (target.isDefined) { renames.get(target.get) match { - case None => Seq(this) - case Some(Seq()) => Seq(AnAnnotation(None, Some(this))) + case None => Seq(this) + case Some(Seq()) => Seq(AnAnnotation(None, Some(this))) case Some(targets) => //TODO: Add cause of renaming, requires FIRRTL change to RenameMap targets.map { t => AnAnnotation(Some(t), Some(this)) } @@ -60,7 +60,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { val annotationsx = a.annotations.filter { case a: DeletedAnnotation => false case AnAnnotation(None, _, _) => false - case _: DupedResult => false + case _: DupedResult => false case _: DedupedResult => false case _ => true } @@ -296,8 +296,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { ) } - - behavior of "EliminateTargetPaths" + behavior.of("EliminateTargetPaths") // NOTE: equivalience is defined structurally in this case trait RightInverseEliminateTargetsFixture extends RightInverseFixture with DefaultExample { @@ -393,24 +392,29 @@ class MorphismSpec extends AnyFlatSpec with Matchers { | inst qux of Baz___Top_qux""".stripMargin override val annotations: AnnotationSeq = Seq( AnAnnotation(CircuitTarget("Top").module("Baz").instOf("foo", "Foo")), - ResolvePaths(Seq( - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo"), - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foox", "Foo"), - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("bar", "Bar"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foox", "Foo"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("bar", "Bar") - )) + ResolvePaths( + Seq( + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo"), + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foox", "Foo"), + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("bar", "Bar"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foox", "Foo"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("bar", "Bar") + ) + ) ) - override val finalAnnotations: Option[AnnotationSeq] = Some(Seq( - AnAnnotation(CircuitTarget("Top").module("Foo___Top_qux_foo")), - AnAnnotation(CircuitTarget("Top").module("Foo___Top_baz_foo")) - )) + override val finalAnnotations: Option[AnnotationSeq] = Some( + Seq( + AnAnnotation(CircuitTarget("Top").module("Foo___Top_qux_foo")), + AnAnnotation(CircuitTarget("Top").module("Foo___Top_baz_foo")) + ) + ) test() } it should "be idempotent with per-module annotations" in new IdempotencyEliminateTargetsFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allModuleInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) @@ -418,6 +422,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } it should "be idempotent with per-instance annotations" in new IdempotencyEliminateTargetsFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allAbsoluteInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) @@ -425,13 +430,14 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } it should "be idempotent with relative module annotations" in new IdempotencyEliminateTargetsFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allRelative2LevelInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) test() } - behavior of "DedupModules" + behavior.of("DedupModules") trait RightInverseDedupModulesFixture extends RightInverseFixture with DefaultExample { override val f: Seq[Transform] = Seq(new firrtl.annotations.transforms.EliminateTargetPaths) @@ -498,24 +504,29 @@ class MorphismSpec extends AnyFlatSpec with Matchers { | inst qux of Baz""".stripMargin override val annotations: AnnotationSeq = Seq( AnAnnotation(CircuitTarget("Top").module("Baz").instOf("foo", "Foo")), - ResolvePaths(Seq( - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo"), - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foox", "Foo"), - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("bar", "Bar"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foox", "Foo"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("bar", "Bar") - )) + ResolvePaths( + Seq( + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo"), + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foox", "Foo"), + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("bar", "Bar"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foox", "Foo"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("bar", "Bar") + ) + ) ) - override val finalAnnotations: Option[AnnotationSeq] = Some(Seq( - AnAnnotation(CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo")), - AnAnnotation(CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo")) - )) + override val finalAnnotations: Option[AnnotationSeq] = Some( + Seq( + AnAnnotation(CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo")), + AnAnnotation(CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo")) + ) + ) test() } it should "be idempotent with per-module annotations" in new IdempotencyDedupModulesFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allModuleInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) @@ -523,6 +534,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } it should "be idempotent with per-instance annotations" in new IdempotencyDedupModulesFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allAbsoluteInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) @@ -530,6 +542,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } it should "be idempotent with relative module annotations" in new IdempotencyDedupModulesFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allRelative2LevelInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) diff --git a/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala b/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala index cbcd72e9..cc875ea1 100644 --- a/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala @@ -11,7 +11,6 @@ import firrtl.annotations.{Annotation, NoTargetAnnotation} case object FoundTargetDirTransformRanAnnotation extends NoTargetAnnotation case object FoundTargetDirTransformFoundTargetDirAnnotation extends NoTargetAnnotation - /** Looks for [[TargetDirAnnotation]] */ class FindTargetDirTransform extends Transform { def inputForm = HighForm @@ -19,14 +18,15 @@ class FindTargetDirTransform extends Transform { def execute(state: CircuitState): CircuitState = { val a: Option[Annotation] = state.annotations.collectFirst { - case TargetDirAnnotation("a/b/c") => FoundTargetDirTransformFoundTargetDirAnnotation } + case TargetDirAnnotation("a/b/c") => FoundTargetDirTransformFoundTargetDirAnnotation + } state.copy(annotations = state.annotations ++ a ++ Some(FoundTargetDirTransformRanAnnotation)) } } class TargetDirAnnotationSpec extends FirrtlFlatSpec { - behavior of "The target directory" + behavior.of("The target directory") val input = """circuit Top : @@ -41,37 +41,35 @@ class TargetDirAnnotationSpec extends FirrtlFlatSpec { val findTargetDir = new FindTargetDirTransform // looks for the annotation val optionsManager = new ExecutionOptionsManager("TargetDir") with HasFirrtlOptions { - commonOptions = commonOptions.copy(targetDirName = targetDir, - topName = "Top") - firrtlOptions = firrtlOptions.copy(compilerName = "high", - firrtlSource = Some(input), - customTransforms = Seq(findTargetDir)) + commonOptions = commonOptions.copy(targetDirName = targetDir, topName = "Top") + firrtlOptions = + firrtlOptions.copy(compilerName = "high", firrtlSource = Some(input), customTransforms = Seq(findTargetDir)) } val annotations: Seq[Annotation] = Driver.execute(optionsManager) match { case a: FirrtlExecutionSuccess => a.circuitState.annotations case _ => fail } - annotations should contain (FoundTargetDirTransformRanAnnotation) - annotations should contain (FoundTargetDirTransformFoundTargetDirAnnotation) + annotations should contain(FoundTargetDirTransformRanAnnotation) + annotations should contain(FoundTargetDirTransformFoundTargetDirAnnotation) // Delete created directory val dir = new java.io.File(targetDir) - dir.exists should be (true) - FileUtils.deleteDirectoryHierarchy("a") should be (true) + dir.exists should be(true) + FileUtils.deleteDirectoryHierarchy("a") should be(true) } it should "NOT be available as an annotation when using a raw compiler" in { val findTargetDir = new FindTargetDirTransform // looks for the annotation val compiler = new VerilogCompiler - val circuit = Parser.parse(input split "\n") + val circuit = Parser.parse(input.split("\n")) val annotations: Seq[Annotation] = compiler .compileAndEmit(CircuitState(circuit, HighForm), Seq(findTargetDir)) .annotations // Check that FindTargetDirTransform does not find the annotation - annotations should contain (FoundTargetDirTransformRanAnnotation) + annotations should contain(FoundTargetDirTransformRanAnnotation) annotations should not contain (FoundTargetDirTransformFoundTargetDirAnnotation) } } diff --git a/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala b/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala index 641eeb99..48f27faa 100644 --- a/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala @@ -24,8 +24,9 @@ class TargetSpec extends FirrtlPropSpec { (top.ref("r").index(1).field("hi").clock, "~Circuit|Top>r[1].hi@clock"), (GenericTarget(None, None, Vector(Ref("r"))), "~???|???>r") ) - targets.foreach { case (t, str) => - assert(t.serialize == str, s"$t does not properly serialize") + targets.foreach { + case (t, str) => + assert(t.serialize == str, s"$t does not properly serialize") } } property("Should convert to/from Named") { @@ -38,7 +39,7 @@ class TargetSpec extends FirrtlPropSpec { check(Target(Some("Top"), Some("Top"), r2)) } property("Should enable creating from API") { - val top = ModuleTarget("Top","Top") + val top = ModuleTarget("Top", "Top") val x_reg0_data = top.instOf("x", "X").ref("reg0").field("data") top.instOf("x", "x") top.ref("y") @@ -47,8 +48,14 @@ class TargetSpec extends FirrtlPropSpec { val circuit = CircuitTarget("Circuit") val top = circuit.module("Top") val targets: Seq[Target] = - Seq(circuit, top, top.instOf("i", "I"), top.ref("r"), - top.ref("r").index(1).field("hi").clock, GenericTarget(None, None, Vector(Ref("r")))) + Seq( + circuit, + top, + top.instOf("i", "I"), + top.ref("r"), + top.ref("r").index(1).field("hi").clock, + GenericTarget(None, None, Vector(Ref("r"))) + ) targets.foreach { t => assert(Target.deserialize(t.serialize) == t, s"$t does not properly serialize/deserialize") } @@ -58,25 +65,20 @@ class TargetSpec extends FirrtlPropSpec { val top = circuit.module("B") val targets = Seq( (circuit, "circuit A:"), - (top, - """|circuit A: - |└── module B:""".stripMargin), - (top.instOf("c", "C"), - """|circuit A: - |└── module B: - | └── inst c of C:""".stripMargin), - (top.ref("r"), - """|circuit A: - |└── module B: - | └── r""".stripMargin), - (top.ref("r").index(1).field("hi").clock, - """|circuit A: - |└── module B: - | └── r[1].hi@clock""".stripMargin), - (GenericTarget(None, None, Vector(Ref("r"))), - """|circuit ???: - |└── module ???: - | └── r""".stripMargin) + (top, """|circuit A: + |└── module B:""".stripMargin), + (top.instOf("c", "C"), """|circuit A: + |└── module B: + | └── inst c of C:""".stripMargin), + (top.ref("r"), """|circuit A: + |└── module B: + | └── r""".stripMargin), + (top.ref("r").index(1).field("hi").clock, """|circuit A: + |└── module B: + | └── r[1].hi@clock""".stripMargin), + (GenericTarget(None, None, Vector(Ref("r"))), """|circuit ???: + |└── module ???: + | └── r""".stripMargin) ) targets.foreach { case (t, str) => assert(t.prettyPrint() == str, s"$t didn't properly prettyPrint") } } diff --git a/src/test/scala/firrtlTests/constraint/InequalitySpec.scala b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala index 8b26c80c..68db6873 100644 --- a/src/test/scala/firrtlTests/constraint/InequalitySpec.scala +++ b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala @@ -7,101 +7,109 @@ import org.scalatest.matchers.should.Matchers class InequalitySpec extends AnyFlatSpec with Matchers { - behavior of "Constraints" + behavior.of("Constraints") "IsConstraints" should "reduce properly" in { - IsMin(Closed(0), Closed(1)) should be (Closed(0)) - IsMin(Closed(-1), Closed(1)) should be (Closed(-1)) - IsMax(Closed(-1), Closed(1)) should be (Closed(1)) - IsNeg(IsMul(Closed(-1), Closed(-2))) should be (Closed(-2)) + IsMin(Closed(0), Closed(1)) should be(Closed(0)) + IsMin(Closed(-1), Closed(1)) should be(Closed(-1)) + IsMax(Closed(-1), Closed(1)) should be(Closed(1)) + IsNeg(IsMul(Closed(-1), Closed(-2))) should be(Closed(-2)) val x = IsMin(IsMul(Closed(1), VarCon("a")), Closed(2)) - x.children.toSet should be (IsMin(Closed(2), IsMul(Closed(1), VarCon("a"))).children.toSet) + x.children.toSet should be(IsMin(Closed(2), IsMul(Closed(1), VarCon("a"))).children.toSet) } "IsAdd" should "reduce properly" in { // All constants - IsAdd(Closed(-1), Closed(1)) should be (Closed(0)) + IsAdd(Closed(-1), Closed(1)) should be(Closed(0)) // Pull Out IsMax - IsAdd(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsAdd(VarCon("a"), Closed(1)))) - IsAdd(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsAdd(Closed(1), IsMax(Closed(1), VarCon("a"))) should be(IsMax(Closed(2), IsAdd(VarCon("a"), Closed(1)))) + IsAdd(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be( IsMax(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1)))) ) // Pull Out IsMin - IsAdd(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsAdd(VarCon("a"), Closed(1)))) - IsAdd(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsAdd(Closed(1), IsMin(Closed(1), VarCon("a"))) should be(IsMin(Closed(2), IsAdd(VarCon("a"), Closed(1)))) + IsAdd(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be( IsMin(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1)))) ) // Add Zero - IsAdd(Closed(0), VarCon("a")) should be (VarCon("a")) + IsAdd(Closed(0), VarCon("a")) should be(VarCon("a")) // One argument - IsAdd(Seq(VarCon("a"))) should be (VarCon("a")) + IsAdd(Seq(VarCon("a"))) should be(VarCon("a")) } "IsMax" should "reduce properly" in { // All constants - IsMax(Closed(-1), Closed(1)) should be (Closed(1)) + IsMax(Closed(-1), Closed(1)) should be(Closed(1)) // Flatten nested IsMax - IsMax(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(1), VarCon("a"))) - IsMax(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMax(Closed(1), IsMax(Closed(1), VarCon("a"))) should be(IsMax(Closed(1), VarCon("a"))) + IsMax(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be( IsMax(Seq(Closed(1), VarCon("a"), VarCon("b"))) ) // Eliminate IsMins if possible - IsMax(Closed(2), IsMin(Closed(1), VarCon("a"))) should be (Closed(2)) - IsMax(Seq( - Closed(2), - IsMin(Closed(1), VarCon("a")), - IsMin(Closed(3), VarCon("b")) - )) should be ( - IsMax(Seq( + IsMax(Closed(2), IsMin(Closed(1), VarCon("a"))) should be(Closed(2)) + IsMax( + Seq( Closed(2), + IsMin(Closed(1), VarCon("a")), IsMin(Closed(3), VarCon("b")) - )) + ) + ) should be( + IsMax( + Seq( + Closed(2), + IsMin(Closed(3), VarCon("b")) + ) + ) ) // One argument - IsMax(Seq(VarCon("a"))) should be (VarCon("a")) - IsMax(Seq(Closed(0))) should be (Closed(0)) - IsMax(Seq(IsMin(VarCon("a"), Closed(0)))) should be (IsMin(VarCon("a"), Closed(0))) + IsMax(Seq(VarCon("a"))) should be(VarCon("a")) + IsMax(Seq(Closed(0))) should be(Closed(0)) + IsMax(Seq(IsMin(VarCon("a"), Closed(0)))) should be(IsMin(VarCon("a"), Closed(0))) } "IsMin" should "reduce properly" in { // All constants - IsMin(Closed(-1), Closed(1)) should be (Closed(-1)) + IsMin(Closed(-1), Closed(1)) should be(Closed(-1)) // Flatten nested IsMin - IsMin(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(1), VarCon("a"))) - IsMin(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMin(Closed(1), IsMin(Closed(1), VarCon("a"))) should be(IsMin(Closed(1), VarCon("a"))) + IsMin(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be( IsMin(Seq(Closed(1), VarCon("a"), VarCon("b"))) ) // Eliminate IsMaxs if possible - IsMin(Closed(1), IsMax(Closed(2), VarCon("a"))) should be (Closed(1)) - IsMin(Seq( - Closed(2), - IsMax(Closed(1), VarCon("a")), - IsMax(Closed(3), VarCon("b")) - )) should be ( - IsMin(Seq( + IsMin(Closed(1), IsMax(Closed(2), VarCon("a"))) should be(Closed(1)) + IsMin( + Seq( Closed(2), - IsMax(Closed(1), VarCon("a")) - )) + IsMax(Closed(1), VarCon("a")), + IsMax(Closed(3), VarCon("b")) + ) + ) should be( + IsMin( + Seq( + Closed(2), + IsMax(Closed(1), VarCon("a")) + ) + ) ) // One argument - IsMin(Seq(VarCon("a"))) should be (VarCon("a")) - IsMin(Seq(Closed(0))) should be (Closed(0)) - IsMin(Seq(IsMax(VarCon("a"), Closed(0)))) should be (IsMax(VarCon("a"), Closed(0))) + IsMin(Seq(VarCon("a"))) should be(VarCon("a")) + IsMin(Seq(Closed(0))) should be(Closed(0)) + IsMin(Seq(IsMax(VarCon("a"), Closed(0)))) should be(IsMax(VarCon("a"), Closed(0))) } "IsMul" should "reduce properly" in { // All constants - IsMul(Closed(2), Closed(3)) should be (Closed(6)) + IsMul(Closed(2), Closed(3)) should be(Closed(6)) // Pull out max, if positive stays max IsMul(Closed(2), IsMax(Closed(3), VarCon("a"))) should be( @@ -124,75 +132,74 @@ class InequalitySpec extends AnyFlatSpec with Matchers { ) // Times zero - IsMul(Closed(0), VarCon("x")) should be (Closed(0)) + IsMul(Closed(0), VarCon("x")) should be(Closed(0)) // Times 1 - IsMul(Closed(1), VarCon("x")) should be (VarCon("x")) + IsMul(Closed(1), VarCon("x")) should be(VarCon("x")) // One argument - IsMul(Seq(Closed(0))) should be (Closed(0)) - IsMul(Seq(VarCon("a"))) should be (VarCon("a")) + IsMul(Seq(Closed(0))) should be(Closed(0)) + IsMul(Seq(VarCon("a"))) should be(VarCon("a")) // No optimizations val isMax = IsMax(VarCon("x"), VarCon("y")) val isMin = IsMin(VarCon("x"), VarCon("y")) val a = VarCon("a") - IsMul(a, isMax).children should be (Vector(a, isMax)) //non-known multiply - IsMul(a, isMin).children should be (Vector(a, isMin)) //non-known multiply - IsMul(Seq(Closed(2), isMin, isMin)).children should be (Vector(Closed(2), isMin, isMin)) //>1 min - IsMul(Seq(Closed(2), isMax, isMax)).children should be (Vector(Closed(2), isMax, isMax)) //>1 max - IsMul(Seq(Closed(2), isMin, isMax)).children should be (Vector(Closed(2), isMin, isMax)) //mixed min/max + IsMul(a, isMax).children should be(Vector(a, isMax)) //non-known multiply + IsMul(a, isMin).children should be(Vector(a, isMin)) //non-known multiply + IsMul(Seq(Closed(2), isMin, isMin)).children should be(Vector(Closed(2), isMin, isMin)) //>1 min + IsMul(Seq(Closed(2), isMax, isMax)).children should be(Vector(Closed(2), isMax, isMax)) //>1 max + IsMul(Seq(Closed(2), isMin, isMax)).children should be(Vector(Closed(2), isMin, isMax)) //mixed min/max } "IsNeg" should "reduce properly" in { // All constants - IsNeg(Closed(1)) should be (Closed(-1)) + IsNeg(Closed(1)) should be(Closed(-1)) // Pull out max - IsNeg(IsMax(Closed(1), VarCon("a"))) should be (IsMin(Closed(-1), IsNeg(VarCon("a")))) + IsNeg(IsMax(Closed(1), VarCon("a"))) should be(IsMin(Closed(-1), IsNeg(VarCon("a")))) // Pull out min - IsNeg(IsMin(Closed(1), VarCon("a"))) should be (IsMax(Closed(-1), IsNeg(VarCon("a")))) + IsNeg(IsMin(Closed(1), VarCon("a"))) should be(IsMax(Closed(-1), IsNeg(VarCon("a")))) // Pull out add - IsNeg(IsAdd(Closed(1), VarCon("a"))) should be (IsAdd(Closed(-1), IsNeg(VarCon("a")))) + IsNeg(IsAdd(Closed(1), VarCon("a"))) should be(IsAdd(Closed(-1), IsNeg(VarCon("a")))) // Pull out mul - IsNeg(IsMul(Closed(2), VarCon("a"))) should be (IsMul(Closed(-2), VarCon("a"))) + IsNeg(IsMul(Closed(2), VarCon("a"))) should be(IsMul(Closed(-2), VarCon("a"))) // No optimizations // (pow), (floor?) - IsNeg(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) - IsNeg(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x")))) + IsNeg(IsPow(VarCon("x"))).children should be(Vector(IsPow(VarCon("x")))) + IsNeg(IsFloor(VarCon("x"))).children should be(Vector(IsFloor(VarCon("x")))) } "IsPow" should "reduce properly" in { // All constants - IsPow(Closed(1)) should be (Closed(2)) + IsPow(Closed(1)) should be(Closed(2)) // Pull out max - IsPow(IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsPow(VarCon("a")))) + IsPow(IsMax(Closed(1), VarCon("a"))) should be(IsMax(Closed(2), IsPow(VarCon("a")))) // Pull out min - IsPow(IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsPow(VarCon("a")))) + IsPow(IsMin(Closed(1), VarCon("a"))) should be(IsMin(Closed(2), IsPow(VarCon("a")))) // Pull out add - IsPow(IsAdd(Closed(1), VarCon("a"))) should be (IsMul(Closed(2), IsPow(VarCon("a")))) + IsPow(IsAdd(Closed(1), VarCon("a"))) should be(IsMul(Closed(2), IsPow(VarCon("a")))) // No optimizations // (mul), (pow), (floor?) - IsPow(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x")))) - IsPow(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) - IsPow(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x")))) + IsPow(IsMul(Closed(2), VarCon("x"))).children should be(Vector(IsMul(Closed(2), VarCon("x")))) + IsPow(IsPow(VarCon("x"))).children should be(Vector(IsPow(VarCon("x")))) + IsPow(IsFloor(VarCon("x"))).children should be(Vector(IsFloor(VarCon("x")))) } "IsFloor" should "reduce properly" in { // All constants - IsFloor(Closed(1.9)) should be (Closed(1)) - IsFloor(Closed(-1.9)) should be (Closed(-2)) + IsFloor(Closed(1.9)) should be(Closed(1)) + IsFloor(Closed(-1.9)) should be(Closed(-2)) // Pull out max - IsFloor(IsMax(Closed(1.9), VarCon("a"))) should be (IsMax(Closed(1), IsFloor(VarCon("a")))) + IsFloor(IsMax(Closed(1.9), VarCon("a"))) should be(IsMax(Closed(1), IsFloor(VarCon("a")))) // Pull out min - IsFloor(IsMin(Closed(1.9), VarCon("a"))) should be (IsMin(Closed(1), IsFloor(VarCon("a")))) + IsFloor(IsMin(Closed(1.9), VarCon("a"))) should be(IsMin(Closed(1), IsFloor(VarCon("a")))) // Cancel with another floor - IsFloor(IsFloor(VarCon("a"))) should be (IsFloor(VarCon("a"))) + IsFloor(IsFloor(VarCon("a"))) should be(IsFloor(VarCon("a"))) // No optimizations // (add), (mul), (pow) - IsFloor(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x")))) - IsFloor(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) - IsFloor(IsAdd(Closed(1), VarCon("x"))).children should be (Vector(IsAdd(Closed(1), VarCon("x")))) + IsFloor(IsMul(Closed(2), VarCon("x"))).children should be(Vector(IsMul(Closed(2), VarCon("x")))) + IsFloor(IsPow(VarCon("x"))).children should be(Vector(IsPow(VarCon("x")))) + IsFloor(IsAdd(Closed(1), VarCon("x"))).children should be(Vector(IsAdd(Closed(1), VarCon("x")))) } } - diff --git a/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala b/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala index 7d250664..0a50e53e 100644 --- a/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala +++ b/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala @@ -31,18 +31,18 @@ object ExecutionTestHelper { // Generate test step counter, create ExecutionTestHelper that represents initial test state val cnt = DefRegister(NoInfo, DUTRules.counter.name, counterType, DUTRules.clock, DUTRules.reset, Utils.zero) - val inc = Connect(NoInfo, DUTRules.counter, DoPrim(PrimOps.Add, Seq(DUTRules.counter, UIntLiteral(1)), Nil, UnknownType)) + val inc = + Connect(NoInfo, DUTRules.counter, DoPrim(PrimOps.Add, Seq(DUTRules.counter, UIntLiteral(1)), Nil, UnknownType)) ExecutionTestHelper(c, Seq(cnt, inc), Map.empty[Expression, Expression], Nil, Nil) } } case class ExecutionTestHelper( - dut: Circuit, - setup: Seq[Statement], - pokeRegs: Map[Expression, Expression], + dut: Circuit, + setup: Seq[Statement], + pokeRegs: Map[Expression, Expression], completedSteps: Seq[Conditionally], - activeStep: Seq[Statement] -) { + activeStep: Seq[Statement]) { def step(n: Int): ExecutionTestHelper = { require(n > 0, "Step length must be positive") @@ -52,9 +52,7 @@ case class ExecutionTestHelper( def poke(expString: String, value: Literal): ExecutionTestHelper = { val pokeExp = ParseExpression(expString) val pokeable = ensurePokeable(pokeExp) - pokeable.addStatements( - Connect(NoInfo, pokeExp, value), - Connect(NoInfo, pokeable.pokeRegs(pokeExp), value)) + pokeable.addStatements(Connect(NoInfo, pokeExp, value), Connect(NoInfo, pokeable.pokeRegs(pokeExp), value)) } def invalidate(expString: String): ExecutionTestHelper = { @@ -85,7 +83,7 @@ case class ExecutionTestHelper( } private def top: Module = { - dut.modules.collectFirst({ case m: Module if m.name == dut.main => m }).get + dut.modules.collectFirst({ case m: Module if m.name == dut.main => m }).get } private[execution] def emit: Circuit = { diff --git a/src/test/scala/firrtlTests/execution/ParserHelpers.scala b/src/test/scala/firrtlTests/execution/ParserHelpers.scala index 3472c19c..1f74d634 100644 --- a/src/test/scala/firrtlTests/execution/ParserHelpers.scala +++ b/src/test/scala/firrtlTests/execution/ParserHelpers.scala @@ -14,10 +14,10 @@ object ParseStatement { val indent = " " val indented = stmtStr.split("\n").mkString(indent, s"\n${indent}", "") s"""circuit ${DUTRules.dutName} : - | module ${DUTRules.dutName} : - | input clock : Clock - | input reset : UInt<1> - |${indented}""".stripMargin + | module ${DUTRules.dutName} : + | input clock : Clock + | input reset : UInt<1> + |${indented}""".stripMargin } private def parse(stmtStr: String): Circuit = { diff --git a/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala b/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala index 2654f476..911f7485 100644 --- a/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala +++ b/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala @@ -20,19 +20,19 @@ trait TestExecution { /** * A class that makes it easier to write execution-driven tests. - * + * * By combining a DUT body (supplied as a string without an enclosing * module or circuit) with a sequence of test operations, an * executable, self-contained Verilog testbench may be automatically * created and checked. - * + * * @note It is necessary to mix in a trait extending TestExecution * @note The DUT has two implicit ports, "clock" and "reset" * @note Execution of the command sequences begins after reset is deasserted - * + * * @see [[firrtlTests.execution.TestExecution]] * @see [[firrtlTests.execution.VerilogExecution]] - * + * * @example {{{ * class AndTester extends SimpleExecutionTest with VerilogExecution { * val body = "reg r : UInt<32>, clock with: (reset => (reset, UInt<32>(0)))" @@ -64,9 +64,9 @@ abstract class SimpleExecutionTest extends FirrtlPropSpec { def commands: Seq[SimpleTestCommand] private def interpretCommand(eth: ExecutionTestHelper, cmd: SimpleTestCommand) = cmd match { - case Step(n) => eth.step(n) - case Invalidate(expStr) => eth.invalidate(expStr) - case Poke(expStr, value) => eth.poke(expStr, UIntLiteral(value)) + case Step(n) => eth.step(n) + case Invalidate(expStr) => eth.invalidate(expStr) + case Poke(expStr, value) => eth.poke(expStr, UIntLiteral(value)) case Expect(expStr, value) => eth.expect(expStr, UIntLiteral(value)) } diff --git a/src/test/scala/firrtlTests/execution/VerilogExecution.scala b/src/test/scala/firrtlTests/execution/VerilogExecution.scala index 89f27609..913cfc71 100644 --- a/src/test/scala/firrtlTests/execution/VerilogExecution.scala +++ b/src/test/scala/firrtlTests/execution/VerilogExecution.scala @@ -30,7 +30,7 @@ trait VerilogExecution extends TestExecution { // Make and run Verilog simulation verilogToCpp(c.main, testDir, Nil, harness) #&& - cppToExe(c.main, testDir) ! loggingProcessLogger + cppToExe(c.main, testDir) ! loggingProcessLogger assert(executeExpectingSuccess(c.main, testDir)) } } diff --git a/src/test/scala/firrtlTests/features/LetterCaseTransformSpec.scala b/src/test/scala/firrtlTests/features/LetterCaseTransformSpec.scala index 4299ac7f..7ab80387 100644 --- a/src/test/scala/firrtlTests/features/LetterCaseTransformSpec.scala +++ b/src/test/scala/firrtlTests/features/LetterCaseTransformSpec.scala @@ -15,7 +15,7 @@ import org.scalatest.matchers.should.Matchers class LetterCaseTransformSpec extends AnyFlatSpec with Matchers { case class TrackingAnnotation(val target: IsMember) extends SingleTargetAnnotation[IsMember] { - override def duplicate(a: IsMember) = this.copy(target=a) + override def duplicate(a: IsMember) = this.copy(target = a) } class CircuitFixture { @@ -66,72 +66,94 @@ class LetterCaseTransformSpec extends AnyFlatSpec with Matchers { private val Foo = CircuitTarget("Foo") private val Bar = Foo.module("Bar") - val annotations = Seq(TrackingAnnotation(Foo.module("Foo").ref("MeM").field("wRITE")field("en")), - ManipulateNamesBlocklistAnnotation(Seq(Seq(Bar)), Dependency[LowerCaseNames]), - ManipulateNamesBlocklistAnnotation(Seq(Seq(Bar.ref("OuT"))), Dependency[UpperCaseNames])) + val annotations = Seq( + TrackingAnnotation(Foo.module("Foo").ref("MeM").field("wRITE").field("en")), + ManipulateNamesBlocklistAnnotation(Seq(Seq(Bar)), Dependency[LowerCaseNames]), + ManipulateNamesBlocklistAnnotation(Seq(Seq(Bar.ref("OuT"))), Dependency[UpperCaseNames]) + ) val state = CircuitState(Parser.parse(input), annotations) } - behavior of "LowerCaseNames" + behavior.of("LowerCaseNames") it should "change all names to lowercase" in new CircuitFixture { val tm = new firrtl.stage.transforms.Compiler(Seq(firrtl.options.Dependency[LowerCaseNames])) val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "foo") => true }, - { case ir.Module(_, "foo", - Seq(ir.Port(_, "clk", _, _), ir.Port(_, "rst_p", _, _), ir.Port(_, "addr", _, _)), _) => true }, - /* Module "Bar" should be skipped via a ManipulateNamesBlocklistAnnotation */ - { case ir.Module(_, "Bar", Seq(ir.Port(_, "out", _, _)), _) => true }, + { + case ir + .Module(_, "foo", Seq(ir.Port(_, "clk", _, _), ir.Port(_, "rst_p", _, _), ir.Port(_, "addr", _, _)), _) => + true + }, + /* Module "Bar" should be skipped via a ManipulateNamesBlocklistAnnotation */ { + case ir.Module(_, "Bar", Seq(ir.Port(_, "out", _, _)), _) => true + }, { case ir.Module(_, "baz_0", Seq(ir.Port(_, "out", _, _)), _) => true }, { case ir.Module(_, "baz", Seq(ir.Port(_, "out", _, _)), _) => true }, - /* External module "Ext" is not renamed */ - { case ir.ExtModule(_, "Ext", Seq(ir.Port(_, "OuT", _, _)), _, _) => true }, + /* External module "Ext" is not renamed */ { + case ir.ExtModule(_, "Ext", Seq(ir.Port(_, "OuT", _, _)), _, _) => true + }, { case ir.DefNode(_, "bar", _) => true }, { case ir.DefRegister(_, "baz", _, WRef("clk", _, _, _), WRef("rst_p", _, _, _), WRef("bar", _, _, _)) => true }, { case ir.DefWire(_, "qux", _) => true }, { case ir.Connect(_, WRef("qux", _, _, _), _) => true }, { case ir.DefNode(_, "quuxquux", _) => true }, { case ir.DefMemory(_, "mem", _, _, _, _, Seq("read"), Seq("write"), Seq("rw"), _) => true }, - /* Ports of memories should be ignored, but these are already lower case */ - { case ir.IsInvalid(_, WSubField(WSubField(WRef("mem", _, _, _), "read", _, _), "addr", _, _)) => true }, + /* Ports of memories should be ignored, but these are already lower case */ { + case ir.IsInvalid(_, WSubField(WSubField(WRef("mem", _, _, _), "read", _, _), "addr", _, _)) => true + }, { case ir.IsInvalid(_, WSubField(WSubField(WRef("mem", _, _, _), "write", _, _), "addr", _, _)) => true }, { case ir.IsInvalid(_, WSubField(WSubField(WRef("mem", _, _, _), "rw", _, _), "addr", _, _)) => true }, /* Module "Bar" was skipped via a ManipulateNamesBlocklistAnnotation. The instance "SuB1" is renamed to "sub1_0" * because node "sub1" already exists. This differs from the upper case test. - */ - { case WDefInstance(_, "sub1_0", "Bar", _) => true }, + */ { case WDefInstance(_, "sub1_0", "Bar", _) => true }, { case WDefInstance(_, "sub2", "baz_0", _) => true }, { case WDefInstance(_, "sub3", "baz", _) => true }, - /* External module instance names are renamed */ - { case WDefInstance(_, "sub4", "Ext", _) => true }, + /* External module instance names are renamed */ { case WDefInstance(_, "sub4", "Ext", _) => true }, { case ir.DefNode(_, "sub1", _) => true }, { case ir.DefNode(_, "corge_corge", WSubField(WRef("sub1_0", _, _, _), "out", _, _)) => true }, - { case ir.DefNode(_, "quuzquuz", - ir.DoPrim(_,Seq(WSubField(WRef("sub2", _, _, _), "out", _, _), - WSubField(WRef("sub3", _, _, _), "out", _, _)), _, _)) => true }, - /* References to external module ports are not renamed, e.g., OuT */ - { case ir.DefNode(_, "graultgrault", - ir.DoPrim(_, Seq(WSubField(WRef("sub4", _, _, _), "OuT", _, _)), _, _)) => true } + { + case ir.DefNode( + _, + "quuzquuz", + ir.DoPrim( + _, + Seq(WSubField(WRef("sub2", _, _, _), "out", _, _), WSubField(WRef("sub3", _, _, _), "out", _, _)), + _, + _ + ) + ) => + true + }, + /* References to external module ports are not renamed, e.g., OuT */ { + case ir.DefNode(_, "graultgrault", ir.DoPrim(_, Seq(WSubField(WRef("sub4", _, _, _), "OuT", _, _)), _, _)) => + true + } ) - expected.foreach( statex should containTree (_) ) + expected.foreach(statex should containTree(_)) } - behavior of "UpperCaseNames" + behavior.of("UpperCaseNames") it should "change all names to uppercase" in new CircuitFixture { val tm = new firrtl.stage.transforms.Compiler(Seq(firrtl.options.Dependency[UpperCaseNames])) val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "FOO") => true }, - { case ir.Module(_, "FOO", - Seq(ir.Port(_, "CLK", _, _), ir.Port(_, "RST_P", _, _), ir.Port(_, "ADDR", _, _)), _) => true }, - /* "Bar>OuT" should be skipped via a ManipulateNamesBlocklistAnnotation */ - { case ir.Module(_, "BAR", Seq(ir.Port(_, "OuT", _, _)), _) => true }, + { + case ir + .Module(_, "FOO", Seq(ir.Port(_, "CLK", _, _), ir.Port(_, "RST_P", _, _), ir.Port(_, "ADDR", _, _)), _) => + true + }, + /* "Bar>OuT" should be skipped via a ManipulateNamesBlocklistAnnotation */ { + case ir.Module(_, "BAR", Seq(ir.Port(_, "OuT", _, _)), _) => true + }, { case ir.Module(_, "BAZ", Seq(ir.Port(_, "OUT", _, _)), _) => true }, { case ir.Module(_, "BAZ_0", Seq(ir.Port(_, "OUT", _, _)), _) => true }, - /* External module "Ext" is not renamed */ - { case ir.ExtModule(_, "Ext", Seq(ir.Port(_, "OuT", _, _)), _, _) => true }, + /* External module "Ext" is not renamed */ { + case ir.ExtModule(_, "Ext", Seq(ir.Port(_, "OuT", _, _)), _, _) => true + }, { case ir.DefNode(_, "BAR", _) => true }, { case ir.DefRegister(_, "BAZ", _, WRef("CLK", _, _, _), WRef("RST_P", _, _, _), WRef("BAR", _, _, _)) => true }, { case ir.DefWire(_, "QUX", _) => true }, @@ -140,28 +162,42 @@ class LetterCaseTransformSpec extends AnyFlatSpec with Matchers { { case ir.DefMemory(_, "MEM", _, _, _, _, Seq("READ"), Seq("WRITE"), Seq("RW"), _) => true }, /* Ports of memories should be ignored while readers/writers are renamed, e.g., "Read" is converted to upper case * while "addr" is not touched. - */ - { case ir.IsInvalid(_, WSubField(WSubField(WRef("MEM", _, _, _), "READ", _, _), "addr", _, _)) => true }, + */ { case ir.IsInvalid(_, WSubField(WSubField(WRef("MEM", _, _, _), "READ", _, _), "addr", _, _)) => true }, { case ir.IsInvalid(_, WSubField(WSubField(WRef("MEM", _, _, _), "WRITE", _, _), "addr", _, _)) => true }, { case ir.IsInvalid(_, WSubField(WSubField(WRef("MEM", _, _, _), "RW", _, _), "addr", _, _)) => true }, { case WDefInstance(_, "SUB1", "BAR", _) => true }, - /* Instance "SuB2" and "SuB3" switch their modules from the lower case test due to namespace behavior. */ - { case WDefInstance(_, "SUB2", "BAZ", _) => true }, + /* Instance "SuB2" and "SuB3" switch their modules from the lower case test due to namespace behavior. */ { + case WDefInstance(_, "SUB2", "BAZ", _) => true + }, { case WDefInstance(_, "SUB3", "BAZ_0", _) => true }, - /* External module "Ext" was skipped via a ManipulateBlocklistAnnotation */ - { case WDefInstance(_, "SUB4", "Ext", _) => true }, - /* Node "sub1" becomes "SUB1_0" because instance "SuB1" already got the "SUB1" name. */ - { case ir.DefNode(_, "SUB1_0", _) => true }, - /* Port "OuT" was skipped via a ManipulateNamesBlocklistAnnotation */ - { case ir.DefNode(_, "CORGE_CORGE", WSubField(WRef("SUB1", _, _, _), "OuT", _, _)) => true }, - { case ir.DefNode(_, "QUUZQUUZ", - ir.DoPrim(_,Seq(WSubField(WRef("SUB2", _, _, _), "OUT", _, _), - WSubField(WRef("SUB3", _, _, _), "OUT", _, _)), _, _)) => true }, - /* References to external module ports are not renamed, e.g., "OuT" */ - { case ir.DefNode(_, "GRAULTGRAULT", - ir.DoPrim(_, Seq(WSubField(WRef("SUB4", _, _, _), "OuT", _, _)), _, _)) => true } + /* External module "Ext" was skipped via a ManipulateBlocklistAnnotation */ { + case WDefInstance(_, "SUB4", "Ext", _) => true + }, + /* Node "sub1" becomes "SUB1_0" because instance "SuB1" already got the "SUB1" name. */ { + case ir.DefNode(_, "SUB1_0", _) => true + }, + /* Port "OuT" was skipped via a ManipulateNamesBlocklistAnnotation */ { + case ir.DefNode(_, "CORGE_CORGE", WSubField(WRef("SUB1", _, _, _), "OuT", _, _)) => true + }, + { + case ir.DefNode( + _, + "QUUZQUUZ", + ir.DoPrim( + _, + Seq(WSubField(WRef("SUB2", _, _, _), "OUT", _, _), WSubField(WRef("SUB3", _, _, _), "OUT", _, _)), + _, + _ + ) + ) => + true + }, + /* References to external module ports are not renamed, e.g., "OuT" */ { + case ir.DefNode(_, "GRAULTGRAULT", ir.DoPrim(_, Seq(WSubField(WRef("SUB4", _, _, _), "OuT", _, _)), _, _)) => + true + } ) - expected.foreach( statex should containTree (_) ) + expected.foreach(statex should containTree(_)) } } diff --git a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala index c4de1f46..a41ac90a 100644 --- a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala @@ -2,21 +2,21 @@ package firrtlTests.fixed -import firrtl.{CircuitState, ChirrtlForm, LowFirrtlCompiler} +import firrtl.{ChirrtlForm, CircuitState, LowFirrtlCompiler} import firrtl.testutils.FirrtlFlatSpec class FixedPointMathSpec extends FirrtlFlatSpec { - val SumPattern = """.*output sum.*<(\d+)>.*.*""".r - val ProductPattern = """.*output product.*<(\d+)>.*""".r + val SumPattern = """.*output sum.*<(\d+)>.*.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r val DifferencePattern = """.*output difference.*<(\d+)>.*""".r - val AssignPattern = """\s*(\w+) <= (\w+)\((.*)\)\s*""".r + val AssignPattern = """\s*(\w+) <= (\w+)\((.*)\)\s*""".r for { - bits1 <- 1 to 4 + bits1 <- 1 to 4 binaryPoint1 <- 1 to 4 - bits2 <- 1 to 4 + bits2 <- 1 to 4 binaryPoint2 <- 1 to 4 } { def config = s"($bits1,$binaryPoint1)($bits2,$binaryPoint2)" @@ -25,26 +25,26 @@ class FixedPointMathSpec extends FirrtlFlatSpec { val input = s"""circuit Unit : - | module Unit : - | input a : Fixed<$bits1><<$binaryPoint1>> - | input b : Fixed<$bits2><<$binaryPoint2>> - | output sum : Fixed - | output product : Fixed - | output difference : Fixed - | sum <= add(a, b) - | product <= mul(a, b) - | difference <= sub(a, b) - | """.stripMargin + | module Unit : + | input a : Fixed<$bits1><<$binaryPoint1>> + | input b : Fixed<$bits2><<$binaryPoint2>> + | output sum : Fixed + | output product : Fixed + | output difference : Fixed + | sum <= add(a, b) + | product <= mul(a, b) + | difference <= sub(a, b) + | """.stripMargin val lowerer = new LowFirrtlCompiler val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) - val output = res.getEmittedCircuit.value split "\n" + val output = res.getEmittedCircuit.value.split("\n") def inferredAddWidth: Int = { val binaryDifference = binaryPoint1 - binaryPoint2 - val (newW1, newW2) = if(binaryDifference > 0) { + val (newW1, newW2) = if (binaryDifference > 0) { (bits1, bits2 + binaryDifference) } else { (bits1 + binaryDifference.abs, bits2) @@ -54,11 +54,11 @@ class FixedPointMathSpec extends FirrtlFlatSpec { for (line <- output) { line match { - case SumPattern(varWidth) => + case SumPattern(varWidth) => assert(varWidth.toInt === inferredAddWidth, s"$config sum sint bits wrong for $line") case ProductPattern(varWidth) => assert(varWidth.toInt === bits1 + bits2, s"$config product bits wrong for $line") - case DifferencePattern(varWidth) => + case DifferencePattern(varWidth) => assert(varWidth.toInt === inferredAddWidth, s"$config difference bits wrong for $line") case AssignPattern(varName, operation, args) => varName match { @@ -66,11 +66,15 @@ class FixedPointMathSpec extends FirrtlFlatSpec { assert(operation === "add", s"var sum should be result of an add in $line") if (binaryPoint1 > binaryPoint2) { assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") - assert(args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), - s"$config second arg incorrect in $line") + assert( + args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), + s"$config second arg incorrect in $line" + ) } else if (binaryPoint1 < binaryPoint2) { - assert(args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), - s"$config second arg incorrect in $line") + assert( + args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), + s"$config second arg incorrect in $line" + ) assert(!args.contains("shl(b"), s"$config second arg should be just b in $line") } else { assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") @@ -84,11 +88,15 @@ class FixedPointMathSpec extends FirrtlFlatSpec { assert(operation === "sub", s"var difference should be result of an sub in $line") if (binaryPoint1 > binaryPoint2) { assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") - assert(args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), - s"$config second arg incorrect in $line") + assert( + args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), + s"$config second arg incorrect in $line" + ) } else if (binaryPoint1 < binaryPoint2) { - assert(args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), - s"$config second arg incorrect in $line") + assert( + args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), + s"$config second arg incorrect in $line" + ) assert(!args.contains("shl(b"), s"$config second arg should be just b in $line") } else { assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") @@ -102,4 +110,3 @@ class FixedPointMathSpec extends FirrtlFlatSpec { } } } - diff --git a/src/test/scala/firrtlTests/fixed/FixedSerializationSpec.scala b/src/test/scala/firrtlTests/fixed/FixedSerializationSpec.scala index db107cb3..68125bc0 100644 --- a/src/test/scala/firrtlTests/fixed/FixedSerializationSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedSerializationSpec.scala @@ -7,7 +7,7 @@ import firrtl.ir import org.scalatest.flatspec.AnyFlatSpec class FixedSerializationSpec extends AnyFlatSpec { - behavior of "FixedType" + behavior.of("FixedType") it should "serialize correctly" in { assert(ir.FixedType(ir.IntWidth(3), ir.IntWidth(2)).serialize == "Fixed<3><<2>>") @@ -16,7 +16,7 @@ class FixedSerializationSpec extends AnyFlatSpec { assert(ir.FixedType(ir.UnknownWidth, ir.UnknownWidth).serialize == "Fixed") } - behavior of "FixedLiteral" + behavior.of("FixedLiteral") it should "serialize correctly" in { assert(ir.FixedLiteral(1, ir.IntWidth(3), ir.IntWidth(2)).serialize == "Fixed<3><<2>>(\"h1\")") diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala index 1a7092bb..4d3dbe98 100644 --- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala @@ -9,12 +9,14 @@ import firrtl.testutils._ class FixedTypeInferenceSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { - val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit - val lines = c.serialize.split("\n") map normalized + val c = passes + .foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + .circuit + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } @@ -29,7 +31,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -46,7 +49,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input c : Fixed<4><<3>> | output d : Fixed<13><<3>> | d <= add(a, add(b, c))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "infer add correctly" in { @@ -59,7 +62,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -76,7 +80,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input c : Fixed<4><<3>> | output d : Fixed<15><<3>> | d <= add(a, add(b, c))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "be correctly shifted left" in { @@ -89,7 +93,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -102,7 +107,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<12><<2>> | d <= shl(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "be correctly shifted right" in { @@ -115,7 +120,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -128,7 +134,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<8><<2>> | d <= shr(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "relatively move binary point left" in { @@ -141,7 +147,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -154,7 +161,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<12><<4>> | d <= incp(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "relatively move binary point right" in { @@ -167,7 +174,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -180,7 +188,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<8><<0>> | d <= decp(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "absolutely set binary point correctly" in { @@ -193,7 +201,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -206,7 +215,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<11><<3>> | d <= setp(a, 3)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "cat, head, tail, bits" in { @@ -219,7 +228,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -246,7 +256,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | head <= head(a, 3) | tail <= tail(a, 3) | bits <= bits(a, 6, 3)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "be cast to" in { @@ -259,7 +269,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -272,7 +283,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : SInt<10> | output d : Fixed<10><<2>> | d <= asFixedPoint(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "support binary point of zero" in { @@ -286,7 +297,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """ |circuit Unit : @@ -312,53 +324,53 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | io_out <= io_in | """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "work with mems" in { def input(memType: String): String = s""" - |circuit Unit : - | module Unit : - | input clock : Clock - | input in : Fixed<16><<8>> - | input ridx : UInt<3> - | output out : Fixed<16><<8>> - | input widx : UInt<3> - | $memType mem : Fixed<16><<8>>[8] - | infer mport min = mem[ridx], clock - | min <= in - | infer mport mout = mem[widx], clock - | out <= mout + |circuit Unit : + | module Unit : + | input clock : Clock + | input in : Fixed<16><<8>> + | input ridx : UInt<3> + | output out : Fixed<16><<8>> + | input widx : UInt<3> + | $memType mem : Fixed<16><<8>>[8] + | infer mport min = mem[ridx], clock + | min <= in + | infer mport mout = mem[widx], clock + | out <= mout """.stripMargin def check(readLatency: Int, moutEn: Int, minEn: Int): String = s""" - |circuit Unit : - | module Unit : - | input clock : Clock - | input in : SInt<16> - | input ridx : UInt<3> - | output out : SInt<16> - | input widx : UInt<3> - | - | mem mem : - | data-type => SInt<16> - | depth => 8 - | read-latency => $readLatency - | write-latency => 1 - | reader => mout - | writer => min - | read-under-write => undefined - | out <= mem.mout.data - | mem.mout.addr <= widx - | mem.mout.en <= UInt<1>("h$moutEn") - | mem.mout.clk <= clock - | mem.min.addr <= ridx - | mem.min.en <= UInt<1>("h$minEn") - | mem.min.clk <= clock - | mem.min.data <= in - | mem.min.mask <= UInt<1>("h1") + |circuit Unit : + | module Unit : + | input clock : Clock + | input in : SInt<16> + | input ridx : UInt<3> + | output out : SInt<16> + | input widx : UInt<3> + | + | mem mem : + | data-type => SInt<16> + | depth => 8 + | read-latency => $readLatency + | write-latency => 1 + | reader => mout + | writer => min + | read-under-write => undefined + | out <= mem.mout.data + | mem.mout.addr <= widx + | mem.mout.en <= UInt<1>("h$moutEn") + | mem.mout.clk <= clock + | mem.min.addr <= ridx + | mem.min.en <= UInt<1>("h$minEn") + | mem.min.clk <= clock + | mem.min.data <= in + | mem.min.mask <= UInt<1>("h1") """.stripMargin - executeTest(input("smem"), check(1, 0, 1).split("\n") map normalized, new LowFirrtlCompiler) - executeTest(input("cmem"), check(0, 1, 1).split("\n") map normalized, new LowFirrtlCompiler) + executeTest(input("smem"), check(1, 0, 1).split("\n").map(normalized), new LowFirrtlCompiler) + executeTest(input("cmem"), check(0, 1, 1).split("\n").map(normalized), new LowFirrtlCompiler) } } diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala index 9dc61927..d0218b11 100644 --- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala +++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala @@ -9,12 +9,14 @@ import firrtl.testutils._ class RemoveFixedTypeSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { - val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit - val lines = c.serialize.split("\n") map normalized + val c = passes + .foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + .circuit + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } @@ -30,7 +32,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -41,13 +44,13 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | d <= add(a, add(b, c))""".stripMargin val check = """circuit Unit : - | module Unit : - | input a : SInt<10> - | input b : SInt<10> - | input c : SInt<4> - | output d : SInt<15> - | d <= shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + | module Unit : + | input a : SInt<10> + | input b : SInt<10> + | input c : SInt<4> + | output d : SInt<15> + | d <= shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "be removed, even with a bulk connect" in { val passes = Seq( @@ -60,7 +63,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -71,13 +75,13 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | d <- add(a, add(b, c))""".stripMargin val check = """circuit Unit : - | module Unit : - | input a : SInt<10> - | input b : SInt<10> - | input c : SInt<4> - | output d : SInt<15> - | d <- shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + | module Unit : + | input a : SInt<10> + | input b : SInt<10> + | input c : SInt<4> + | output d : SInt<15> + | d <- shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "remove binary point shift correctly" in { @@ -91,7 +95,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -104,7 +109,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | input a : SInt<10> | output d : SInt<12> | d <= shl(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "remove binary point shift correctly in reverse" in { @@ -118,7 +123,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -131,7 +137,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | input a : SInt<10> | output d : SInt<9> | d <= shr(a, 1)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "remove an absolutely set binary point correctly" in { @@ -145,7 +151,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -158,7 +165,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | input a : SInt<10> | output d : SInt<11> | d <= shl(a, 1)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed point numbers" should "allow binary point to be set to zero at creation" in { @@ -197,7 +204,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """ |circuit Unit : @@ -210,6 +218,6 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | module Unit : | node x = asSInt(asSInt(UInt<2>("h3"))) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } } diff --git a/src/test/scala/firrtlTests/formal/AssertSubmoduleAssumptionsSpec.scala b/src/test/scala/firrtlTests/formal/AssertSubmoduleAssumptionsSpec.scala index edfd31d3..e413a70d 100644 --- a/src/test/scala/firrtlTests/formal/AssertSubmoduleAssumptionsSpec.scala +++ b/src/test/scala/firrtlTests/formal/AssertSubmoduleAssumptionsSpec.scala @@ -1,4 +1,3 @@ - package firrtlTests.formal import firrtl.{CircuitState, Parser, Transform, UnknownForm} @@ -7,24 +6,25 @@ import firrtl.transforms.formal.AssertSubmoduleAssumptions import firrtl.stage.{Forms, TransformManager} class AssertSubmoduleAssumptionsSpec extends FirrtlFlatSpec { - behavior of "AssertSubmoduleAssumptions" + behavior.of("AssertSubmoduleAssumptions") - val transforms = new TransformManager(Forms.HighForm, Forms.MinimalHighForm) - .flattenedTransformOrder ++ Seq(new AssertSubmoduleAssumptions) + val transforms = new TransformManager(Forms.HighForm, Forms.MinimalHighForm).flattenedTransformOrder ++ Seq( + new AssertSubmoduleAssumptions + ) def run(input: String, check: Seq[String], debug: Boolean = false): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) - val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } - val lines = result.circuit.serialize.split("\n") map normalized + val lines = result.circuit.serialize.split("\n").map(normalized) if (debug) { println(lines.mkString("\n")) } for (ch <- check) { - lines should contain (ch) + lines should contain(ch) } } diff --git a/src/test/scala/firrtlTests/formal/ConvertAssertsSpec.scala b/src/test/scala/firrtlTests/formal/ConvertAssertsSpec.scala index c70a3ce4..847c211e 100644 --- a/src/test/scala/firrtlTests/formal/ConvertAssertsSpec.scala +++ b/src/test/scala/firrtlTests/formal/ConvertAssertsSpec.scala @@ -8,15 +8,15 @@ import firrtl.transforms.formal.ConvertAsserts class ConvertAssertsSpec extends FirrtlFlatSpec { val preamble = - """circuit DUT: - | module DUT: - | input clock: Clock - | input reset: UInt<1> - | input x: UInt<8> - | output y: UInt<8> - | y <= x - | node ne5 = neq(x, UInt(5)) - |""".stripMargin + """circuit DUT: + | module DUT: + | input clock: Clock + | input reset: UInt<1> + | input x: UInt<8> + | output y: UInt<8> + | y <= x + | node ne5 = neq(x, UInt(5)) + |""".stripMargin "assert nodes" should "be converted to predicated prints and stops" in { val input = preamble + @@ -29,7 +29,7 @@ class ConvertAssertsSpec extends FirrtlFlatSpec { |""".stripMargin val outputCS = ConvertAsserts.execute(CircuitState(parse(input), Nil)) - (parse(outputCS.circuit.serialize)) should be (parse(ref)) + (parse(outputCS.circuit.serialize)) should be(parse(ref)) } "assert nodes with no message" should "omit printed messages" in { @@ -42,6 +42,6 @@ class ConvertAssertsSpec extends FirrtlFlatSpec { |""".stripMargin val outputCS = ConvertAsserts.execute(CircuitState(parse(input), Nil)) - (parse(outputCS.circuit.serialize)) should be (parse(ref)) + (parse(outputCS.circuit.serialize)) should be(parse(ref)) } } diff --git a/src/test/scala/firrtlTests/formal/RemoveVerificationStatementsSpec.scala b/src/test/scala/firrtlTests/formal/RemoveVerificationStatementsSpec.scala index 10e63ae4..40d810c5 100644 --- a/src/test/scala/firrtlTests/formal/RemoveVerificationStatementsSpec.scala +++ b/src/test/scala/firrtlTests/formal/RemoveVerificationStatementsSpec.scala @@ -1,4 +1,3 @@ - package firrtlTests.formal import firrtl.{CircuitState, Parser, Transform, UnknownForm} @@ -7,17 +6,18 @@ import firrtl.testutils.FirrtlFlatSpec import firrtl.transforms.formal.RemoveVerificationStatements class RemoveVerificationStatementsSpec extends FirrtlFlatSpec { - behavior of "RemoveVerificationStatements" + behavior.of("RemoveVerificationStatements") - val transforms = new TransformManager(Forms.HighForm, Forms.MinimalHighForm) - .flattenedTransformOrder ++ Seq(new RemoveVerificationStatements) + val transforms = new TransformManager(Forms.HighForm, Forms.MinimalHighForm).flattenedTransformOrder ++ Seq( + new RemoveVerificationStatements + ) def run(input: String, antiCheck: Seq[String], debug: Boolean = false): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) - val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } - val lines = result.circuit.serialize.split("\n") map normalized + val lines = result.circuit.serialize.split("\n").map(normalized) if (debug) { println(lines.mkString("\n")) diff --git a/src/test/scala/firrtlTests/formal/VerificationSpec.scala b/src/test/scala/firrtlTests/formal/VerificationSpec.scala index 73d1404d..a8e28c13 100644 --- a/src/test/scala/firrtlTests/formal/VerificationSpec.scala +++ b/src/test/scala/firrtlTests/formal/VerificationSpec.scala @@ -2,14 +2,14 @@ package firrtlTests.formal -import firrtl.{CircuitState, SystemVerilogCompiler, ir} +import firrtl.{ir, CircuitState, SystemVerilogCompiler} import firrtl.testutils.FirrtlFlatSpec import logger.{LogLevel, Logger} import firrtl.options.Dependency import firrtl.stage.TransformManager class VerificationSpec extends FirrtlFlatSpec { - behavior of "Formal" + behavior.of("Formal") it should "generate SystemVerilog verification statements" in { val compiler = new SystemVerilogCompiler @@ -56,7 +56,7 @@ class VerificationSpec extends FirrtlFlatSpec { | end | end |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, expected, compiler) } diff --git a/src/test/scala/firrtlTests/graph/DiGraphTests.scala b/src/test/scala/firrtlTests/graph/DiGraphTests.scala index 0f8c7193..0f5cf09c 100644 --- a/src/test/scala/firrtlTests/graph/DiGraphTests.scala +++ b/src/test/scala/firrtlTests/graph/DiGraphTests.scala @@ -7,32 +7,24 @@ import firrtl.testutils._ class DiGraphTests extends FirrtlFlatSpec { - val acyclicGraph = DiGraph(Map( - "a" -> Set("b","c"), - "b" -> Set("d"), - "c" -> Set("d"), - "d" -> Set("e"), - "e" -> Set.empty[String])) - - val reversedAcyclicGraph = DiGraph(Map( - "a" -> Set.empty[String], - "b" -> Set("a"), - "c" -> Set("a"), - "d" -> Set("b", "c"), - "e" -> Set("d"))) - - val cyclicGraph = DiGraph(Map( - "a" -> Set("b","c"), - "b" -> Set("d"), - "c" -> Set("d"), - "d" -> Set("a"))) - - val tupleGraph = DiGraph(Map( - ("a", 0) -> Set(("b", 2)), - ("a", 1) -> Set(("c", 3)), - ("b", 2) -> Set.empty[(String, Int)], - ("c", 3) -> Set.empty[(String, Int)] - )) + val acyclicGraph = DiGraph( + Map("a" -> Set("b", "c"), "b" -> Set("d"), "c" -> Set("d"), "d" -> Set("e"), "e" -> Set.empty[String]) + ) + + val reversedAcyclicGraph = DiGraph( + Map("a" -> Set.empty[String], "b" -> Set("a"), "c" -> Set("a"), "d" -> Set("b", "c"), "e" -> Set("d")) + ) + + val cyclicGraph = DiGraph(Map("a" -> Set("b", "c"), "b" -> Set("d"), "c" -> Set("d"), "d" -> Set("a"))) + + val tupleGraph = DiGraph( + Map( + ("a", 0) -> Set(("b", 2)), + ("a", 1) -> Set(("c", 3)), + ("b", 2) -> Set.empty[(String, Int)], + ("c", 3) -> Set.empty[(String, Int)] + ) + ) val degenerateGraph = DiGraph(Map("a" -> Set.empty[String])) @@ -45,109 +37,113 @@ class DiGraphTests extends FirrtlFlatSpec { } "Asking a DiGraph for a path that exists" should "work" in { - acyclicGraph.path("a","e") should not be empty + acyclicGraph.path("a", "e") should not be empty } "Asking a DiGraph for a path from one node to another with no path" should "error" in { - an [PathNotFoundException] should be thrownBy acyclicGraph.path("e","a") + an[PathNotFoundException] should be thrownBy acyclicGraph.path("e", "a") } "The first element in a linearized graph with a single root node" should "be the root" in { - acyclicGraph.linearize.head should equal ("a") + acyclicGraph.linearize.head should equal("a") } "A DiGraph with a cycle" should "error when linearized" in { - a [CyclicException] should be thrownBy cyclicGraph.linearize + a[CyclicException] should be thrownBy cyclicGraph.linearize } "CyclicExceptions" should "contain information about the cycle" in { - val c = the [CyclicException] thrownBy { + val c = the[CyclicException] thrownBy { cyclicGraph.linearize } - c.getMessage.contains("found at a") should be (true) - c.node.asInstanceOf[String] should be ("a") + c.getMessage.contains("found at a") should be(true) + c.node.asInstanceOf[String] should be("a") } "Reversing a graph" should "reverse all of the edges" in { - acyclicGraph.reverse.getEdgeMap should equal (reversedAcyclicGraph.getEdgeMap) + acyclicGraph.reverse.getEdgeMap should equal(reversedAcyclicGraph.getEdgeMap) } "Reversing a graph with no edges" should "equal the graph itself" in { - degenerateGraph.getEdgeMap should equal (degenerateGraph.reverse.getEdgeMap) + degenerateGraph.getEdgeMap should equal(degenerateGraph.reverse.getEdgeMap) } "transformNodes" should "combine vertices that collide, not drop them" in { - tupleGraph.transformNodes(_._1).getEdgeMap should contain ("a" -> Set("b", "c")) + tupleGraph.transformNodes(_._1).getEdgeMap should contain("a" -> Set("b", "c")) } "Graph summation" should "be order-wise equivalent to original" in { val first = acyclicGraph.subgraph(Set("a", "b", "c")) val second = acyclicGraph.subgraph(Set("b", "c", "d", "e")) - (first + second).getEdgeMap should equal (acyclicGraph.getEdgeMap) + (first + second).getEdgeMap should equal(acyclicGraph.getEdgeMap) } it should "be idempotent" in { val first = acyclicGraph.subgraph(Set("a", "b", "c")) val second = acyclicGraph.subgraph(Set("b", "c", "d", "e")) - (first + second + second + second).getEdgeMap should equal (acyclicGraph.getEdgeMap) + (first + second + second + second).getEdgeMap should equal(acyclicGraph.getEdgeMap) } "linearize" should "not cause a stack overflow on very large graphs" in { // Graph of 0 -> 1, 1 -> 2, etc. val N = 10000 - val edges = (1 to N).zipWithIndex.map({ case (n, idx) => idx -> Set(n)}).toMap + val edges = (1 to N).zipWithIndex.map({ case (n, idx) => idx -> Set(n) }).toMap val bigGraph = DiGraph(edges + (N -> Set.empty[Int])) - bigGraph.linearize should be (0 to N) + bigGraph.linearize should be(0 to N) } it should "work on multi-rooted graphs" in { val graph = DiGraph(Map("a" -> Set[String](), "b" -> Set[String]())) - graph.linearize.toSet should be (graph.getVertices) + graph.linearize.toSet should be(graph.getVertices) } "acyclic graph" should "be rendered" in { - val acyclicGraph2 = DiGraph(Map( - "a" -> Set("b","c"), - "b" -> Set("d", "x", "z"), - "c" -> Set("d", "x"), - "d" -> Set("e", "k", "l"), - "x" -> Set("e"), - "z" -> Set("e", "j"), - "j" -> Set("k", "l", "c"), - "k" -> Set("l"), - "l" -> Set("e"), - "e" -> Set.empty[String] - )) + val acyclicGraph2 = DiGraph( + Map( + "a" -> Set("b", "c"), + "b" -> Set("d", "x", "z"), + "c" -> Set("d", "x"), + "d" -> Set("e", "k", "l"), + "x" -> Set("e"), + "z" -> Set("e", "j"), + "j" -> Set("k", "l", "c"), + "k" -> Set("l"), + "l" -> Set("e"), + "e" -> Set.empty[String] + ) + ) val render = new RenderDiGraph(acyclicGraph2) val dotLines = render.toDotRanked.split("\n") - dotLines.count(s => s.contains("rank=same")) should be (4) - dotLines.exists(s => s.contains(""""b" -> { "d" "x" "z" };""")) should be (true) - dotLines.exists(s => s.contains("""rankdir="LR";""")) should be (true) + dotLines.count(s => s.contains("rank=same")) should be(4) + dotLines.exists(s => s.contains(""""b" -> { "d" "x" "z" };""")) should be(true) + dotLines.exists(s => s.contains("""rankdir="LR";""")) should be(true) } "subgraphs containing cycles" should "be rendered with loop edges in red, can override orientation" in { - val cyclicGraph2 = DiGraph(Map( - "a" -> Set("b","c"), - "b" -> Set("d", "x", "z"), - "c" -> Set("d", "x"), - "d" -> Set("e", "k", "l"), - "x" -> Set("e"), - "z" -> Set("e", "j"), - "j" -> Set("k", "l", "c"), - "k" -> Set("l"), - "l" -> Set("e"), - "e" -> Set("c") - )) + val cyclicGraph2 = DiGraph( + Map( + "a" -> Set("b", "c"), + "b" -> Set("d", "x", "z"), + "c" -> Set("d", "x"), + "d" -> Set("e", "k", "l"), + "x" -> Set("e"), + "z" -> Set("e", "j"), + "j" -> Set("k", "l", "c"), + "k" -> Set("l"), + "l" -> Set("e"), + "e" -> Set("c") + ) + ) val render = new RenderDiGraph(cyclicGraph2, rankDir = "TB") val dotLines = render.showOnlyTheLoopAsDot.split("\n") - dotLines.count(s => s.contains("rank=same")) should be (4) - dotLines.count(s => s.contains("""[color=red,penwidth=3.0];""")) should be (3) - dotLines.exists(s => s.contains(""""d" -> "k";""")) should be (true) - dotLines.exists(s => s.contains("""rankdir="TB";""")) should be (true) + dotLines.count(s => s.contains("rank=same")) should be(4) + dotLines.count(s => s.contains("""[color=red,penwidth=3.0];""")) should be(3) + dotLines.exists(s => s.contains(""""d" -> "k";""")) should be(true) + dotLines.exists(s => s.contains("""rankdir="TB";""")) should be(true) } "reachableFrom" should "omit the queried node if no self-path exists" in { diff --git a/src/test/scala/firrtlTests/graph/EulerTourTests.scala b/src/test/scala/firrtlTests/graph/EulerTourTests.scala index f6deb721..703235af 100644 --- a/src/test/scala/firrtlTests/graph/EulerTourTests.scala +++ b/src/test/scala/firrtlTests/graph/EulerTourTests.scala @@ -11,26 +11,30 @@ class EulerTourTests extends FirrtlFlatSpec { val third_layer = Set("3a", "3b", "3c") val last_null = Set.empty[String] - val m = Map(top -> first_layer) ++ first_layer.map{ - case x => Map(x -> second_layer) }.flatten.toMap ++ second_layer.map{ - case x => Map(x -> third_layer) }.flatten.toMap ++ third_layer.map{ - case x => Map(x -> last_null) }.flatten.toMap + val m = Map(top -> first_layer) ++ first_layer.map { + case x => Map(x -> second_layer) + }.flatten.toMap ++ second_layer.map { + case x => Map(x -> third_layer) + }.flatten.toMap ++ third_layer.map { + case x => Map(x -> last_null) + }.flatten.toMap val graph = DiGraph(m) val instances = graph.pathsInDAG(top).values.flatten val tour = EulerTour(graph, top) it should "show equivalency of Berkman--Vishkin and naive RMQs" in { - instances.toSeq.combinations(2).toList.map { case Seq(a, b) => - tour.rmqNaive(a, b) should be (tour.rmqBV(a, b)) + instances.toSeq.combinations(2).toList.map { + case Seq(a, b) => + tour.rmqNaive(a, b) should be(tour.rmqBV(a, b)) } } it should "determine naive RMQs of itself correctly" in { - instances.toSeq.map { case a => tour.rmqNaive(a, a) should be (a) } + instances.toSeq.map { case a => tour.rmqNaive(a, a) should be(a) } } it should "determine Berkman--Vishkin RMQs of itself correctly" in { - instances.toSeq.map { case a => tour.rmqNaive(a, a) should be (a) } + instances.toSeq.map { case a => tour.rmqNaive(a, a) should be(a) } } } diff --git a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala index 656e1f8c..74e6cabf 100644 --- a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala +++ b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala @@ -10,14 +10,14 @@ import firrtl.constraint._ import firrtl.testutils.FirrtlFlatSpec class IntervalMathSpec extends FirrtlFlatSpec { - val SumPattern = """.*output sum.*<(\d+)>.*""".r - val ProductPattern = """.*output product.*<(\d+)>.*""".r - val DifferencePattern = """.*output difference.*<(\d+)>.*""".r - val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r - val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r - val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r - val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r - val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r + val SumPattern = """.*output sum.*<(\d+)>.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r + val DifferencePattern = """.*output difference.*<(\d+)>.*""".r + val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r + val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r + val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r + val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r + val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r val ArithAssignPattern = """\s*(\w+) <= asSInt\(bits\((\w+)\((.*)\).*\)\)\s*""".r def getBound(bound: String, value: BigDecimal): IsKnown = bound match { case "[" => Closed(value) @@ -29,16 +29,16 @@ class IntervalMathSpec extends FirrtlFlatSpec { val prec = 0.5 for { - lb1 <- Seq("[", "(") - lv1 <- Range.BigDecimal(-1.0, 1.0, prec) - uv1 <- if(lb1 == "[") Range.BigDecimal(lv1, 1.0, prec) else Range.BigDecimal(lv1 + prec, 1.0, prec) - ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") - bp1 <- 0 to 1 - lb2 <- Seq("[", "(") - lv2 <- Range.BigDecimal(-1.0, 1.0, prec) - uv2 <- if(lb2 == "[") Range.BigDecimal(lv2, 1.0, prec) else Range.BigDecimal(lv2 + prec, 1.0, prec) - ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") - bp2 <- 0 to 1 + lb1 <- Seq("[", "(") + lv1 <- Range.BigDecimal(-1.0, 1.0, prec) + uv1 <- if (lb1 == "[") Range.BigDecimal(lv1, 1.0, prec) else Range.BigDecimal(lv1 + prec, 1.0, prec) + ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") + bp1 <- 0 to 1 + lb2 <- Seq("[", "(") + lv2 <- Range.BigDecimal(-1.0, 1.0, prec) + uv2 <- if (lb2 == "[") Range.BigDecimal(lv2, 1.0, prec) else Range.BigDecimal(lv2 + prec, 1.0, prec) + ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") + bp2 <- 0 to 1 } { val it1 = IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1.toInt)) val it2 = IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2.toInt)) @@ -47,103 +47,108 @@ class IntervalMathSpec extends FirrtlFlatSpec { case (_, Some(Nil)) => case _ => def config = s"$lb1$lv1,$uv1$ub1.$bp1 and $lb2$lv2,$uv2$ub2.$bp2" - + s"Configuration $config" should "pass" in { - + val input = s"""circuit Unit : - | module Unit : - | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 - | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 - | input amt : UInt<3> - | output sum : Interval - | output difference : Interval - | output product : Interval - | output shl : Interval - | output shr : Interval - | output dshl : Interval - | output dshr : Interval - | output lt : UInt - | output leq : UInt - | output gt : UInt - | output geq : UInt - | output eq : UInt - | output neq : UInt - | output cat : UInt - | sum <= add(in1, in2) - | difference <= sub(in1, in2) - | product <= mul(in1, in2) - | shl <= shl(in1, 3) - | shr <= shr(in1, 3) - | dshl <= dshl(in1, amt) - | dshr <= dshr(in1, amt) - | lt <= lt(in1, in2) - | leq <= leq(in1, in2) - | gt <= gt(in1, in2) - | geq <= geq(in1, in2) - | eq <= eq(in1, in2) - | neq <= lt(in1, in2) - | cat <= cat(in1, in2) - | """.stripMargin - + | module Unit : + | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 + | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 + | input amt : UInt<3> + | output sum : Interval + | output difference : Interval + | output product : Interval + | output shl : Interval + | output shr : Interval + | output dshl : Interval + | output dshr : Interval + | output lt : UInt + | output leq : UInt + | output gt : UInt + | output geq : UInt + | output eq : UInt + | output neq : UInt + | output cat : UInt + | sum <= add(in1, in2) + | difference <= sub(in1, in2) + | product <= mul(in1, in2) + | shl <= shl(in1, 3) + | shr <= shr(in1, 3) + | dshl <= dshl(in1, amt) + | dshr <= dshr(in1, amt) + | lt <= lt(in1, in2) + | leq <= leq(in1, in2) + | gt <= gt(in1, in2) + | geq <= geq(in1, in2) + | eq <= eq(in1, in2) + | neq <= lt(in1, in2) + | cat <= cat(in1, in2) + | """.stripMargin + val lowerer = new LowFirrtlCompiler val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) - val output = res.getEmittedCircuit.value split "\n" + val output = res.getEmittedCircuit.value.split("\n") val min1 = Closed(it1.min.get) val max1 = Closed(it1.max.get) val min2 = Closed(it2.min.get) val max2 = Closed(it2.max.get) for (line <- output) { line match { - case SumPattern(varWidth) => + case SumPattern(varWidth) => val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) val it = IntervalType(IsAdd(min1, min2), IsAdd(max1, max2), bp) assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, s"$line,${it.range}") - case ProductPattern(varWidth) => + case ProductPattern(varWidth) => val bp = IntWidth(bp1.toInt + bp2.toInt) val lv = IsMin(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) val uv = IsMax(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "product") - case DifferencePattern(varWidth) => + case DifferencePattern(varWidth) => val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) val lv = min1 + max2.neg val uv = max1 + min2.neg assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "diff") - case ShiftLeftPattern(varWidth) => + case ShiftLeftPattern(varWidth) => val bp = IntWidth(bp1.toInt) val lv = min1 * Closed(8) val uv = max1 * Closed(8) val it = IntervalType(lv, uv, bp) assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, "shl") - case ShiftRightPattern(varWidth) => + case ShiftRightPattern(varWidth) => val bp = IntWidth(bp1.toInt) - val lv = min1 * Closed(1/3) - val uv = max1 * Closed(1/3) + val lv = min1 * Closed(1 / 3) + val uv = max1 * Closed(1 / 3) assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "shr") - case DShiftLeftPattern(varWidth) => + case DShiftLeftPattern(varWidth) => val bp = IntWidth(bp1.toInt) val lv = min1 * Closed(128) val uv = max1 * Closed(128) assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshl") - case DShiftRightPattern(varWidth) => + case DShiftRightPattern(varWidth) => val bp = IntWidth(bp1.toInt) val lv = min1 val uv = max1 assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshr") case ComparisonPattern(varWidth) => assert(varWidth.toInt == 1, "==") case ArithAssignPattern(varName, operation, args) => - val arg1 = if(IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) """SInt<1>("h0")""" else "in1" - val arg2 = if(IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) """SInt<1>("h0")""" else "in2" + val arg1 = + if (IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) + """SInt<1>("h0")""" + else "in1" + val arg2 = + if (IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) + """SInt<1>("h0")""" + else "in2" varName match { case "sum" => assert(operation === "add", s"""var sum should be result of an add in ${output.mkString("\n")}""") if (bp1 > bp2) { - if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") - assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), - s"$config second arg incorrect in $line") + if (arg1 != arg2) + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), s"$config second arg incorrect in $line") } else if (bp1 < bp2) { - assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), - s"$config second arg incorrect in $line") + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), s"$config second arg incorrect in $line") assert(!args.contains("shl($arg2"), s"$config second arg should be just $arg2 in $line") } else { assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") @@ -156,13 +161,13 @@ class IntervalMathSpec extends FirrtlFlatSpec { case "difference" => assert(operation === "sub", s"var difference should be result of an sub in $line") if (bp1 > bp2) { - if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") - assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), - s"$config second arg incorrect in $line") + if (arg1 != arg2) + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), s"$config second arg incorrect in $line") } else if (bp1 < bp2) { - assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), - s"$config second arg incorrect in $line") - if (arg1 != arg2) assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), s"$config second arg incorrect in $line") + if (arg1 != arg2) + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") } else { assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") @@ -170,12 +175,11 @@ class IntervalMathSpec extends FirrtlFlatSpec { case _ => } case _ => + } } } - } } } } - // vim: set ts=4 sw=4 et: diff --git a/src/test/scala/firrtlTests/interval/IntervalSpec.scala b/src/test/scala/firrtlTests/interval/IntervalSpec.scala index 5d82f6b5..1a39e98e 100644 --- a/src/test/scala/firrtlTests/interval/IntervalSpec.scala +++ b/src/test/scala/firrtlTests/interval/IntervalSpec.scala @@ -10,13 +10,12 @@ import firrtl.testutils.FirrtlFlatSpec class IntervalSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Transform) => - p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Transform) => + p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit } - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } @@ -37,7 +36,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output out1 : Interval | out0 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6)))))) | out1 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))""".stripMargin - executeTest(input, input.split("\n") map normalized, passes) + executeTest(input, input.split("\n").map(normalized), passes) } "Interval types" should "infer bp correctly" in { @@ -58,7 +57,7 @@ class IntervalSpec extends FirrtlFlatSpec { | input in2 : Interval(-0.32, 10].2 | output out0 : Interval.4 | out0 <= add(in0, add(in1, in2))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "trim known intervals correctly" in { @@ -79,11 +78,12 @@ class IntervalSpec extends FirrtlFlatSpec { | input in2 : Interval[-0.25, 10].2 | output out0 : Interval.4 | out0 <= add(in0, incp(add(in1, incp(in2, 1)), 1))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer intervals correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val passes = + Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) val input = """circuit Unit : | module Unit : @@ -100,11 +100,19 @@ class IntervalSpec extends FirrtlFlatSpec { """output out0 : Interval[-0.5625, 22.9375].4 |output out1 : Interval[-74.53125, 298.125].9 |output out2 : Interval[-10.6875, 12.8125].4""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "be removed correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) + val passes = Seq( + ToWorkingIR, + InferTypes, + ResolveFlows, + new InferBinaryPoints(), + new TrimIntervals(), + new InferWidths(), + new RemoveIntervals() + ) val input = """circuit Unit : | module Unit : @@ -129,209 +137,227 @@ class IntervalSpec extends FirrtlFlatSpec { | out0 <= add(in0, shl(add(in1, shl(in2, 1)), 1)) | out1 <= mul(in0, mul(in1, in2)) | out2 <= sub(in0, shl(sub(in1, shl(in2, 1)), 1))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } -"Interval types" should "infer multiplication by zero correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + "Interval types" should "infer multiplication by zero correctly" in { + val passes = + Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) val input = s"""circuit Unit : - | module Unit : - | input in1 : Interval[0, 0.5].1 - | input in2 : Interval[0, 0].1 - | output mul : Interval - | mul <= mul(in2, in1) - | """.stripMargin - val check = s"""output mul : Interval[0, 0].2 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) -} + | module Unit : + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output mul : Interval + | mul <= mul(in2, in1) + | """.stripMargin + val check = s"""output mul : Interval[0, 0].2 """.stripMargin + executeTest(input, check.split("\n").map(normalized), passes) + } "Interval types" should "infer muxes correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) - val input = - s"""circuit Unit : - | module Unit : - | input p : UInt<1> - | input in1 : Interval[0, 0.5].1 - | input in2 : Interval[0, 0].1 - | output out : Interval - | out <= mux(p, in2, in1) - | """.stripMargin + val passes = + Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<1> + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output out : Interval + | out <= mux(p, in2, in1) + | """.stripMargin val check = s"""output out : Interval[0, 0.5].1 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer dshl correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds, ResolveFlows, new InferBinaryPoints(), new TrimIntervals, new InferWidths()) - val input = - s"""circuit Unit : - | module Unit : - | input p : UInt<3> - | input in1 : Interval[-1, 1].0 - | output out : Interval - | out <= dshl(in1, p) - | """.stripMargin + val passes = Seq( + ToWorkingIR, + InferTypes, + ResolveKinds, + ResolveFlows, + new InferBinaryPoints(), + new TrimIntervals, + new InferWidths() + ) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | input in1 : Interval[-1, 1].0 + | output out : Interval + | out <= dshl(in1, p) + | """.stripMargin val check = s"""output out : Interval[-128, 128].0 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer asInterval correctly" in { val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferWidths()) - val input = - s"""circuit Unit : - | module Unit : - | input p : UInt<3> - | output out : Interval - | out <= asInterval(p, 0, 4, 1) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | output out : Interval + | out <= asInterval(p, 0, 4, 1) + | """.stripMargin val check = s"""output out : Interval[0, 2].1 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "do wrap/clip correctly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck()) - val input = - s"""circuit Unit : - | module Unit : - | input s: SInt<2> - | input u: UInt<3> - | input in1: Interval[-3, 5].0 - | output wrap3: Interval - | output wrap4: Interval - | output wrap5: Interval - | output wrap6: Interval - | output wrap7: Interval - | output clip3: Interval - | output clip4: Interval - | output clip5: Interval - | output clip6: Interval - | output clip7: Interval - | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) - | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0)) - | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) - | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) - | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) - | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) - | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) - | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) - | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) - | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap4: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) """.stripMargin - //| output wrap1: Interval - //| output wrap2: Interval - //| output clip1: Interval - //| output clip2: Interval - //| wrap1 <= wrap(in1, u, 0) - //| wrap2 <= wrap(in1, s, 0) - //| clip1 <= clip(in1, u) - //| clip2 <= clip(in1, s) + //| output wrap1: Interval + //| output wrap2: Interval + //| output clip1: Interval + //| output clip2: Interval + //| wrap1 <= wrap(in1, u, 0) + //| wrap2 <= wrap(in1, s, 0) + //| clip1 <= clip(in1, u) + //| clip2 <= clip(in1, s) val check = s""" - | output wrap3 : Interval[-2, 4].0 - | output wrap4 : Interval[-1, 1].0 - | output wrap5 : Interval[-4, 4].0 - | output wrap6 : Interval[-1, 7].0 - | output wrap7 : Interval[-4, 7].0 - | output clip3 : Interval[-2, 4].0 - | output clip4 : Interval[-1, 1].0 - | output clip5 : Interval[-3, 4].0 - | output clip6 : Interval[-1, 5].0 - | output clip7 : Interval[-3, 5].0 """.stripMargin - // TODO: this optimization - //| output wrap1 : Interval[0, 7].0 - //| output wrap2 : Interval[-2, 1].0 - //| output clip1 : Interval[0, 5].0 - //| output clip2 : Interval[-2, 1].0 - //| output wrap7 : Interval[-3, 5].0 - executeTest(input, check.split("\n") map normalized, passes) + | output wrap3 : Interval[-2, 4].0 + | output wrap4 : Interval[-1, 1].0 + | output wrap5 : Interval[-4, 4].0 + | output wrap6 : Interval[-1, 7].0 + | output wrap7 : Interval[-4, 7].0 + | output clip3 : Interval[-2, 4].0 + | output clip4 : Interval[-1, 1].0 + | output clip5 : Interval[-3, 4].0 + | output clip6 : Interval[-1, 5].0 + | output clip7 : Interval[-3, 5].0 """.stripMargin + // TODO: this optimization + //| output wrap1 : Interval[0, 7].0 + //| output wrap2 : Interval[-2, 1].0 + //| output clip1 : Interval[0, 5].0 + //| output clip2 : Interval[-2, 1].0 + //| output wrap7 : Interval[-3, 5].0 + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "remove wrap/clip correctly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck(), new RemoveIntervals()) - val input = - s"""circuit Unit : - | module Unit : - | input s: SInt<2> - | input u: UInt<3> - | input in1: Interval[-3, 5].0 - | output wrap3: Interval - | output wrap5: Interval - | output wrap6: Interval - | output wrap7: Interval - | output clip3: Interval - | output clip4: Interval - | output clip5: Interval - | output clip6: Interval - | output clip7: Interval - | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) - | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) - | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) - | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) - | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) - | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) - | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) - | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) - | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + | """.stripMargin val check = s""" - | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1)) - | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1) - | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1) - | wrap7 <= in1 - | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1)) - | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)) - | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1) - | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1) - | clip7 <= in1 + | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1)) + | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1) + | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1) + | wrap7 <= in1 + | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1)) + | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)) + | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1) + | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1) + | clip7 <= in1 """.stripMargin - //| output wrap4: Interval - //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0) - //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1")) - executeTest(input, check.split("\n") map normalized, passes) + //| output wrap4: Interval + //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0) + //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1")) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "shift wrap/clip correctly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals()) - val input = - s"""circuit Unit : - | module Unit : - | input s: SInt<2> - | input in1: Interval[-3, 5].1 - | output wrap1: Interval - | output clip1: Interval - | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0)) - | clip1 <= clip(in1, asInterval(s, -2, 2, 0)) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input in1: Interval[-3, 5].1 + | output wrap1: Interval + | output clip1: Interval + | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0)) + | clip1 <= clip(in1, asInterval(s, -2, 2, 0)) + | """.stripMargin val check = s""" - | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1)) - | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1)) + | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1)) + | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1)) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer negative binary points" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck()) - val input = - s"""circuit Unit : - | module Unit : - | input in1: Interval[-2, 4].-1 - | input in2: Interval[-4, 8].-2 - | output out: Interval - | out <= add(in1, in2) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval + | out <= add(in1, in2) + | """.stripMargin val check = s""" - | output out : Interval[-6, 12].-1 + | output out : Interval[-6, 12].-1 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "remove negative binary points" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) - val input = - s"""circuit Unit : - | module Unit : - | input in1: Interval[-2, 4].-1 - | input in2: Interval[-4, 8].-2 - | output out: Interval.0 - | out <= add(in1, in2) - | """.stripMargin + val passes = Seq( + ToWorkingIR, + InferTypes, + ResolveFlows, + new InferBinaryPoints(), + new TrimIntervals(), + new InferWidths(), + new RemoveIntervals() + ) + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval.0 + | out <= add(in1, in2) + | """.stripMargin val check = s""" - | output out : SInt<5> - | out <= shl(add(in1, shl(in2, 1)), 1) + | output out : SInt<5> + | out <= shl(add(in1, shl(in2, 1)), 1) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "implement squz properly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck) @@ -372,7 +398,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output minOff : Interval[-1, 4].1 | output offMin : Interval[-1, 4].2 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "lower squz properly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) @@ -413,7 +439,7 @@ class IntervalSpec extends FirrtlFlatSpec { | minOff <= asSInt(bits(min, 4, 0)) | offMin <= asSInt(bits(off, 5, 0)) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Assigning a larger interval to a smaller interval" should "error!" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) @@ -424,7 +450,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output out: Interval[2, 3].1 | out <= in | """.stripMargin - intercept[InvalidConnect]{ + intercept[InvalidConnect] { executeTest(input, Nil, passes) } } @@ -437,7 +463,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output out: Interval[2, 3].1 | out <= in | """.stripMargin - intercept[InvalidConnect]{ + intercept[InvalidConnect] { executeTest(input, Nil, passes) } } @@ -512,7 +538,6 @@ class IntervalSpec extends FirrtlFlatSpec { ) } - "Wrap with remainder" should "error" in { intercept[WrapWithRemainder] { val input = diff --git a/src/test/scala/firrtlTests/options/OptionParserSpec.scala b/src/test/scala/firrtlTests/options/OptionParserSpec.scala index e93c9b2c..452e6cb7 100644 --- a/src/test/scala/firrtlTests/options/OptionParserSpec.scala +++ b/src/test/scala/firrtlTests/options/OptionParserSpec.scala @@ -19,66 +19,69 @@ class OptionParserSpec extends AnyFlatSpec with Matchers with firrtl.testutils.U /* An option parser that prepends to a Seq[Int] */ class IntParser extends OptionParser[AnnotationSeq]("Int Parser") { - opt[Int]("integer").abbr("n").unbounded.action( (x, c) => IntAnnotation(x) +: c ) + opt[Int]("integer").abbr("n").unbounded.action((x, c) => IntAnnotation(x) +: c) help("help") } trait DuplicateShortOption { this: OptionParser[AnnotationSeq] => - opt[Int]("not-an-integer").abbr("n").unbounded.action( (x, c) => IntAnnotation(x) +: c ) + opt[Int]("not-an-integer").abbr("n").unbounded.action((x, c) => IntAnnotation(x) +: c) } trait DuplicateLongOption { this: OptionParser[AnnotationSeq] => - opt[Int]("integer").abbr("m").unbounded.action( (x, c) => IntAnnotation(x) +: c ) + opt[Int]("integer").abbr("m").unbounded.action((x, c) => IntAnnotation(x) +: c) } trait WithIntParser { val parser = new IntParser } - behavior of "A default OptionsParser" + behavior.of("A default OptionsParser") it should "call sys.exit if terminate is called" in new WithIntParser { info("exit status of 1 for failure") - catchStatus { parser.terminate(Left("some message")) } should be (Left(1)) + catchStatus { parser.terminate(Left("some message")) } should be(Left(1)) info("exit status of 0 for success") - catchStatus { parser.terminate(Right(())) } should be (Left(0)) + catchStatus { parser.terminate(Right(())) } should be(Left(0)) } it should "print to stderr on an invalid option" in new WithIntParser { - grabStdOutErr{ parser.parse(Array("--foo"), Seq[Annotation]()) }._2 should include ("Unknown option --foo") + grabStdOutErr { parser.parse(Array("--foo"), Seq[Annotation]()) }._2 should include("Unknown option --foo") } - behavior of "An OptionParser with DoNotTerminateOnExit mixed in" + behavior.of("An OptionParser with DoNotTerminateOnExit mixed in") it should "disable sys.exit for terminate method" in { val parser = new IntParser with DoNotTerminateOnExit info("no exit for failure") - catchStatus { parser.terminate(Left("some message")) } should be (Right(())) + catchStatus { parser.terminate(Left("some message")) } should be(Right(())) info("no exit for success") - catchStatus { parser.terminate(Right(())) } should be (Right(())) + catchStatus { parser.terminate(Right(())) } should be(Right(())) } - behavior of "An OptionParser with DuplicateHandling mixed in" + behavior.of("An OptionParser with DuplicateHandling mixed in") it should "detect short duplicates" in { val parser = new IntParser with DuplicateHandling with DuplicateShortOption - intercept[OptionsException] { parser.parse(Array[String](), Seq[Annotation]()) } - .getMessage should startWith ("Duplicate short option") + intercept[OptionsException] { parser.parse(Array[String](), Seq[Annotation]()) }.getMessage should startWith( + "Duplicate short option" + ) } it should "detect long duplicates" in { val parser = new IntParser with DuplicateHandling with DuplicateLongOption - intercept[OptionsException] { parser.parse(Array[String](), Seq[Annotation]()) } - .getMessage should startWith ("Duplicate long option") + intercept[OptionsException] { parser.parse(Array[String](), Seq[Annotation]()) }.getMessage should startWith( + "Duplicate long option" + ) } - behavior of "An OptionParser with ExceptOnError mixed in" + behavior.of("An OptionParser with ExceptOnError mixed in") it should "cause an OptionsException on an invalid option" in { val parser = new IntParser with ExceptOnError - intercept[OptionsException] { parser.parse(Array("--foo"), Seq[Annotation]()) } - .getMessage should include ("Unknown option") + intercept[OptionsException] { parser.parse(Array("--foo"), Seq[Annotation]()) }.getMessage should include( + "Unknown option" + ) } } diff --git a/src/test/scala/firrtlTests/options/OptionsViewSpec.scala b/src/test/scala/firrtlTests/options/OptionsViewSpec.scala index 0c868cb2..504dcdf6 100644 --- a/src/test/scala/firrtlTests/options/OptionsViewSpec.scala +++ b/src/test/scala/firrtlTests/options/OptionsViewSpec.scala @@ -2,10 +2,9 @@ package firrtlTests.options - import firrtl.options.OptionsView import firrtl.AnnotationSeq -import firrtl.annotations.{Annotation,NoTargetAnnotation} +import firrtl.annotations.{Annotation, NoTargetAnnotation} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -22,7 +21,7 @@ class OptionsViewSpec extends AnyFlatSpec with Matchers { /* An OptionsView that converts an AnnotationSeq to Option[Foo] */ implicit object FooView extends OptionsView[Foo] { private def append(foo: Foo, anno: Annotation): Foo = anno match { - case NameAnnotation(n) => foo.copy(name = Some(n)) + case NameAnnotation(n) => foo.copy(name = Some(n)) case ValueAnnotation(v) => foo.copy(value = Some(v)) case _ => foo } @@ -40,20 +39,20 @@ class OptionsViewSpec extends AnyFlatSpec with Matchers { def view(options: AnnotationSeq): Bar = options.foldLeft(Bar())(append) } - behavior of "OptionsView" + behavior.of("OptionsView") it should "convert annotations to one of two types" in { /* Some default annotations */ val annos = Seq(NameAnnotation("foo"), ValueAnnotation(42)) info("Foo conversion okay") - FooView.view(annos) should be (Foo(Some("foo"), Some(42))) + FooView.view(annos) should be(Foo(Some("foo"), Some(42))) info("Bar conversion okay") - BarView.view(annos) should be (Bar("foo")) + BarView.view(annos) should be(Bar("foo")) } - behavior of "Viewer" + behavior.of("Viewer") it should "implicitly view annotations as the specified type" in { import firrtl.options.Viewer._ @@ -62,9 +61,9 @@ class OptionsViewSpec extends AnyFlatSpec with Matchers { val annos = Seq[Annotation]() info("Foo view okay") - view[Foo](annos) should be (Foo(None, None)) + view[Foo](annos) should be(Foo(None, None)) info("Bar view okay") - view[Bar](annos) should be (Bar()) + view[Bar](annos) should be(Bar()) } } diff --git a/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala b/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala index 108f3730..f31b96fd 100644 --- a/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala +++ b/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala @@ -2,9 +2,8 @@ package firrtlTests.options - import firrtl.AnnotationSeq -import firrtl.options.{DependencyManagerException, Phase, PhaseManager, Dependency} +import firrtl.options.{Dependency, DependencyManagerException, Phase, PhaseManager} import java.io.{File, PrintWriter} @@ -62,7 +61,6 @@ class F extends IdentityPhase { } } - /** [[Phase]] that requires [[C]] and invalidates [[F]] */ class G extends IdentityPhase { override def prerequisites = Seq(Dependency[C]) @@ -235,7 +233,7 @@ object UnrelatedFixture { trait InvalidatesB8Dep { this: Phase => override def invalidates(a: Phase) = a match { case _: B8Dep => true - case _ => false + case _ => false } } @@ -368,7 +366,7 @@ object OrderingFixture { class B extends IdentityPhase { override def invalidates(phase: Phase): Boolean = phase match { case _: A => true - case _ => false + case _ => false } } @@ -376,7 +374,7 @@ object OrderingFixture { override def prerequisites = Seq(Dependency[A], Dependency[B]) override def invalidates(phase: Phase): Boolean = phase match { case _: B => true - case _ => false + case _ => false } } @@ -423,7 +421,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { } - behavior of this.getClass.getName + behavior.of(this.getClass.getName) it should "do nothing if all targets are reached" in { val targets = Seq(Dependency[A], Dependency[B], Dependency[C], Dependency[D]) @@ -431,7 +429,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/DoNothing") - pm.flattenedTransformOrder should be (empty) + pm.flattenedTransformOrder should be(empty) } it should "handle a simple dependency" in { @@ -441,7 +439,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/SimpleDependency") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } it should "handle a simple dependency with an invalidation" in { @@ -451,7 +449,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/OneInvalidate") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } it should "handle a dependency with two invalidates optimally" in { @@ -460,7 +458,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/TwoInvalidates") - pm.flattenedTransformOrder.size should be (targets.size) + pm.flattenedTransformOrder.size should be(targets.size) } it should "throw an exception for cyclic prerequisites" in { @@ -469,8 +467,9 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/CyclicPrerequisites") - intercept[DependencyManagerException]{ pm.flattenedTransformOrder } - .getMessage should startWith ("No transform ordering possible") + intercept[DependencyManagerException] { pm.flattenedTransformOrder }.getMessage should startWith( + "No transform ordering possible" + ) } it should "throw an exception for cyclic invalidates" in { @@ -479,8 +478,9 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/CyclicInvalidates") - intercept[DependencyManagerException]{ pm.flattenedTransformOrder } - .getMessage should startWith ("No transform ordering possible") + intercept[DependencyManagerException] { pm.flattenedTransformOrder }.getMessage should startWith( + "No transform ordering possible" + ) } it should "handle a complicated graph" in { @@ -491,41 +491,31 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/Complicated") info("only one phase was recomputed") - pm.flattenedTransformOrder.size should be (targets.size + 1) + pm.flattenedTransformOrder.size should be(targets.size + 1) } it should "handle repeated recomputed analyses" in { val f = RepeatedAnalysisFixture val targets = Seq(Dependency[f.A], Dependency[f.B], Dependency[f.C]) val order = - Seq( classOf[f.Analysis], - classOf[f.A], - classOf[f.Analysis], - classOf[f.B], - classOf[f.Analysis], - classOf[f.C]) + Seq(classOf[f.Analysis], classOf[f.A], classOf[f.Analysis], classOf[f.B], classOf[f.Analysis], classOf[f.C]) val pm = new PhaseManager(targets) writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/RepeatedAnalysis") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } it should "handle inverted repeated recomputed analyses" in { val f = InvertedAnalysisFixture val targets = Seq(Dependency[f.A], Dependency[f.B], Dependency[f.C]) val order = - Seq( classOf[f.Analysis], - classOf[f.C], - classOf[f.Analysis], - classOf[f.B], - classOf[f.Analysis], - classOf[f.A]) + Seq(classOf[f.Analysis], classOf[f.C], classOf[f.Analysis], classOf[f.B], classOf[f.Analysis], classOf[f.A]) val pm = new PhaseManager(targets) writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/InvertedRepeatedAnalysis") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } /** This test shows how the optionalPrerequisiteOf member can be used to run one transform before another. */ @@ -535,7 +525,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { info("without the custom transform it runs: First -> Second") val pm = new PhaseManager(Seq(Dependency[f.Second])) val orderNoCustom = Seq(classOf[f.First], classOf[f.Second]) - pm.flattenedTransformOrder.map(_.getClass) should be (orderNoCustom) + pm.flattenedTransformOrder.map(_.getClass) should be(orderNoCustom) info("with the custom transform it runs: First -> Custom -> Second") val pmCustom = new PhaseManager(Seq(Dependency[f.Custom], Dependency[f.Second])) @@ -543,7 +533,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pmCustom, "test_run_dir/PhaseManagerSpec/SingleDependent") - pmCustom.flattenedTransformOrder.map(_.getClass) should be (orderCustom) + pmCustom.flattenedTransformOrder.map(_.getClass) should be(orderCustom) } it should "handle chained invalidation" in { @@ -553,11 +543,11 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { val current = Seq(Dependency[f.B], Dependency[f.C], Dependency[f.D]) val pm = new PhaseManager(targets, current) - val order = Seq( classOf[f.A], classOf[f.B], classOf[f.C], classOf[f.D], classOf[f.E] ) + val order = Seq(classOf[f.A], classOf[f.B], classOf[f.C], classOf[f.D], classOf[f.E]) writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/ChainedInvalidate") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } it should "maintain the order of input targets" in { @@ -565,62 +555,70 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { /** A bunch of unrelated Phases. This ensures that these run in the order in which they are specified. */ val targets = - Seq( Dependency[f.B0], - Dependency[f.B1], - Dependency[f.B2], - Dependency[f.B3], - Dependency[f.B4], - Dependency[f.B5], - Dependency[f.B6], - Dependency[f.B7], - Dependency[f.B8], - Dependency[f.B9], - Dependency[f.B10], - Dependency[f.B11], - Dependency[f.B12], - Dependency[f.B13], - Dependency[f.B14], - Dependency[f.B15] ) + Seq( + Dependency[f.B0], + Dependency[f.B1], + Dependency[f.B2], + Dependency[f.B3], + Dependency[f.B4], + Dependency[f.B5], + Dependency[f.B6], + Dependency[f.B7], + Dependency[f.B8], + Dependency[f.B9], + Dependency[f.B10], + Dependency[f.B11], + Dependency[f.B12], + Dependency[f.B13], + Dependency[f.B14], + Dependency[f.B15] + ) + /** A sequence of custom transforms that should all run after B6 and before B7. This exercises correct ordering of the * prerequisiteGraph and optionalPrerequisiteOfGraph. */ val prerequisiteTargets = - Seq( Dependency[f.B6_0], - Dependency[f.B6_1], - Dependency[f.B6_2], - Dependency[f.B6_3], - Dependency[f.B6_4], - Dependency[f.B6_5], - Dependency[f.B6_6], - Dependency[f.B6_7], - Dependency[f.B6_8], - Dependency[f.B6_9], - Dependency[f.B6_10], - Dependency[f.B6_11], - Dependency[f.B6_12], - Dependency[f.B6_13], - Dependency[f.B6_14], - Dependency[f.B6_15] ) + Seq( + Dependency[f.B6_0], + Dependency[f.B6_1], + Dependency[f.B6_2], + Dependency[f.B6_3], + Dependency[f.B6_4], + Dependency[f.B6_5], + Dependency[f.B6_6], + Dependency[f.B6_7], + Dependency[f.B6_8], + Dependency[f.B6_9], + Dependency[f.B6_10], + Dependency[f.B6_11], + Dependency[f.B6_12], + Dependency[f.B6_13], + Dependency[f.B6_14], + Dependency[f.B6_15] + ) + /** A sequence of transforms that are invalidated by B0 and only define optionalPrerequisiteOf on B8. This exercises * the ordering defined by "otherPrerequisites". */ val current = - Seq( Dependency[f.B8_0], - Dependency[f.B8_1], - Dependency[f.B8_2], - Dependency[f.B8_3], - Dependency[f.B8_4], - Dependency[f.B8_5], - Dependency[f.B8_6], - Dependency[f.B8_7], - Dependency[f.B8_8], - Dependency[f.B8_9], - Dependency[f.B8_10], - Dependency[f.B8_11], - Dependency[f.B8_12], - Dependency[f.B8_13], - Dependency[f.B8_14], - Dependency[f.B8_15] ) + Seq( + Dependency[f.B8_0], + Dependency[f.B8_1], + Dependency[f.B8_2], + Dependency[f.B8_3], + Dependency[f.B8_4], + Dependency[f.B8_5], + Dependency[f.B8_6], + Dependency[f.B8_7], + Dependency[f.B8_8], + Dependency[f.B8_9], + Dependency[f.B8_10], + Dependency[f.B8_11], + Dependency[f.B8_12], + Dependency[f.B8_13], + Dependency[f.B8_14], + Dependency[f.B8_15] + ) /** The resulting order: B0--B6, B6_0--B6_B15, B7, B8_0--B8_15, B8--B15 */ val expectedDeps = targets.slice(0, 7) ++ prerequisiteTargets ++ Some(targets(7)) ++ current ++ targets.drop(8) @@ -630,7 +628,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/DeterministicOrder") - pm.flattenedTransformOrder.map(_.getClass) should be (expectedClasses) + pm.flattenedTransformOrder.map(_.getClass) should be(expectedClasses) } it should "allow conditional placement of custom transforms" in { @@ -642,13 +640,21 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { val targetsFull = Seq(Dependency[f.Custom], Dependency[f.DoneFull]) val pmFull = new PhaseManager(targetsFull) - val expectedMinimum = Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.AfterOpt], classOf[f.Custom], classOf[f.DoneMinimum]) + val expectedMinimum = + Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.AfterOpt], classOf[f.Custom], classOf[f.DoneMinimum]) writeGraphviz(pmMinimum, "test_run_dir/PhaseManagerSpec/CustomAfterOptimization/minimum") - pmMinimum.flattenedTransformOrder.map(_.getClass) should be (expectedMinimum) - - val expectedFull = Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.OptFull], classOf[f.AfterOpt], classOf[f.Custom], classOf[f.DoneFull]) + pmMinimum.flattenedTransformOrder.map(_.getClass) should be(expectedMinimum) + + val expectedFull = Seq( + classOf[f.Root], + classOf[f.OptMinimum], + classOf[f.OptFull], + classOf[f.AfterOpt], + classOf[f.Custom], + classOf[f.DoneFull] + ) writeGraphviz(pmFull, "test_run_dir/PhaseManagerSpec/CustomAfterOptimization/full") - pmFull.flattenedTransformOrder.map(_.getClass) should be (expectedFull) + pmFull.flattenedTransformOrder.map(_.getClass) should be(expectedFull) } it should "support optional prerequisites" in { @@ -662,11 +668,12 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { val expectedMinimum = Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.Custom], classOf[f.DoneMinimum]) writeGraphviz(pmMinimum, "test_run_dir/PhaseManagerSpec/CustomAfterOptimization/minimum") - pmMinimum.flattenedTransformOrder.map(_.getClass) should be (expectedMinimum) + pmMinimum.flattenedTransformOrder.map(_.getClass) should be(expectedMinimum) - val expectedFull = Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.OptFull], classOf[f.Custom], classOf[f.DoneFull]) + val expectedFull = + Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.OptFull], classOf[f.Custom], classOf[f.DoneFull]) writeGraphviz(pmFull, "test_run_dir/PhaseManagerSpec/CustomAfterOptimization/full") - pmFull.flattenedTransformOrder.map(_.getClass) should be (expectedFull) + pmFull.flattenedTransformOrder.map(_.getClass) should be(expectedFull) } /** This tests a situation the ordering of edges matters. Namely, this test is dependent on the ordering in which @@ -678,13 +685,13 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { { val targets = Seq(Dependency[f.A], Dependency[f.B], Dependency[f.C]) val order = Seq(classOf[f.B], classOf[f.A], classOf[f.C], classOf[f.B], classOf[f.A]) - (new PhaseManager(targets)).flattenedTransformOrder.map(_.getClass) should be (order) + (new PhaseManager(targets)).flattenedTransformOrder.map(_.getClass) should be(order) } { val targets = Seq(Dependency[f.A], Dependency[f.B], Dependency[f.Cx]) val order = Seq(classOf[f.B], classOf[f.A], classOf[f.Cx], classOf[f.B], classOf[f.A]) - (new PhaseManager(targets)).flattenedTransformOrder.map(_.getClass) should be (order) + (new PhaseManager(targets)).flattenedTransformOrder.map(_.getClass) should be(order) } } diff --git a/src/test/scala/firrtlTests/options/RegistrationSpec.scala b/src/test/scala/firrtlTests/options/RegistrationSpec.scala index fa6b0fa0..821ac8b3 100644 --- a/src/test/scala/firrtlTests/options/RegistrationSpec.scala +++ b/src/test/scala/firrtlTests/options/RegistrationSpec.scala @@ -6,7 +6,7 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import java.util.ServiceLoader -import firrtl.options.{RegisteredTransform, RegisteredLibrary, ShellOption} +import firrtl.options.{RegisteredLibrary, RegisteredTransform, ShellOption} import firrtl.passes.Pass import firrtl.ir.Circuit import firrtl.annotations.NoTargetAnnotation @@ -19,10 +19,8 @@ class FooTransform extends Pass with RegisteredTransform { def run(c: Circuit): Circuit = c val options = Seq( - new ShellOption[Unit]( - longOption = "hello", - toAnnotationSeq = _ => Seq(HelloAnnotation), - helpText = "Hello option") ) + new ShellOption[Unit](longOption = "hello", toAnnotationSeq = _ => Seq(HelloAnnotation), helpText = "Hello option") + ) } @@ -30,15 +28,13 @@ class BarLibrary extends RegisteredLibrary { def name: String = "Bar" val options = Seq( - new ShellOption[Unit]( - longOption = "world", - toAnnotationSeq = _ => Seq(HelloAnnotation), - helpText = "World option") ) + new ShellOption[Unit](longOption = "world", toAnnotationSeq = _ => Seq(HelloAnnotation), helpText = "World option") + ) } class RegistrationSpec extends AnyFlatSpec with Matchers { - behavior of "RegisteredTransform" + behavior.of("RegisteredTransform") it should "FooTransform should be discovered by Java.util.ServiceLoader" in { val iter = ServiceLoader.load(classOf[RegisteredTransform]).iterator() @@ -46,10 +42,10 @@ class RegistrationSpec extends AnyFlatSpec with Matchers { while (iter.hasNext) { transforms += iter.next() } - transforms.map(_.getClass.getName) should contain ("firrtlTests.options.FooTransform") + transforms.map(_.getClass.getName) should contain("firrtlTests.options.FooTransform") } - behavior of "RegisteredLibrary" + behavior.of("RegisteredLibrary") it should "BarLibrary be discovered by Java.util.ServiceLoader" in { val iter = ServiceLoader.load(classOf[RegisteredLibrary]).iterator() @@ -57,6 +53,6 @@ class RegistrationSpec extends AnyFlatSpec with Matchers { while (iter.hasNext) { transforms += iter.next() } - transforms.map(_.getClass.getName) should contain ("firrtlTests.options.BarLibrary") + transforms.map(_.getClass.getName) should contain("firrtlTests.options.BarLibrary") } } diff --git a/src/test/scala/firrtlTests/options/ShellSpec.scala b/src/test/scala/firrtlTests/options/ShellSpec.scala index af6b2669..178b1128 100644 --- a/src/test/scala/firrtlTests/options/ShellSpec.scala +++ b/src/test/scala/firrtlTests/options/ShellSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.options - import firrtl.annotations.NoTargetAnnotation import firrtl.options.Shell import org.scalatest.flatspec.AnyFlatSpec @@ -17,25 +16,26 @@ class ShellSpec extends AnyFlatSpec with Matchers { case object E extends NoTargetAnnotation trait AlphabeticalCli { this: Shell => - parser.opt[Unit]('c', "c-option").unbounded().action( (x, c) => C +: c ) - parser.opt[Unit]('d', "d-option").unbounded().action( (x, c) => D +: c ) - parser.opt[Unit]('e', "e-option").unbounded().action( (x, c) => E +: c ) } + parser.opt[Unit]('c', "c-option").unbounded().action((x, c) => C +: c) + parser.opt[Unit]('d', "d-option").unbounded().action((x, c) => D +: c) + parser.opt[Unit]('e', "e-option").unbounded().action((x, c) => E +: c) + } - behavior of "Shell" + behavior.of("Shell") it should "detect all registered libraries and transforms" in { val shell = new Shell("foo") info("Found FooTransform") - shell.registeredTransforms.map(_.getClass.getName) should contain ("firrtlTests.options.FooTransform") + shell.registeredTransforms.map(_.getClass.getName) should contain("firrtlTests.options.FooTransform") info("Found BarLibrary") - shell.registeredLibraries.map(_.getClass.getName) should contain ("firrtlTests.options.BarLibrary") + shell.registeredLibraries.map(_.getClass.getName) should contain("firrtlTests.options.BarLibrary") } it should "correctly order annotations and options" in { val shell = new Shell("foo") with AlphabeticalCli - shell.parse(Array("-c", "-d", "-e"), Seq(A, B)).toSeq should be (Seq(A, B, C, D, E)) + shell.parse(Array("-c", "-d", "-e"), Seq(A, B)).toSeq should be(Seq(A, B, C, D, E)) } } diff --git a/src/test/scala/firrtlTests/options/phases/AddDefaultsSpec.scala b/src/test/scala/firrtlTests/options/phases/AddDefaultsSpec.scala index 3401a408..f625f991 100644 --- a/src/test/scala/firrtlTests/options/phases/AddDefaultsSpec.scala +++ b/src/test/scala/firrtlTests/options/phases/AddDefaultsSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.options.phases - import firrtl.options.{Phase, TargetDirAnnotation} import firrtl.options.phases.AddDefaults import org.scalatest.flatspec.AnyFlatSpec @@ -16,13 +15,13 @@ class AddDefaultsSpec extends AnyFlatSpec with Matchers { val defaultDir = TargetDirAnnotation(".") } - behavior of classOf[AddDefaults].toString + behavior.of(classOf[AddDefaults].toString) it should "add a TargetDirAnnotation if it does not exist" in new Fixture { - phase.transform(Seq.empty).toSeq should be (Seq(defaultDir)) + phase.transform(Seq.empty).toSeq should be(Seq(defaultDir)) } it should "don't add a TargetDirAnnotation if it exists" in new Fixture { - phase.transform(Seq(targetDir)).toSeq should be (Seq(targetDir)) + phase.transform(Seq(targetDir)).toSeq should be(Seq(targetDir)) } } diff --git a/src/test/scala/firrtlTests/options/phases/ChecksSpec.scala b/src/test/scala/firrtlTests/options/phases/ChecksSpec.scala index 96d6569d..62afed94 100644 --- a/src/test/scala/firrtlTests/options/phases/ChecksSpec.scala +++ b/src/test/scala/firrtlTests/options/phases/ChecksSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.options.phases - import firrtl.AnnotationSeq import firrtl.options.{OptionsException, OutputAnnotationFileAnnotation, Phase, TargetDirAnnotation} import firrtl.options.phases.Checks @@ -20,9 +19,9 @@ class ChecksSpec extends AnyFlatSpec with Matchers { val min = Seq(targetDir) def checkExceptionMessage(phase: Phase, annotations: AnnotationSeq, messageStart: String): Unit = - intercept[OptionsException]{ phase.transform(annotations) }.getMessage should startWith(messageStart) + intercept[OptionsException] { phase.transform(annotations) }.getMessage should startWith(messageStart) - behavior of classOf[Checks].toString + behavior.of(classOf[Checks].toString) it should "enforce exactly one TargetDirAnnotation" in new Fixture { info("0 target directories throws an exception") diff --git a/src/test/scala/firrtlTests/options/phases/GetIncludesSpec.scala b/src/test/scala/firrtlTests/options/phases/GetIncludesSpec.scala index 7d20ac89..95c2a435 100644 --- a/src/test/scala/firrtlTests/options/phases/GetIncludesSpec.scala +++ b/src/test/scala/firrtlTests/options/phases/GetIncludesSpec.scala @@ -2,12 +2,10 @@ package firrtlTests.options.phases - import java.io.{File, PrintWriter} import firrtl.AnnotationSeq -import firrtl.annotations.{AnnotationFileNotFoundException, JsonProtocol, - NoTargetAnnotation} +import firrtl.annotations.{AnnotationFileNotFoundException, JsonProtocol, NoTargetAnnotation} import firrtl.options.phases.GetIncludes import firrtl.options.{InputAnnotationFileAnnotation, Phase} import firrtl.util.BackendCompilationUtilities @@ -29,10 +27,10 @@ class GetIncludesSpec extends AnyFlatSpec with Matchers with BackendCompilationU def checkAnnos(a: AnnotationSeq, b: AnnotationSeq): Unit = { info("read the expected number of annotations") - a.size should be (b.size) + a.size should be(b.size) info("annotations match exact order") - a.zip(b).foreach{ case (ax, bx) => ax should be (bx) } + a.zip(b).foreach { case (ax, bx) => ax should be(bx) } } val files = Seq( @@ -43,19 +41,21 @@ class GetIncludesSpec extends AnyFlatSpec with Matchers with BackendCompilationU new File(dir + "/e.anno.json") -> Seq(E) ) - files.foreach{ case (file, annotations) => - val pw = new PrintWriter(file) - pw.write(JsonProtocol.serialize(annotations)) - pw.close() + files.foreach { + case (file, annotations) => + val pw = new PrintWriter(file) + pw.write(JsonProtocol.serialize(annotations)) + pw.close() } class Fixture { val phase: Phase = new GetIncludes } - behavior of classOf[GetIncludes].toString + behavior.of(classOf[GetIncludes].toString) it should "throw an exception if the annotation file doesn't exit" in new Fixture { - intercept[AnnotationFileNotFoundException]{ phase.transform(Seq(ref("f"))) } - .getMessage should startWith("Annotation file") + intercept[AnnotationFileNotFoundException] { phase.transform(Seq(ref("f"))) }.getMessage should startWith( + "Annotation file" + ) } it should "read annotations from a file" in new Fixture { @@ -75,9 +75,9 @@ class GetIncludesSpec extends AnyFlatSpec with Matchers with BackendCompilationU checkAnnos(out, expect) - Seq("d", "e").foreach{ x => + Seq("d", "e").foreach { x => info(s"a warning about '$x.anno.json' was printed") - stdout should include (s"Warning: Annotation file ($dir/$x.anno.json) already included!") + stdout should include(s"Warning: Annotation file ($dir/$x.anno.json) already included!") } } @@ -90,7 +90,7 @@ class GetIncludesSpec extends AnyFlatSpec with Matchers with BackendCompilationU checkAnnos(out, expect) info("a warning about 'a.anno.json' was printed") - stdout should include (s"Warning: Annotation file ($dir/a.anno.json)") + stdout should include(s"Warning: Annotation file ($dir/a.anno.json)") } } diff --git a/src/test/scala/firrtlTests/options/phases/WriteOutputAnnotationsSpec.scala b/src/test/scala/firrtlTests/options/phases/WriteOutputAnnotationsSpec.scala index 0a3cce67..4fe16041 100644 --- a/src/test/scala/firrtlTests/options/phases/WriteOutputAnnotationsSpec.scala +++ b/src/test/scala/firrtlTests/options/phases/WriteOutputAnnotationsSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.options.phases - import java.io.File import firrtl.AnnotationSeq @@ -15,7 +14,8 @@ import firrtl.options.{ PhaseException, StageOptions, TargetDirAnnotation, - WriteDeletedAnnotation} + WriteDeletedAnnotation +} import firrtl.options.Viewer.view import firrtl.options.phases.{GetIncludes, WriteOutputAnnotations} import org.scalatest.flatspec.AnyFlatSpec @@ -37,33 +37,38 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t info(s"reading '$f' works") val read = (new GetIncludes) .transform(Seq(InputAnnotationFileAnnotation(f.toString))) - .filterNot{ + .filterNot { case a @ DeletedAnnotation(_, _: InputAnnotationFileAnnotation) => true - case _ => false } + case _ => false + } info(s"annotations in file are expected size") - read.size should be (a.size) + read.size should be(a.size) read .zip(a) - .foreach{ case (read, expected) => - info(s"$read matches") - read should be (expected) } + .foreach { + case (read, expected) => + info(s"$read matches") + read should be(expected) + } f.delete() } class Fixture { val phase: Phase = new WriteOutputAnnotations } - behavior of classOf[WriteOutputAnnotations].toString + behavior.of(classOf[WriteOutputAnnotations].toString) it should "write annotations to a file (excluding DeletedAnnotations)" in new Fixture { val file = new File(dir + "/should-write-annotations-to-a-file.anno.json") - val annotations = Seq( OutputAnnotationFileAnnotation(file.toString), - WriteOutputAnnotationsSpec.FooAnnotation, - WriteOutputAnnotationsSpec.BarAnnotation(0), - WriteOutputAnnotationsSpec.BarAnnotation(1), - DeletedAnnotation("foo", WriteOutputAnnotationsSpec.FooAnnotation) ) + val annotations = Seq( + OutputAnnotationFileAnnotation(file.toString), + WriteOutputAnnotationsSpec.FooAnnotation, + WriteOutputAnnotationsSpec.BarAnnotation(0), + WriteOutputAnnotationsSpec.BarAnnotation(1), + DeletedAnnotation("foo", WriteOutputAnnotationsSpec.FooAnnotation) + ) val expected = annotations.filter { case a: DeletedAnnotation => false case a => true @@ -71,31 +76,35 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t val out = phase.transform(annotations) info("annotations are unmodified") - out.toSeq should be (annotations) + out.toSeq should be(annotations) fileContainsAnnotations(file, expected) } it should "include DeletedAnnotations if a WriteDeletedAnnotation is present" in new Fixture { val file = new File(dir + "should-include-deleted.anno.json") - val annotations = Seq( OutputAnnotationFileAnnotation(file.toString), - WriteOutputAnnotationsSpec.FooAnnotation, - WriteOutputAnnotationsSpec.BarAnnotation(0), - WriteOutputAnnotationsSpec.BarAnnotation(1), - DeletedAnnotation("foo", WriteOutputAnnotationsSpec.FooAnnotation), - WriteDeletedAnnotation ) + val annotations = Seq( + OutputAnnotationFileAnnotation(file.toString), + WriteOutputAnnotationsSpec.FooAnnotation, + WriteOutputAnnotationsSpec.BarAnnotation(0), + WriteOutputAnnotationsSpec.BarAnnotation(1), + DeletedAnnotation("foo", WriteOutputAnnotationsSpec.FooAnnotation), + WriteDeletedAnnotation + ) val out = phase.transform(annotations) info("annotations are unmodified") - out.toSeq should be (annotations) + out.toSeq should be(annotations) fileContainsAnnotations(file, annotations) } it should "do nothing if no output annotation file is specified" in new Fixture { - val annotations = Seq( WriteOutputAnnotationsSpec.FooAnnotation, - WriteOutputAnnotationsSpec.BarAnnotation(0), - WriteOutputAnnotationsSpec.BarAnnotation(1) ) + val annotations = Seq( + WriteOutputAnnotationsSpec.FooAnnotation, + WriteOutputAnnotationsSpec.BarAnnotation(0), + WriteOutputAnnotationsSpec.BarAnnotation(1) + ) val out = catchWrites { phase.transform(annotations) } match { case Right(a) => @@ -106,14 +115,16 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t } info("annotations are unmodified") - out.toSeq should be (annotations) + out.toSeq should be(annotations) } it should "write CustomFileEmission annotations" in new Fixture { val file = new File("write-CustomFileEmission-annotations.anno.json") - val annotations = Seq( TargetDirAnnotation(dir), - OutputAnnotationFileAnnotation(file.toString), - WriteOutputAnnotationsSpec.Custom("hello!") ) + val annotations = Seq( + TargetDirAnnotation(dir), + OutputAnnotationFileAnnotation(file.toString), + WriteOutputAnnotationsSpec.Custom("hello!") + ) val serializedFileName = view[StageOptions](annotations).getBuildFileName("Custom", Some(".Emission")) val expected = annotations.map { case _: WriteOutputAnnotationsSpec.Custom => WriteOutputAnnotationsSpec.Replacement(serializedFileName) @@ -123,7 +134,7 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t val out = phase.transform(annotations) info("annotations are unmodified") - out.toSeq should be (annotations) + out.toSeq should be(annotations) fileContainsAnnotations(new File(dir, file.toString), expected) @@ -133,13 +144,15 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t it should "error if multiple annotations try to write to the same file" in new Fixture { val file = new File("write-CustomFileEmission-annotations-error.anno.json") - val annotations = Seq( TargetDirAnnotation(dir), - OutputAnnotationFileAnnotation(file.toString), - WriteOutputAnnotationsSpec.Custom("foo"), - WriteOutputAnnotationsSpec.Custom("bar") ) + val annotations = Seq( + TargetDirAnnotation(dir), + OutputAnnotationFileAnnotation(file.toString), + WriteOutputAnnotationsSpec.Custom("foo"), + WriteOutputAnnotationsSpec.Custom("bar") + ) intercept[PhaseException] { phase.transform(annotations) - }.getMessage should startWith ("Multiple CustomFileEmission annotations") + }.getMessage should startWith("Multiple CustomFileEmission annotations") } } diff --git a/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala b/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala index bfc72f49..b628c1b7 100644 --- a/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala +++ b/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala @@ -6,48 +6,50 @@ import firrtl.ir.SubField import firrtl.options.Dependency import firrtl.stage.TransformManager import firrtl.{InstanceKind, MemKind, NodeKind, PortKind, RegKind, WireKind} -import firrtl.{CircuitState, SinkFlow, SourceFlow, ir, passes} +import firrtl.{ir, passes, CircuitState, SinkFlow, SourceFlow} import org.scalatest.flatspec.AnyFlatSpec /** Tests the combined results of ResolveKinds, InferTypes and ResolveFlows */ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { - private val deps = Seq( - Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.ResolveFlows)) + private val deps = + Seq(Dependency(passes.ResolveKinds), Dependency(passes.InferTypes), Dependency(passes.ResolveFlows)) private val manager = new TransformManager(deps) private def infer(src: String): ir.Circuit = manager.execute(CircuitState(firrtl.Parser.parse(src), Seq())).circuit private def getNodes(s: ir.Statement): Seq[(String, ir.Expression)] = s match { - case ir.DefNode(_, name, value) => Seq((name, value)) - case ir.Block(stmts) => stmts.flatMap(getNodes) - case ir.Conditionally(_, _, a, b) => Seq(a,b).flatMap(getNodes) - case _ => Seq() + case ir.DefNode(_, name, value) => Seq((name, value)) + case ir.Block(stmts) => stmts.flatMap(getNodes) + case ir.Conditionally(_, _, a, b) => Seq(a, b).flatMap(getNodes) + case _ => Seq() } private def getConnects(s: ir.Statement): Seq[ir.Connect] = s match { - case c : ir.Connect => Seq(c) - case ir.Block(stmts) => stmts.flatMap(getConnects) - case ir.Conditionally(_, _, a, b) => Seq(a,b).flatMap(getConnects) - case _ => Seq() + case c: ir.Connect => Seq(c) + case ir.Block(stmts) => stmts.flatMap(getConnects) + case ir.Conditionally(_, _, a, b) => Seq(a, b).flatMap(getConnects) + case _ => Seq() } private def getModule(c: ir.Circuit, name: String): ir.Module = c.modules.find(_.name == name).get.asInstanceOf[ir.Module] it should "infer references to ports, wires, nodes and registers" in { - val node = getNodes(getModule(infer( - """circuit m: - | module m: - | input clk: Clock - | input a: UInt<4> - | wire b : SInt<5> - | reg c: UInt<5>, clk - | node na = a - | node nb = b - | node nc = c - | node nna = na - | node na2 = a - | node a_plus_c = add(a, c) - |""".stripMargin), "m").body).toMap + val node = getNodes( + getModule( + infer("""circuit m: + | module m: + | input clk: Clock + | input a: UInt<4> + | wire b : SInt<5> + | reg c: UInt<5>, clk + | node na = a + | node nb = b + | node nc = c + | node nna = na + | node na2 = a + | node a_plus_c = add(a, c) + |""".stripMargin), + "m" + ).body + ).toMap assert(node("na").tpe == ir.UIntType(ir.IntWidth(4))) assert(node("na").asInstanceOf[ir.Reference].flow == SourceFlow) @@ -74,29 +76,29 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { } it should "infer types for references to instances" in { - val m = getModule(infer( - """circuit m: - | module other: - | output x: { y: UInt, flip z: UInt<1> } - | module m: - | inst i of other - | node i_x = i.x - | node i_x_y = i.x.y - | node i_x_y_2 = i_x.y - | node a = UInt<1>(1) - | i.x.z <= a - |""".stripMargin), "m") + val m = getModule( + infer("""circuit m: + | module other: + | output x: { y: UInt, flip z: UInt<1> } + | module m: + | inst i of other + | node i_x = i.x + | node i_x_y = i.x.y + | node i_x_y_2 = i_x.y + | node a = UInt<1>(1) + | i.x.z <= a + |""".stripMargin), + "m" + ) val node = getNodes(m.body).toMap val con = getConnects(m.body) - // node i_x_y = i.x.y assert(node("i_x_y").tpe.isInstanceOf[ir.UIntType]) // the type inference replaces all unknown widths with a variable assert(node("i_x_y").tpe.asInstanceOf[ir.UIntType].width.isInstanceOf[ir.VarWidth]) assert(node("i_x_y").asInstanceOf[ir.SubField].flow == SourceFlow) - // node i_x = i.x val x = node("i_x").asInstanceOf[ir.SubField] assert(x.tpe.isInstanceOf[ir.BundleType]) @@ -110,12 +112,10 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { assert(i.kind == InstanceKind) assert(i.flow == SourceFlow) - // node i_x_y_2 = i_x.y assert(node("i_x_y").tpe == node("i_x_y_2").tpe) assert(node("i_x_y").asInstanceOf[ir.SubField].flow == node("i_x_y_2").asInstanceOf[ir.SubField].flow) - // i.x.z <= a val (left, right) = (con.head.loc.asInstanceOf[ir.SubField], con.head.expr.asInstanceOf[ir.Reference]) @@ -131,29 +131,27 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { } it should "infer types for references to memories" in { - val c = infer( - """circuit m: - | module m: - | mem m: - | data-type => UInt - | depth => 30 - | reader => r - | writer => w - | read-latency => 1 - | write-latency => 1 - | read-under-write => undefined - | - | node m_r_addr = m.r.addr - | node m_r_data = m.r.data - | node m_w_addr = m.w.addr - | node m_w_data = m.w.data - |""".stripMargin) + val c = infer("""circuit m: + | module m: + | mem m: + | data-type => UInt + | depth => 30 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | node m_r_addr = m.r.addr + | node m_r_data = m.r.data + | node m_w_addr = m.w.addr + | node m_w_data = m.w.data + |""".stripMargin) val m = getModule(c, "m") val node = getNodes(m.body).toMap // this might be a little flaky... val memory = m.body.asInstanceOf[ir.Block].stmts.head.asInstanceOf[ir.DefMemory] - // after InferTypes, all expressions referring to the `data` should have this type: val dataTpe = memory.dataType.asInstanceOf[ir.UIntType] val addrTpe = ir.UIntType(ir.IntWidth(5)) @@ -163,8 +161,12 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { assert(node("m_w_addr").tpe == addrTpe) assert(node("m_w_data").tpe == dataTpe) - val memory_ref = node("m_r_addr").asInstanceOf[ir.SubField].expr - .asInstanceOf[ir.SubField].expr.asInstanceOf[ir.Reference] + val memory_ref = node("m_r_addr") + .asInstanceOf[ir.SubField] + .expr + .asInstanceOf[ir.SubField] + .expr + .asInstanceOf[ir.Reference] assert(memory_ref.kind == MemKind) val mem_ref_tpe = memory_ref.tpe.asInstanceOf[ir.BundleType] val r_tpe = mem_ref_tpe.fields.find(_.name == "r").get.tpe.asInstanceOf[ir.BundleType] @@ -176,18 +178,17 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { } it should "infer different instances of the same module to have the same width variable" in { - val c = infer( - """circuit m: - | module other: - | input x: UInt - | module x: - | inst i of other - | i.x <= UInt<16>(3) - | module m: - | inst x of x - | inst i of other - | i.x <= UInt<1>(1) - |""".stripMargin) + val c = infer("""circuit m: + | module other: + | input x: UInt + | module x: + | inst i of other + | i.x <= UInt<16>(3) + | module m: + | inst x of x + | inst i of other + | i.x <= UInt<1>(1) + |""".stripMargin) val m_con = getConnects(getModule(c, "m").body).head val x_con = getConnects(getModule(c, "x").body).head val other = getModule(c, "other") diff --git a/src/test/scala/firrtlTests/stage/FirrtlCliSpec.scala b/src/test/scala/firrtlTests/stage/FirrtlCliSpec.scala index 157520ea..d4caa546 100644 --- a/src/test/scala/firrtlTests/stage/FirrtlCliSpec.scala +++ b/src/test/scala/firrtlTests/stage/FirrtlCliSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage - import firrtl.stage.RunFirrtlTransformAnnotation import firrtl.options.Shell import firrtl.stage.FirrtlCli @@ -11,25 +10,30 @@ import org.scalatest.matchers.should.Matchers class FirrtlCliSpec extends AnyFlatSpec with Matchers { - behavior of "FirrtlCli for RunFirrtlTransformAnnotation / -fct / --custom-transforms" + behavior.of("FirrtlCli for RunFirrtlTransformAnnotation / -fct / --custom-transforms") it should "preserver transform order" in { val shell = new Shell("foo") with FirrtlCli val args = Array( - "--custom-transforms", "firrtl.transforms.BlackBoxSourceHelper,firrtl.transforms.CheckCombLoops", - "--custom-transforms", "firrtl.transforms.CombineCats", - "--custom-transforms", "firrtl.transforms.ConstantPropagation" ) + "--custom-transforms", + "firrtl.transforms.BlackBoxSourceHelper,firrtl.transforms.CheckCombLoops", + "--custom-transforms", + "firrtl.transforms.CombineCats", + "--custom-transforms", + "firrtl.transforms.ConstantPropagation" + ) val expected = Seq( classOf[firrtl.transforms.BlackBoxSourceHelper], classOf[firrtl.transforms.CheckCombLoops], classOf[firrtl.transforms.CombineCats], - classOf[firrtl.transforms.ConstantPropagation] ) + classOf[firrtl.transforms.ConstantPropagation] + ) shell .parse(args) - .collect{ case a: RunFirrtlTransformAnnotation => a } + .collect { case a: RunFirrtlTransformAnnotation => a } .zip(expected) - .map{ case (RunFirrtlTransformAnnotation(a), b) => a.getClass should be (b) } + .map { case (RunFirrtlTransformAnnotation(a), b) => a.getClass should be(b) } } } diff --git a/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala b/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala index 7d57f7ed..9274bac6 100644 --- a/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala +++ b/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala @@ -21,7 +21,11 @@ import org.scalatest.matchers.should.Matchers * This test uses the [[org.scalatest.FeatureSpec FeatureSpec]] intentionally as this test exercises the top-level * interface and is more suitable to an Acceptance Testing style. */ -class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers with firrtl.testutils.Utils +class FirrtlMainSpec + extends AnyFeatureSpec + with GivenWhenThen + with Matchers + with firrtl.testutils.Utils with BackendCompilationUtilities { /** Parameterizes one test of [[FirrtlMain]]. Running the [[FirrtlMain]] `main` with certain args should produce @@ -36,13 +40,14 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit * @param result expected exit code */ case class FirrtlMainTest( - args: Array[String], - circuit: Option[FirrtlCircuitFixture] = Some(new SimpleFirrtlCircuitFixture), - files: Seq[String] = Seq.empty, + args: Array[String], + circuit: Option[FirrtlCircuitFixture] = Some(new SimpleFirrtlCircuitFixture), + files: Seq[String] = Seq.empty, notFiles: Seq[String] = Seq.empty, - stdout: Option[String] = None, - stderr: Option[String] = None, - result: Int = 0) { + stdout: Option[String] = None, + stderr: Option[String] = None, + result: Int = 0) { + /** Generate a name for the test based on the arguments */ def testName: String = "args" + args.mkString("_") @@ -70,8 +75,8 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit case None => Array.empty } - p.files.foreach( f => new File(td.buildDir + s"/$f").delete() ) - p.notFiles.foreach( f => new File(td.buildDir + s"/$f").delete() ) + p.files.foreach(f => new File(td.buildDir + s"/$f").delete()) + p.notFiles.foreach(f => new File(td.buildDir + s"/$f").delete()) When(s"""the user tries to compile with '${p.argsString}'""") val (stdout, stderr, result) = @@ -80,25 +85,25 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit p.stdout match { case Some(a) => Then(s"""STDOUT should include "$a"""") - stdout should include (a) + stdout should include(a) case None => Then(s"nothing should print to STDOUT") - stdout should be (empty) + stdout should be(empty) } p.stderr match { case Some(a) => And(s"""STDERR should include "$a"""") - stderr should include (a) + stderr should include(a) case None => And(s"nothing should print to STDERR") - stderr should be (empty) + stderr should be(empty) } p.result match { case 0 => And(s"the exit code should be 0") - result shouldBe a [Right[_,_]] + result shouldBe a[Right[_, _]] case a => And(s"the exit code should be $a") result shouldBe (Left(a)) @@ -113,12 +118,11 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit p.notFiles.foreach { f => And(s"file '$f' should NOT be emitted in the target directory") val out = new File(td.buildDir + s"/$f") - out should not (exist) + out should not(exist) } } } - /** Test fixture that links to the [[FirrtlMain]] object. This could be done without, but its use matches the * Given/When/Then style more accurately. */ @@ -137,7 +141,7 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit } trait FirrtlCircuitFixture { - val main: String + val main: String val input: String } @@ -185,13 +189,13 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit val (out, _, result) = grabStdOutErr { catchStatus { f.stage.main(Array("--help")) } } Then("the usage text should be shown") - out should include ("Usage: firrtl") + out should include("Usage: firrtl") And("usage text should show known registered transforms") - out should include ("--no-dce") + out should include("--no-dce") And("usage text should show known registered libraries") - out should include ("MemLib Options") + out should include("MemLib Options") info("""And the exit code should be 0, but scopt catches all throwable, so we can't check this... ¯\_(ツ)_/¯""") // And("the exit code should be zero") @@ -200,67 +204,89 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit Seq( /* Test all standard emitters with and without annotation file outputs */ - FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl"), - files = Seq("Top.fir")), - FirrtlMainTest(args = Array("-X", "high", "-E", "high"), - stdout = defaultStdOut, - files = Seq("Top.hi.fir")), - FirrtlMainTest(args = Array("-X", "middle", "-E", "middle", "-foaf", "Top"), - stdout = defaultStdOut, - files = Seq("Top.mid.fir", "Top.anno.json")), - FirrtlMainTest(args = Array("-X", "low", "-E", "low", "-foaf", "annotations.anno.json"), - stdout = defaultStdOut, - files = Seq("Top.lo.fir", "annotations.anno.json")), - FirrtlMainTest(args = Array("-X", "verilog", "-E", "verilog", "-foaf", "foo.anno"), - stdout = defaultStdOut, - files = Seq("Top.v", "foo.anno.anno.json")), - FirrtlMainTest(args = Array("-X", "sverilog", "-E", "sverilog", "-foaf", "foo.json"), - stdout = defaultStdOut, - files = Seq("Top.sv", "foo.json.anno.json")), - + FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl"), files = Seq("Top.fir")), + FirrtlMainTest(args = Array("-X", "high", "-E", "high"), stdout = defaultStdOut, files = Seq("Top.hi.fir")), + FirrtlMainTest( + args = Array("-X", "middle", "-E", "middle", "-foaf", "Top"), + stdout = defaultStdOut, + files = Seq("Top.mid.fir", "Top.anno.json") + ), + FirrtlMainTest( + args = Array("-X", "low", "-E", "low", "-foaf", "annotations.anno.json"), + stdout = defaultStdOut, + files = Seq("Top.lo.fir", "annotations.anno.json") + ), + FirrtlMainTest( + args = Array("-X", "verilog", "-E", "verilog", "-foaf", "foo.anno"), + stdout = defaultStdOut, + files = Seq("Top.v", "foo.anno.anno.json") + ), + FirrtlMainTest( + args = Array("-X", "sverilog", "-E", "sverilog", "-foaf", "foo.json"), + stdout = defaultStdOut, + files = Seq("Top.sv", "foo.json.anno.json") + ), /* Test all one file per module emitters */ - FirrtlMainTest(args = Array("-X", "none", "-e", "chirrtl"), - files = Seq("Top.fir", "Child.fir")), - FirrtlMainTest(args = Array("-X", "high", "-e", "high"), - stdout = defaultStdOut, - files = Seq("Top.hi.fir", "Child.hi.fir")), - FirrtlMainTest(args = Array("-X", "middle", "-e", "middle"), - stdout = defaultStdOut, - files = Seq("Top.mid.fir", "Child.mid.fir")), - FirrtlMainTest(args = Array("-X", "low", "-e", "low"), - stdout = defaultStdOut, - files = Seq("Top.lo.fir", "Child.lo.fir")), - FirrtlMainTest(args = Array("-X", "verilog", "-e", "verilog"), - stdout = defaultStdOut, - files = Seq("Top.v", "Child.v")), - FirrtlMainTest(args = Array("-X", "sverilog", "-e", "sverilog"), - stdout = defaultStdOut, - files = Seq("Top.sv", "Child.sv")), - + FirrtlMainTest(args = Array("-X", "none", "-e", "chirrtl"), files = Seq("Top.fir", "Child.fir")), + FirrtlMainTest( + args = Array("-X", "high", "-e", "high"), + stdout = defaultStdOut, + files = Seq("Top.hi.fir", "Child.hi.fir") + ), + FirrtlMainTest( + args = Array("-X", "middle", "-e", "middle"), + stdout = defaultStdOut, + files = Seq("Top.mid.fir", "Child.mid.fir") + ), + FirrtlMainTest( + args = Array("-X", "low", "-e", "low"), + stdout = defaultStdOut, + files = Seq("Top.lo.fir", "Child.lo.fir") + ), + FirrtlMainTest( + args = Array("-X", "verilog", "-e", "verilog"), + stdout = defaultStdOut, + files = Seq("Top.v", "Child.v") + ), + FirrtlMainTest( + args = Array("-X", "sverilog", "-e", "sverilog"), + stdout = defaultStdOut, + files = Seq("Top.sv", "Child.sv") + ), /* Test mixing of -E with -e */ - FirrtlMainTest(args = Array("-X", "middle", "-E", "high", "-e", "middle"), - stdout = defaultStdOut, - files = Seq("Top.hi.fir", "Top.mid.fir", "Child.mid.fir"), - notFiles = Seq("Child.hi.fir")), - + FirrtlMainTest( + args = Array("-X", "middle", "-E", "high", "-e", "middle"), + stdout = defaultStdOut, + files = Seq("Top.hi.fir", "Top.mid.fir", "Child.mid.fir"), + notFiles = Seq("Child.hi.fir") + ), /* Test changes to output file name */ - FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl", "-o", "foo"), - files = Seq("foo.fir")), - FirrtlMainTest(args = Array("-X", "high", "-E", "high", "-o", "foo"), - stdout = defaultStdOut, - files = Seq("foo.hi.fir")), - FirrtlMainTest(args = Array("-X", "middle", "-E", "middle", "-o", "foo.middle"), - stdout = defaultStdOut, - files = Seq("foo.middle.mid.fir")), - FirrtlMainTest(args = Array("-X", "low", "-E", "low", "-o", "foo.lo.fir"), - stdout = defaultStdOut, - files = Seq("foo.lo.fir")), - FirrtlMainTest(args = Array("-X", "verilog", "-E", "verilog", "-o", "foo.sv"), - stdout = defaultStdOut, - files = Seq("foo.sv.v")), - FirrtlMainTest(args = Array("-X", "sverilog", "-E", "sverilog", "-o", "Foo"), - stdout = defaultStdOut, - files = Seq("Foo.sv")) + FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl", "-o", "foo"), files = Seq("foo.fir")), + FirrtlMainTest( + args = Array("-X", "high", "-E", "high", "-o", "foo"), + stdout = defaultStdOut, + files = Seq("foo.hi.fir") + ), + FirrtlMainTest( + args = Array("-X", "middle", "-E", "middle", "-o", "foo.middle"), + stdout = defaultStdOut, + files = Seq("foo.middle.mid.fir") + ), + FirrtlMainTest( + args = Array("-X", "low", "-E", "low", "-o", "foo.lo.fir"), + stdout = defaultStdOut, + files = Seq("foo.lo.fir") + ), + FirrtlMainTest( + args = Array("-X", "verilog", "-E", "verilog", "-o", "foo.sv"), + stdout = defaultStdOut, + files = Seq("foo.sv.v") + ), + FirrtlMainTest( + args = Array("-X", "sverilog", "-E", "sverilog", "-o", "Foo"), + stdout = defaultStdOut, + files = Seq("Foo.sv") + ) ) .foreach(runStageExpectFiles) @@ -272,15 +298,17 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit val out = new File(s"$outName.hi.fir") out.delete() val result = catchStatus { - f.stage.main(Array("-i", "src/test/resources/integration/GCDTester.fir", "-o", outName, "-X", "high", - "-E", "high")) } + f.stage.main( + Array("-i", "src/test/resources/integration/GCDTester.fir", "-o", outName, "-X", "high", "-E", "high") + ) + } Then("outputs should be written to current directory") out should (exist) out.delete() And("the exit code should be 0") - result shouldBe a [Right[_,_]] + result shouldBe a[Right[_, _]] } Scenario("User provides Protocol Buffer input") { @@ -292,8 +320,9 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit copyResourceToFile("/integration/GCDTester.pb", protobufIn) When("the user tries to compile to High FIRRTL") - f.stage.main(Array("-i", protobufIn.toString, "-X", "high", "-E", "high", "-td", td.buildDir.toString, - "-o", "Foo")) + f.stage.main( + Array("-i", protobufIn.toString, "-X", "high", "-E", "high", "-td", td.buildDir.toString, "-o", "Foo") + ) Then("the output should be the same as using FIRRTL input") new File(td.buildDir + "/Foo.hi.fir") should (exist) @@ -311,16 +340,16 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit val (out, err, result) = grabStdOutErr { catchStatus { f.stage.main(Array.empty) } } Then("an error should be printed on stdout") - out should include (s"Error: Unable to determine FIRRTL source to read") + out should include(s"Error: Unable to determine FIRRTL source to read") And("no usage text should be shown") - out should not include ("Usage: firrtl") + (out should not).include("Usage: firrtl") And("nothing should print to stderr") - err should be (empty) + err should be(empty) And("the exit code should be 1") - result should be (Left(1)) + result should be(Left(1)) } } @@ -333,22 +362,30 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit Seq( /* Erroneous inputs */ - FirrtlMainTest(args = Array("--thisIsNotASupportedOption"), - circuit = None, - stdout = Some("Error: Unknown option"), - result = 1), - FirrtlMainTest(args = Array("-i", "foo", "--info-mode", "Use"), - circuit = None, - stdout = Some("Unknown info mode 'Use'! (Did you misspell it?)"), - result = 1), - FirrtlMainTest(args = Array("-i", "test_run_dir/I-DO-NOT-EXIST"), - circuit = None, - stdout = Some("Input file 'test_run_dir/I-DO-NOT-EXIST' not found!"), - result = 1), - FirrtlMainTest(args = Array("-i", "foo", "-X", "Verilog"), - circuit = None, - stdout = Some("Unknown compiler name 'Verilog'! (Did you misspell it?)"), - result = 1) + FirrtlMainTest( + args = Array("--thisIsNotASupportedOption"), + circuit = None, + stdout = Some("Error: Unknown option"), + result = 1 + ), + FirrtlMainTest( + args = Array("-i", "foo", "--info-mode", "Use"), + circuit = None, + stdout = Some("Unknown info mode 'Use'! (Did you misspell it?)"), + result = 1 + ), + FirrtlMainTest( + args = Array("-i", "test_run_dir/I-DO-NOT-EXIST"), + circuit = None, + stdout = Some("Input file 'test_run_dir/I-DO-NOT-EXIST' not found!"), + result = 1 + ), + FirrtlMainTest( + args = Array("-i", "foo", "-X", "Verilog"), + circuit = None, + stdout = Some("Unknown compiler name 'Verilog'! (Did you misspell it?)"), + result = 1 + ) ) .foreach(runStageExpectFiles) @@ -364,13 +401,13 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit val (out, _, result) = grabStdOutErr { catchStatus { f.stage.main(Array("--show-registrations")) } } Then("stdout should show registered transforms") - out should include ("firrtl.passes.InlineInstances") + out should include("firrtl.passes.InlineInstances") And("stdout should show registered libraries") out should include("firrtl.passes.memlib.MemLibOptions") And("the exit code should be 1") - result should be (Left(1)) + result should be(Left(1)) } } @@ -380,23 +417,21 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit def optionRemoved(a: String): Option[String] = Some(s"Option '$a' was removed as part of the FIRRTL Stage refactor") Seq( /* Removed --top-name/-tn handling */ - FirrtlMainTest(args = Array("--top-name", "foo"), - circuit = None, - stdout = optionRemoved("--top-name/-tn"), - result = 1), - FirrtlMainTest(args = Array("-tn"), - circuit = None, - stdout = optionRemoved("--top-name/-tn"), - result = 1), + FirrtlMainTest( + args = Array("--top-name", "foo"), + circuit = None, + stdout = optionRemoved("--top-name/-tn"), + result = 1 + ), + FirrtlMainTest(args = Array("-tn"), circuit = None, stdout = optionRemoved("--top-name/-tn"), result = 1), /* Removed --split-modules/-fsm handling */ - FirrtlMainTest(args = Array("--split-modules"), - circuit = None, - stdout = optionRemoved("--split-modules/-fsm"), - result = 1), - FirrtlMainTest(args = Array("-fsm"), - circuit = None, - stdout = optionRemoved("--split-modules/-fsm"), - result = 1) + FirrtlMainTest( + args = Array("--split-modules"), + circuit = None, + stdout = optionRemoved("--split-modules/-fsm"), + result = 1 + ), + FirrtlMainTest(args = Array("-fsm"), circuit = None, stdout = optionRemoved("--split-modules/-fsm"), result = 1) ) .foreach(runStageExpectFiles) } diff --git a/src/test/scala/firrtlTests/stage/FirrtlOptionsViewSpec.scala b/src/test/scala/firrtlTests/stage/FirrtlOptionsViewSpec.scala index 4161d29b..00aa8e6a 100644 --- a/src/test/scala/firrtlTests/stage/FirrtlOptionsViewSpec.scala +++ b/src/test/scala/firrtlTests/stage/FirrtlOptionsViewSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage - import firrtl.stage._ import firrtl.{ir, NoneCompiler, Parser} @@ -17,7 +16,7 @@ class Baz_Compiler extends NoneCompiler class FirrtlOptionsViewSpec extends AnyFlatSpec with Matchers { - behavior of FirrtlOptionsView.getClass.getName + behavior.of(FirrtlOptionsView.getClass.getName) def circuitString(main: String): String = s"""|circuit $main: | module $main: @@ -37,9 +36,9 @@ class FirrtlOptionsViewSpec extends AnyFlatSpec with Matchers { it should "construct a view from an AnnotationSeq" in { val out = view[FirrtlOptions](annotations) - out.outputFileName should be (Some("bar")) - out.infoModeName should be ("use") - out.firrtlCircuit should be (Some(grault)) + out.outputFileName should be(Some("bar")) + out.infoModeName should be("use") + out.firrtlCircuit should be(Some(grault)) } /* This test only exists to catch changes to existing behavior. This test does not indicate that this is the correct @@ -57,9 +56,9 @@ class FirrtlOptionsViewSpec extends AnyFlatSpec with Matchers { val out = view[FirrtlOptions](annotations ++ overwrites) - out.outputFileName should be (Some("bar_")) - out.infoModeName should be ("gen") - out.firrtlCircuit should be (Some(grault_)) + out.outputFileName should be(Some("bar_")) + out.infoModeName should be("gen") + out.firrtlCircuit should be(Some(grault_)) } } diff --git a/src/test/scala/firrtlTests/stage/phases/AddCircuitSpec.scala b/src/test/scala/firrtlTests/stage/phases/AddCircuitSpec.scala index 58026ecd..aac18dee 100644 --- a/src/test/scala/firrtlTests/stage/phases/AddCircuitSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/AddCircuitSpec.scala @@ -2,12 +2,16 @@ package firrtlTests.stage.phases - import firrtl.Parser import firrtl.annotations.NoTargetAnnotation import firrtl.options.{OptionsException, Phase, PhasePrerequisiteException} -import firrtl.stage.{CircuitOption, FirrtlCircuitAnnotation, FirrtlSourceAnnotation, InfoModeAnnotation, - FirrtlFileAnnotation} +import firrtl.stage.{ + CircuitOption, + FirrtlCircuitAnnotation, + FirrtlFileAnnotation, + FirrtlSourceAnnotation, + InfoModeAnnotation +} import firrtl.stage.phases.AddCircuit import java.io.{File, FileWriter} @@ -21,7 +25,7 @@ class AddCircuitSpec extends AnyFlatSpec with Matchers { class Fixture { val phase: Phase = new AddCircuit } - behavior of classOf[AddCircuit].toString + behavior.of(classOf[AddCircuit].toString) def firrtlSource(name: String): String = s"""|circuit $name: @@ -32,15 +36,16 @@ class AddCircuitSpec extends AnyFlatSpec with Matchers { |""".stripMargin it should "throw a PhasePrerequisiteException if a CircuitOption exists without an InfoModeAnnotation" in - new Fixture { - {the [PhasePrerequisiteException] thrownBy phase.transform(Seq(FirrtlSourceAnnotation("foo")))} - .message should startWith ("An InfoModeAnnotation must be present") - } + new Fixture { + { + the[PhasePrerequisiteException] thrownBy phase.transform(Seq(FirrtlSourceAnnotation("foo"))) + }.message should startWith("An InfoModeAnnotation must be present") + } it should "do nothing if no CircuitOption annotations are present" in new Fixture { val annotations = (1 to 10).map(FooAnnotation) ++ ('a' to 'm').map(_.toString).map(BarAnnotation) :+ InfoModeAnnotation("ignore") - phase.transform(annotations).toSeq should be (annotations.toSeq) + phase.transform(annotations).toSeq should be(annotations.toSeq) } val (file, fileCircuit) = { @@ -66,39 +71,44 @@ class AddCircuitSpec extends AnyFlatSpec with Matchers { FirrtlFileAnnotation(file), FirrtlSourceAnnotation(source), FirrtlCircuitAnnotation(circuit), - InfoModeAnnotation("ignore") ) + InfoModeAnnotation("ignore") + ) val annotationsExpected = Set( FirrtlCircuitAnnotation(fileCircuit), FirrtlCircuitAnnotation(sourceCircuit), - FirrtlCircuitAnnotation(circuit) ) + FirrtlCircuitAnnotation(circuit) + ) val out = phase.transform(annotations).toSeq info("generated expected FirrtlCircuitAnnotations") - out.collect{ case a: FirrtlCircuitAnnotation => a}.toSet should be (annotationsExpected) + out.collect { case a: FirrtlCircuitAnnotation => a }.toSet should be(annotationsExpected) info("all CircuitOptions were removed") - out.collect{ case a: CircuitOption => a } should be (empty) + out.collect { case a: CircuitOption => a } should be(empty) } it should """add info for a FirrtlFileAnnotation with a "gen" info mode""" in new Fixture { - phase.transform(Seq(InfoModeAnnotation("gen"), FirrtlFileAnnotation(file))) - .collectFirst{ case a: FirrtlCircuitAnnotation => a.circuit.serialize } - .get should include ("AddCircuitSpec") + phase + .transform(Seq(InfoModeAnnotation("gen"), FirrtlFileAnnotation(file))) + .collectFirst { case a: FirrtlCircuitAnnotation => a.circuit.serialize } + .get should include("AddCircuitSpec") } it should """add info for a FirrtlSourceAnnotation with an "append" info mode""" in new Fixture { - phase.transform(Seq(InfoModeAnnotation("append"), FirrtlSourceAnnotation(source))) - .collectFirst{ case a: FirrtlCircuitAnnotation => a.circuit.serialize } - .get should include ("anonymous source") + phase + .transform(Seq(InfoModeAnnotation("append"), FirrtlSourceAnnotation(source))) + .collectFirst { case a: FirrtlCircuitAnnotation => a.circuit.serialize } + .get should include("anonymous source") } it should "throw an OptionsException if the specified file doesn't exist" in new Fixture { val a = Seq(InfoModeAnnotation("ignore"), FirrtlFileAnnotation("test_run_dir/I-DO-NOT-EXIST")) - {the [OptionsException] thrownBy phase.transform(a)} - .message should startWith (s"Input file 'test_run_dir/I-DO-NOT-EXIST' not found") + { the[OptionsException] thrownBy phase.transform(a) }.message should startWith( + s"Input file 'test_run_dir/I-DO-NOT-EXIST' not found" + ) } } diff --git a/src/test/scala/firrtlTests/stage/phases/AddDefaultsSpec.scala b/src/test/scala/firrtlTests/stage/phases/AddDefaultsSpec.scala index b600e6c5..686c42ad 100644 --- a/src/test/scala/firrtlTests/stage/phases/AddDefaultsSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/AddDefaultsSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import firrtl.NoneCompiler import firrtl.annotations.Annotation import firrtl.stage.phases.AddDefaults @@ -16,24 +15,25 @@ class AddDefaultsSpec extends AnyFlatSpec with Matchers { class Fixture { val phase: Phase = new AddDefaults } - behavior of classOf[AddDefaults].toString + behavior.of(classOf[AddDefaults].toString) it should "add expected default annotations and nothing else" in new Fixture { val expected = Seq( (a: Annotation) => a match { case BlackBoxTargetDirAnno(b) => b == TargetDirAnnotation().directory }, - (a: Annotation) => a match { case RunFirrtlTransformAnnotation(e: firrtl.Emitter) => - Dependency.fromTransform(e) == Dependency[firrtl.VerilogEmitter] }, - (a: Annotation) => a match { case InfoModeAnnotation(b) => b == InfoModeAnnotation().modeName } ) - - phase.transform(Seq.empty).zip(expected).map { case (x, f) => f(x) should be (true) } + (a: Annotation) => + a match { + case RunFirrtlTransformAnnotation(e: firrtl.Emitter) => + Dependency.fromTransform(e) == Dependency[firrtl.VerilogEmitter] + }, + (a: Annotation) => a match { case InfoModeAnnotation(b) => b == InfoModeAnnotation().modeName } + ) + + phase.transform(Seq.empty).zip(expected).map { case (x, f) => f(x) should be(true) } } it should "not overwrite existing annotations" in new Fixture { - val input = Seq( - BlackBoxTargetDirAnno("foo"), - CompilerAnnotation(new NoneCompiler()), - InfoModeAnnotation("ignore")) + val input = Seq(BlackBoxTargetDirAnno("foo"), CompilerAnnotation(new NoneCompiler()), InfoModeAnnotation("ignore")) - phase.transform(input).toSeq should be (input) + phase.transform(input).toSeq should be(input) } } diff --git a/src/test/scala/firrtlTests/stage/phases/AddImplicitEmitterSpec.scala b/src/test/scala/firrtlTests/stage/phases/AddImplicitEmitterSpec.scala index 941f1883..1252090b 100644 --- a/src/test/scala/firrtlTests/stage/phases/AddImplicitEmitterSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/AddImplicitEmitterSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import firrtl.{EmitAllModulesAnnotation, EmitCircuitAnnotation, HighFirrtlEmitter, VerilogCompiler} import firrtl.annotations.NoTargetAnnotation import firrtl.options.Phase @@ -20,27 +19,26 @@ class AddImplicitEmitterSpec extends AnyFlatSpec with Matchers { val someAnnos = Seq(FooAnnotation(1), FooAnnotation(2), BarAnnotation("bar")) - behavior of classOf[AddImplicitEmitter].toString + behavior.of(classOf[AddImplicitEmitter].toString) it should "do nothing if no CompilerAnnotation is present" in new Fixture { - phase.transform(someAnnos).toSeq should be (someAnnos) + phase.transform(someAnnos).toSeq should be(someAnnos) } it should "add an EmitCircuitAnnotation derived from a CompilerAnnotation" in new Fixture { val input = CompilerAnnotation(new VerilogCompiler) +: someAnnos - val expected = input.flatMap{ - case a@ CompilerAnnotation(b) => Seq(a, - RunFirrtlTransformAnnotation(b.emitter), - EmitCircuitAnnotation(b.emitter.getClass)) + val expected = input.flatMap { + case a @ CompilerAnnotation(b) => + Seq(a, RunFirrtlTransformAnnotation(b.emitter), EmitCircuitAnnotation(b.emitter.getClass)) case a => Some(a) } - phase.transform(input).toSeq should be (expected) + phase.transform(input).toSeq should be(expected) } it should "not add an EmitCircuitAnnotation if an EmitAnnotation already exists" in new Fixture { - val input = Seq(CompilerAnnotation(new VerilogCompiler), - EmitAllModulesAnnotation(classOf[HighFirrtlEmitter])) ++ someAnnos - phase.transform(input).toSeq should be (input) + val input = + Seq(CompilerAnnotation(new VerilogCompiler), EmitAllModulesAnnotation(classOf[HighFirrtlEmitter])) ++ someAnnos + phase.transform(input).toSeq should be(input) } } diff --git a/src/test/scala/firrtlTests/stage/phases/AddImplicitOutputFileSpec.scala b/src/test/scala/firrtlTests/stage/phases/AddImplicitOutputFileSpec.scala index 5ec051f4..499b05ae 100644 --- a/src/test/scala/firrtlTests/stage/phases/AddImplicitOutputFileSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/AddImplicitOutputFileSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import firrtl.{ChirrtlEmitter, EmitAllModulesAnnotation, Parser} import firrtl.options.Phase import firrtl.stage.{FirrtlCircuitAnnotation, OutputFileAnnotation} @@ -21,27 +20,27 @@ class AddImplicitOutputFileSpec extends AnyFlatSpec with Matchers { val circuit = Parser.parse(foo) - behavior of classOf[AddImplicitOutputFile].toString + behavior.of(classOf[AddImplicitOutputFile].toString) it should "default to an output file named 'a'" in new Fixture { - phase.transform(Seq.empty).toSeq should be (Seq(OutputFileAnnotation("a"))) + phase.transform(Seq.empty).toSeq should be(Seq(OutputFileAnnotation("a"))) } it should "set the output file based on a FirrtlCircuitAnnotation's main" in new Fixture { val in = Seq(FirrtlCircuitAnnotation(circuit)) val out = OutputFileAnnotation(circuit.main) +: in - phase.transform(in).toSeq should be (out) + phase.transform(in).toSeq should be(out) } it should "do nothing if an OutputFileAnnotation or EmitAllModulesAnnotation already exists" in new Fixture { info("OutputFileAnnotation works") val outputFile = Seq(OutputFileAnnotation("Bar"), FirrtlCircuitAnnotation(circuit)) - phase.transform(outputFile).toSeq should be (outputFile) + phase.transform(outputFile).toSeq should be(outputFile) info("EmitAllModulesAnnotation works") val eam = Seq(EmitAllModulesAnnotation(classOf[ChirrtlEmitter]), FirrtlCircuitAnnotation(circuit)) - phase.transform(eam).toSeq should be (eam) + phase.transform(eam).toSeq should be(eam) } } diff --git a/src/test/scala/firrtlTests/stage/phases/ChecksSpec.scala b/src/test/scala/firrtlTests/stage/phases/ChecksSpec.scala index e10bbe6d..65516da5 100644 --- a/src/test/scala/firrtlTests/stage/phases/ChecksSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/ChecksSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import firrtl.stage._ import firrtl.{AnnotationSeq, ChirrtlEmitter, EmitAllModulesAnnotation, NoneCompiler} @@ -25,9 +24,9 @@ class ChecksSpec extends AnyFlatSpec with Matchers { val min = Seq(inputFile, goodCompiler, infoMode) def checkExceptionMessage(phase: Phase, annotations: AnnotationSeq, messageStart: String): Unit = - intercept[OptionsException]{ phase.transform(annotations) }.getMessage should startWith(messageStart) + intercept[OptionsException] { phase.transform(annotations) }.getMessage should startWith(messageStart) - behavior of classOf[Checks].toString + behavior.of(classOf[Checks].toString) it should "require exactly one input source" in new Fixture { info("0 input source causes an exception") @@ -74,8 +73,11 @@ class ChecksSpec extends AnyFlatSpec with Matchers { it should "enforce exactly one info mode" in new Fixture { info("0 info modes should throw an exception") - checkExceptionMessage(phase, Seq(inputFile, goodCompiler), - "Exactly one info mode must be specified, but none found") + checkExceptionMessage( + phase, + Seq(inputFile, goodCompiler), + "Exactly one info mode must be specified, but none found" + ) info("2 info modes should throw an exception") val i = infoMode.modeName diff --git a/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala b/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala index 0446d4a3..12ec66c2 100644 --- a/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import scala.collection.mutable import firrtl.{Compiler => _, _} @@ -16,10 +15,10 @@ class CompilerSpec extends AnyFlatSpec with Matchers { class Fixture { val phase: Phase = new Compiler } - behavior of classOf[Compiler].toString + behavior.of(classOf[Compiler].toString) it should "do nothing for an empty AnnotationSeq" in new Fixture { - phase.transform(Seq.empty).toSeq should be (empty) + phase.transform(Seq.empty).toSeq should be(empty) } /** A circuit with a parameterized main (top name) that is different at High, Mid, and Low FIRRTL forms. */ @@ -36,11 +35,9 @@ class CompilerSpec extends AnyFlatSpec with Matchers { val circuitIn = Parser.parse(chirrtl("top")) val circuitOut = compiler.compile(CircuitState(circuitIn, ChirrtlForm), Seq.empty).circuit - val input = Seq( - FirrtlCircuitAnnotation(circuitIn), - CompilerAnnotation(compiler) ) + val input = Seq(FirrtlCircuitAnnotation(circuitIn), CompilerAnnotation(compiler)) - phase.transform(input).toSeq should be (Seq(FirrtlCircuitAnnotation(circuitOut))) + phase.transform(input).toSeq should be(Seq(FirrtlCircuitAnnotation(circuitOut))) } it should "compile multiple FirrtlCircuitAnnotations" in new Fixture { @@ -50,32 +47,31 @@ class CompilerSpec extends AnyFlatSpec with Matchers { new MiddleFirrtlCompiler, new LowFirrtlCompiler, new VerilogCompiler, - new SystemVerilogCompiler ) + new SystemVerilogCompiler + ) val (ce, hfe, mfe, lfe, ve, sve) = ( new ChirrtlEmitter, new HighFirrtlEmitter, new MiddleFirrtlEmitter, new LowFirrtlEmitter, new VerilogEmitter, - new SystemVerilogEmitter ) + new SystemVerilogEmitter + ) val a = Seq( /* Default Compiler is HighFirrtlCompiler */ CompilerAnnotation(hfc), - /* First compiler group, use NoneCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("a"))), CompilerAnnotation(nc), RunFirrtlTransformAnnotation(ce), EmitCircuitAnnotation(ce.getClass), - /* Second compiler group, use default HighFirrtlCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("b"))), RunFirrtlTransformAnnotation(ce), EmitCircuitAnnotation(ce.getClass), RunFirrtlTransformAnnotation(hfe), EmitCircuitAnnotation(hfe.getClass), - /* Third compiler group, use MiddleFirrtlCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("c"))), CompilerAnnotation(mfc), @@ -85,7 +81,6 @@ class CompilerSpec extends AnyFlatSpec with Matchers { EmitCircuitAnnotation(hfe.getClass), RunFirrtlTransformAnnotation(mfe), EmitCircuitAnnotation(mfe.getClass), - /* Fourth compiler group, use LowFirrtlCompiler*/ FirrtlCircuitAnnotation(Parser.parse(chirrtl("d"))), CompilerAnnotation(lfc), @@ -97,7 +92,6 @@ class CompilerSpec extends AnyFlatSpec with Matchers { EmitCircuitAnnotation(mfe.getClass), RunFirrtlTransformAnnotation(lfe), EmitCircuitAnnotation(lfe.getClass), - /* Fifth compiler group, use VerilogCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("e"))), CompilerAnnotation(vc), @@ -111,7 +105,6 @@ class CompilerSpec extends AnyFlatSpec with Matchers { EmitCircuitAnnotation(lfe.getClass), RunFirrtlTransformAnnotation(ve), EmitCircuitAnnotation(ve.getClass), - /* Sixth compiler group, use SystemVerilogCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("f"))), CompilerAnnotation(svc), @@ -130,14 +123,10 @@ class CompilerSpec extends AnyFlatSpec with Matchers { val output = phase.transform(a) info("with the same number of output FirrtlCircuitAnnotations") - output - .collect{ case a: FirrtlCircuitAnnotation => a } - .size should be (6) + output.collect { case a: FirrtlCircuitAnnotation => a }.size should be(6) info("and all expected EmittedAnnotations should be generated") - output - .collect{ case a: EmittedAnnotation[_] => a } - .size should be (20) + output.collect { case a: EmittedAnnotation[_] => a }.size should be(20) } it should "run transforms in sequential order" in new Fixture { @@ -145,20 +134,23 @@ class CompilerSpec extends AnyFlatSpec with Matchers { val circuitIn = Parser.parse(chirrtl("top")) val annotations = - Seq( FirrtlCircuitAnnotation(circuitIn), - CompilerAnnotation(new VerilogCompiler), - RunFirrtlTransformAnnotation(new FirstTransform), - RunFirrtlTransformAnnotation(new SecondTransform) ) + Seq( + FirrtlCircuitAnnotation(circuitIn), + CompilerAnnotation(new VerilogCompiler), + RunFirrtlTransformAnnotation(new FirstTransform), + RunFirrtlTransformAnnotation(new SecondTransform) + ) phase.transform(annotations) - CompilerSpec.globalState.toSeq should be (Seq(classOf[FirstTransform], classOf[SecondTransform])) + CompilerSpec.globalState.toSeq should be(Seq(classOf[FirstTransform], classOf[SecondTransform])) } } object CompilerSpec { - private[CompilerSpec] val globalState: mutable.Queue[Class[_ <: Transform]] = mutable.Queue.empty[Class[_ <: Transform]] + private[CompilerSpec] val globalState: mutable.Queue[Class[_ <: Transform]] = + mutable.Queue.empty[Class[_ <: Transform]] class LoggingTransform extends Transform { override def inputForm = UnknownForm diff --git a/src/test/scala/firrtlTests/stage/phases/WriteEmittedSpec.scala b/src/test/scala/firrtlTests/stage/phases/WriteEmittedSpec.scala index bbec32fe..73ee455d 100644 --- a/src/test/scala/firrtlTests/stage/phases/WriteEmittedSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/WriteEmittedSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import java.io.File import firrtl._ @@ -22,21 +21,22 @@ class WriteEmittedSpec extends AnyFlatSpec with Matchers { class Fixture { val phase: Phase = new WriteEmitted } - behavior of classOf[WriteEmitted].toString + behavior.of(classOf[WriteEmitted].toString) it should "write emitted circuits" in new Fixture { val annotations = Seq( TargetDirAnnotation("test_run_dir/WriteEmittedSpec"), EmittedFirrtlCircuitAnnotation(EmittedFirrtlCircuit("foo", "", ".foocircuit")), EmittedFirrtlCircuitAnnotation(EmittedFirrtlCircuit("bar", "", ".barcircuit")), - EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit("baz", "", ".bazcircuit")) ) + EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit("baz", "", ".bazcircuit")) + ) val expected = Seq("foo.foocircuit", "bar.barcircuit", "baz.bazcircuit") .map(a => new File(s"test_run_dir/WriteEmittedSpec/$a")) info("annotations are unmodified") - phase.transform(annotations).toSeq should be (removeEmitted(annotations).toSeq) + phase.transform(annotations).toSeq should be(removeEmitted(annotations).toSeq) - expected.foreach{ a => + expected.foreach { a => info(s"$a was written") a should (exist) a.delete() @@ -47,11 +47,12 @@ class WriteEmittedSpec extends AnyFlatSpec with Matchers { val annotations = Seq( TargetDirAnnotation("test_run_dir/WriteEmittedSpec"), OutputFileAnnotation("quux"), - EmittedFirrtlCircuitAnnotation(EmittedFirrtlCircuit("qux", "", ".quxcircuit")) ) + EmittedFirrtlCircuitAnnotation(EmittedFirrtlCircuit("qux", "", ".quxcircuit")) + ) val expected = new File("test_run_dir/WriteEmittedSpec/quux.quxcircuit") info("annotations are unmodified") - phase.transform(annotations).toSeq should be (removeEmitted(annotations).toSeq) + phase.transform(annotations).toSeq should be(removeEmitted(annotations).toSeq) info(s"$expected was written") expected should (exist) @@ -63,14 +64,15 @@ class WriteEmittedSpec extends AnyFlatSpec with Matchers { TargetDirAnnotation("test_run_dir/WriteEmittedSpec"), EmittedFirrtlModuleAnnotation(EmittedFirrtlModule("foo", "", ".foomodule")), EmittedFirrtlModuleAnnotation(EmittedFirrtlModule("bar", "", ".barmodule")), - EmittedVerilogModuleAnnotation(EmittedVerilogModule("baz", "", ".bazmodule")) ) + EmittedVerilogModuleAnnotation(EmittedVerilogModule("baz", "", ".bazmodule")) + ) val expected = Seq("foo.foomodule", "bar.barmodule", "baz.bazmodule") .map(a => new File(s"test_run_dir/WriteEmittedSpec/$a")) info("EmittedComponent annotations are deleted") - phase.transform(annotations).toSeq should be (removeEmitted(annotations).toSeq) + phase.transform(annotations).toSeq should be(removeEmitted(annotations).toSeq) - expected.foreach{ a => + expected.foreach { a => info(s"$a was written") a should (exist) a.delete() diff --git a/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala b/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala index 2c746c99..a52df4a9 100644 --- a/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala +++ b/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala @@ -8,9 +8,8 @@ import firrtl.{Transform, VerilogEmitter} import firrtl.FileUtils import firrtl.testutils.LowTransformSpec - class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { - def transform: Transform = new BlackBoxSourceHelper + def transform: Transform = new BlackBoxSourceHelper private val moduleName = ModuleName("Top", CircuitName("Top")) private val input = """ @@ -31,21 +30,21 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { | y <= a1.bar """.stripMargin private val output = """ - |circuit Top : - | - | extmodule AdderExtModule : - | input foo : UInt<16> - | output bar : UInt<16> - | - | defname = BBFAdd - | - | module Top : - | input x : UInt<16> - | output y : UInt<16> - | - | inst a1 of AdderExtModule - | y <= a1.bar - | a1.foo <= x + |circuit Top : + | + | extmodule AdderExtModule : + | input foo : UInt<16> + | output bar : UInt<16> + | + | defname = BBFAdd + | + | module Top : + | input x : UInt<16> + | output y : UInt<16> + | + | inst a1 of AdderExtModule + | y <= a1.bar + | a1.foo <= x """.stripMargin "annotated external modules with absolute path" should "appear in output directory" in { @@ -61,8 +60,8 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { val module = new java.io.File("test_run_dir/AdderExtModule.v") val fileList = new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}") - module.exists should be (true) - fileList.exists should be (true) + module.exists should be(true) + fileList.exists should be(true) module.delete() fileList.delete() @@ -80,8 +79,8 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { val module = new java.io.File("test_run_dir/AdderExtModule.v") val fileList = new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}") - module.exists should be (true) - fileList.exists should be (true) + module.exists should be(true) + fileList.exists should be(true) module.delete() fileList.delete() @@ -96,8 +95,8 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { execute(input, output, annos) - new java.io.File("test_run_dir/AdderExtModule.v").exists should be (true) - new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}").exists should be (true) + new java.io.File("test_run_dir/AdderExtModule.v").exists should be(true) + new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}").exists should be(true) } "verilog header files" should "be available but not mentioned in the file list" in { @@ -114,40 +113,41 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { // We'll copy the following resources to the test_run_dir via BlackBoxResourceAnno's val resourceNames = Seq("ParameterizedViaHeaderAdderExtModule.v", "VerilogHeaderFile.vh") - val annos = Seq( - BlackBoxTargetDirAnno("test_run_dir")) ++ resourceNames.map{ n => BlackBoxResourceAnno(moduleName, "/blackboxes/" + n)} + val annos = Seq(BlackBoxTargetDirAnno("test_run_dir")) ++ resourceNames.map { n => + BlackBoxResourceAnno(moduleName, "/blackboxes/" + n) + } execute(pInput, pOutput, annos) // Our resource files should exist in the test_run_dir, for (n <- resourceNames) - new java.io.File("test_run_dir/" + n).exists should be (true) + new java.io.File("test_run_dir/" + n).exists should be(true) // but our file list should not include the verilog header file. val fileListFile = new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}") - fileListFile.exists should be (true) + fileListFile.exists should be(true) val fileList = FileUtils.getText(fileListFile) - fileList.contains("ParameterizedViaHeaderAdderExtModule.v") should be (true) - fileList.contains("VerilogHeaderFile.vh") should be (false) + fileList.contains("ParameterizedViaHeaderAdderExtModule.v") should be(true) + fileList.contains("VerilogHeaderFile.vh") should be(false) } - behavior of "BlackBox resources that do not exist" + behavior.of("BlackBox resources that do not exist") it should "provide a useful error message for BlackBoxResourceAnno" in { - val annos = Seq( BlackBoxTargetDirAnno("test_run_dir"), - BlackBoxResourceAnno(moduleName, "/blackboxes/IDontExist.v") ) + val annos = Seq(BlackBoxTargetDirAnno("test_run_dir"), BlackBoxResourceAnno(moduleName, "/blackboxes/IDontExist.v")) - (the [BlackBoxNotFoundException] thrownBy { execute(input, "", annos) }) - .getMessage should include ("Did you misspell it?") + (the[BlackBoxNotFoundException] thrownBy { execute(input, "", annos) }).getMessage should include( + "Did you misspell it?" + ) } it should "provide a useful error message for BlackBoxPathAnno" in { val absPath = new java.io.File("src/test/resources/blackboxes/IDontExist.v").getCanonicalPath - val annos = Seq( BlackBoxTargetDirAnno("test_run_dir"), - BlackBoxPathAnno(moduleName, absPath) ) + val annos = Seq(BlackBoxTargetDirAnno("test_run_dir"), BlackBoxPathAnno(moduleName, absPath)) - (the [BlackBoxNotFoundException] thrownBy { execute(input, "", annos) }) - .getMessage should include ("Did you misspell it?") + (the[BlackBoxNotFoundException] thrownBy { execute(input, "", annos) }).getMessage should include( + "Did you misspell it?" + ) } } diff --git a/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala b/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala index f2672bce..a916eac5 100644 --- a/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala +++ b/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala @@ -14,9 +14,11 @@ class CombineCatsSpec extends FirrtlFlatSpec { private val annotations = Seq(new MaxCatLenAnnotation(12)) private def execute(input: String, transforms: Seq[Transform], annotations: AnnotationSeq): CircuitState = { - val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm, annotations)) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit + val c = transforms + .foldLeft(CircuitState(parse(input), UnknownForm, annotations)) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit CircuitState(c, UnknownForm, Seq(), None) } @@ -86,11 +88,24 @@ class CombineCatsSpec extends FirrtlFlatSpec { // temp5 should get cat(cat(cat(in3, in2), cat(in4, in3)), cat(cat(in3, in2), cat(in4, in3))) result should containTree { - case DoPrim(Cat, Seq( - DoPrim(Cat, Seq( - DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _), - DoPrim(Cat, Seq(WRef("in3", _, _, _), WRef("in2", _, _, _)), _, _)), _, _), - DoPrim(Cat, Seq(WRef("in4", _, _, _), WRef("in3", _, _, _)), _, _)), _, _) => true + case DoPrim( + Cat, + Seq( + DoPrim( + Cat, + Seq( + DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _), + DoPrim(Cat, Seq(WRef("in3", _, _, _), WRef("in2", _, _, _)), _, _) + ), + _, + _ + ), + DoPrim(Cat, Seq(WRef("in4", _, _, _), WRef("in3", _, _, _)), _, _) + ), + _, + _ + ) => + true } } @@ -117,17 +132,19 @@ class CombineCatsSpec extends FirrtlFlatSpec { // should not contain any cat chains greater than 3 result shouldNot containTree { - case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true + case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _)), _, _) => true } // temp2 should get cat(in3, cat(in2, in1)) result should containTree { - case DoPrim(Cat, Seq( - WRef("in3", _, _, _), - DoPrim(Cat, Seq( - WRef("in2", _, _, _), - WRef("in1", _, _, _)), _, _)), _, _) => true + case DoPrim( + Cat, + Seq(WRef("in3", _, _, _), DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _)), + _, + _ + ) => + true } } @@ -152,8 +169,8 @@ class CombineCatsSpec extends FirrtlFlatSpec { val result = execute(input, transforms, Seq.empty) result shouldNot containTree { - case DoPrim(Cat, Seq(_, DoPrim(Add, _, _, _)), _, _) => true - case DoPrim(Cat, Seq(_, DoPrim(Sub, _, _, _)), _, _) => true + case DoPrim(Cat, Seq(_, DoPrim(Add, _, _, _)), _, _) => true + case DoPrim(Cat, Seq(_, DoPrim(Sub, _, _, _)), _, _) => true case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true } } diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala index 8ab3026c..8c2835dd 100644 --- a/src/test/scala/firrtlTests/transforms/DedupTests.scala +++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala @@ -10,8 +10,8 @@ import firrtl.transforms.{DedupModules, NoCircuitDedupAnnotation} import firrtl.testutils._ /** - * Tests inline instances transformation - */ + * Tests inline instances transformation + */ class DedupModuleTests extends HighTransformSpec { case class MultiTargetDummyAnnotation(targets: Seq[Target], tag: Int) extends Annotation { override def update(renames: RenameMap): Seq[Annotation] = { @@ -24,234 +24,236 @@ class DedupModuleTests extends HighTransformSpec { } def transform = new DedupModules "The module A" should "be deduped" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | module A_ : - | output x: UInt<1> - | x <= UInt(1) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | module A_ : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A - | module A : - | output x: UInt<1> - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - execute(input, check, Seq.empty) + execute(input, check, Seq.empty) } "The module A and B" should "be deduped" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | inst b of B - | x <= b.x - | module A_ : - | output x: UInt<1> - | inst b of B_ - | x <= b.x - | module B : - | output x: UInt<1> - | x <= UInt(1) - | module B_ : - | output x: UInt<1> - | x <= UInt(1) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module A_ : + | output x: UInt<1> + | inst b of B_ + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A - | module A : - | output x: UInt<1> - | inst b of B - | x <= b.x - | module B : - | output x: UInt<1> - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - execute(input, check, Seq.empty) + execute(input, check, Seq.empty) } "The module A and B with comments" should "be deduped" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B @[yy 2:2] - | x <= b.x @[yy 2:2] - | module A_ : @[xx 1:1] - | output x: UInt<1> @[xx 1:1] - | inst b of B_ @[xx 1:1] - | x <= b.x @[xx 1:1] - | module B : - | output x: UInt<1> - | x <= UInt(1) - | module B_ : - | output x: UInt<1> - | x <= UInt(1) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B @[yy 2:2] + | x <= b.x @[yy 2:2] + | module A_ : @[xx 1:1] + | output x: UInt<1> @[xx 1:1] + | inst b of B_ @[xx 1:1] + | x <= b.x @[xx 1:1] + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B @[yy 2:2] - | x <= b.x @[yy 2:2] - | module B : - | output x: UInt<1> - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B @[yy 2:2] + | x <= b.x @[yy 2:2] + | module B : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - execute(input, check, Seq.empty) + execute(input, check, Seq.empty) } "A_ but not A" should "be deduped if not annotated" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) - | module A_ : @[xx 1:1] - | output x: UInt<1> @[xx 1:1] - | x <= UInt(1) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + | module A_ : @[xx 1:1] + | output x: UInt<1> @[xx 1:1] + | x <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) - | module A_ : @[xx 1:1] - | output x: UInt<1> @[xx 1:1] - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + | module A_ : @[xx 1:1] + | output x: UInt<1> @[xx 1:1] + | x <= UInt(1) """.stripMargin - execute(input, check, Seq(dontDedup("A"))) + execute(input, check, Seq(dontDedup("A"))) } "The module A and A_" should "be deduped even with different port names and info, and annotations should remapped" in { - val input = - """circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A_ - | out <= and(a1.x, a2.y) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) - | module A_ : @[xx 1:1] - | output y: UInt<1> @[xx 1:1] - | y <= UInt(1) + val input = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A_ + | out <= and(a1.x, a2.y) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + | module A_ : @[xx 1:1] + | output y: UInt<1> @[xx 1:1] + | y <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A - | out <= and(a1.x, a2.x) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A + | out <= and(a1.x, a2.x) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) """.stripMargin - val mname = ModuleName("Top", CircuitName("Top")) - val finalState = execute(input, check, Seq(SingleTargetDummyAnnotation(ComponentName("a2.y", mname)))) - finalState.annotations.collect({ case d: SingleTargetDummyAnnotation => d }).head should be(SingleTargetDummyAnnotation(ComponentName("a2.x", mname))) + val mname = ModuleName("Top", CircuitName("Top")) + val finalState = execute(input, check, Seq(SingleTargetDummyAnnotation(ComponentName("a2.y", mname)))) + finalState.annotations.collect({ case d: SingleTargetDummyAnnotation => d }).head should be( + SingleTargetDummyAnnotation(ComponentName("a2.x", mname)) + ) } "Extmodules" should "with the same defname and parameters should dedup" in { - val input = - """circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A_ - | out <= and(a1.x, a2.y) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B - | x <= b.u - | module A_ : @[xx 1:1] - | output y: UInt<1> @[xx 1:1] - | inst c of C - | y <= c.u - | extmodule B : @[aa 3:3] - | output u : UInt<1> @[aa 4:4] - | defname = BB - | parameter N = 0 - | extmodule C : @[bb 5:5] - | output u : UInt<1> @[bb 6:6] - | defname = BB - | parameter N = 0 + val input = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A_ + | out <= and(a1.x, a2.y) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B + | x <= b.u + | module A_ : @[xx 1:1] + | output y: UInt<1> @[xx 1:1] + | inst c of C + | y <= c.u + | extmodule B : @[aa 3:3] + | output u : UInt<1> @[aa 4:4] + | defname = BB + | parameter N = 0 + | extmodule C : @[bb 5:5] + | output u : UInt<1> @[bb 6:6] + | defname = BB + | parameter N = 0 """.stripMargin - val check = - """circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A - | out <= and(a1.x, a2.x) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B - | x <= b.u - | extmodule B : @[aa 3:3] - | output u : UInt<1> @[aa 4:4] - | defname = BB - | parameter N = 0 + val check = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A + | out <= and(a1.x, a2.x) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B + | x <= b.u + | extmodule B : @[aa 3:3] + | output u : UInt<1> @[aa 4:4] + | defname = BB + | parameter N = 0 """.stripMargin - execute(input, check, Seq.empty) + execute(input, check, Seq.empty) } "Extmodules" should "with the different defname or parameters should NOT dedup" in { - def mkfir(defnames: (String, String), params: (String, String)) = - s"""circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A_ - | out <= and(a1.x, a2.y) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B - | x <= b.u - | module A_ : @[xx 1:1] - | output y: UInt<1> @[xx 1:1] - | inst c of C - | y <= c.u - | extmodule B : @[aa 3:3] - | output u : UInt<1> @[aa 4:4] - | defname = ${defnames._1} - | parameter N = ${params._1} - | extmodule C : @[bb 5:5] - | output u : UInt<1> @[bb 6:6] - | defname = ${defnames._2} - | parameter N = ${params._2} + def mkfir(defnames: (String, String), params: (String, String)) = + s"""circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A_ + | out <= and(a1.x, a2.y) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B + | x <= b.u + | module A_ : @[xx 1:1] + | output y: UInt<1> @[xx 1:1] + | inst c of C + | y <= c.u + | extmodule B : @[aa 3:3] + | output u : UInt<1> @[aa 4:4] + | defname = ${defnames._1} + | parameter N = ${params._1} + | extmodule C : @[bb 5:5] + | output u : UInt<1> @[bb 6:6] + | defname = ${defnames._2} + | parameter N = ${params._2} """.stripMargin - val diff_defname = mkfir(("BB", "CC"), ("0", "0")) - execute(diff_defname, diff_defname, Seq.empty) - val diff_params = mkfir(("BB", "BB"), ("0", "1")) - execute(diff_params, diff_params, Seq.empty) + val diff_defname = mkfir(("BB", "CC"), ("0", "0")) + execute(diff_defname, diff_defname, Seq.empty) + val diff_params = mkfir(("BB", "BB"), ("0", "1")) + execute(diff_params, diff_params, Seq.empty) } "Modules with aggregate ports that are bulk connected" should "NOT dedup if their port names differ" in { @@ -426,12 +428,16 @@ class DedupModuleTests extends HighTransformSpec { | wire b: UInt<1> | x <= b """.stripMargin - val cs = execute(input, check, Seq( - dontTouch(ReferenceTarget("Top", "A", Nil, "b", Nil)), - dontTouch(ReferenceTarget("Top", "A_", Nil, "b", Nil)) - )) - cs.annotations.toSeq should contain (dontTouch(ModuleTarget("Top", "Top").instOf("a1", "A").ref("b"))) - cs.annotations.toSeq should contain (dontTouch(ModuleTarget("Top", "Top").instOf("a2", "A").ref("b"))) + val cs = execute( + input, + check, + Seq( + dontTouch(ReferenceTarget("Top", "A", Nil, "b", Nil)), + dontTouch(ReferenceTarget("Top", "A_", Nil, "b", Nil)) + ) + ) + cs.annotations.toSeq should contain(dontTouch(ModuleTarget("Top", "Top").instOf("a1", "A").ref("b"))) + cs.annotations.toSeq should contain(dontTouch(ModuleTarget("Top", "Top").instOf("a2", "A").ref("b"))) cs.annotations.toSeq should not contain dontTouch(ReferenceTarget("Top", "A_", Nil, "b", Nil)) } "The module A and A_" should "be deduped with same annotation targets when there are a lot" in { @@ -508,12 +514,24 @@ class DedupModuleTests extends HighTransformSpec { val annoAB = MultiTargetDummyAnnotation(Seq(A, B), 0) val annoA_B_ = MultiTargetDummyAnnotation(Seq(A_, B_), 1) val cs = execute(input, check, Seq(annoAB, annoA_B_)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top_a1, Top_a1_b - ), 0)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top_a2, Top_a2_b - ), 1)) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top_a1, + Top_a1_b + ), + 0 + ) + ) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top_a2, + Top_a2_b + ), + 1 + ) + ) } "The module A and A_" should "be deduped with same annotations with same multi-targets, that share roots" in { val input = @@ -555,15 +573,25 @@ class DedupModuleTests extends HighTransformSpec { val annoA = MultiTargetDummyAnnotation(Seq(A, A.instOf("b", "B")), 0) val annoA_ = MultiTargetDummyAnnotation(Seq(A_, A_.instOf("b", "B_")), 1) val cs = execute(input, check, Seq(annoA, annoA_)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.module("Top").instOf("a1", "A"), - Top.module("Top").instOf("a1", "A").instOf("b", "B") - ),0)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.module("Top").instOf("a2", "A"), - Top.module("Top").instOf("a2", "A").instOf("b", "B") - ),1)) - cs.deletedAnnotations.isEmpty should be (true) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.module("Top").instOf("a1", "A"), + Top.module("Top").instOf("a1", "A").instOf("b", "B") + ), + 0 + ) + ) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.module("Top").instOf("a2", "A"), + Top.module("Top").instOf("a2", "A").instOf("b", "B") + ), + 1 + ) + ) + cs.deletedAnnotations.isEmpty should be(true) } "The deduping module A and A_" should "rename internal signals that have different names" in { val input = @@ -600,12 +628,12 @@ class DedupModuleTests extends HighTransformSpec { val Top = CircuitTarget("Top") val A = Top.module("A") val A_ = Top.module("A_") - val annoA = SingleTargetDummyAnnotation(A.ref("a")) + val annoA = SingleTargetDummyAnnotation(A.ref("a")) val annoA_ = SingleTargetDummyAnnotation(A_.ref("b")) val cs = execute(input, check, Seq(annoA, annoA_)) - cs.annotations.toSeq should contain (annoA) + cs.annotations.toSeq should contain(annoA) cs.annotations.toSeq should not contain (SingleTargetDummyAnnotation(A.ref("b"))) - cs.deletedAnnotations.isEmpty should be (true) + cs.deletedAnnotations.isEmpty should be(true) } "main" should "not be deduped even if it's the last module" in { val input = @@ -691,14 +719,25 @@ class DedupModuleTests extends HighTransformSpec { val anno1 = MultiTargetDummyAnnotation(Seq(inst1, ref1), 0) val anno2 = MultiTargetDummyAnnotation(Seq(inst2, ref2), 1) val cs = execute(input, check, Seq(anno1, anno2)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - inst1, ref1 - ),0)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.module("Top").instOf("a_", "A").instOf("b", "B"), - Top.module("Top").instOf("a_", "A").instOf("b", "B").ref("foo") - ),1)) - cs.deletedAnnotations.isEmpty should be (true) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + inst1, + ref1 + ), + 0 + ) + ) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.module("Top").instOf("a_", "A").instOf("b", "B"), + Top.module("Top").instOf("a_", "A").instOf("b", "B").ref("foo") + ), + 1 + ) + ) + cs.deletedAnnotations.isEmpty should be(true) } "The deduping module A and A_" should "rename nested instances that have different names" in { @@ -746,14 +785,25 @@ class DedupModuleTests extends HighTransformSpec { val anno1 = MultiTargetDummyAnnotation(Seq(inst1, ref1), 0) val anno2 = MultiTargetDummyAnnotation(Seq(inst2, ref2), 1) val cs = execute(input, check, Seq(anno1, anno2)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - inst1, ref1 - ),0)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.module("Top").instOf("a_", "A").instOf("b", "B").instOf("c", "C").instOf("d", "D"), - Top.module("Top").instOf("a_", "A").instOf("b", "B").instOf("c", "C").instOf("d", "D").ref("foo") - ),1)) - cs.deletedAnnotations.isEmpty should be (true) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + inst1, + ref1 + ), + 0 + ) + ) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.module("Top").instOf("a_", "A").instOf("b", "B").instOf("c", "C").instOf("d", "D"), + Top.module("Top").instOf("a_", "A").instOf("b", "B").instOf("c", "C").instOf("d", "D").ref("foo") + ), + 1 + ) + ) + cs.deletedAnnotations.isEmpty should be(true) } "Deduping modules with multiple instances" should "corectly rename instances" in { @@ -801,50 +851,55 @@ class DedupModuleTests extends HighTransformSpec { val cInstances = bInstances.map(_.instOf("c", "C")) val annos = MultiTargetDummyAnnotation(bInstances ++ cInstances, 0) val cs = execute(input, check, Seq(annos)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.instOf("b", "B"), - Top.instOf("b_", "B"), - Top.instOf("a1", "A").instOf("b_", "B"), - Top.instOf("a2", "A").instOf("b_", "B"), - Top.instOf("a1", "A").instOf("b", "B"), - Top.instOf("a2", "A").instOf("b", "B"), - Top.instOf("b", "B").instOf("c", "C"), - Top.instOf("b_", "B").instOf("c", "C"), - Top.instOf("a1", "A").instOf("b_", "B").instOf("c", "C"), - Top.instOf("a2", "A").instOf("b_", "B").instOf("c", "C"), - Top.instOf("a1", "A").instOf("b", "B").instOf("c", "C"), - Top.instOf("a2", "A").instOf("b", "B").instOf("c", "C") - ),0)) - cs.deletedAnnotations.isEmpty should be (true) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.instOf("b", "B"), + Top.instOf("b_", "B"), + Top.instOf("a1", "A").instOf("b_", "B"), + Top.instOf("a2", "A").instOf("b_", "B"), + Top.instOf("a1", "A").instOf("b", "B"), + Top.instOf("a2", "A").instOf("b", "B"), + Top.instOf("b", "B").instOf("c", "C"), + Top.instOf("b_", "B").instOf("c", "C"), + Top.instOf("a1", "A").instOf("b_", "B").instOf("c", "C"), + Top.instOf("a2", "A").instOf("b_", "B").instOf("c", "C"), + Top.instOf("a1", "A").instOf("b", "B").instOf("c", "C"), + Top.instOf("a2", "A").instOf("b", "B").instOf("c", "C") + ), + 0 + ) + ) + cs.deletedAnnotations.isEmpty should be(true) } "dedup" should "properly rename target components after retyping" in { val input = """ - |circuit top: - | module top: - | input ia: {z: {y: {x: UInt<1>}}, a: UInt<1>} - | input ib: {a: {b: {c: UInt<1>}}, z: UInt<1>} - | output oa: {z: {y: {x: UInt<1>}}, a: UInt<1>} - | output ob: {a: {b: {c: UInt<1>}}, z: UInt<1>} - | inst a of a - | a.i.z.y.x <= ia.z.y.x - | a.i.a <= ia.a - | oa.z.y.x <= a.o.z.y.x - | oa.a <= a.o.a - | inst b of b - | b.q.a.b.c <= ib.a.b.c - | b.q.z <= ib.z - | ob.a.b.c <= b.r.a.b.c - | ob.z <= b.r.z - | module a: - | input i: {z: {y: {x: UInt<1>}}, a: UInt<1>} - | output o: {z: {y: {x: UInt<1>}}, a: UInt<1>} - | o <= i - | module b: - | input q: {a: {b: {c: UInt<1>}}, z: UInt<1>} - | output r: {a: {b: {c: UInt<1>}}, z: UInt<1>} - | r <= q - |""".stripMargin + |circuit top: + | module top: + | input ia: {z: {y: {x: UInt<1>}}, a: UInt<1>} + | input ib: {a: {b: {c: UInt<1>}}, z: UInt<1>} + | output oa: {z: {y: {x: UInt<1>}}, a: UInt<1>} + | output ob: {a: {b: {c: UInt<1>}}, z: UInt<1>} + | inst a of a + | a.i.z.y.x <= ia.z.y.x + | a.i.a <= ia.a + | oa.z.y.x <= a.o.z.y.x + | oa.a <= a.o.a + | inst b of b + | b.q.a.b.c <= ib.a.b.c + | b.q.z <= ib.z + | ob.a.b.c <= b.r.a.b.c + | ob.z <= b.r.z + | module a: + | input i: {z: {y: {x: UInt<1>}}, a: UInt<1>} + | output o: {z: {y: {x: UInt<1>}}, a: UInt<1>} + | o <= i + | module b: + | input q: {a: {b: {c: UInt<1>}}, z: UInt<1>} + | output r: {a: {b: {c: UInt<1>}}, z: UInt<1>} + | r <= q + |""".stripMargin case class DummyRTAnnotation(target: ReferenceTarget) extends SingleTargetAnnotation[ReferenceTarget] { def duplicate(n: ReferenceTarget) = DummyRTAnnotation(n) @@ -853,7 +908,6 @@ class DedupModuleTests extends HighTransformSpec { val annA = DummyRTAnnotation(ReferenceTarget("top", "a", Nil, "i", Seq(TargetToken.Field("a")))) val annB = DummyRTAnnotation(ReferenceTarget("top", "b", Nil, "q", Seq(TargetToken.Field("a")))) - val cs = CircuitState(Parser.parseString(input, Parser.IgnoreInfo), Seq(annA, annB)) val deduper = new stage.transforms.Compiler(stage.Forms.Deduped, Nil) @@ -871,7 +925,7 @@ class DedupModuleTests extends HighTransformSpec { val bPath = Seq((TargetToken.Instance("b"), TargetToken.OfModule("a"))) val expectedAnnA = DummyRTAnnotation(ReferenceTarget("top", "top", aPath, "i", Seq(TargetToken.Field("a")))) val expectedAnnB = DummyRTAnnotation(ReferenceTarget("top", "top", aPath, "i", Seq(TargetToken.Field("a")))) - csDeduped.annotations.toSeq should contain (expectedAnnA) - csDeduped.annotations.toSeq should contain (expectedAnnB) + csDeduped.annotations.toSeq should contain(expectedAnnA) + csDeduped.annotations.toSeq should contain(expectedAnnB) } } diff --git a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala index fdb129a1..65544764 100644 --- a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala +++ b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala @@ -17,75 +17,75 @@ class GroupComponentsSpec extends MiddleTransformSpec { def topComp(name: String): ComponentName = ComponentName(name, ModuleName(top, CircuitName(top))) "The register r" should "be grouped" in { val input = - s"""circuit $top : - | module $top : - | input clk: Clock - | input data: UInt<16> - | output out: UInt<16> - | reg r: UInt<16>, clk - | r <= data - | out <= r + s"""circuit $top : + | module $top : + | input clk: Clock + | input data: UInt<16> + | output out: UInt<16> + | reg r: UInt<16>, clk + | r <= data + | out <= r """.stripMargin val groups = Seq( GroupAnnotation(Seq(topComp("r")), "MyReg", "rInst", Some("_OUT"), Some("_IN")) ) val check = - s"""circuit Top : - | module $top : - | input clk: Clock - | input data: UInt<16> - | output out: UInt<16> - | inst rInst of MyReg - | rInst.clk_IN <= clk - | out <= rInst.r_OUT - | rInst.data_IN <= data - | module MyReg : - | input clk_IN: Clock - | output r_OUT: UInt<16> - | input data_IN: UInt<16> - | reg r: UInt<16>, clk_IN - | r_OUT <= r - | r <= data_IN + s"""circuit Top : + | module $top : + | input clk: Clock + | input data: UInt<16> + | output out: UInt<16> + | inst rInst of MyReg + | rInst.clk_IN <= clk + | out <= rInst.r_OUT + | rInst.data_IN <= data + | module MyReg : + | input clk_IN: Clock + | output r_OUT: UInt<16> + | input data_IN: UInt<16> + | reg r: UInt<16>, clk_IN + | r_OUT <= r + | r <= data_IN """.stripMargin execute(input, check, groups) } "Grouping" should "work even when there are unused nodes" in { val input = - s"""circuit $top : - | module $top : - | input in: UInt<16> - | output out: UInt<16> - | node n = UInt<16>("h0") - | wire w : UInt<16> - | wire a : UInt<16> - | wire b : UInt<16> - | a <= UInt<16>("h0") - | b <= a - | w <= in - | out <= w + s"""circuit $top : + | module $top : + | input in: UInt<16> + | output out: UInt<16> + | node n = UInt<16>("h0") + | wire w : UInt<16> + | wire a : UInt<16> + | wire b : UInt<16> + | a <= UInt<16>("h0") + | b <= a + | w <= in + | out <= w """.stripMargin val groups = Seq( GroupAnnotation(Seq(topComp("w")), "Child", "inst", Some("_OUT"), Some("_IN")) ) val check = - s"""circuit Top : - | module $top : - | input in: UInt<16> - | output out: UInt<16> - | inst inst of Child - | node n = UInt<16>("h0") - | wire a : UInt<16> - | wire b : UInt<16> - | out <= inst.w_OUT - | inst.in_IN <= in - | a <= UInt<16>("h0") - | b <= a - | module Child : - | output w_OUT : UInt<16> - | input in_IN : UInt<16> - | wire w : UInt<16> - | w_OUT <= w - | w <= in_IN + s"""circuit Top : + | module $top : + | input in: UInt<16> + | output out: UInt<16> + | inst inst of Child + | node n = UInt<16>("h0") + | wire a : UInt<16> + | wire b : UInt<16> + | out <= inst.w_OUT + | inst.in_IN <= in + | a <= UInt<16>("h0") + | b <= a + | module Child : + | output w_OUT : UInt<16> + | input in_IN : UInt<16> + | wire w : UInt<16> + | w_OUT <= w + | w <= in_IN """.stripMargin execute(input, check, groups) } @@ -116,8 +116,8 @@ class GroupComponentsSpec extends MiddleTransformSpec { | out <= UInt(2) """.stripMargin val annotations = Seq( - GroupAnnotation(Seq(topComp("c1a"), topComp("c2a")/*, topComp("asum")*/), "A", "cA", Some("_OUT"), Some("_IN")), - GroupAnnotation(Seq(topComp("c1b"), topComp("c2b")/*, topComp("bsum")*/), "B", "cB", Some("_OUT"), Some("_IN")), + GroupAnnotation(Seq(topComp("c1a"), topComp("c2a") /*, topComp("asum")*/ ), "A", "cA", Some("_OUT"), Some("_IN")), + GroupAnnotation(Seq(topComp("c1b"), topComp("c2b") /*, topComp("bsum")*/ ), "B", "cB", Some("_OUT"), Some("_IN")), NoCircuitDedupAnnotation ) val check = @@ -380,7 +380,7 @@ class GroupComponentsIntegrationSpec extends FirrtlFlatSpec { def topComp(name: String): ComponentName = ComponentName(name, ModuleName("Top", CircuitName("Top"))) "Grouping" should "properly set kinds" in { val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input data: UInt<16> @@ -397,13 +397,13 @@ class GroupComponentsIntegrationSpec extends FirrtlFlatSpec { Seq(new GroupComponents) ) result should containTree { - case Connect(_, WSubField(WRef("inst",_, InstanceKind,_), "data_IN", _,_), WRef("data",_,_,_)) => true + case Connect(_, WSubField(WRef("inst", _, InstanceKind, _), "data_IN", _, _), WRef("data", _, _, _)) => true } result should containTree { - case Connect(_, WSubField(WRef("inst",_, InstanceKind,_), "clk_IN", _,_), WRef("clk",_,_,_)) => true + case Connect(_, WSubField(WRef("inst", _, InstanceKind, _), "clk_IN", _, _), WRef("clk", _, _, _)) => true } result should containTree { - case Connect(_, WRef("out",_,_,_), WSubField(WRef("inst",_, InstanceKind,_), "r_OUT", _,_)) => true + case Connect(_, WRef("out", _, _, _), WSubField(WRef("inst", _, InstanceKind, _), "r_OUT", _, _)) => true } } } diff --git a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala index c5847364..0043cb1f 100644 --- a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala +++ b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala @@ -5,36 +5,25 @@ package firrtlTests.transforms import firrtl.testutils.FirrtlFlatSpec import firrtl._ import firrtl.passes._ -import firrtl.passes.wiring.{WiringTransform, SourceAnnotation, SinkAnnotation} +import firrtl.passes.wiring.{SinkAnnotation, SourceAnnotation, WiringTransform} import firrtl.annotations._ import firrtl.annotations.TargetToken.{Field, Index} - class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { - private def executeTest(input: String, - check: String, - transforms: Seq[Transform], - annotations: Seq[Annotation]) = { + private def executeTest(input: String, check: String, transforms: Seq[Transform], annotations: Seq[Annotation]) = { val start = CircuitState(parse(input), ChirrtlForm, annotations) - val end = transforms.foldLeft(start) { - (c: CircuitState, t: Transform) => t.runTransform(c) + val end = transforms.foldLeft(start) { (c: CircuitState, t: Transform) => + t.runTransform(c) } - val resLines = end.circuit.serialize.split("\n") map normalized - val checkLines = parse(check).serialize.split("\n") map normalized + val resLines = end.circuit.serialize.split("\n").map(normalized) + val checkLines = parse(check).serialize.split("\n").map(normalized) - resLines should be (checkLines) + resLines should be(checkLines) } "CheckWidths on wires with unknown widths" should "result in an error" in { - val transforms = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val transforms = + Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, new InferWidths, CheckWidths) val input = """circuit Top : @@ -55,19 +44,15 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { } "InferWidthsWithAnnos" should "infer widths using WidthGeqConstraintAnnotation" in { - val transforms = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val transforms = + Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, new InferWidths, CheckWidths) - val annos = Seq(WidthGeqConstraintAnnotation( - ReferenceTarget("Top", "A", Nil, "y", Nil), - ReferenceTarget("Top", "B", Nil, "x", Nil))) + val annos = Seq( + WidthGeqConstraintAnnotation( + ReferenceTarget("Top", "A", Nil, "y", Nil), + ReferenceTarget("Top", "B", Nil, "x", Nil) + ) + ) val input = """circuit Top : @@ -98,15 +83,8 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { } "InferWidthsWithAnnos" should "work with token paths" in { - val transforms = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val transforms = + Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, new InferWidths, CheckWidths) val tokenLists = Seq( Seq(Field("x")), @@ -117,7 +95,8 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { val annos = tokenLists.map { tokens => WidthGeqConstraintAnnotation( ReferenceTarget("Top", "A", Nil, "bundle", tokens), - ReferenceTarget("Top", "B", Nil, "bundle", tokens)) + ReferenceTarget("Top", "B", Nil, "bundle", tokens) + ) } val input = @@ -174,7 +153,8 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { val wgeqAnnos = tokenLists.map { tokens => WidthGeqConstraintAnnotation( ReferenceTarget("Top", "A", Nil, "bundle", tokens), - ReferenceTarget("Top", "B", Nil, "bundle", tokens)) + ReferenceTarget("Top", "B", Nil, "bundle", tokens) + ) } val failAnnos = Seq(source, sink) @@ -209,8 +189,7 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { | module A : | output bundle_0 : {x : UInt<1>, y: {yy : UInt<3>}[2] } | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] } - | bundle_0 <= bundle""" - .stripMargin + | bundle_0 <= bundle""".stripMargin // should fail without extra constraint annos due to UninferredWidths val exceptions = intercept[PassExceptions] { diff --git a/src/test/scala/firrtlTests/transforms/LegalizeClocks.scala b/src/test/scala/firrtlTests/transforms/LegalizeClocks.scala index f57586f6..6ee0f5a0 100644 --- a/src/test/scala/firrtlTests/transforms/LegalizeClocks.scala +++ b/src/test/scala/firrtlTests/transforms/LegalizeClocks.scala @@ -10,7 +10,7 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { def compile(input: String): CircuitState = (new MinimumVerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), Nil) - behavior of "LegalizeClocksTransform" + behavior.of("LegalizeClocksTransform") it should "not emit @(posedge 1'h0) for stop" in { val input = @@ -19,8 +19,8 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { | stop(asClock(UInt(1)), UInt(1), 1) |""".stripMargin val result = compile(input) - result should containLine (s"always @(posedge _GEN_0) begin") - result.getEmittedCircuit.value shouldNot include ("always @(posedge 1") + result should containLine(s"always @(posedge _GEN_0) begin") + result.getEmittedCircuit.value shouldNot include("always @(posedge 1") } it should "not emit @(posedge 1'h0) for printf" in { @@ -30,8 +30,8 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { | printf(asClock(UInt(1)), UInt(1), "hi") |""".stripMargin val result = compile(input) - result should containLine (s"always @(posedge _GEN_0) begin") - result.getEmittedCircuit.value shouldNot include ("always @(posedge 1") + result should containLine(s"always @(posedge _GEN_0) begin") + result.getEmittedCircuit.value shouldNot include("always @(posedge 1") } it should "not emit @(posedge 1'h0) for reg" in { @@ -45,8 +45,8 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { | out <= r |""".stripMargin val result = compile(input) - result should containLine (s"always @(posedge _GEN_0) begin") - result.getEmittedCircuit.value shouldNot include ("always @(posedge 1") + result should containLine(s"always @(posedge _GEN_0) begin") + result.getEmittedCircuit.value shouldNot include("always @(posedge 1") } it should "deduplicate injected nodes for literal clocks" in { @@ -57,11 +57,11 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { | stop(asClock(UInt(1)), UInt(1), 1) |""".stripMargin val result = compile(input) - result should containLine (s"wire _GEN_0 = 1'h1;") + result should containLine(s"wire _GEN_0 = 1'h1;") // Check that there's only 1 _GEN_0 instantiation val verilog = result.getEmittedCircuit.value val matches = "wire\\s+_GEN_0\\s+=\\s+1'h1".r.findAllIn(verilog) - matches.size should be (1) + matches.size should be(1) } } diff --git a/src/test/scala/firrtlTests/transforms/LegalizeReductions.scala b/src/test/scala/firrtlTests/transforms/LegalizeReductions.scala index 5368c54c..3df47f1d 100644 --- a/src/test/scala/firrtlTests/transforms/LegalizeReductions.scala +++ b/src/test/scala/firrtlTests/transforms/LegalizeReductions.scala @@ -12,12 +12,11 @@ import java.io.File object LegalizeAndReductionsTransformSpec extends FirrtlRunners { private case class Test( - name: String, - op: String, - input: BigInt, - expected: BigInt, - forceWidth: Option[Int] = None - ) { + name: String, + op: String, + input: BigInt, + expected: BigInt, + forceWidth: Option[Int] = None) { def toFirrtl: String = { val width = forceWidth.getOrElse(input.bitLength) val inputLit = s"""UInt("h${input.toString(16)}")""" @@ -62,9 +61,9 @@ circuit $name : // Run FIRRTL val annos = FirrtlSourceAnnotation(test.toFirrtl) :: - TargetDirAnnotation(testDir.toString) :: - CompilerAnnotation(new MinimumVerilogCompiler) :: - Nil + TargetDirAnnotation(testDir.toString) :: + CompilerAnnotation(new MinimumVerilogCompiler) :: + Nil val resultAnnos = (new FirrtlStage).transform(annos) val outputFilename = resultAnnos.collectFirst { case OutputFileAnnotation(f) => f } outputFilename.toRight(s"Output file not found!") @@ -73,8 +72,8 @@ circuit $name : copyResourceToFile(cppHarnessResourceName, harness) // Run Verilator verilogToCpp(prefix, testDir, Nil, harness, suppressVcd = true) #&& - cppToExe(prefix, testDir) ! - loggingProcessLogger + cppToExe(prefix, testDir) ! + loggingProcessLogger // Run binary if (!executeExpectingSuccess(prefix, testDir)) { throw new Exception("Test failed!") with scala.util.control.NoStackTrace @@ -82,24 +81,23 @@ circuit $name : } } - class LegalizeAndReductionsTransformSpec extends AnyFlatSpec { import LegalizeAndReductionsTransformSpec._ - behavior of "LegalizeAndReductionsTransform" + behavior.of("LegalizeAndReductionsTransform") private val tests = // name primop input expected width - Test("andreduce_ones", "andr", BigInt("1"*68, 2), 1) :: - Test("andreduce_zero", "andr", 0, 0, Some(68)) :: - Test("orreduce_ones", "orr", BigInt("1"*68, 2), 1) :: - Test("orreduce_high_one", "orr", BigInt("1" + "0"*67, 2), 1) :: - Test("orreduce_zero", "orr", 0, 0, Some(68)) :: - Test("xorreduce_high_one", "xorr", BigInt("1" + "0"*67, 2), 1) :: - Test("xorreduce_high_low_one", "xorr", BigInt("1" + "0"*66 + "1", 2), 0) :: - Test("xorreduce_zero", "xorr", 0, 0, Some(68)) :: - Nil + Test("andreduce_ones", "andr", BigInt("1" * 68, 2), 1) :: + Test("andreduce_zero", "andr", 0, 0, Some(68)) :: + Test("orreduce_ones", "orr", BigInt("1" * 68, 2), 1) :: + Test("orreduce_high_one", "orr", BigInt("1" + "0" * 67, 2), 1) :: + Test("orreduce_zero", "orr", 0, 0, Some(68)) :: + Test("xorreduce_high_one", "xorr", BigInt("1" + "0" * 67, 2), 1) :: + Test("xorreduce_high_low_one", "xorr", BigInt("1" + "0" * 66 + "1", 2), 0) :: + Test("xorreduce_zero", "xorr", 0, 0, Some(68)) :: + Nil for (test <- tests) { it should s"support ${test.name}" in { diff --git a/src/test/scala/firrtlTests/transforms/ManipulateNamesSpec.scala b/src/test/scala/firrtlTests/transforms/ManipulateNamesSpec.scala index ec1b505b..a616b4bd 100644 --- a/src/test/scala/firrtlTests/transforms/ManipulateNamesSpec.scala +++ b/src/test/scala/firrtlTests/transforms/ManipulateNamesSpec.scala @@ -2,22 +2,15 @@ package firrtlTests.transforms -import firrtl.{ - ir, - CircuitState, - FirrtlUserException, - Namespace, - Parser, - RenameMap -} +import firrtl.{ir, CircuitState, FirrtlUserException, Namespace, Parser, RenameMap} import firrtl.annotations.CircuitTarget import firrtl.options.Dependency import firrtl.testutils.FirrtlCheckers._ import firrtl.transforms.{ ManipulateNames, - ManipulateNamesBlocklistAnnotation, ManipulateNamesAllowlistAnnotation, - ManipulateNamesAllowlistResultAnnotation + ManipulateNamesAllowlistResultAnnotation, + ManipulateNamesBlocklistAnnotation } import org.scalatest.flatspec.AnyFlatSpec @@ -57,24 +50,24 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val tm = new firrtl.stage.transforms.Compiler(Seq(Dependency[AddPrefix])) } - behavior of "ManipulateNames" + behavior.of("ManipulateNames") it should "rename everything by default" in new CircuitFixture { val state = CircuitState(Parser.parse(input), Seq.empty) val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "prefix_Foo") => true }, - { case ir.Module(_, "prefix_Foo", _, _) => true}, - { case ir.Module(_, "prefix_Bar", _, _) => true} + { case ir.Module(_, "prefix_Foo", _, _) => true }, + { case ir.Module(_, "prefix_Bar", _, _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "do nothing if the circuit is blocklisted" in new CircuitFixture { val annotations = Seq(ManipulateNamesBlocklistAnnotation(Seq(Seq(`~Foo`)), Dependency[AddPrefix])) val state = CircuitState(Parser.parse(input), annotations) val statex = tm.execute(state) - state.circuit.serialize should be (statex.circuit.serialize) + state.circuit.serialize should be(statex.circuit.serialize) } it should "not rename the circuit if the top module is blocklisted" in new CircuitFixture { @@ -82,31 +75,31 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val state = CircuitState(Parser.parse(input), annotations) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "Foo") => true }, - { case ir.Module(_, "Foo", _, _) => true}, - { case ir.Module(_, "prefix_Bar", _, _) => true} + { case ir.Module(_, "Foo", _, _) => true }, + { case ir.Module(_, "prefix_Bar", _, _) => true } ) val statex = tm.execute(state) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "not rename instances if blocklisted" in new CircuitFixture { val annotations = Seq(ManipulateNamesBlocklistAnnotation(Seq(Seq(`~Foo|Foo/bar:Bar`)), Dependency[AddPrefix])) val state = CircuitState(Parser.parse(input), annotations) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( - { case ir.DefInstance(_, "bar", "prefix_Bar", _) => true}, - { case ir.Module(_, "prefix_Bar", _, _) => true} + { case ir.DefInstance(_, "bar", "prefix_Bar", _) => true }, + { case ir.Module(_, "prefix_Bar", _, _) => true } ) val statex = tm.execute(state) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } - it should "do nothing if the circuit is not allowlisted" in new CircuitFixture { + it should "do nothing if the circuit is not allowlisted" in new CircuitFixture { val annotations = Seq( ManipulateNamesAllowlistAnnotation(Seq(Seq(`~Foo|Foo`)), Dependency[AddPrefix]) ) val state = CircuitState(Parser.parse(input), annotations) val statex = tm.execute(state) - state.circuit.serialize should be (statex.circuit.serialize) + state.circuit.serialize should be(statex.circuit.serialize) } it should "rename only the circuit if allowlisted" in new CircuitFixture { @@ -118,13 +111,13 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "prefix_Foo") => true }, - { case ir.Module(_, "prefix_Foo", _, _) => true}, - { case ir.DefInstance(_, "bar", "Bar", _) => true}, - { case ir.DefInstance(_, "bar2", "Bar", _) => true}, - { case ir.Module(_, "Bar", _, _) => true}, - { case ir.DefNode(_, "a", _) => true} + { case ir.Module(_, "prefix_Foo", _, _) => true }, + { case ir.DefInstance(_, "bar", "Bar", _) => true }, + { case ir.DefInstance(_, "bar2", "Bar", _) => true }, + { case ir.Module(_, "Bar", _, _) => true }, + { case ir.DefNode(_, "a", _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "rename an instance via allowlisting" in new CircuitFixture { @@ -136,13 +129,13 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "Foo") => true }, - { case ir.Module(_, "Foo", _, _) => true}, - { case ir.DefInstance(_, "prefix_bar", "Bar", _) => true}, - { case ir.DefInstance(_, "bar2", "Bar", _) => true}, - { case ir.Module(_, "Bar", _, _) => true}, - { case ir.DefNode(_, "a", _) => true} + { case ir.Module(_, "Foo", _, _) => true }, + { case ir.DefInstance(_, "prefix_bar", "Bar", _) => true }, + { case ir.DefInstance(_, "bar2", "Bar", _) => true }, + { case ir.Module(_, "Bar", _, _) => true }, + { case ir.DefNode(_, "a", _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "rename a node via allowlisting" in new CircuitFixture { @@ -154,13 +147,13 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "Foo") => true }, - { case ir.Module(_, "Foo", _, _) => true}, - { case ir.DefInstance(_, "bar", "Bar", _) => true}, - { case ir.DefInstance(_, "bar2", "Bar", _) => true}, - { case ir.Module(_, "Bar", _, _) => true}, - { case ir.DefNode(_, "prefix_a", _) => true} + { case ir.Module(_, "Foo", _, _) => true }, + { case ir.DefInstance(_, "bar", "Bar", _) => true }, + { case ir.DefInstance(_, "bar2", "Bar", _) => true }, + { case ir.Module(_, "Bar", _, _) => true }, + { case ir.DefNode(_, "prefix_a", _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "throw user errors on circuits that haven't been run through LowerTypes" in { @@ -171,9 +164,9 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { | node baz = bar.a |""".stripMargin val state = CircuitState(Parser.parse(input), Seq.empty) - intercept [FirrtlUserException] { + intercept[FirrtlUserException] { (new AddPrefix).transform(state) - }.getMessage should include ("LowerTypes") + }.getMessage should include("LowerTypes") } it should "only consume annotations whose type parameter matches" in new CircuitFixture { @@ -187,25 +180,25 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "prefix_Foo") => true }, - { case ir.Module(_, "prefix_Foo", _, _) => true}, - { case ir.DefInstance(_, "prefix_bar", "prefix_Bar", _) => true}, - { case ir.DefInstance(_, "prefix_bar2", "prefix_Bar", _) => true}, - { case ir.Module(_, "prefix_Bar", _, _) => true}, - { case ir.DefNode(_, "a_suffix", _) => true} + { case ir.Module(_, "prefix_Foo", _, _) => true }, + { case ir.DefInstance(_, "prefix_bar", "prefix_Bar", _) => true }, + { case ir.DefInstance(_, "prefix_bar2", "prefix_Bar", _) => true }, + { case ir.Module(_, "prefix_Bar", _, _) => true }, + { case ir.DefNode(_, "a_suffix", _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } - behavior of "ManipulateNamesBlocklistAnnotation" + behavior.of("ManipulateNamesBlocklistAnnotation") it should "throw an exception if a non-local target is skipped" in new CircuitFixture { val barA = CircuitTarget("Foo").module("Foo").instOf("bar", "Bar").ref("a") - assertThrows[java.lang.IllegalArgumentException]{ + assertThrows[java.lang.IllegalArgumentException] { Seq(ManipulateNamesBlocklistAnnotation(Seq(Seq(barA)), Dependency[AddPrefix])) } } - behavior of "ManipulateNamesAllowlistResultAnnotation" + behavior.of("ManipulateNamesAllowlistResultAnnotation") it should "delete itself if the new target is deleted" in { val `~Foo|Bar` = CircuitTarget("Foo").module("Bar") @@ -220,7 +213,7 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val r = RenameMap() r.delete(`~Foo|prefix_Bar`) - a.update(r) should be (empty) + a.update(r) should be(empty) } it should "drop a deleted target" in { @@ -242,12 +235,12 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { case b: ManipulateNamesAllowlistResultAnnotation[_] => b } - ax should not be length (1) + ax should not be length(1) val keys = ax.head.toRenameMap.getUnderlying.keys keys should not contain (`~Foo|Bar`) - keys should contain (`~Foo|Baz`) + keys should contain(`~Foo|Baz`) } } diff --git a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala index 299a4f48..d603db69 100644 --- a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala +++ b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala @@ -8,7 +8,7 @@ import firrtl.testutils.FirrtlFlatSpec import firrtl.testutils.FirrtlCheckers._ import firrtl.{CircuitState, WRef} -import firrtl.ir.{Connect, Mux, DefRegister} +import firrtl.ir.{Connect, DefRegister, Mux} import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage} class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { @@ -17,12 +17,12 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { When("the circuit is compiled to low FIRRTL") (new FirrtlStage) .execute(Array("-X", "low"), Seq(FirrtlSourceAnnotation(string))) - .collectFirst{ case FirrtlCircuitAnnotation(a) => a } + .collectFirst { case FirrtlCircuitAnnotation(a) => a } .map(a => firrtl.CircuitState(a, firrtl.UnknownForm)) .get } - behavior of "RemoveReset" + behavior.of("RemoveReset") it should "not generate a reset mux for an invalid init" in { Given("a 1-bit register 'foo' initialized to invalid, 1-bit wire 'bar'") @@ -44,7 +44,7 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { val outputState = toLowFirrtl(input) Then("'foo' is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo", _, _, _), Mux(_, _, _, _)) => true } } it should "generate a reset mux for only the portion of an invalid aggregate that is reset" in { @@ -71,11 +71,11 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { val outputState = toLowFirrtl(input) Then("foo.a[0] is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo_a_0",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo_a_0", _, _, _), Mux(_, _, _, _)) => true } And("foo.a[1] is connected to a reset mux") - outputState should containTree { case Connect(_, WRef("foo_a_1",_,_,_), Mux(_,_,_,_)) => true } + outputState should containTree { case Connect(_, WRef("foo_a_1", _, _, _), Mux(_, _, _, _)) => true } And("foo.b is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo_b", _, _, _), Mux(_, _, _, _)) => true } } it should "propagate invalidations across connects" in { @@ -107,9 +107,9 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { val outputState = toLowFirrtl(input) Then("'foo.a' is connected to a reset mux") - outputState should containTree { case Connect(_, WRef("foo_a",_,_,_), Mux(_,_,_,_)) => true } + outputState should containTree { case Connect(_, WRef("foo_a", _, _, _), Mux(_, _, _, _)) => true } And("'foo.b' is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo_b", _, _, _), Mux(_, _, _, _)) => true } } it should "canvert a reset wired to UInt<0> to a canonical non-reset" in { @@ -128,8 +128,8 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { val outputState = toLowFirrtl(input) Then("foo has a canonical non-reset declaration after RemoveReset") - outputState should containTree { case DefRegister(_, "foo", _,_, firrtl.Utils.zero, WRef("foo", _,_,_)) => true } + outputState should containTree { case DefRegister(_, "foo", _, _, firrtl.Utils.zero, WRef("foo", _, _, _)) => true } And("foo is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo", _, _, _), Mux(_, _, _, _)) => true } } } diff --git a/src/test/scala/firrtlTests/transforms/TopWiringTest.scala b/src/test/scala/firrtlTests/transforms/TopWiringTest.scala index 0ac12ef8..97fafe41 100644 --- a/src/test/scala/firrtlTests/transforms/TopWiringTest.scala +++ b/src/test/scala/firrtlTests/transforms/TopWiringTest.scala @@ -6,724 +6,718 @@ package transforms import java.io._ import firrtl._ -import firrtl.ir.{Type, GroundType, IntWidth} +import firrtl.ir.{GroundType, IntWidth, Type} import firrtl.Parser -import firrtl.annotations.{ - CircuitName, - ModuleName, - ComponentName, - Target -} +import firrtl.annotations.{CircuitName, ComponentName, ModuleName, Target} import firrtl.transforms.TopWiring._ import firrtl.testutils._ - trait TopWiringTestsCommon extends FirrtlRunners { - val testDir = createTestDirectory("TopWiringTests") - val testDirName = testDir.getPath - def transform = new TopWiringTransform + val testDir = createTestDirectory("TopWiringTests") + val testDirName = testDir.getPath + def transform = new TopWiringTransform - def topWiringDummyOutputFilesFunction(dir: String, - mapping: Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], - state: CircuitState): CircuitState = { - state - } + def topWiringDummyOutputFilesFunction( + dir: String, + mapping: Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], + state: CircuitState + ): CircuitState = { + state + } - def topWiringTestOutputFilesFunction(dir: String, - mapping: Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], - state: CircuitState): CircuitState = { - val testOutputFile = new PrintWriter(new File(dir, "TopWiringOutputTest.txt" )) - mapping map { - case ((_, tpe, _, path, prefix), index) => { - val portwidth = tpe match { case GroundType(IntWidth(w)) => w } - val portnum = index - val portname = prefix + path.mkString("_") - testOutputFile.append(s"new top level port $portnum : $portname, with width $portwidth \n") - } - } - testOutputFile.close() - state - } + def topWiringTestOutputFilesFunction( + dir: String, + mapping: Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], + state: CircuitState + ): CircuitState = { + val testOutputFile = new PrintWriter(new File(dir, "TopWiringOutputTest.txt")) + mapping.map { + case ((_, tpe, _, path, prefix), index) => { + val portwidth = tpe match { case GroundType(IntWidth(w)) => w } + val portnum = index + val portname = prefix + path.mkString("_") + testOutputFile.append(s"new top level port $portnum : $portname, with width $portwidth \n") + } + } + testOutputFile.close() + state + } } /** - * Tests TopWiring transformation - */ -class TopWiringTests extends MiddleTransformSpec with TopWiringTestsCommon { + * Tests TopWiring transformation + */ +class TopWiringTests extends MiddleTransformSpec with TopWiringTestsCommon { - "The signal x in module C" should s"be connected to Top port with topwiring prefix and outputfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C" should s"be connected to Top port with topwiring prefix and outputfile in $testDirName" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | inst c1 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | inst c1 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C inst c1 and c2" should + "The signal x in module C inst c1 and c2" should s"be connected to Top port with topwiring prefix and outfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | output topwiring_a1_b1_c2_x: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | output topwiring_b1_c2_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | topwiring_b1_c2_x <= b1.topwiring_c2_x - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | output topwiring_c2_x: UInt<1> - | inst c1 of C - | inst c2 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | topwiring_c2_x <= c2.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | output topwiring_a1_b1_c2_x: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | output topwiring_b1_c2_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | topwiring_b1_c2_x <= b1.topwiring_c2_x + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | output topwiring_c2_x: UInt<1> + | inst c1 of C + | inst c2 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | topwiring_c2_x <= c2.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C" should - s"be connected to Top port with topwiring prefix and outputfile in $testDirName, after name colission" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | wire topwiring_a1_b1_c1_x : UInt<1> - | topwiring_a1_b1_c1_x <= UInt(0) - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | wire topwiring_b1_c1_x : UInt<1> - | topwiring_b1_c1_x <= UInt(0) - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C" should + s"be connected to Top port with topwiring prefix and outputfile in $testDirName, after name colission" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | wire topwiring_a1_b1_c1_x : UInt<1> + | topwiring_a1_b1_c1_x <= UInt(0) + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | wire topwiring_b1_c1_x : UInt<1> + | topwiring_b1_c1_x <= UInt(0) + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x_0: UInt<1> - | inst a1 of A - | inst a2 of A_ - | wire topwiring_a1_b1_c1_x : UInt<1> - | topwiring_a1_b1_c1_x <= UInt<1>("h0") - | topwiring_a1_b1_c1_x_0 <= a1.topwiring_b1_c1_x_0 - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x_0: UInt<1> - | inst b1 of B - | wire topwiring_b1_c1_x : UInt<1> - | x <= UInt(1) - | topwiring_b1_c1_x <= UInt<1>("h0") - | topwiring_b1_c1_x_0 <= b1.topwiring_c1_x - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | inst c1 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x_0: UInt<1> + | inst a1 of A + | inst a2 of A_ + | wire topwiring_a1_b1_c1_x : UInt<1> + | topwiring_a1_b1_c1_x <= UInt<1>("h0") + | topwiring_a1_b1_c1_x_0 <= a1.topwiring_b1_c1_x_0 + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x_0: UInt<1> + | inst b1 of B + | wire topwiring_b1_c1_x : UInt<1> + | x <= UInt(1) + | topwiring_b1_c1_x <= UInt<1>("h0") + | topwiring_b1_c1_x_0 <= b1.topwiring_c1_x + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | inst c1 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C" should - "be connected to Top port with topwiring prefix and no output function" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C" should + "be connected to Top port with topwiring prefix and no output function" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_")) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | inst c1 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = + Seq(TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_")) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | inst c1 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C inst c1 and c2 and signal y in module A_" should - s"be connected to Top port with topwiring prefix and outfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | wire y : UInt<1> - | y <= UInt(1) - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C inst c1 and c2 and signal y in module A_" should + s"be connected to Top port with topwiring prefix and outfile in $testDirName" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | wire y : UInt<1> + | y <= UInt(1) + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringAnnotation(ComponentName(s"y", - ModuleName(s"A_", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | output topwiring_a1_b1_c2_x: UInt<1> - | output topwiring_a2_y: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x - | topwiring_a2_y <= a2.topwiring_y - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | output topwiring_b1_c2_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | topwiring_b1_c2_x <= b1.topwiring_c2_x - | module A_ : - | output x: UInt<1> - | output topwiring_y: UInt<1> - | wire y : UInt<1> - | x <= UInt(1) - | y <= UInt<1>("h1") - | topwiring_y <= y - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | output topwiring_c2_x: UInt<1> - | inst c1 of C - | inst c2 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | topwiring_c2_x <= c2.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringAnnotation(ComponentName(s"y", ModuleName(s"A_", CircuitName(s"Top"))), s"topwiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | output topwiring_a1_b1_c2_x: UInt<1> + | output topwiring_a2_y: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x + | topwiring_a2_y <= a2.topwiring_y + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | output topwiring_b1_c2_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | topwiring_b1_c2_x <= b1.topwiring_c2_x + | module A_ : + | output x: UInt<1> + | output topwiring_y: UInt<1> + | wire y : UInt<1> + | x <= UInt(1) + | y <= UInt<1>("h1") + | topwiring_y <= y + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | output topwiring_c2_x: UInt<1> + | inst c1 of C + | inst c2 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | topwiring_c2_x <= c2.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C inst c1 and c2 and signal y in module A_" should - s"be connected to Top port with topwiring and top2wiring prefix and outfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | wire y : UInt<1> - | y <= UInt(1) - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C inst c1 and c2 and signal y in module A_" should + s"be connected to Top port with topwiring and top2wiring prefix and outfile in $testDirName" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | wire y : UInt<1> + | y <= UInt(1) + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringAnnotation(ComponentName(s"y", - ModuleName(s"A_", CircuitName(s"Top"))), - s"top2wiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | output topwiring_a1_b1_c2_x: UInt<1> - | output top2wiring_a2_y: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x - | top2wiring_a2_y <= a2.top2wiring_y - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | output topwiring_b1_c2_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | topwiring_b1_c2_x <= b1.topwiring_c2_x - | module A_ : - | output x: UInt<1> - | output top2wiring_y: UInt<1> - | wire y : UInt<1> - | x <= UInt(1) - | y <= UInt<1>("h1") - | top2wiring_y <= y - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | output topwiring_c2_x: UInt<1> - | inst c1 of C - | inst c2 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | topwiring_c2_x <= c2.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringAnnotation(ComponentName(s"y", ModuleName(s"A_", CircuitName(s"Top"))), s"top2wiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | output topwiring_a1_b1_c2_x: UInt<1> + | output top2wiring_a2_y: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x + | top2wiring_a2_y <= a2.top2wiring_y + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | output topwiring_b1_c2_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | topwiring_b1_c2_x <= b1.topwiring_c2_x + | module A_ : + | output x: UInt<1> + | output top2wiring_y: UInt<1> + | wire y : UInt<1> + | x <= UInt(1) + | y <= UInt<1>("h1") + | top2wiring_y <= y + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | output topwiring_c2_x: UInt<1> + | inst c1 of C + | inst c2 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | topwiring_c2_x <= c2.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal fullword in module C inst c1 and c2 and signal y in module A_" should - s"be connected to Top port with topwiring and top2wiring prefix and outfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output fullword: UInt<1> - | fullword <= UInt(1) - | inst b1 of B - | module A_ : - | output fullword: UInt<1> - | wire y : UInt<1> - | y <= UInt(1) - | fullword <= UInt(1) - | module B : - | output fullword: UInt<1> - | fullword <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output fullword: UInt<1> - | fullword <= UInt(0) + "The signal fullword in module C inst c1 and c2 and signal y in module A_" should + s"be connected to Top port with topwiring and top2wiring prefix and outfile in $testDirName" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output fullword: UInt<1> + | fullword <= UInt(1) + | inst b1 of B + | module A_ : + | output fullword: UInt<1> + | wire y : UInt<1> + | y <= UInt(1) + | fullword <= UInt(1) + | module B : + | output fullword: UInt<1> + | fullword <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output fullword: UInt<1> + | fullword <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"fullword", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringAnnotation(ComponentName(s"y", - ModuleName(s"A_", CircuitName(s"Top"))), - s"top2wiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_fullword: UInt<1> - | output topwiring_a1_b1_c2_fullword: UInt<1> - | output top2wiring_a2_y: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_fullword <= a1.topwiring_b1_c1_fullword - | topwiring_a1_b1_c2_fullword <= a1.topwiring_b1_c2_fullword - | top2wiring_a2_y <= a2.top2wiring_y - | module A : - | output fullword: UInt<1> - | output topwiring_b1_c1_fullword: UInt<1> - | output topwiring_b1_c2_fullword: UInt<1> - | inst b1 of B - | fullword <= UInt(1) - | topwiring_b1_c1_fullword <= b1.topwiring_c1_fullword - | topwiring_b1_c2_fullword <= b1.topwiring_c2_fullword - | module A_ : - | output fullword: UInt<1> - | output top2wiring_y: UInt<1> - | wire y : UInt<1> - | fullword <= UInt(1) - | y <= UInt<1>("h1") - | top2wiring_y <= y - | module B : - | output fullword: UInt<1> - | output topwiring_c1_fullword: UInt<1> - | output topwiring_c2_fullword: UInt<1> - | inst c1 of C - | inst c2 of C - | fullword <= UInt(1) - | topwiring_c1_fullword <= c1.topwiring_fullword - | topwiring_c2_fullword <= c2.topwiring_fullword - | module C: - | output fullword: UInt<1> - | output topwiring_fullword: UInt<1> - | fullword <= UInt(0) - | topwiring_fullword <= fullword + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"fullword", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringAnnotation(ComponentName(s"y", ModuleName(s"A_", CircuitName(s"Top"))), s"top2wiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_fullword: UInt<1> + | output topwiring_a1_b1_c2_fullword: UInt<1> + | output top2wiring_a2_y: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_fullword <= a1.topwiring_b1_c1_fullword + | topwiring_a1_b1_c2_fullword <= a1.topwiring_b1_c2_fullword + | top2wiring_a2_y <= a2.top2wiring_y + | module A : + | output fullword: UInt<1> + | output topwiring_b1_c1_fullword: UInt<1> + | output topwiring_b1_c2_fullword: UInt<1> + | inst b1 of B + | fullword <= UInt(1) + | topwiring_b1_c1_fullword <= b1.topwiring_c1_fullword + | topwiring_b1_c2_fullword <= b1.topwiring_c2_fullword + | module A_ : + | output fullword: UInt<1> + | output top2wiring_y: UInt<1> + | wire y : UInt<1> + | fullword <= UInt(1) + | y <= UInt<1>("h1") + | top2wiring_y <= y + | module B : + | output fullword: UInt<1> + | output topwiring_c1_fullword: UInt<1> + | output topwiring_c2_fullword: UInt<1> + | inst c1 of C + | inst c2 of C + | fullword <= UInt(1) + | topwiring_c1_fullword <= c1.topwiring_fullword + | topwiring_c2_fullword <= c2.topwiring_fullword + | module C: + | output fullword: UInt<1> + | output topwiring_fullword: UInt<1> + | fullword <= UInt(0) + | topwiring_fullword <= fullword """.stripMargin - execute(input, check, topwiringannos) - } - + execute(input, check, topwiringannos) + } - "The signal fullword in module C inst c1 and c2 and signal fullword in module B" should - s"be connected to Top port with topwiring prefix" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output fullword: UInt<1> - | fullword <= UInt(1) - | inst b1 of B - | module A_ : - | output fullword: UInt<1> - | wire y : UInt<1> - | y <= UInt(1) - | fullword <= UInt(1) - | module B : - | output fullword: UInt<1> - | fullword <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output fullword: UInt<1> - | fullword <= UInt(0) + "The signal fullword in module C inst c1 and c2 and signal fullword in module B" should + s"be connected to Top port with topwiring prefix" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output fullword: UInt<1> + | fullword <= UInt(1) + | inst b1 of B + | module A_ : + | output fullword: UInt<1> + | wire y : UInt<1> + | y <= UInt(1) + | fullword <= UInt(1) + | module B : + | output fullword: UInt<1> + | fullword <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output fullword: UInt<1> + | fullword <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"fullword", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringAnnotation(ComponentName(s"fullword", - ModuleName(s"B", CircuitName(s"Top"))), - s"topwiring_")) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_fullword: UInt<1> - | output topwiring_a1_b1_c1_fullword: UInt<1> - | output topwiring_a1_b1_c2_fullword: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_fullword <= a1.topwiring_b1_fullword - | topwiring_a1_b1_c1_fullword <= a1.topwiring_b1_c1_fullword - | topwiring_a1_b1_c2_fullword <= a1.topwiring_b1_c2_fullword - | module A : - | output fullword: UInt<1> - | output topwiring_b1_fullword: UInt<1> - | output topwiring_b1_c1_fullword: UInt<1> - | output topwiring_b1_c2_fullword: UInt<1> - | inst b1 of B - | fullword <= UInt(1) - | topwiring_b1_fullword <= b1.topwiring_fullword - | topwiring_b1_c1_fullword <= b1.topwiring_c1_fullword - | topwiring_b1_c2_fullword <= b1.topwiring_c2_fullword - | module A_ : - | output fullword: UInt<1> - | wire y : UInt<1> - | fullword <= UInt(1) - | y <= UInt<1>("h1") - | module B : - | output fullword: UInt<1> - | output topwiring_fullword: UInt<1> - | output topwiring_c1_fullword: UInt<1> - | output topwiring_c2_fullword: UInt<1> - | inst c1 of C - | inst c2 of C - | fullword <= UInt(1) - | topwiring_fullword <= fullword - | topwiring_c1_fullword <= c1.topwiring_fullword - | topwiring_c2_fullword <= c2.topwiring_fullword - | module C: - | output fullword: UInt<1> - | output topwiring_fullword: UInt<1> - | fullword <= UInt(0) - | topwiring_fullword <= fullword + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"fullword", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringAnnotation(ComponentName(s"fullword", ModuleName(s"B", CircuitName(s"Top"))), s"topwiring_") + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_fullword: UInt<1> + | output topwiring_a1_b1_c1_fullword: UInt<1> + | output topwiring_a1_b1_c2_fullword: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_fullword <= a1.topwiring_b1_fullword + | topwiring_a1_b1_c1_fullword <= a1.topwiring_b1_c1_fullword + | topwiring_a1_b1_c2_fullword <= a1.topwiring_b1_c2_fullword + | module A : + | output fullword: UInt<1> + | output topwiring_b1_fullword: UInt<1> + | output topwiring_b1_c1_fullword: UInt<1> + | output topwiring_b1_c2_fullword: UInt<1> + | inst b1 of B + | fullword <= UInt(1) + | topwiring_b1_fullword <= b1.topwiring_fullword + | topwiring_b1_c1_fullword <= b1.topwiring_c1_fullword + | topwiring_b1_c2_fullword <= b1.topwiring_c2_fullword + | module A_ : + | output fullword: UInt<1> + | wire y : UInt<1> + | fullword <= UInt(1) + | y <= UInt<1>("h1") + | module B : + | output fullword: UInt<1> + | output topwiring_fullword: UInt<1> + | output topwiring_c1_fullword: UInt<1> + | output topwiring_c2_fullword: UInt<1> + | inst c1 of C + | inst c2 of C + | fullword <= UInt(1) + | topwiring_fullword <= fullword + | topwiring_c1_fullword <= c1.topwiring_fullword + | topwiring_c2_fullword <= c2.topwiring_fullword + | module C: + | output fullword: UInt<1> + | output topwiring_fullword: UInt<1> + | fullword <= UInt(0) + | topwiring_fullword <= fullword """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "TopWiringTransform" should "do nothing if run without TopWiring* annotations" in { - val input = """|circuit Top : - | module Top : - | input foo : UInt<1>""".stripMargin - val inputFile = { - val fileName = s"${testDir.getAbsolutePath}/input-no-sources.fir" - val w = new PrintWriter(fileName) - w.write(input) - w.close() - fileName - } - val args = Array( - "--custom-transforms", "firrtl.transforms.TopWiring.TopWiringTransform", - "--input-file", inputFile, - "--top-name", "Top", - "--compiler", "low", - "--info-mode", "ignore" - ) - firrtl.Driver.execute(args) match { - case FirrtlExecutionSuccess(_, emitted) => - parse(emitted).serialize should be (parse(input).serialize) - case _ => fail - } - } + "TopWiringTransform" should "do nothing if run without TopWiring* annotations" in { + val input = """|circuit Top : + | module Top : + | input foo : UInt<1>""".stripMargin + val inputFile = { + val fileName = s"${testDir.getAbsolutePath}/input-no-sources.fir" + val w = new PrintWriter(fileName) + w.write(input) + w.close() + fileName + } + val args = Array( + "--custom-transforms", + "firrtl.transforms.TopWiring.TopWiringTransform", + "--input-file", + inputFile, + "--top-name", + "Top", + "--compiler", + "low", + "--info-mode", + "ignore" + ) + firrtl.Driver.execute(args) match { + case FirrtlExecutionSuccess(_, emitted) => + parse(emitted).serialize should be(parse(input).serialize) + case _ => fail + } + } - "TopWiringTransform" should "remove TopWiringAnnotations" in { - val input = - """|circuit Top: - | module Top: - | wire foo: UInt<1>""".stripMargin + "TopWiringTransform" should "remove TopWiringAnnotations" in { + val input = + """|circuit Top: + | module Top: + | wire foo: UInt<1>""".stripMargin - val bar = - Target - .deserialize("~Top|Top>foo") - .toNamed match { case a: ComponentName => a } + val bar = + Target + .deserialize("~Top|Top>foo") + .toNamed match { case a: ComponentName => a } - val annotations = Seq(TopWiringAnnotation(bar, "bar_")) - val outputState = (new TopWiringTransform).execute(CircuitState(Parser.parse(input), MidForm, annotations, None)) + val annotations = Seq(TopWiringAnnotation(bar, "bar_")) + val outputState = (new TopWiringTransform).execute(CircuitState(Parser.parse(input), MidForm, annotations, None)) - outputState.circuit.serialize should include ("output bar_foo") - outputState.annotations.toSeq should be (empty) - } + outputState.circuit.serialize should include("output bar_foo") + outputState.annotations.toSeq should be(empty) + } } class AggregateTopWiringTests extends MiddleTransformSpec with TopWiringTestsCommon { - "An aggregate wire named myAgg in A" should s"be wired to Top's IO as topwiring_a1_myAgg" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | module A: - | wire myAgg: { a: UInt<1>, b: SInt<8> } - | myAgg.a <= UInt(0) - | myAgg.b <= SInt(-1) + "An aggregate wire named myAgg in A" should s"be wired to Top's IO as topwiring_a1_myAgg" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | module A: + | wire myAgg: { a: UInt<1>, b: SInt<8> } + | myAgg.a <= UInt(0) + | myAgg.b <= SInt(-1) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"myAgg", ModuleName(s"A", CircuitName(s"Top"))), - s"topwiring_")) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } - | inst a1 of A - | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a - | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b - | module A : - | output topwiring_myAgg: { a: UInt<1>, b: SInt<8> } - | wire myAgg: { a: UInt<1>, b: SInt<8> } - | myAgg.a <= UInt(0) - | myAgg.b <= SInt(-1) - | topwiring_myAgg.a <= myAgg.a - | topwiring_myAgg.b <= myAgg.b + val topwiringannos = + Seq(TopWiringAnnotation(ComponentName(s"myAgg", ModuleName(s"A", CircuitName(s"Top"))), s"topwiring_")) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } + | inst a1 of A + | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a + | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b + | module A : + | output topwiring_myAgg: { a: UInt<1>, b: SInt<8> } + | wire myAgg: { a: UInt<1>, b: SInt<8> } + | myAgg.a <= UInt(0) + | myAgg.b <= SInt(-1) + | topwiring_myAgg.a <= myAgg.a + | topwiring_myAgg.b <= myAgg.b """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "Aggregate wires myAgg in Top.a1, Top.b.a1 and Top.b.a2" should - s"be wired to Top's IO as topwiring_a1_myAgg, topwiring_b_a1_myAgg, and topwiring_b_a2_myAgg" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst b of B - | module B: - | inst a1 of A - | inst a2 of A - | module A: - | wire myAgg: { a: UInt<1>, b: SInt<8> } - | myAgg.a <= UInt(0) - | myAgg.b <= SInt(-1) + "Aggregate wires myAgg in Top.a1, Top.b.a1 and Top.b.a2" should + s"be wired to Top's IO as topwiring_a1_myAgg, topwiring_b_a1_myAgg, and topwiring_b_a2_myAgg" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst b of B + | module B: + | inst a1 of A + | inst a2 of A + | module A: + | wire myAgg: { a: UInt<1>, b: SInt<8> } + | myAgg.a <= UInt(0) + | myAgg.b <= SInt(-1) """.stripMargin - val topwiringannos = Seq( - TopWiringAnnotation(ComponentName(s"myAgg", ModuleName(s"A", CircuitName(s"Top"))), s"topwiring_")) + val topwiringannos = + Seq(TopWiringAnnotation(ComponentName(s"myAgg", ModuleName(s"A", CircuitName(s"Top"))), s"topwiring_")) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } - | output topwiring_b_a1_myAgg: { a: UInt<1>, b: SInt<8> } - | output topwiring_b_a2_myAgg: { a: UInt<1>, b: SInt<8> } - | inst a1 of A - | inst b of B - | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a - | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b - | topwiring_b_a1_myAgg.a <= b.topwiring_a1_myAgg.a - | topwiring_b_a1_myAgg.b <= b.topwiring_a1_myAgg.b - | topwiring_b_a2_myAgg.a <= b.topwiring_a2_myAgg.a - | topwiring_b_a2_myAgg.b <= b.topwiring_a2_myAgg.b - | module B: - | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } - | output topwiring_a2_myAgg: { a: UInt<1>, b: SInt<8> } - | inst a1 of A - | inst a2 of A - | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a - | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b - | topwiring_a2_myAgg.a <= a2.topwiring_myAgg.a - | topwiring_a2_myAgg.b <= a2.topwiring_myAgg.b - | module A : - | output topwiring_myAgg: { a: UInt<1>, b: SInt<8> } - | wire myAgg: { a: UInt<1>, b: SInt<8> } - | myAgg.a <= UInt(0) - | myAgg.b <= SInt(-1) - | topwiring_myAgg.a <= myAgg.a - | topwiring_myAgg.b <= myAgg.b + val check = + """circuit Top : + | module Top : + | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } + | output topwiring_b_a1_myAgg: { a: UInt<1>, b: SInt<8> } + | output topwiring_b_a2_myAgg: { a: UInt<1>, b: SInt<8> } + | inst a1 of A + | inst b of B + | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a + | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b + | topwiring_b_a1_myAgg.a <= b.topwiring_a1_myAgg.a + | topwiring_b_a1_myAgg.b <= b.topwiring_a1_myAgg.b + | topwiring_b_a2_myAgg.a <= b.topwiring_a2_myAgg.a + | topwiring_b_a2_myAgg.b <= b.topwiring_a2_myAgg.b + | module B: + | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } + | output topwiring_a2_myAgg: { a: UInt<1>, b: SInt<8> } + | inst a1 of A + | inst a2 of A + | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a + | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b + | topwiring_a2_myAgg.a <= a2.topwiring_myAgg.a + | topwiring_a2_myAgg.b <= a2.topwiring_myAgg.b + | module A : + | output topwiring_myAgg: { a: UInt<1>, b: SInt<8> } + | wire myAgg: { a: UInt<1>, b: SInt<8> } + | myAgg.a <= UInt(0) + | myAgg.b <= SInt(-1) + | topwiring_myAgg.a <= myAgg.a + | topwiring_myAgg.b <= myAgg.b """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } } diff --git a/src/test/scala/loggertests/LoggerSpec.scala b/src/test/scala/loggertests/LoggerSpec.scala index c8aae949..553f4966 100644 --- a/src/test/scala/loggertests/LoggerSpec.scala +++ b/src/test/scala/loggertests/LoggerSpec.scala @@ -260,7 +260,6 @@ class LoggerSpec extends AnyFreeSpec with Matchers with OneInstancePerTest with val captor = new OutputCaptor Logger.setOutput(captor.printStream) - Logger.setLevel(LogLevel.Info) Logger.setLevel("loggertests.LogsInfo2", LogLevel.Error) @@ -302,47 +301,47 @@ class LoggerSpec extends AnyFreeSpec with Matchers with OneInstancePerTest with } val logText = captor.getOutputAsString - logText should include ("message 1") - logText should include ("message 2") + logText should include("message 1") + logText should include("message 2") } } "Show that nested makeScopes share same state" in { - Logger.getGlobalLevel should be (LoggerSpec.globalLevel) + Logger.getGlobalLevel should be(LoggerSpec.globalLevel) Logger.makeScope() { Logger.setLevel(LogLevel.Info) - Logger.getGlobalLevel should be (LogLevel.Info) + Logger.getGlobalLevel should be(LogLevel.Info) Logger.makeScope() { - Logger.getGlobalLevel should be (LogLevel.Info) + Logger.getGlobalLevel should be(LogLevel.Info) } Logger.makeScope() { Logger.setLevel(LogLevel.Debug) - Logger.getGlobalLevel should be (LogLevel.Debug) + Logger.getGlobalLevel should be(LogLevel.Debug) } - Logger.getGlobalLevel should be (LogLevel.Debug) + Logger.getGlobalLevel should be(LogLevel.Debug) } - Logger.getGlobalLevel should be (LoggerSpec.globalLevel) + Logger.getGlobalLevel should be(LoggerSpec.globalLevel) } "Show that first makeScope starts with fresh state" in { - Logger.getGlobalLevel should be (LoggerSpec.globalLevel) + Logger.getGlobalLevel should be(LoggerSpec.globalLevel) Logger.setLevel(LogLevel.Warn) - Logger.getGlobalLevel should be (LogLevel.Warn) + Logger.getGlobalLevel should be(LogLevel.Warn) Logger.makeScope() { - Logger.getGlobalLevel should be (LoggerSpec.globalLevel) + Logger.getGlobalLevel should be(LoggerSpec.globalLevel) Logger.setLevel(LogLevel.Trace) - Logger.getGlobalLevel should be (LogLevel.Trace) + Logger.getGlobalLevel should be(LogLevel.Trace) } - Logger.getGlobalLevel should be (LogLevel.Warn) + Logger.getGlobalLevel should be(LogLevel.Warn) } } } |
