From 161938b84013a6c3307abc2707f541deddf487b4 Mon Sep 17 00:00:00 2001 From: Adam Izraelevitz Date: Mon, 8 Jun 2020 10:29:39 -0700 Subject: Grouping Chisel API (#1073) * Added group chisel API * Removed println * Added scaladoc * Added more tests * Cleaned spacing and removed println Co-authored-by: Chick Markley Co-authored-by: Jim Lawson Co-authored-by: Schuyler Eldridge --- .../scala/chiselTests/experimental/GroupSpec.scala | 118 +++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 src/test/scala/chiselTests/experimental/GroupSpec.scala (limited to 'src/test') diff --git a/src/test/scala/chiselTests/experimental/GroupSpec.scala b/src/test/scala/chiselTests/experimental/GroupSpec.scala new file mode 100644 index 00000000..593179f4 --- /dev/null +++ b/src/test/scala/chiselTests/experimental/GroupSpec.scala @@ -0,0 +1,118 @@ +// See LICENSE for license details. + +package chiselTests.experimental + +import chiselTests.ChiselFlatSpec +import chisel3._ +import chisel3.RawModule +import chisel3.stage.{ChiselGeneratorAnnotation, ChiselMain} +import chisel3.util.experimental.group +import firrtl.analyses.InstanceGraph +import firrtl.options.TargetDirAnnotation +import firrtl.stage.CompilerAnnotation +import firrtl.{LowFirrtlCompiler, ir => fir} + +import scala.collection.mutable + +class GroupSpec extends ChiselFlatSpec { + + def collectInstances(c: fir.Circuit, top: Option[String] = None): Seq[String] = new InstanceGraph(c) + .fullHierarchy.values.flatten.toSeq + .map( v => (top.getOrElse(v.head.name) +: v.tail.map(_.name)).mkString(".") ) + + def collectDeclarations(m: fir.DefModule): Set[String] = { + val decs = mutable.HashSet[String]() + def onStmt(s: fir.Statement): fir.Statement = s.mapStmt(onStmt) match { + case d: fir.IsDeclaration => decs += d.name; d + case other => other + } + m.mapStmt(onStmt) + decs.toSet + } + + def lower[T <: RawModule](gen: () => T): fir.Circuit = { + (ChiselMain.stage.run( + Seq( + CompilerAnnotation(new LowFirrtlCompiler()), + TargetDirAnnotation("test_run_dir"), + ChiselGeneratorAnnotation(gen) + ) + ) collectFirst { + case firrtl.stage.FirrtlCircuitAnnotation(circuit) => circuit + }).get + } + + "Module Grouping" should "compile to low FIRRTL" in { + class MyModule extends Module { + val io = IO(new Bundle{ + val a = Input(Bool()) + val b = Output(Bool()) + }) + val reg1 = RegInit(0.U) + reg1 := io.a + val reg2 = RegNext(reg1) + io.b := reg2 + group(Seq(reg1, reg2), "DosRegisters", "doubleReg") + } + + val firrtlCircuit = lower(() => new MyModule) + firrtlCircuit.modules.collect { + case m: fir.Module if m.name == "MyModule" => + Set("doubleReg") should be (collectDeclarations(m)) + case m: fir.Module if m.name == "DosRegisters" => + Set("reg1", "reg2") should be (collectDeclarations(m)) + } + val instances = collectInstances(firrtlCircuit, Some("MyModule")).toSet + Set("MyModule", "MyModule.doubleReg") should be (instances) + } + + "Module Grouping" should "not include intermediate registers" in { + class MyModule extends Module { + val io = IO(new Bundle{ + val a = Input(Bool()) + val b = Output(Bool()) + }) + val reg1 = RegInit(0.U) + reg1 := io.a + val reg2 = RegNext(reg1) + val reg3 = RegNext(reg2) + io.b := reg3 + group(Seq(reg1, reg3), "DosRegisters", "doubleReg") + } + + val firrtlCircuit = lower(() => new MyModule) + firrtlCircuit.modules.collect { + case m: fir.Module if m.name == "MyModule" => + Set("reg2", "doubleReg") should be (collectDeclarations(m)) + case m: fir.Module if m.name == "DosRegisters" => + Set("reg1", "reg3") should be (collectDeclarations(m)) + } + val instances = collectInstances(firrtlCircuit, Some("MyModule")).toSet + Set("MyModule", "MyModule.doubleReg") should be (instances) + } + + "Module Grouping" should "include intermediate wires" in { + class MyModule extends Module { + val io = IO(new Bundle{ + val a = Input(Bool()) + val b = Output(Bool()) + }) + val reg1 = RegInit(0.U) + reg1 := io.a + val wire = WireInit(reg1) + val reg3 = RegNext(wire) + io.b := reg3 + group(Seq(reg1, reg3), "DosRegisters", "doubleReg") + } + + val firrtlCircuit = lower(() => new MyModule) + firrtlCircuit.modules.collect { + case m: fir.Module if m.name == "MyModule" => + Set("doubleReg") should be (collectDeclarations(m)) + case m: fir.Module if m.name == "DosRegisters" => + Set("reg1", "reg3", "wire") should be (collectDeclarations(m)) + } + val instances = collectInstances(firrtlCircuit, Some("MyModule")).toSet + Set("MyModule", "MyModule.doubleReg") should be (instances) + } +} -- cgit v1.2.3