aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Ingalls2020-01-06 18:47:19 -0800
committerJack Koenig2020-01-06 18:47:19 -0800
commitf77487d37bd7c61be231a8000a3197d37cf55499 (patch)
tree99208af73baad6fef176ce86d14a17e790e15d10
parentdcf0076ca9b4b3c094d2d082717265fb4e326ae0 (diff)
Verilog emitter transform InlineNots (#1270)
[skip formal checks] * ConstProp FoldEqual/FoldNotEqual propagate boolean (non-)equality with true/false * transform InlineNots * transform back-to-back Nots into straight rename * swap mux with inverted select Co-authored-by: Jack Koenig <jack.koenig3@gmail.com>
-rw-r--r--src/main/scala/firrtl/Emitter.scala12
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala2
-rw-r--r--src/main/scala/firrtl/transforms/InlineNots.scala84
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala72
-rw-r--r--src/test/scala/firrtlTests/VerilogEmitterTests.scala56
5 files changed, 223 insertions, 3 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index c87928c5..95c762ae 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -237,7 +237,10 @@ class VerilogEmitter extends SeqTransform with Emitter {
if (e.tpe == AsyncResetType) {
throw EmitterException("Cannot emit async reset muxes directly")
}
- emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1)
+ e.cond match {
+ case DoPrim(Not, Seq(sel), _,_) => emit(Seq(sel," ? ",cast(e.fval)," : ",cast(e.tval)),top + 1)
+ case _ => emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1)
+ }
}
case (e: ValidIf) => emit(Seq(cast(e.value)),top + 1)
case (e: WRef) => w write e.serialize
@@ -307,13 +310,15 @@ class VerilogEmitter extends SeqTransform with Emitter {
def c0: Int = doprim.consts.head.toInt
def c1: Int = doprim.consts(1).toInt
- def checkArgumentLegality(e: Expression) = e match {
+ def checkArgumentLegality(e: Expression): Unit = e match {
case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
+ case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality)
case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
}
def checkCatArgumentLegality(e: Expression): Unit = e match {
case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
+ case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality)
case DoPrim(Cat, args, _, _) => args foreach(checkCatArgumentLegality)
case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
}
@@ -378,7 +383,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
case (_: UIntType) => Seq("{1'b0,", cast(a0), "}")
case (_: SIntType) => Seq(cast(a0))
}
- case Not => Seq("~ ", a0)
+ case Not => Seq("~", a0)
case And => Seq(cast_as(a0), " & ", cast_as(a1))
case Or => Seq(cast_as(a0), " | ", cast_as(a1))
case Xor => Seq(cast_as(a0), " ^ ", cast_as(a1))
@@ -961,6 +966,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
def transforms = Seq(
new BlackBoxSourceHelper,
new ReplaceTruncatingArithmetic,
+ new InlineNotsTransform,
new FlattenRegUpdate,
new DeadCodeElimination,
passes.VerilogModulusCleanup,
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index f224546b..a008a4d3 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -155,6 +155,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
}
@@ -163,6 +164,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
}
diff --git a/src/main/scala/firrtl/transforms/InlineNots.scala b/src/main/scala/firrtl/transforms/InlineNots.scala
new file mode 100644
index 00000000..3dab5168
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/InlineNots.scala
@@ -0,0 +1,84 @@
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.PrimOps.Not
+import firrtl.Utils.isTemp
+import firrtl.WrappedExpression._
+
+import scala.collection.mutable
+
+object InlineNotsTransform {
+
+ /** Returns true if Expression is a Not PrimOp, false otherwise */
+ private def isNot(expr: Expression): Boolean = expr match {
+ case DoPrim(Not, args,_,_) => args.forall(isSimpleExpr)
+ case _ => false
+ }
+
+ // Checks if an Expression is made up of only Nots terminated by a Literal or Reference.
+ // private because it's not clear if this definition of "Simple Expression" would be useful elsewhere.
+ // Note that this can have false negatives but MUST NOT have false positives.
+ private def isSimpleExpr(expr: Expression): Boolean = expr match {
+ case _: WRef | _: Literal | _: WSubField => true
+ case DoPrim(Not, args, _,_) => args.forall(isSimpleExpr)
+ case _ => false
+ }
+
+ /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
+ type Netlist = mutable.HashMap[WrappedExpression, Expression]
+
+ /** Recursively replace [[WRef]]s with new [[Expression]]s
+ *
+ * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
+ * [[firrtl.ir.Expression Expression]]s. It is '''not''' mutated in this function
+ * @param expr the Expression being transformed
+ * @return Returns expr with Nots inlined
+ */
+ def onExpr(netlist: Netlist)(expr: Expression): Expression = {
+ expr.map(onExpr(netlist)) match {
+ case e @ WRef(name, _,_,_) =>
+ netlist.get(we(e))
+ .filter(isNot)
+ .getOrElse(e)
+ // replace back-to-back inversions with a straight rename
+ case lhs @ DoPrim(Not, Seq(inv), _,_) if isSimpleExpr(inv) =>
+ netlist.getOrElse(we(inv), inv) match {
+ case DoPrim(Not, Seq(rhs), _,_) if isSimpleExpr(inv) => rhs
+ case _ => lhs // Not a candiate
+ }
+ case other => other // Not a candidate
+ }
+ }
+
+ /** Inline nots in a Statement
+ *
+ * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
+ * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt is a [[firrtl.ir.DefNode
+ * DefNode]] with a value that is a [[PrimOp]] Not
+ * @param stmt the Statement being searched for nodes and transformed
+ * @return Returns stmt with nots inlined
+ */
+ def onStmt(netlist: Netlist)(stmt: Statement): Statement =
+ stmt.map(onStmt(netlist)).map(onExpr(netlist)) match {
+ case node @ DefNode(_, name, value) if isTemp(name) =>
+ netlist(we(WRef(name))) = value
+ node
+ case other => other
+ }
+
+ /** Inline nots in a Module */
+ def onMod(mod: DefModule): DefModule = mod.map(onStmt(new Netlist))
+}
+
+/** Inline nodes that are simple nots */
+class InlineNotsTransform extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val modulesx = state.circuit.modules.map(InlineNotsTransform.onMod(_))
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 79e73c80..71709255 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -735,6 +735,78 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
(parse(exec(input))) should be(parse(check))
}
+ "ConstProp" should "propagate boolean equality with true" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= eq(x, UInt<1>("h1"))
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ """.stripMargin
+ (parse(exec(input))) should be(parse(check))
+ }
+
+ "ConstProp" should "propagate boolean equality with false" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= eq(x, UInt<1>("h0"))
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= not(x)
+ """.stripMargin
+ (parse(exec(input))) should be(parse(check))
+ }
+
+ "ConstProp" should "propagate boolean non-equality with true" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= neq(x, UInt<1>("h1"))
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= not(x)
+ """.stripMargin
+ (parse(exec(input))) should be(parse(check))
+ }
+
+ "ConstProp" should "propagate boolean non-equality with false" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= neq(x, UInt<1>("h0"))
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ """.stripMargin
+ (parse(exec(input))) should be(parse(check))
+ }
+
// Optimizing this mux gives: z <= pad(UInt<2>(0), 4)
// Thus this checks that we then optimize that pad
"ConstProp" should "optimize nested Expressions" in {
diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala
index cf2ff320..0376a830 100644
--- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala
+++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala
@@ -70,6 +70,62 @@ class DoPrimVerilog extends FirrtlFlatSpec {
|""".stripMargin.split("\n") map normalized
executeTest(input, check, compiler)
}
+ "Not" should "emit correctly" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit Not :
+ | module Not :
+ | input a: UInt<1>
+ | output b: UInt<1>
+ | b <= not(a)""".stripMargin
+ val check =
+ """module Not(
+ | input a,
+ | output b
+ |);
+ | assign b = ~a;
+ |endmodule
+ |""".stripMargin.split("\n") map normalized
+ executeTest(input, check, compiler)
+ }
+ "inline Not" should "emit correctly" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit InlineNot :
+ | module InlineNot :
+ | input a: UInt<1>
+ | input b: UInt<1>
+ | input c: UInt<4>
+ | output d: UInt<1>
+ | output e: UInt<1>
+ | output f: UInt<1>
+ | output g: UInt<1>
+ | d <= and(a, not(b))
+ | e <= or(a, not(b))
+ | f <= not(not(not(bits(c, 2, 2))))
+ | g <= mux(not(bits(c, 2, 2)), a, b)""".stripMargin
+ val check =
+ """module InlineNot(
+ | input a,
+ | input b,
+ | input [3:0] c,
+ | output d,
+ | output e,
+ | output f,
+ | output g
+ |);
+ | wire _GEN_2;
+ | wire _GEN_4;
+ | assign d = a & ~b;
+ | assign e = a | ~b;
+ | assign _GEN_2 = c[2];
+ | assign _GEN_4 = _GEN_2;
+ | assign f = ~_GEN_4;
+ | assign g = _GEN_2 ? b : a;
+ |endmodule
+ |""".stripMargin.split("\n") map normalized
+ executeTest(input, check, compiler)
+ }
"Rem" should "emit correctly" in {
val compiler = new VerilogCompiler
val input =