aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Chen2018-11-15 08:53:16 -0800
committerSchuyler Eldridge2018-11-15 08:53:16 -0800
commitb90589f5cd9d4048ada2a05d5225874791546170 (patch)
tree126699d5955746ecb7e4d5432299c648ec3446d5 /src
parent6ece732d09b8610ae50545dab312d6759ac2f8e2 (diff)
Combine cats (#851)
- Add firrtl.transforms.CombineCats - Use CombineCats in LowFirrtlOptimization - Modify Verilog emitter to allow for nested Cat DoPrims - Modify firrtlEquivalenceTest to write input FIRRTL string to test directory
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Emitter.scala26
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala1
-rw-r--r--src/main/scala/firrtl/transforms/CombineCats.scala68
-rw-r--r--src/test/scala/firrtlTests/FirrtlSpec.scala3
-rw-r--r--src/test/scala/firrtlTests/VerilogEmitterTests.scala34
-rw-r--r--src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala160
6 files changed, 290 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index e28a3d20..0897b2db 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -246,7 +246,29 @@ class VerilogEmitter extends SeqTransform with Emitter {
case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
}
- doprim.args foreach checkArgumentLegality
+
+ def checkCatArgumentLegality(e: Expression): Unit = e match {
+ case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
+ case DoPrim(Cat, args, _, _) => args foreach(checkCatArgumentLegality)
+ case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
+ }
+
+ def castCatArgs(a0: Expression, a1: Expression): Seq[Any] = {
+ val a0Seq = a0 match {
+ case cat@DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1))
+ case _ => Seq(cast(a0))
+ }
+ val a1Seq = a1 match {
+ case cat@DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1))
+ case _ => Seq(cast(a1))
+ }
+ a0Seq ++ Seq(",") ++ a1Seq
+ }
+
+ doprim.op match {
+ case Cat => doprim.args foreach(checkCatArgumentLegality)
+ case other => doprim.args foreach checkArgumentLegality
+ }
doprim.op match {
case Add => Seq(cast_if(a0), " + ", cast_if(a1))
case Addw => Seq(cast_if(a0), " + ", cast_if(a1))
@@ -298,7 +320,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
case Andr => Seq("&", cast(a0))
case Orr => Seq("|", cast(a0))
case Xorr => Seq("^", cast(a0))
- case Cat => Seq("{", cast(a0), ",", cast(a1), "}")
+ case Cat => "{" +: (castCatArgs(a0, a1) :+ "}")
// If selecting zeroth bit and single-bit wire, just emit the wire
case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
case Bits if c0 == c1 => Seq(a0, "[", c0, "]")
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 92f9a9a4..c22fe99f 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -108,6 +108,7 @@ class LowFirrtlOptimization extends CoreTransform {
passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
new firrtl.transforms.ConstantPropagation,
passes.SplitExpressions,
+ new firrtl.transforms.CombineCats,
passes.CommonSubexpressionElimination,
new firrtl.transforms.DeadCodeElimination)
}
diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala
new file mode 100644
index 00000000..1265a5c3
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/CombineCats.scala
@@ -0,0 +1,68 @@
+
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.PrimOps._
+import firrtl.WrappedExpression._
+import firrtl.annotations.NoTargetAnnotation
+
+import scala.collection.mutable
+
+case class MaxCatLenAnnotation(maxCatLen: Int) extends NoTargetAnnotation
+
+object CombineCats {
+ /** Mapping from references to the [[Expression]]s that drive them paired with their Cat length */
+ type Netlist = mutable.HashMap[WrappedExpression, (Int, Expression)]
+
+ def expandCatArgs(maxCatLen: Int, netlist: Netlist)(expr: Expression): (Int, Expression) = expr match {
+ case cat@DoPrim(Cat, args, _, _) =>
+ val (a0Len, a0Expanded) = expandCatArgs(maxCatLen - 1, netlist)(args.head)
+ val (a1Len, a1Expanded) = expandCatArgs(maxCatLen - a0Len, netlist)(args(1))
+ (a0Len + a1Len, cat.copy(args = Seq(a0Expanded, a1Expanded)).asInstanceOf[Expression])
+ case other =>
+ netlist.get(we(expr)).collect {
+ case (len, cat@DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat)
+ }.getOrElse((1, other))
+ }
+
+ def onStmt(maxCatLen: Int, netlist: Netlist)(stmt: Statement): Statement = {
+ stmt.map(onStmt(maxCatLen, netlist)) match {
+ case node@DefNode(_, name, value) =>
+ val catLenAndVal = value match {
+ case cat@DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat)
+ case other => (1, other)
+ }
+ netlist(we(WRef(name))) = catLenAndVal
+ node.copy(value = catLenAndVal._2)
+ case other => other
+ }
+ }
+
+ def onMod(maxCatLen: Int)(mod: DefModule): DefModule = mod.map(onStmt(maxCatLen, new Netlist))
+}
+
+/** Combine Cat DoPrims
+ *
+ * Expands the arguments of any Cat DoPrims if they are references to other Cat DoPrims.
+ * Operates only on Cat DoPrims that are node values.
+ *
+ * Use [[MaxCatLenAnnotation]] to limit the number of elements that can be concatenated.
+ * The default maximum number of elements is 10.
+ */
+class CombineCats extends Transform {
+ def inputForm: LowForm.type = LowForm
+ def outputForm: LowForm.type = LowForm
+ val defaultMaxCatLen = 10
+
+ def execute(state: CircuitState): CircuitState = {
+ val maxCatLen = state.annotations.collectFirst {
+ case m: MaxCatLenAnnotation => m.maxCatLen
+ }.getOrElse(defaultMaxCatLen)
+
+ val modulesx = state.circuit.modules.map(CombineCats.onMod(maxCatLen))
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}
+
diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala
index 95b09d93..88238785 100644
--- a/src/test/scala/firrtlTests/FirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/FirrtlSpec.scala
@@ -59,6 +59,9 @@ trait FirrtlRunners extends BackendCompilationUtilities {
val compiler = new MinimumVerilogCompiler
val prefix = circuit.main
val testDir = createTestDirectory(prefix + "_equivalence_test")
+ val firrtlWriter = new PrintWriter(s"${testDir.getAbsolutePath}/$prefix.fir")
+ firrtlWriter.write(input)
+ firrtlWriter.close()
val customVerilog = compiler.compileAndEmit(CircuitState(circuit, HighForm, customAnnotations),
new GetNamespace +: new RenameTop(s"${prefix}_custom") +: customTransforms)
diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala
index 3b9f4702..b712b46d 100644
--- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala
+++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala
@@ -3,6 +3,7 @@
package firrtlTests
import java.io._
+
import org.scalatest._
import org.scalatest.prop._
import firrtl._
@@ -12,6 +13,7 @@ import firrtl.passes._
import firrtl.transforms.VerilogRename
import firrtl.Parser.IgnoreInfo
import FirrtlCheckers._
+import firrtl.transforms.CombineCats
class DoPrimVerilog extends FirrtlFlatSpec {
"Xorr" should "emit correctly" in {
@@ -89,6 +91,38 @@ class DoPrimVerilog extends FirrtlFlatSpec {
|""".stripMargin.split("\n") map normalized
executeTest(input, check, compiler)
}
+ "nested cats" should "emit correctly" in {
+ val compiler = new MinimumVerilogCompiler
+ val input =
+ """circuit Test :
+ | module Test :
+ | input in1 : UInt<1>
+ | input in2 : UInt<2>
+ | input in3 : UInt<3>
+ | input in4 : UInt<4>
+ | output out : UInt<10>
+ | out <= cat(in4, cat(in3, cat(in2, in1)))
+ |""".stripMargin
+ val check =
+ """module Test(
+ | input in1,
+ | input [1:0] in2,
+ | input [2:0] in3,
+ | input [3:0] in4,
+ | output [9:0] out
+ |);
+ | wire [5:0] _GEN_1;
+ | assign out = {in4,_GEN_1};
+ | assign _GEN_1 = {in3,in2,in1};
+ |endmodule
+ |""".stripMargin.split("\n") map normalized
+
+ val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm), Seq(new CombineCats()))
+ val lines = finalState.getEmittedCircuit.value split "\n" map normalized
+ for (e <- check) {
+ lines should contain (e)
+ }
+ }
}
class VerilogEmitterSpec extends FirrtlFlatSpec {
diff --git a/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala b/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala
new file mode 100644
index 00000000..6ac2d14e
--- /dev/null
+++ b/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala
@@ -0,0 +1,160 @@
+// See LICENSE for license details.
+
+package firrtlTests.transforms
+
+import firrtl.PrimOps._
+import firrtl._
+import firrtl.ir.DoPrim
+import firrtl.transforms.{CombineCats, MaxCatLenAnnotation}
+import firrtlTests.FirrtlFlatSpec
+import firrtlTests.FirrtlCheckers._
+
+class CombineCatsSpec extends FirrtlFlatSpec {
+ private val transforms = Seq(new IRToWorkingIR, new CombineCats)
+ private val annotations = Seq(new MaxCatLenAnnotation(12))
+
+ private def execute(input: String, transforms: Seq[Transform], annotations: AnnotationSeq): CircuitState = {
+ val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm, annotations)) {
+ (c: CircuitState, t: Transform) => t.runTransform(c)
+ }.circuit
+ CircuitState(c, UnknownForm, Seq(), None)
+ }
+
+ "circuit1 with combined cats" should "be equivalent to one without" in {
+ val input =
+ """circuit Test_CombinedCats1 :
+ | module Test_CombinedCats1 :
+ | input in1 : UInt<1>
+ | input in2 : UInt<2>
+ | input in3 : UInt<3>
+ | input in4 : UInt<4>
+ | output out : UInt<10>
+ | out <= cat(in4, cat(in3, cat(in2, in1)))
+ |""".stripMargin
+ firrtlEquivalenceTest(input, transforms, annotations)
+ }
+
+ "circuit2 with combined cats" should "be equivalent to one without" in {
+ val input =
+ """circuit Test_CombinedCats2 :
+ | module Test_CombinedCats2 :
+ | input in1 : UInt<1>
+ | input in2 : UInt<2>
+ | input in3 : UInt<3>
+ | input in4 : UInt<4>
+ | output out : UInt<10>
+ | out <= cat(cat(in4, in1), cat(cat(in4, in3), cat(in2, in1)))
+ |""".stripMargin
+ firrtlEquivalenceTest(input, transforms, annotations)
+ }
+
+ "circuit3 with combined cats" should "be equivalent to one without" in {
+ val input =
+ """circuit Test_CombinedCats3 :
+ | module Test_CombinedCats3 :
+ | input in1 : UInt<1>
+ | input in2 : UInt<2>
+ | input in3 : UInt<3>
+ | input in4 : UInt<4>
+ | output out : UInt<10>
+ | node temp1 = cat(cat(in4, in3), cat(in2, in1))
+ | node temp2 = cat(in4, cat(in3, cat(in2, in1)))
+ | out <= add(temp1, temp2)
+ |""".stripMargin
+ firrtlEquivalenceTest(input, transforms, annotations)
+ }
+
+ "nested cats" should "be combined" in {
+ val input =
+ """circuit Test_CombinedCats4 :
+ | module Test_CombinedCats4 :
+ | input in1 : UInt<1>
+ | input in2 : UInt<2>
+ | input in3 : UInt<3>
+ | input in4 : UInt<4>
+ | output out : UInt<10>
+ | node temp1 = cat(in2, in1)
+ | node temp2 = cat(in3, in2)
+ | node temp3 = cat(in4, in3)
+ | node temp4 = cat(temp1, temp2)
+ | node temp5 = cat(temp4, temp3)
+ | out <= temp5
+ |""".stripMargin
+
+ firrtlEquivalenceTest(input, transforms, annotations)
+ val result = execute(input, transforms, Seq.empty)
+
+ // temp5 should get cat(cat(cat(in3, in2), cat(in4, in3)), cat(cat(in3, in2), cat(in4, in3)))
+ result should containTree {
+ case DoPrim(Cat, Seq(
+ DoPrim(Cat, Seq(
+ DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _),
+ DoPrim(Cat, Seq(WRef("in3", _, _, _), WRef("in2", _, _, _)), _, _)), _, _),
+ DoPrim(Cat, Seq(WRef("in4", _, _, _), WRef("in3", _, _, _)), _, _)), _, _) => true
+ }
+ }
+
+ "cats" should "not be longer than maxCatLen" in {
+ val input =
+ """circuit Test_CombinedCats5 :
+ | module Test_CombinedCats5 :
+ | input in1 : UInt<1>
+ | input in2 : UInt<2>
+ | input in3 : UInt<3>
+ | input in4 : UInt<4>
+ | input in5 : UInt<5>
+ | output out : UInt<10>
+ | node temp1 = cat(in2, in1)
+ | node temp2 = cat(in3, temp1)
+ | node temp3 = cat(in4, temp2)
+ | node temp4 = cat(in5, temp3)
+ | out <= temp4
+ |""".stripMargin
+
+ val maxCatLenAnnotation3 = Seq(new MaxCatLenAnnotation(3))
+ firrtlEquivalenceTest(input, transforms, maxCatLenAnnotation3)
+ val result = execute(input, transforms, maxCatLenAnnotation3)
+
+ // should not contain any cat chains greater than 3
+ result shouldNot containTree {
+ case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true
+ case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _)), _, _) => true
+ }
+
+ // temp2 should get cat(in3, cat(in2, in1))
+ result should containTree {
+ case DoPrim(Cat, Seq(
+ WRef("in3", _, _, _),
+ DoPrim(Cat, Seq(
+ WRef("in2", _, _, _),
+ WRef("in1", _, _, _)), _, _)), _, _) => true
+ }
+ }
+
+ "nested nodes that are not cats" should "not be expanded" in {
+ val input =
+ """circuit Test_CombinedCats5 :
+ | module Test_CombinedCats5 :
+ | input in1 : UInt<1>
+ | input in2 : UInt<2>
+ | input in3 : UInt<3>
+ | input in4 : UInt<4>
+ | input in5 : UInt<5>
+ | output out : UInt<10>
+ | node temp1 = add(in2, in1)
+ | node temp2 = cat(in3, temp1)
+ | node temp3 = sub(in4, temp2)
+ | node temp4 = cat(in5, temp3)
+ | out <= temp4
+ |""".stripMargin
+
+ firrtlEquivalenceTest(input, transforms, annotations)
+
+ val result = execute(input, transforms, Seq.empty)
+ result shouldNot containTree {
+ case DoPrim(Cat, Seq(_, DoPrim(Add, _, _, _)), _, _) => true
+ case DoPrim(Cat, Seq(_, DoPrim(Sub, _, _, _)), _, _) => true
+ case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true
+ }
+ }
+}