diff options
| author | Albert Chen | 2018-11-15 08:53:16 -0800 |
|---|---|---|
| committer | Schuyler Eldridge | 2018-11-15 08:53:16 -0800 |
| commit | b90589f5cd9d4048ada2a05d5225874791546170 (patch) | |
| tree | 126699d5955746ecb7e4d5432299c648ec3446d5 /src | |
| parent | 6ece732d09b8610ae50545dab312d6759ac2f8e2 (diff) | |
Combine cats (#851)
- Add firrtl.transforms.CombineCats
- Use CombineCats in LowFirrtlOptimization
- Modify Verilog emitter to allow for nested Cat DoPrims
- Modify firrtlEquivalenceTest to write input FIRRTL string to test directory
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 26 | ||||
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/CombineCats.scala | 68 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/FirrtlSpec.scala | 3 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/VerilogEmitterTests.scala | 34 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala | 160 |
6 files changed, 290 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index e28a3d20..0897b2db 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -246,7 +246,29 @@ class VerilogEmitter extends SeqTransform with Emitter { case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField => case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") } - doprim.args foreach checkArgumentLegality + + def checkCatArgumentLegality(e: Expression): Unit = e match { + case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField => + case DoPrim(Cat, args, _, _) => args foreach(checkCatArgumentLegality) + case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") + } + + def castCatArgs(a0: Expression, a1: Expression): Seq[Any] = { + val a0Seq = a0 match { + case cat@DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1)) + case _ => Seq(cast(a0)) + } + val a1Seq = a1 match { + case cat@DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1)) + case _ => Seq(cast(a1)) + } + a0Seq ++ Seq(",") ++ a1Seq + } + + doprim.op match { + case Cat => doprim.args foreach(checkCatArgumentLegality) + case other => doprim.args foreach checkArgumentLegality + } doprim.op match { case Add => Seq(cast_if(a0), " + ", cast_if(a1)) case Addw => Seq(cast_if(a0), " + ", cast_if(a1)) @@ -298,7 +320,7 @@ class VerilogEmitter extends SeqTransform with Emitter { case Andr => Seq("&", cast(a0)) case Orr => Seq("|", cast(a0)) case Xorr => Seq("^", cast(a0)) - case Cat => Seq("{", cast(a0), ",", cast(a1), "}") + case Cat => "{" +: (castCatArgs(a0, a1) :+ "}") // If selecting zeroth bit and single-bit wire, just emit the wire case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0) case Bits if c0 == c1 => Seq(a0, "[", c0, "]") diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 92f9a9a4..c22fe99f 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -108,6 +108,7 @@ class LowFirrtlOptimization extends CoreTransform { passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter new firrtl.transforms.ConstantPropagation, passes.SplitExpressions, + new firrtl.transforms.CombineCats, passes.CommonSubexpressionElimination, new firrtl.transforms.DeadCodeElimination) } diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala new file mode 100644 index 00000000..1265a5c3 --- /dev/null +++ b/src/main/scala/firrtl/transforms/CombineCats.scala @@ -0,0 +1,68 @@ + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.PrimOps._ +import firrtl.WrappedExpression._ +import firrtl.annotations.NoTargetAnnotation + +import scala.collection.mutable + +case class MaxCatLenAnnotation(maxCatLen: Int) extends NoTargetAnnotation + +object CombineCats { + /** Mapping from references to the [[Expression]]s that drive them paired with their Cat length */ + type Netlist = mutable.HashMap[WrappedExpression, (Int, Expression)] + + def expandCatArgs(maxCatLen: Int, netlist: Netlist)(expr: Expression): (Int, Expression) = expr match { + case cat@DoPrim(Cat, args, _, _) => + val (a0Len, a0Expanded) = expandCatArgs(maxCatLen - 1, netlist)(args.head) + val (a1Len, a1Expanded) = expandCatArgs(maxCatLen - a0Len, netlist)(args(1)) + (a0Len + a1Len, cat.copy(args = Seq(a0Expanded, a1Expanded)).asInstanceOf[Expression]) + case other => + netlist.get(we(expr)).collect { + case (len, cat@DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat) + }.getOrElse((1, other)) + } + + def onStmt(maxCatLen: Int, netlist: Netlist)(stmt: Statement): Statement = { + stmt.map(onStmt(maxCatLen, netlist)) match { + case node@DefNode(_, name, value) => + val catLenAndVal = value match { + case cat@DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat) + case other => (1, other) + } + netlist(we(WRef(name))) = catLenAndVal + node.copy(value = catLenAndVal._2) + case other => other + } + } + + def onMod(maxCatLen: Int)(mod: DefModule): DefModule = mod.map(onStmt(maxCatLen, new Netlist)) +} + +/** Combine Cat DoPrims + * + * Expands the arguments of any Cat DoPrims if they are references to other Cat DoPrims. + * Operates only on Cat DoPrims that are node values. + * + * Use [[MaxCatLenAnnotation]] to limit the number of elements that can be concatenated. + * The default maximum number of elements is 10. + */ +class CombineCats extends Transform { + def inputForm: LowForm.type = LowForm + def outputForm: LowForm.type = LowForm + val defaultMaxCatLen = 10 + + def execute(state: CircuitState): CircuitState = { + val maxCatLen = state.annotations.collectFirst { + case m: MaxCatLenAnnotation => m.maxCatLen + }.getOrElse(defaultMaxCatLen) + + val modulesx = state.circuit.modules.map(CombineCats.onMod(maxCatLen)) + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} + diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala index 95b09d93..88238785 100644 --- a/src/test/scala/firrtlTests/FirrtlSpec.scala +++ b/src/test/scala/firrtlTests/FirrtlSpec.scala @@ -59,6 +59,9 @@ trait FirrtlRunners extends BackendCompilationUtilities { val compiler = new MinimumVerilogCompiler val prefix = circuit.main val testDir = createTestDirectory(prefix + "_equivalence_test") + val firrtlWriter = new PrintWriter(s"${testDir.getAbsolutePath}/$prefix.fir") + firrtlWriter.write(input) + firrtlWriter.close() val customVerilog = compiler.compileAndEmit(CircuitState(circuit, HighForm, customAnnotations), new GetNamespace +: new RenameTop(s"${prefix}_custom") +: customTransforms) diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index 3b9f4702..b712b46d 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -3,6 +3,7 @@ package firrtlTests import java.io._ + import org.scalatest._ import org.scalatest.prop._ import firrtl._ @@ -12,6 +13,7 @@ import firrtl.passes._ import firrtl.transforms.VerilogRename import firrtl.Parser.IgnoreInfo import FirrtlCheckers._ +import firrtl.transforms.CombineCats class DoPrimVerilog extends FirrtlFlatSpec { "Xorr" should "emit correctly" in { @@ -89,6 +91,38 @@ class DoPrimVerilog extends FirrtlFlatSpec { |""".stripMargin.split("\n") map normalized executeTest(input, check, compiler) } + "nested cats" should "emit correctly" in { + val compiler = new MinimumVerilogCompiler + val input = + """circuit Test : + | module Test : + | input in1 : UInt<1> + | input in2 : UInt<2> + | input in3 : UInt<3> + | input in4 : UInt<4> + | output out : UInt<10> + | out <= cat(in4, cat(in3, cat(in2, in1))) + |""".stripMargin + val check = + """module Test( + | input in1, + | input [1:0] in2, + | input [2:0] in3, + | input [3:0] in4, + | output [9:0] out + |); + | wire [5:0] _GEN_1; + | assign out = {in4,_GEN_1}; + | assign _GEN_1 = {in3,in2,in1}; + |endmodule + |""".stripMargin.split("\n") map normalized + + val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm), Seq(new CombineCats())) + val lines = finalState.getEmittedCircuit.value split "\n" map normalized + for (e <- check) { + lines should contain (e) + } + } } class VerilogEmitterSpec extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala b/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala new file mode 100644 index 00000000..6ac2d14e --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala @@ -0,0 +1,160 @@ +// See LICENSE for license details. + +package firrtlTests.transforms + +import firrtl.PrimOps._ +import firrtl._ +import firrtl.ir.DoPrim +import firrtl.transforms.{CombineCats, MaxCatLenAnnotation} +import firrtlTests.FirrtlFlatSpec +import firrtlTests.FirrtlCheckers._ + +class CombineCatsSpec extends FirrtlFlatSpec { + private val transforms = Seq(new IRToWorkingIR, new CombineCats) + private val annotations = Seq(new MaxCatLenAnnotation(12)) + + private def execute(input: String, transforms: Seq[Transform], annotations: AnnotationSeq): CircuitState = { + val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm, annotations)) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit + CircuitState(c, UnknownForm, Seq(), None) + } + + "circuit1 with combined cats" should "be equivalent to one without" in { + val input = + """circuit Test_CombinedCats1 : + | module Test_CombinedCats1 : + | input in1 : UInt<1> + | input in2 : UInt<2> + | input in3 : UInt<3> + | input in4 : UInt<4> + | output out : UInt<10> + | out <= cat(in4, cat(in3, cat(in2, in1))) + |""".stripMargin + firrtlEquivalenceTest(input, transforms, annotations) + } + + "circuit2 with combined cats" should "be equivalent to one without" in { + val input = + """circuit Test_CombinedCats2 : + | module Test_CombinedCats2 : + | input in1 : UInt<1> + | input in2 : UInt<2> + | input in3 : UInt<3> + | input in4 : UInt<4> + | output out : UInt<10> + | out <= cat(cat(in4, in1), cat(cat(in4, in3), cat(in2, in1))) + |""".stripMargin + firrtlEquivalenceTest(input, transforms, annotations) + } + + "circuit3 with combined cats" should "be equivalent to one without" in { + val input = + """circuit Test_CombinedCats3 : + | module Test_CombinedCats3 : + | input in1 : UInt<1> + | input in2 : UInt<2> + | input in3 : UInt<3> + | input in4 : UInt<4> + | output out : UInt<10> + | node temp1 = cat(cat(in4, in3), cat(in2, in1)) + | node temp2 = cat(in4, cat(in3, cat(in2, in1))) + | out <= add(temp1, temp2) + |""".stripMargin + firrtlEquivalenceTest(input, transforms, annotations) + } + + "nested cats" should "be combined" in { + val input = + """circuit Test_CombinedCats4 : + | module Test_CombinedCats4 : + | input in1 : UInt<1> + | input in2 : UInt<2> + | input in3 : UInt<3> + | input in4 : UInt<4> + | output out : UInt<10> + | node temp1 = cat(in2, in1) + | node temp2 = cat(in3, in2) + | node temp3 = cat(in4, in3) + | node temp4 = cat(temp1, temp2) + | node temp5 = cat(temp4, temp3) + | out <= temp5 + |""".stripMargin + + firrtlEquivalenceTest(input, transforms, annotations) + val result = execute(input, transforms, Seq.empty) + + // temp5 should get cat(cat(cat(in3, in2), cat(in4, in3)), cat(cat(in3, in2), cat(in4, in3))) + result should containTree { + case DoPrim(Cat, Seq( + DoPrim(Cat, Seq( + DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _), + DoPrim(Cat, Seq(WRef("in3", _, _, _), WRef("in2", _, _, _)), _, _)), _, _), + DoPrim(Cat, Seq(WRef("in4", _, _, _), WRef("in3", _, _, _)), _, _)), _, _) => true + } + } + + "cats" should "not be longer than maxCatLen" in { + val input = + """circuit Test_CombinedCats5 : + | module Test_CombinedCats5 : + | input in1 : UInt<1> + | input in2 : UInt<2> + | input in3 : UInt<3> + | input in4 : UInt<4> + | input in5 : UInt<5> + | output out : UInt<10> + | node temp1 = cat(in2, in1) + | node temp2 = cat(in3, temp1) + | node temp3 = cat(in4, temp2) + | node temp4 = cat(in5, temp3) + | out <= temp4 + |""".stripMargin + + val maxCatLenAnnotation3 = Seq(new MaxCatLenAnnotation(3)) + firrtlEquivalenceTest(input, transforms, maxCatLenAnnotation3) + val result = execute(input, transforms, maxCatLenAnnotation3) + + // should not contain any cat chains greater than 3 + result shouldNot containTree { + case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true + case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _)), _, _) => true + } + + // temp2 should get cat(in3, cat(in2, in1)) + result should containTree { + case DoPrim(Cat, Seq( + WRef("in3", _, _, _), + DoPrim(Cat, Seq( + WRef("in2", _, _, _), + WRef("in1", _, _, _)), _, _)), _, _) => true + } + } + + "nested nodes that are not cats" should "not be expanded" in { + val input = + """circuit Test_CombinedCats5 : + | module Test_CombinedCats5 : + | input in1 : UInt<1> + | input in2 : UInt<2> + | input in3 : UInt<3> + | input in4 : UInt<4> + | input in5 : UInt<5> + | output out : UInt<10> + | node temp1 = add(in2, in1) + | node temp2 = cat(in3, temp1) + | node temp3 = sub(in4, temp2) + | node temp4 = cat(in5, temp3) + | out <= temp4 + |""".stripMargin + + firrtlEquivalenceTest(input, transforms, annotations) + + val result = execute(input, transforms, Seq.empty) + result shouldNot containTree { + case DoPrim(Cat, Seq(_, DoPrim(Add, _, _, _)), _, _) => true + case DoPrim(Cat, Seq(_, DoPrim(Sub, _, _, _)), _, _) => true + case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true + } + } +} |
