aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/transforms')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala2
-rw-r--r--src/main/scala/firrtl/transforms/InlineNots.scala84
2 files changed, 86 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index f224546b..a008a4d3 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -155,6 +155,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
}
@@ -163,6 +164,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
}
diff --git a/src/main/scala/firrtl/transforms/InlineNots.scala b/src/main/scala/firrtl/transforms/InlineNots.scala
new file mode 100644
index 00000000..3dab5168
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/InlineNots.scala
@@ -0,0 +1,84 @@
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.PrimOps.Not
+import firrtl.Utils.isTemp
+import firrtl.WrappedExpression._
+
+import scala.collection.mutable
+
+object InlineNotsTransform {
+
+ /** Returns true if Expression is a Not PrimOp, false otherwise */
+ private def isNot(expr: Expression): Boolean = expr match {
+ case DoPrim(Not, args,_,_) => args.forall(isSimpleExpr)
+ case _ => false
+ }
+
+ // Checks if an Expression is made up of only Nots terminated by a Literal or Reference.
+ // private because it's not clear if this definition of "Simple Expression" would be useful elsewhere.
+ // Note that this can have false negatives but MUST NOT have false positives.
+ private def isSimpleExpr(expr: Expression): Boolean = expr match {
+ case _: WRef | _: Literal | _: WSubField => true
+ case DoPrim(Not, args, _,_) => args.forall(isSimpleExpr)
+ case _ => false
+ }
+
+ /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them */
+ type Netlist = mutable.HashMap[WrappedExpression, Expression]
+
+ /** Recursively replace [[WRef]]s with new [[Expression]]s
+ *
+ * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
+ * [[firrtl.ir.Expression Expression]]s. It is '''not''' mutated in this function
+ * @param expr the Expression being transformed
+ * @return Returns expr with Nots inlined
+ */
+ def onExpr(netlist: Netlist)(expr: Expression): Expression = {
+ expr.map(onExpr(netlist)) match {
+ case e @ WRef(name, _,_,_) =>
+ netlist.get(we(e))
+ .filter(isNot)
+ .getOrElse(e)
+ // replace back-to-back inversions with a straight rename
+ case lhs @ DoPrim(Not, Seq(inv), _,_) if isSimpleExpr(inv) =>
+ netlist.getOrElse(we(inv), inv) match {
+ case DoPrim(Not, Seq(rhs), _,_) if isSimpleExpr(inv) => rhs
+ case _ => lhs // Not a candiate
+ }
+ case other => other // Not a candidate
+ }
+ }
+
+ /** Inline nots in a Statement
+ *
+ * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
+ * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt is a [[firrtl.ir.DefNode
+ * DefNode]] with a value that is a [[PrimOp]] Not
+ * @param stmt the Statement being searched for nodes and transformed
+ * @return Returns stmt with nots inlined
+ */
+ def onStmt(netlist: Netlist)(stmt: Statement): Statement =
+ stmt.map(onStmt(netlist)).map(onExpr(netlist)) match {
+ case node @ DefNode(_, name, value) if isTemp(name) =>
+ netlist(we(WRef(name))) = value
+ node
+ case other => other
+ }
+
+ /** Inline nots in a Module */
+ def onMod(mod: DefModule): DefModule = mod.map(onStmt(new Netlist))
+}
+
+/** Inline nodes that are simple nots */
+class InlineNotsTransform extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val modulesx = state.circuit.modules.map(InlineNotsTransform.onMod(_))
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}