diff options
| author | John Ingalls | 2020-01-06 18:47:19 -0800 |
|---|---|---|
| committer | Jack Koenig | 2020-01-06 18:47:19 -0800 |
| commit | f77487d37bd7c61be231a8000a3197d37cf55499 (patch) | |
| tree | 99208af73baad6fef176ce86d14a17e790e15d10 | |
| parent | dcf0076ca9b4b3c094d2d082717265fb4e326ae0 (diff) | |
Verilog emitter transform InlineNots (#1270)
[skip formal checks]
* ConstProp FoldEqual/FoldNotEqual propagate boolean (non-)equality with true/false
* transform InlineNots
* transform back-to-back Nots into straight rename
* swap mux with inverted select
Co-authored-by: Jack Koenig <jack.koenig3@gmail.com>
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 12 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/InlineNots.scala | 84 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ConstantPropagationTests.scala | 72 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/VerilogEmitterTests.scala | 56 |
5 files changed, 223 insertions, 3 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index c87928c5..95c762ae 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -237,7 +237,10 @@ class VerilogEmitter extends SeqTransform with Emitter { if (e.tpe == AsyncResetType) { throw EmitterException("Cannot emit async reset muxes directly") } - emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1) + e.cond match { + case DoPrim(Not, Seq(sel), _,_) => emit(Seq(sel," ? ",cast(e.fval)," : ",cast(e.tval)),top + 1) + case _ => emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1) + } } case (e: ValidIf) => emit(Seq(cast(e.value)),top + 1) case (e: WRef) => w write e.serialize @@ -307,13 +310,15 @@ class VerilogEmitter extends SeqTransform with Emitter { def c0: Int = doprim.consts.head.toInt def c1: Int = doprim.consts(1).toInt - def checkArgumentLegality(e: Expression) = e match { + def checkArgumentLegality(e: Expression): Unit = e match { case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField => + case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality) case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") } def checkCatArgumentLegality(e: Expression): Unit = e match { case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField => + case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality) case DoPrim(Cat, args, _, _) => args foreach(checkCatArgumentLegality) case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") } @@ -378,7 +383,7 @@ class VerilogEmitter extends SeqTransform with Emitter { case (_: UIntType) => Seq("{1'b0,", cast(a0), "}") case (_: SIntType) => Seq(cast(a0)) } - case Not => Seq("~ ", a0) + case Not => Seq("~", a0) case And => Seq(cast_as(a0), " & ", cast_as(a1)) case Or => Seq(cast_as(a0), " | ", cast_as(a1)) case Xor => Seq(cast_as(a0), " ^ ", cast_as(a1)) @@ -961,6 +966,7 @@ class VerilogEmitter extends SeqTransform with Emitter { def transforms = Seq( new BlackBoxSourceHelper, new ReplaceTruncatingArithmetic, + new InlineNotsTransform, new FlattenRegUpdate, new DeadCodeElimination, passes.VerilogModulusCleanup, diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index f224546b..a008a4d3 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -155,6 +155,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs + case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } } @@ -163,6 +164,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs + case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } } diff --git a/src/main/scala/firrtl/transforms/InlineNots.scala b/src/main/scala/firrtl/transforms/InlineNots.scala new file mode 100644 index 00000000..3dab5168 --- /dev/null +++ b/src/main/scala/firrtl/transforms/InlineNots.scala @@ -0,0 +1,84 @@ +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.PrimOps.Not +import firrtl.Utils.isTemp +import firrtl.WrappedExpression._ + +import scala.collection.mutable + +object InlineNotsTransform { + + /** Returns true if Expression is a Not PrimOp, false otherwise */ + private def isNot(expr: Expression): Boolean = expr match { + case DoPrim(Not, args,_,_) => args.forall(isSimpleExpr) + case _ => false + } + + // Checks if an Expression is made up of only Nots terminated by a Literal or Reference. + // private because it's not clear if this definition of "Simple Expression" would be useful elsewhere. + // Note that this can have false negatives but MUST NOT have false positives. + private def isSimpleExpr(expr: Expression): Boolean = expr match { + case _: WRef | _: Literal | _: WSubField => true + case DoPrim(Not, args, _,_) => args.forall(isSimpleExpr) + case _ => false + } + + /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */ + type Netlist = mutable.HashMap[WrappedExpression, Expression] + + /** Recursively replace [[WRef]]s with new [[Expression]]s + * + * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected + * [[firrtl.ir.Expression Expression]]s. It is '''not''' mutated in this function + * @param expr the Expression being transformed + * @return Returns expr with Nots inlined + */ + def onExpr(netlist: Netlist)(expr: Expression): Expression = { + expr.map(onExpr(netlist)) match { + case e @ WRef(name, _,_,_) => + netlist.get(we(e)) + .filter(isNot) + .getOrElse(e) + // replace back-to-back inversions with a straight rename + case lhs @ DoPrim(Not, Seq(inv), _,_) if isSimpleExpr(inv) => + netlist.getOrElse(we(inv), inv) match { + case DoPrim(Not, Seq(rhs), _,_) if isSimpleExpr(inv) => rhs + case _ => lhs // Not a candiate + } + case other => other // Not a candidate + } + } + + /** Inline nots in a Statement + * + * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected + * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt is a [[firrtl.ir.DefNode + * DefNode]] with a value that is a [[PrimOp]] Not + * @param stmt the Statement being searched for nodes and transformed + * @return Returns stmt with nots inlined + */ + def onStmt(netlist: Netlist)(stmt: Statement): Statement = + stmt.map(onStmt(netlist)).map(onExpr(netlist)) match { + case node @ DefNode(_, name, value) if isTemp(name) => + netlist(we(WRef(name))) = value + node + case other => other + } + + /** Inline nots in a Module */ + def onMod(mod: DefModule): DefModule = mod.map(onStmt(new Netlist)) +} + +/** Inline nodes that are simple nots */ +class InlineNotsTransform extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + def execute(state: CircuitState): CircuitState = { + val modulesx = state.circuit.modules.map(InlineNotsTransform.onMod(_)) + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 79e73c80..71709255 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -735,6 +735,78 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { (parse(exec(input))) should be(parse(check)) } + "ConstProp" should "propagate boolean equality with true" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= eq(x, UInt<1>("h1")) + """.stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + """.stripMargin + (parse(exec(input))) should be(parse(check)) + } + + "ConstProp" should "propagate boolean equality with false" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= eq(x, UInt<1>("h0")) + """.stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= not(x) + """.stripMargin + (parse(exec(input))) should be(parse(check)) + } + + "ConstProp" should "propagate boolean non-equality with true" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= neq(x, UInt<1>("h1")) + """.stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= not(x) + """.stripMargin + (parse(exec(input))) should be(parse(check)) + } + + "ConstProp" should "propagate boolean non-equality with false" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= neq(x, UInt<1>("h0")) + """.stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + """.stripMargin + (parse(exec(input))) should be(parse(check)) + } + // Optimizing this mux gives: z <= pad(UInt<2>(0), 4) // Thus this checks that we then optimize that pad "ConstProp" should "optimize nested Expressions" in { diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index cf2ff320..0376a830 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -70,6 +70,62 @@ class DoPrimVerilog extends FirrtlFlatSpec { |""".stripMargin.split("\n") map normalized executeTest(input, check, compiler) } + "Not" should "emit correctly" in { + val compiler = new VerilogCompiler + val input = + """circuit Not : + | module Not : + | input a: UInt<1> + | output b: UInt<1> + | b <= not(a)""".stripMargin + val check = + """module Not( + | input a, + | output b + |); + | assign b = ~a; + |endmodule + |""".stripMargin.split("\n") map normalized + executeTest(input, check, compiler) + } + "inline Not" should "emit correctly" in { + val compiler = new VerilogCompiler + val input = + """circuit InlineNot : + | module InlineNot : + | input a: UInt<1> + | input b: UInt<1> + | input c: UInt<4> + | output d: UInt<1> + | output e: UInt<1> + | output f: UInt<1> + | output g: UInt<1> + | d <= and(a, not(b)) + | e <= or(a, not(b)) + | f <= not(not(not(bits(c, 2, 2)))) + | g <= mux(not(bits(c, 2, 2)), a, b)""".stripMargin + val check = + """module InlineNot( + | input a, + | input b, + | input [3:0] c, + | output d, + | output e, + | output f, + | output g + |); + | wire _GEN_2; + | wire _GEN_4; + | assign d = a & ~b; + | assign e = a | ~b; + | assign _GEN_2 = c[2]; + | assign _GEN_4 = _GEN_2; + | assign f = ~_GEN_4; + | assign g = _GEN_2 ? b : a; + |endmodule + |""".stripMargin.split("\n") map normalized + executeTest(input, check, compiler) + } "Rem" should "emit correctly" in { val compiler = new VerilogCompiler val input = |
