aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJohn Ingalls2020-01-15 15:34:19 -0800
committermergify[bot]2020-01-15 23:34:19 +0000
commitbc8605d6e198ca38f446547a52d492ac678eda7d (patch)
treef1f4b5a9928cbf0b82bdbac536aeffdf236daf93 /src
parent0aa0ba8fac56fc81f57b24b6e0694d93de2b66df (diff)
Verilog emitter transform InlineBitExtractions (#1296)
* transform InlineBitExtractions * InlineNotsTransform, InlineBitExtractionsTransform: inputForm/outputForm = UnknownForm * clean up some minor redundancies from Adam review * clarifications from Seldrige review
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Emitter.scala20
-rw-r--r--src/main/scala/firrtl/Utils.scala11
-rw-r--r--src/main/scala/firrtl/transforms/InlineBitExtractions.scala102
-rw-r--r--src/main/scala/firrtl/transforms/InlineNots.scala22
-rw-r--r--src/test/scala/firrtlTests/VerilogEmitterTests.scala110
5 files changed, 242 insertions, 23 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 2d98cf04..7df3e242 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -320,6 +320,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality)
case DoPrim(op, args, _,_) if isCast(op) => args.foreach(checkArgumentLegality)
+ case DoPrim(op, args, _,_) if isBitExtract(op) => args.foreach(checkArgumentLegality)
case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
}
@@ -327,6 +328,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality)
case DoPrim(op, args, _,_) if isCast(op) => args.foreach(checkArgumentLegality)
+ case DoPrim(op, args, _,_) if isBitExtract(op) => args.foreach(checkArgumentLegality)
case DoPrim(Cat, args, _, _) => args foreach(checkCatArgumentLegality)
case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
}
@@ -386,6 +388,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
case Shl => if (c0 > 0) Seq("{", cast(a0), s", $c0'h0}") else Seq(cast(a0))
case Shr if c0 >= bitWidth(a0.tpe) =>
error("Verilog emitter does not support SHIFT_RIGHT >= arg width")
+ case Shr if c0 == (bitWidth(a0.tpe)-1) => Seq(a0,"[", bitWidth(a0.tpe) - 1, "]")
case Shr => Seq(a0,"[", bitWidth(a0.tpe) - 1, ":", c0, "]")
case Neg => Seq("-{", cast(a0), "}")
case Cvt => a0.tpe match {
@@ -404,15 +407,15 @@ class VerilogEmitter extends SeqTransform with Emitter {
case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
case Bits if c0 == c1 => Seq(a0, "[", c0, "]")
case Bits => Seq(a0, "[", c0, ":", c1, "]")
+ // If selecting zeroth bit and single-bit wire, just emit the wire
+ case Head if c0 == 1 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
+ case Head if c0 == 1 => Seq(a0, "[", bitWidth(a0.tpe)-1, "]")
case Head =>
- val w = bitWidth(a0.tpe)
- val high = w - 1
- val low = w - c0
- Seq(a0, "[", high, ":", low, "]")
- case Tail =>
- val w = bitWidth(a0.tpe)
- val low = w - c0 - 1
- Seq(a0, "[", low, ":", 0, "]")
+ val msb = bitWidth(a0.tpe) - 1
+ val lsb = bitWidth(a0.tpe) - c0
+ Seq(a0, "[", msb, ":", lsb, "]")
+ case Tail if c0 == (bitWidth(a0.tpe)-1) => Seq(a0, "[0]")
+ case Tail => Seq(a0, "[", bitWidth(a0.tpe) - c0 - 1, ":0]")
}
}
@@ -976,6 +979,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
new BlackBoxSourceHelper,
new ReplaceTruncatingArithmetic,
new InlineNotsTransform,
+ new InlineBitExtractionsTransform, // here after InlineNots to clean up not(not(...)) rename
new InlineCastsTransform,
new LegalizeClocksTransform,
new FlattenRegUpdate,
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 6cb309b3..b9c642d9 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -199,6 +199,17 @@ object Utils extends LazyLogging {
case _ => false
}
+ /** Returns true if PrimOp is a BitExtraction, false otherwise */
+ def isBitExtract(op: PrimOp): Boolean = op match {
+ case Bits | Head | Tail | Shr => true
+ case _ => false
+ }
+ /** Returns true if Expression is a Bits PrimOp, false otherwise */
+ def isBitExtract(expr: Expression): Boolean = expr match {
+ case DoPrim(op, _,_, UIntType(_)) if isBitExtract(op) => true
+ case _ => false
+ }
+
/** Provide a nice name to create a temporary **/
def niceName(e: Expression): String = niceName(1)(e)
def niceName(depth: Int)(e: Expression): String = {
diff --git a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala
new file mode 100644
index 00000000..c4f40700
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala
@@ -0,0 +1,102 @@
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.PrimOps.{Bits, Head, Tail, Shr}
+import firrtl.Utils.{isBitExtract, isTemp}
+import firrtl.WrappedExpression._
+
+import scala.collection.mutable
+
+object InlineBitExtractionsTransform {
+
+ // Checks if an Expression is made up of only Bits terminated by a Literal or Reference.
+ // private because it's not clear if this definition of "Simple Expression" would be useful elsewhere.
+ // Note that this can have false negatives but MUST NOT have false positives.
+ private def isSimpleExpr(expr: Expression): Boolean = expr match {
+ case _: WRef | _: Literal | _: WSubField => true
+ case DoPrim(op, args, _,_) if isBitExtract(op) => args.forall(isSimpleExpr)
+ case _ => false
+ }
+
+ // replace Head/Tail/Shr with Bits for easier back-to-back Bits Extractions
+ private def lowerToDoPrimOpBits(expr: Expression): Expression = expr match {
+ case DoPrim(Head, rhs, c, tpe) if isSimpleExpr(expr) =>
+ val msb = bitWidth(rhs.head.tpe) - 1
+ val lsb = bitWidth(rhs.head.tpe) - c.head
+ DoPrim(Bits, rhs, Seq(msb,lsb), tpe)
+ case DoPrim(Tail, rhs, c, tpe) if isSimpleExpr(expr) =>
+ val msb = bitWidth(rhs.head.tpe) - c.head - 1
+ DoPrim(Bits, rhs, Seq(msb,0), tpe)
+ case DoPrim(Shr, rhs, c, tpe) if isSimpleExpr(expr) =>
+ DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe)-1, c.head), tpe)
+ case _ => expr // Not a candidate
+ }
+
+ /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
+ type Netlist = mutable.HashMap[WrappedExpression, Expression]
+
+ /** Recursively replace [[WRef]]s with new [[Expression]]s
+ *
+ * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
+ * [[firrtl.ir.Expression Expression]]s. It is '''not''' mutated in this function
+ * @param expr the Expression being transformed
+ * @return Returns expr with Bits inlined
+ */
+ def onExpr(netlist: Netlist)(expr: Expression): Expression = {
+ expr.map(onExpr(netlist)) match {
+ case e @ WRef(name, _,_,_) =>
+ netlist.get(we(e))
+ .filter(isBitExtract)
+ .getOrElse(e)
+ // replace back-to-back Bits Extractions
+ case lhs @ DoPrim(lop, ival, lc, ltpe) if isSimpleExpr(lhs) =>
+ ival.head match {
+ case of @ DoPrim(rop, rhs, rc, rtpe) if isSimpleExpr(of) =>
+ (lop, rop) match {
+ case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head min rc.head), ltpe)
+ case (Tail, Tail) => DoPrim(Tail, rhs, Seq(lc.head + rc.head), ltpe)
+ case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe)
+ case (_,_) => (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match {
+ case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) =>
+ DoPrim(Bits, rhs, Seq(lmsb+rlsb,llsb+rlsb), ltpe)
+ case (_,_) => lhs // Not a candidate
+ }
+ }
+ case _ => lhs // Not a candidate
+ }
+ case other => other // Not a candidate
+ }
+ }
+
+ /** Inline bits in a Statement
+ *
+ * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
+ * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt is a [[firrtl.ir.DefNode
+ * DefNode]] with a Temporary name and a value that is a [[PrimOp]] Bits
+ * @param stmt the Statement being searched for nodes and transformed
+ * @return Returns stmt with Bits inlined
+ */
+ def onStmt(netlist: Netlist)(stmt: Statement): Statement =
+ stmt.map(onStmt(netlist)).map(onExpr(netlist)) match {
+ case node @ DefNode(_, name, value) if isTemp(name) =>
+ netlist(we(WRef(name))) = value
+ node
+ case other => other
+ }
+
+ /** Replaces bits in a Module */
+ def onMod(mod: DefModule): DefModule = mod.map(onStmt(new Netlist))
+}
+
+/** Inline nodes that are simple bits */
+class InlineBitExtractionsTransform extends Transform {
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val modulesx = state.circuit.modules.map(InlineBitExtractionsTransform.onMod(_))
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}
diff --git a/src/main/scala/firrtl/transforms/InlineNots.scala b/src/main/scala/firrtl/transforms/InlineNots.scala
index 3dab5168..299c130a 100644
--- a/src/main/scala/firrtl/transforms/InlineNots.scala
+++ b/src/main/scala/firrtl/transforms/InlineNots.scala
@@ -3,8 +3,8 @@ package transforms
import firrtl.ir._
import firrtl.Mappers._
-import firrtl.PrimOps.Not
-import firrtl.Utils.isTemp
+import firrtl.PrimOps.{Bits, Not}
+import firrtl.Utils.{isBitExtract, isTemp}
import firrtl.WrappedExpression._
import scala.collection.mutable
@@ -42,10 +42,18 @@ object InlineNotsTransform {
netlist.get(we(e))
.filter(isNot)
.getOrElse(e)
+ // replace bits-of-not with not-of-bits to enable later bit extraction transform
+ case lhs @ DoPrim(op, Seq(lval), lcons, ltpe) if isBitExtract(op) && isSimpleExpr(lval) =>
+ netlist.getOrElse(we(lval), lval) match {
+ case DoPrim(Not, Seq(rhs), rcons, rtpe) =>
+ DoPrim(Not, Seq(DoPrim(op, Seq(rhs), lcons, ltpe)), rcons, ltpe)
+ case _ => lhs // Not a candiate
+ }
// replace back-to-back inversions with a straight rename
- case lhs @ DoPrim(Not, Seq(inv), _,_) if isSimpleExpr(inv) =>
+ case lhs @ DoPrim(Not, Seq(inv), _, invtpe) if isSimpleExpr(lhs) && isSimpleExpr(inv) && (lhs.tpe == invtpe) && (bitWidth(lhs.tpe) == bitWidth(inv.tpe)) =>
netlist.getOrElse(we(inv), inv) match {
- case DoPrim(Not, Seq(rhs), _,_) if isSimpleExpr(inv) => rhs
+ case DoPrim(Not, Seq(rhs), _, rtpe) if (invtpe == rtpe) && (bitWidth(inv.tpe) == bitWidth(rhs.tpe)) =>
+ DoPrim(Bits, Seq(rhs), Seq(bitWidth(lhs.tpe)-1,0), rtpe)
case _ => lhs // Not a candiate
}
case other => other // Not a candidate
@@ -56,7 +64,7 @@ object InlineNotsTransform {
*
* @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
* [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt is a [[firrtl.ir.DefNode
- * DefNode]] with a value that is a [[PrimOp]] Not
+ * DefNode]] with a Temporary name and a value that is a [[PrimOp]] Not
* @param stmt the Statement being searched for nodes and transformed
* @return Returns stmt with nots inlined
*/
@@ -74,8 +82,8 @@ object InlineNotsTransform {
/** Inline nodes that are simple nots */
class InlineNotsTransform extends Transform {
- def inputForm = LowForm
- def outputForm = LowForm
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
def execute(state: CircuitState): CircuitState = {
val modulesx = state.circuit.modules.map(InlineNotsTransform.onMod(_))
diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala
index bb7659e9..bce9b155 100644
--- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala
+++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala
@@ -88,6 +88,100 @@ class DoPrimVerilog extends FirrtlFlatSpec {
|""".stripMargin.split("\n") map normalized
executeTest(input, check, compiler)
}
+ "inline Bits" should "emit correctly" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit InlineBits :
+ | module InlineBits :
+ | input a: UInt<4>
+ | output b: UInt<1>
+ | output c: UInt<3>
+ | output d: UInt<2>
+ | output e: UInt<2>
+ | output f: UInt<2>
+ | output g: UInt<2>
+ | output h: UInt<2>
+ | output i: UInt<2>
+ | output j: UInt<2>
+ | output k: UInt<1>
+ | output l: UInt<1>
+ | output m: UInt<1>
+ | output n: UInt<1>
+ | output o: UInt<2>
+ | output p: UInt<2>
+ | output q: UInt<2>
+ | output r: UInt<1>
+ | output s: UInt<2>
+ | output t: UInt<2>
+ | output u: UInt<1>
+ | b <= bits(a, 2, 2)
+ | c <= bits(a, 3, 1)
+ | d <= head(a, 2)
+ | e <= tail(a, 2)
+ | f <= bits(bits(a, 3, 1), 2, 1)
+ | g <= bits(head(a, 3), 1, 0)
+ | h <= bits(tail(a, 1), 1, 0)
+ | i <= bits(shr(a, 1), 1, 0)
+ | j <= head(bits(a, 3, 1), 2)
+ | k <= head(head(a, 3), 1)
+ | l <= head(tail(a, 1), 1)
+ | m <= head(shr(a, 1), 1)
+ | n <= tail(bits(a, 3, 1), 2)
+ | o <= tail(head(a, 3), 1)
+ | p <= tail(tail(a, 1), 1)
+ | q <= tail(shr(a, 1), 1)
+ | r <= shr(bits(a, 1, 0), 1)
+ | s <= shr(head(a, 3), 1)
+ | t <= shr(tail(a, 1), 1)
+ | u <= shr(shr(a, 1), 2)""".stripMargin
+ val check =
+ """module InlineBits(
+ | input [3:0] a,
+ | output b,
+ | output [2:0] c,
+ | output [1:0] d,
+ | output [1:0] e,
+ | output [1:0] f,
+ | output [1:0] g,
+ | output [1:0] h,
+ | output [1:0] i,
+ | output [1:0] j,
+ | output k,
+ | output l,
+ | output m,
+ | output n,
+ | output [1:0] o,
+ | output [1:0] p,
+ | output [1:0] q,
+ | output r,
+ | output [1:0] s,
+ | output [1:0] t,
+ | output u
+ |);
+ | assign b = a[2];
+ | assign c = a[3:1];
+ | assign d = a[3:2];
+ | assign e = a[1:0];
+ | assign f = a[3:2];
+ | assign g = a[2:1];
+ | assign h = a[1:0];
+ | assign i = a[2:1];
+ | assign j = a[3:2];
+ | assign k = a[3];
+ | assign l = a[2];
+ | assign m = a[3];
+ | assign n = a[1];
+ | assign o = a[2:1];
+ | assign p = a[1:0];
+ | assign q = a[2:1];
+ | assign r = a[1];
+ | assign s = a[3:2];
+ | assign t = a[2:1];
+ | assign u = a[3];
+ |endmodule
+ |""".stripMargin.split("\n") map normalized
+ executeTest(input, check, compiler)
+ }
"inline Not" should "emit correctly" in {
val compiler = new VerilogCompiler
val input =
@@ -100,10 +194,12 @@ class DoPrimVerilog extends FirrtlFlatSpec {
| output e: UInt<1>
| output f: UInt<1>
| output g: UInt<1>
+ | output h: UInt<1>
| d <= and(a, not(b))
| e <= or(a, not(b))
| f <= not(not(not(bits(c, 2, 2))))
- | g <= mux(not(bits(c, 2, 2)), a, b)""".stripMargin
+ | g <= mux(not(bits(c, 2, 2)), a, b)
+ | h <= shr(not(bits(c, 2, 1)), 1)""".stripMargin
val check =
"""module InlineNot(
| input a,
@@ -112,16 +208,14 @@ class DoPrimVerilog extends FirrtlFlatSpec {
| output d,
| output e,
| output f,
- | output g
+ | output g,
+ | output h
|);
- | wire _GEN_2;
- | wire _GEN_4;
| assign d = a & ~b;
| assign e = a | ~b;
- | assign _GEN_2 = c[2];
- | assign _GEN_4 = _GEN_2;
- | assign f = ~_GEN_4;
- | assign g = _GEN_2 ? b : a;
+ | assign f = ~c[2];
+ | assign g = c[2] ? b : a;
+ | assign h = ~c[2];
|endmodule
|""".stripMargin.split("\n") map normalized
executeTest(input, check, compiler)