diff options
| author | John Ingalls | 2020-01-15 15:34:19 -0800 |
|---|---|---|
| committer | mergify[bot] | 2020-01-15 23:34:19 +0000 |
| commit | bc8605d6e198ca38f446547a52d492ac678eda7d (patch) | |
| tree | f1f4b5a9928cbf0b82bdbac536aeffdf236daf93 /src | |
| parent | 0aa0ba8fac56fc81f57b24b6e0694d93de2b66df (diff) | |
Verilog emitter transform InlineBitExtractions (#1296)
* transform InlineBitExtractions
* InlineNotsTransform, InlineBitExtractionsTransform: inputForm/outputForm = UnknownForm
* clean up some minor redundancies from Adam review
* clarifications from Seldrige review
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 20 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 11 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/InlineBitExtractions.scala | 102 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/InlineNots.scala | 22 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/VerilogEmitterTests.scala | 110 |
5 files changed, 242 insertions, 23 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 2d98cf04..7df3e242 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -320,6 +320,7 @@ class VerilogEmitter extends SeqTransform with Emitter { case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField => case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality) case DoPrim(op, args, _,_) if isCast(op) => args.foreach(checkArgumentLegality) + case DoPrim(op, args, _,_) if isBitExtract(op) => args.foreach(checkArgumentLegality) case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") } @@ -327,6 +328,7 @@ class VerilogEmitter extends SeqTransform with Emitter { case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField => case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality) case DoPrim(op, args, _,_) if isCast(op) => args.foreach(checkArgumentLegality) + case DoPrim(op, args, _,_) if isBitExtract(op) => args.foreach(checkArgumentLegality) case DoPrim(Cat, args, _, _) => args foreach(checkCatArgumentLegality) case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") } @@ -386,6 +388,7 @@ class VerilogEmitter extends SeqTransform with Emitter { case Shl => if (c0 > 0) Seq("{", cast(a0), s", $c0'h0}") else Seq(cast(a0)) case Shr if c0 >= bitWidth(a0.tpe) => error("Verilog emitter does not support SHIFT_RIGHT >= arg width") + case Shr if c0 == (bitWidth(a0.tpe)-1) => Seq(a0,"[", bitWidth(a0.tpe) - 1, "]") case Shr => Seq(a0,"[", bitWidth(a0.tpe) - 1, ":", c0, "]") case Neg => Seq("-{", cast(a0), "}") case Cvt => a0.tpe match { @@ -404,15 +407,15 @@ class VerilogEmitter extends SeqTransform with Emitter { case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0) case Bits if c0 == c1 => Seq(a0, "[", c0, "]") case Bits => Seq(a0, "[", c0, ":", c1, "]") + // If selecting zeroth bit and single-bit wire, just emit the wire + case Head if c0 == 1 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0) + case Head if c0 == 1 => Seq(a0, "[", bitWidth(a0.tpe)-1, "]") case Head => - val w = bitWidth(a0.tpe) - val high = w - 1 - val low = w - c0 - Seq(a0, "[", high, ":", low, "]") - case Tail => - val w = bitWidth(a0.tpe) - val low = w - c0 - 1 - Seq(a0, "[", low, ":", 0, "]") + val msb = bitWidth(a0.tpe) - 1 + val lsb = bitWidth(a0.tpe) - c0 + Seq(a0, "[", msb, ":", lsb, "]") + case Tail if c0 == (bitWidth(a0.tpe)-1) => Seq(a0, "[0]") + case Tail => Seq(a0, "[", bitWidth(a0.tpe) - c0 - 1, ":0]") } } @@ -976,6 +979,7 @@ class VerilogEmitter extends SeqTransform with Emitter { new BlackBoxSourceHelper, new ReplaceTruncatingArithmetic, new InlineNotsTransform, + new InlineBitExtractionsTransform, // here after InlineNots to clean up not(not(...)) rename new InlineCastsTransform, new LegalizeClocksTransform, new FlattenRegUpdate, diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 6cb309b3..b9c642d9 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -199,6 +199,17 @@ object Utils extends LazyLogging { case _ => false } + /** Returns true if PrimOp is a BitExtraction, false otherwise */ + def isBitExtract(op: PrimOp): Boolean = op match { + case Bits | Head | Tail | Shr => true + case _ => false + } + /** Returns true if Expression is a Bits PrimOp, false otherwise */ + def isBitExtract(expr: Expression): Boolean = expr match { + case DoPrim(op, _,_, UIntType(_)) if isBitExtract(op) => true + case _ => false + } + /** Provide a nice name to create a temporary **/ def niceName(e: Expression): String = niceName(1)(e) def niceName(depth: Int)(e: Expression): String = { diff --git a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala new file mode 100644 index 00000000..c4f40700 --- /dev/null +++ b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala @@ -0,0 +1,102 @@ +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.PrimOps.{Bits, Head, Tail, Shr} +import firrtl.Utils.{isBitExtract, isTemp} +import firrtl.WrappedExpression._ + +import scala.collection.mutable + +object InlineBitExtractionsTransform { + + // Checks if an Expression is made up of only Bits terminated by a Literal or Reference. + // private because it's not clear if this definition of "Simple Expression" would be useful elsewhere. + // Note that this can have false negatives but MUST NOT have false positives. + private def isSimpleExpr(expr: Expression): Boolean = expr match { + case _: WRef | _: Literal | _: WSubField => true + case DoPrim(op, args, _,_) if isBitExtract(op) => args.forall(isSimpleExpr) + case _ => false + } + + // replace Head/Tail/Shr with Bits for easier back-to-back Bits Extractions + private def lowerToDoPrimOpBits(expr: Expression): Expression = expr match { + case DoPrim(Head, rhs, c, tpe) if isSimpleExpr(expr) => + val msb = bitWidth(rhs.head.tpe) - 1 + val lsb = bitWidth(rhs.head.tpe) - c.head + DoPrim(Bits, rhs, Seq(msb,lsb), tpe) + case DoPrim(Tail, rhs, c, tpe) if isSimpleExpr(expr) => + val msb = bitWidth(rhs.head.tpe) - c.head - 1 + DoPrim(Bits, rhs, Seq(msb,0), tpe) + case DoPrim(Shr, rhs, c, tpe) if isSimpleExpr(expr) => + DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe)-1, c.head), tpe) + case _ => expr // Not a candidate + } + + /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */ + type Netlist = mutable.HashMap[WrappedExpression, Expression] + + /** Recursively replace [[WRef]]s with new [[Expression]]s + * + * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected + * [[firrtl.ir.Expression Expression]]s. It is '''not''' mutated in this function + * @param expr the Expression being transformed + * @return Returns expr with Bits inlined + */ + def onExpr(netlist: Netlist)(expr: Expression): Expression = { + expr.map(onExpr(netlist)) match { + case e @ WRef(name, _,_,_) => + netlist.get(we(e)) + .filter(isBitExtract) + .getOrElse(e) + // replace back-to-back Bits Extractions + case lhs @ DoPrim(lop, ival, lc, ltpe) if isSimpleExpr(lhs) => + ival.head match { + case of @ DoPrim(rop, rhs, rc, rtpe) if isSimpleExpr(of) => + (lop, rop) match { + case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head min rc.head), ltpe) + case (Tail, Tail) => DoPrim(Tail, rhs, Seq(lc.head + rc.head), ltpe) + case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe) + case (_,_) => (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match { + case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) => + DoPrim(Bits, rhs, Seq(lmsb+rlsb,llsb+rlsb), ltpe) + case (_,_) => lhs // Not a candidate + } + } + case _ => lhs // Not a candidate + } + case other => other // Not a candidate + } + } + + /** Inline bits in a Statement + * + * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected + * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt is a [[firrtl.ir.DefNode + * DefNode]] with a Temporary name and a value that is a [[PrimOp]] Bits + * @param stmt the Statement being searched for nodes and transformed + * @return Returns stmt with Bits inlined + */ + def onStmt(netlist: Netlist)(stmt: Statement): Statement = + stmt.map(onStmt(netlist)).map(onExpr(netlist)) match { + case node @ DefNode(_, name, value) if isTemp(name) => + netlist(we(WRef(name))) = value + node + case other => other + } + + /** Replaces bits in a Module */ + def onMod(mod: DefModule): DefModule = mod.map(onStmt(new Netlist)) +} + +/** Inline nodes that are simple bits */ +class InlineBitExtractionsTransform extends Transform { + def inputForm = UnknownForm + def outputForm = UnknownForm + + def execute(state: CircuitState): CircuitState = { + val modulesx = state.circuit.modules.map(InlineBitExtractionsTransform.onMod(_)) + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} diff --git a/src/main/scala/firrtl/transforms/InlineNots.scala b/src/main/scala/firrtl/transforms/InlineNots.scala index 3dab5168..299c130a 100644 --- a/src/main/scala/firrtl/transforms/InlineNots.scala +++ b/src/main/scala/firrtl/transforms/InlineNots.scala @@ -3,8 +3,8 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ -import firrtl.PrimOps.Not -import firrtl.Utils.isTemp +import firrtl.PrimOps.{Bits, Not} +import firrtl.Utils.{isBitExtract, isTemp} import firrtl.WrappedExpression._ import scala.collection.mutable @@ -42,10 +42,18 @@ object InlineNotsTransform { netlist.get(we(e)) .filter(isNot) .getOrElse(e) + // replace bits-of-not with not-of-bits to enable later bit extraction transform + case lhs @ DoPrim(op, Seq(lval), lcons, ltpe) if isBitExtract(op) && isSimpleExpr(lval) => + netlist.getOrElse(we(lval), lval) match { + case DoPrim(Not, Seq(rhs), rcons, rtpe) => + DoPrim(Not, Seq(DoPrim(op, Seq(rhs), lcons, ltpe)), rcons, ltpe) + case _ => lhs // Not a candiate + } // replace back-to-back inversions with a straight rename - case lhs @ DoPrim(Not, Seq(inv), _,_) if isSimpleExpr(inv) => + case lhs @ DoPrim(Not, Seq(inv), _, invtpe) if isSimpleExpr(lhs) && isSimpleExpr(inv) && (lhs.tpe == invtpe) && (bitWidth(lhs.tpe) == bitWidth(inv.tpe)) => netlist.getOrElse(we(inv), inv) match { - case DoPrim(Not, Seq(rhs), _,_) if isSimpleExpr(inv) => rhs + case DoPrim(Not, Seq(rhs), _, rtpe) if (invtpe == rtpe) && (bitWidth(inv.tpe) == bitWidth(rhs.tpe)) => + DoPrim(Bits, Seq(rhs), Seq(bitWidth(lhs.tpe)-1,0), rtpe) case _ => lhs // Not a candiate } case other => other // Not a candidate @@ -56,7 +64,7 @@ object InlineNotsTransform { * * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt is a [[firrtl.ir.DefNode - * DefNode]] with a value that is a [[PrimOp]] Not + * DefNode]] with a Temporary name and a value that is a [[PrimOp]] Not * @param stmt the Statement being searched for nodes and transformed * @return Returns stmt with nots inlined */ @@ -74,8 +82,8 @@ object InlineNotsTransform { /** Inline nodes that are simple nots */ class InlineNotsTransform extends Transform { - def inputForm = LowForm - def outputForm = LowForm + def inputForm = UnknownForm + def outputForm = UnknownForm def execute(state: CircuitState): CircuitState = { val modulesx = state.circuit.modules.map(InlineNotsTransform.onMod(_)) diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index bb7659e9..bce9b155 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -88,6 +88,100 @@ class DoPrimVerilog extends FirrtlFlatSpec { |""".stripMargin.split("\n") map normalized executeTest(input, check, compiler) } + "inline Bits" should "emit correctly" in { + val compiler = new VerilogCompiler + val input = + """circuit InlineBits : + | module InlineBits : + | input a: UInt<4> + | output b: UInt<1> + | output c: UInt<3> + | output d: UInt<2> + | output e: UInt<2> + | output f: UInt<2> + | output g: UInt<2> + | output h: UInt<2> + | output i: UInt<2> + | output j: UInt<2> + | output k: UInt<1> + | output l: UInt<1> + | output m: UInt<1> + | output n: UInt<1> + | output o: UInt<2> + | output p: UInt<2> + | output q: UInt<2> + | output r: UInt<1> + | output s: UInt<2> + | output t: UInt<2> + | output u: UInt<1> + | b <= bits(a, 2, 2) + | c <= bits(a, 3, 1) + | d <= head(a, 2) + | e <= tail(a, 2) + | f <= bits(bits(a, 3, 1), 2, 1) + | g <= bits(head(a, 3), 1, 0) + | h <= bits(tail(a, 1), 1, 0) + | i <= bits(shr(a, 1), 1, 0) + | j <= head(bits(a, 3, 1), 2) + | k <= head(head(a, 3), 1) + | l <= head(tail(a, 1), 1) + | m <= head(shr(a, 1), 1) + | n <= tail(bits(a, 3, 1), 2) + | o <= tail(head(a, 3), 1) + | p <= tail(tail(a, 1), 1) + | q <= tail(shr(a, 1), 1) + | r <= shr(bits(a, 1, 0), 1) + | s <= shr(head(a, 3), 1) + | t <= shr(tail(a, 1), 1) + | u <= shr(shr(a, 1), 2)""".stripMargin + val check = + """module InlineBits( + | input [3:0] a, + | output b, + | output [2:0] c, + | output [1:0] d, + | output [1:0] e, + | output [1:0] f, + | output [1:0] g, + | output [1:0] h, + | output [1:0] i, + | output [1:0] j, + | output k, + | output l, + | output m, + | output n, + | output [1:0] o, + | output [1:0] p, + | output [1:0] q, + | output r, + | output [1:0] s, + | output [1:0] t, + | output u + |); + | assign b = a[2]; + | assign c = a[3:1]; + | assign d = a[3:2]; + | assign e = a[1:0]; + | assign f = a[3:2]; + | assign g = a[2:1]; + | assign h = a[1:0]; + | assign i = a[2:1]; + | assign j = a[3:2]; + | assign k = a[3]; + | assign l = a[2]; + | assign m = a[3]; + | assign n = a[1]; + | assign o = a[2:1]; + | assign p = a[1:0]; + | assign q = a[2:1]; + | assign r = a[1]; + | assign s = a[3:2]; + | assign t = a[2:1]; + | assign u = a[3]; + |endmodule + |""".stripMargin.split("\n") map normalized + executeTest(input, check, compiler) + } "inline Not" should "emit correctly" in { val compiler = new VerilogCompiler val input = @@ -100,10 +194,12 @@ class DoPrimVerilog extends FirrtlFlatSpec { | output e: UInt<1> | output f: UInt<1> | output g: UInt<1> + | output h: UInt<1> | d <= and(a, not(b)) | e <= or(a, not(b)) | f <= not(not(not(bits(c, 2, 2)))) - | g <= mux(not(bits(c, 2, 2)), a, b)""".stripMargin + | g <= mux(not(bits(c, 2, 2)), a, b) + | h <= shr(not(bits(c, 2, 1)), 1)""".stripMargin val check = """module InlineNot( | input a, @@ -112,16 +208,14 @@ class DoPrimVerilog extends FirrtlFlatSpec { | output d, | output e, | output f, - | output g + | output g, + | output h |); - | wire _GEN_2; - | wire _GEN_4; | assign d = a & ~b; | assign e = a | ~b; - | assign _GEN_2 = c[2]; - | assign _GEN_4 = _GEN_2; - | assign f = ~_GEN_4; - | assign g = _GEN_2 ? b : a; + | assign f = ~c[2]; + | assign g = c[2] ? b : a; + | assign h = ~c[2]; |endmodule |""".stripMargin.split("\n") map normalized executeTest(input, check, compiler) |
