aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/transforms')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala31
-rw-r--r--src/main/scala/firrtl/transforms/InlineCasts.scala70
-rw-r--r--src/main/scala/firrtl/transforms/LegalizeClocks.scala69
3 files changed, 163 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index a008a4d3..20b24e60 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -286,16 +286,33 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
case _ => e
}
- case AsUInt => e.args.head match {
- case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
- case u: UIntLiteral => u
- case _ => e
- }
+ case AsUInt =>
+ e.args.head match {
+ case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
+ case arg => arg.tpe match {
+ case _: UIntType => arg
+ case _ => e
+ }
+ }
case AsSInt => e.args.head match {
case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w))
- case s: SIntLiteral => s
- case _ => e
+ case arg => arg.tpe match {
+ case _: SIntType => arg
+ case _ => e
+ }
}
+ case AsClock =>
+ val arg = e.args.head
+ arg.tpe match {
+ case ClockType => arg
+ case _ => e
+ }
+ case AsAsyncReset =>
+ val arg = e.args.head
+ arg.tpe match {
+ case AsyncResetType => arg
+ case _ => e
+ }
case Pad => e.args.head match {
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w))
case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w))
diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala
new file mode 100644
index 00000000..e504eb70
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/InlineCasts.scala
@@ -0,0 +1,70 @@
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+
+import firrtl.Utils.{isCast, NodeMap}
+
+object InlineCastsTransform {
+
+ // Checks if an Expression is made up of only casts terminated by a Literal or Reference
+ // There must be at least one cast
+ // Note that this can have false negatives but MUST NOT have false positives
+ private def isSimpleCast(castSeen: Boolean)(expr: Expression): Boolean = expr match {
+ case _: WRef | _: Literal | _: WSubField => castSeen
+ case DoPrim(op, args, _,_) if isCast(op) => args.forall(isSimpleCast(true))
+ case _ => false
+ }
+
+ /** Recursively replace [[WRef]]s with new [[Expression]]s
+ *
+ * @param replace a '''mutable''' HashMap mapping [[WRef]]s to values with which the [[WRef]]
+ * will be replaced. It is '''not''' mutated in this function
+ * @param expr the Expression being transformed
+ * @return Returns expr with [[WRef]]s replaced by values found in replace
+ */
+ def onExpr(replace: NodeMap)(expr: Expression): Expression = {
+ expr.map(onExpr(replace)) match {
+ case e @ WRef(name, _,_,_) =>
+ replace.get(name)
+ .filter(isSimpleCast(castSeen=false))
+ .getOrElse(e)
+ case e @ DoPrim(op, Seq(WRef(name, _,_,_)), _,_) if isCast(op) =>
+ replace.get(name)
+ .map(value => e.copy(args = Seq(value)))
+ .getOrElse(e)
+ case other => other // Not a candidate
+ }
+ }
+
+ /** Inline casts in a Statement
+ *
+ * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected
+ * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt is a [[firrtl.ir.DefNode
+ * DefNode]] with a value that is a cast [[PrimOp]]
+ * @param stmt the Statement being searched for nodes and transformed
+ * @return Returns stmt with casts inlined
+ */
+ def onStmt(netlist: NodeMap)(stmt: Statement): Statement =
+ stmt.map(onStmt(netlist)).map(onExpr(netlist)) match {
+ case node @ DefNode(_, name, value) =>
+ netlist(name) = value
+ node
+ case other => other
+ }
+
+ /** Replaces truncating arithmetic in a Module */
+ def onMod(mod: DefModule): DefModule = mod.map(onStmt(new NodeMap))
+}
+
+/** Inline nodes that are simple casts */
+class InlineCastsTransform extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val modulesx = state.circuit.modules.map(InlineCastsTransform.onMod(_))
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}
diff --git a/src/main/scala/firrtl/transforms/LegalizeClocks.scala b/src/main/scala/firrtl/transforms/LegalizeClocks.scala
new file mode 100644
index 00000000..1c2fc045
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/LegalizeClocks.scala
@@ -0,0 +1,69 @@
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.Utils.isCast
+
+// Fixup otherwise legal Verilog that lint tools and other tools don't like
+// Currently:
+// - don't emit "always @(posedge <literal>)"
+// Hitting this case is rare, but legal FIRRTL
+// TODO This should be unified with all Verilog legalization transforms
+object LegalizeClocksTransform {
+
+ // Checks if an Expression is illegal in use in a @(posedge <Expression>) construct
+ // Legality is defined here by what standard lint tools accept
+ // Currently only looks for literals nested within casts
+ private def illegalClockExpr(expr: Expression): Boolean = expr match {
+ case _: Literal => true
+ case DoPrim(op, args, _,_) if isCast(op) => args.exists(illegalClockExpr)
+ case _ => false
+ }
+
+ /** Legalize Clocks in a Statement
+ *
+ * Enforces legal Verilog semantics on all Clock Expressions.
+ * Legal is defined as what standard lint tools accept.
+ * Currently only Literal Expressions (guarded by casts) are handled.
+ *
+ * @note namespace is lazy because it should not typically be needed
+ */
+ def onStmt(namespace: => Namespace)(stmt: Statement): Statement =
+ stmt.map(onStmt(namespace)) match {
+ // Proper union types would deduplicate this code
+ case r: DefRegister if illegalClockExpr(r.clock) =>
+ val node = DefNode(r.info, namespace.newTemp, r.clock)
+ val rx = r.copy(clock = WRef(node))
+ Block(Seq(node, rx))
+ case p: Print if illegalClockExpr(p.clk) =>
+ val node = DefNode(p.info, namespace.newTemp, p.clk)
+ val px = p.copy(clk = WRef(node))
+ Block(Seq(node, px))
+ case s: Stop if illegalClockExpr(s.clk) =>
+ val node = DefNode(s.info, namespace.newTemp, s.clk)
+ val sx = s.copy(clk = WRef(node))
+ Block(Seq(node, sx))
+ case other => other
+ }
+
+ def onMod(mod: DefModule): DefModule = {
+ // It's actually *extremely* important that this Namespace is a lazy val
+ // onStmt accepts it lazily so that we don't perform the namespacing traversal unless necessary
+ // If we were to inline the declaration, it would create a Namespace for every problem, causing
+ // name collisions
+ lazy val namespace = Namespace(mod)
+ mod.map(onStmt(namespace))
+ }
+}
+
+/** Ensure Clocks to be emitted are legal Verilog */
+class LegalizeClocksTransform extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val modulesx = state.circuit.modules.map(LegalizeClocksTransform.onMod(_))
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}