diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 31 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/InlineCasts.scala | 70 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/LegalizeClocks.scala | 69 |
3 files changed, 163 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index a008a4d3..20b24e60 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -286,16 +286,33 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) case _ => e } - case AsUInt => e.args.head match { - case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case u: UIntLiteral => u - case _ => e - } + case AsUInt => + e.args.head match { + case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) + case arg => arg.tpe match { + case _: UIntType => arg + case _ => e + } + } case AsSInt => e.args.head match { case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) - case s: SIntLiteral => s - case _ => e + case arg => arg.tpe match { + case _: SIntType => arg + case _ => e + } } + case AsClock => + val arg = e.args.head + arg.tpe match { + case ClockType => arg + case _ => e + } + case AsAsyncReset => + val arg = e.args.head + arg.tpe match { + case AsyncResetType => arg + case _ => e + } case Pad => e.args.head match { case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala new file mode 100644 index 00000000..e504eb70 --- /dev/null +++ b/src/main/scala/firrtl/transforms/InlineCasts.scala @@ -0,0 +1,70 @@ +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ + +import firrtl.Utils.{isCast, NodeMap} + +object InlineCastsTransform { + + // Checks if an Expression is made up of only casts terminated by a Literal or Reference + // There must be at least one cast + // Note that this can have false negatives but MUST NOT have false positives + private def isSimpleCast(castSeen: Boolean)(expr: Expression): Boolean = expr match { + case _: WRef | _: Literal | _: WSubField => castSeen + case DoPrim(op, args, _,_) if isCast(op) => args.forall(isSimpleCast(true)) + case _ => false + } + + /** Recursively replace [[WRef]]s with new [[Expression]]s + * + * @param replace a '''mutable''' HashMap mapping [[WRef]]s to values with which the [[WRef]] + * will be replaced. It is '''not''' mutated in this function + * @param expr the Expression being transformed + * @return Returns expr with [[WRef]]s replaced by values found in replace + */ + def onExpr(replace: NodeMap)(expr: Expression): Expression = { + expr.map(onExpr(replace)) match { + case e @ WRef(name, _,_,_) => + replace.get(name) + .filter(isSimpleCast(castSeen=false)) + .getOrElse(e) + case e @ DoPrim(op, Seq(WRef(name, _,_,_)), _,_) if isCast(op) => + replace.get(name) + .map(value => e.copy(args = Seq(value))) + .getOrElse(e) + case other => other // Not a candidate + } + } + + /** Inline casts in a Statement + * + * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected + * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate it if stmt is a [[firrtl.ir.DefNode + * DefNode]] with a value that is a cast [[PrimOp]] + * @param stmt the Statement being searched for nodes and transformed + * @return Returns stmt with casts inlined + */ + def onStmt(netlist: NodeMap)(stmt: Statement): Statement = + stmt.map(onStmt(netlist)).map(onExpr(netlist)) match { + case node @ DefNode(_, name, value) => + netlist(name) = value + node + case other => other + } + + /** Replaces truncating arithmetic in a Module */ + def onMod(mod: DefModule): DefModule = mod.map(onStmt(new NodeMap)) +} + +/** Inline nodes that are simple casts */ +class InlineCastsTransform extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + def execute(state: CircuitState): CircuitState = { + val modulesx = state.circuit.modules.map(InlineCastsTransform.onMod(_)) + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} diff --git a/src/main/scala/firrtl/transforms/LegalizeClocks.scala b/src/main/scala/firrtl/transforms/LegalizeClocks.scala new file mode 100644 index 00000000..1c2fc045 --- /dev/null +++ b/src/main/scala/firrtl/transforms/LegalizeClocks.scala @@ -0,0 +1,69 @@ +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.Utils.isCast + +// Fixup otherwise legal Verilog that lint tools and other tools don't like +// Currently: +// - don't emit "always @(posedge <literal>)" +// Hitting this case is rare, but legal FIRRTL +// TODO This should be unified with all Verilog legalization transforms +object LegalizeClocksTransform { + + // Checks if an Expression is illegal in use in a @(posedge <Expression>) construct + // Legality is defined here by what standard lint tools accept + // Currently only looks for literals nested within casts + private def illegalClockExpr(expr: Expression): Boolean = expr match { + case _: Literal => true + case DoPrim(op, args, _,_) if isCast(op) => args.exists(illegalClockExpr) + case _ => false + } + + /** Legalize Clocks in a Statement + * + * Enforces legal Verilog semantics on all Clock Expressions. + * Legal is defined as what standard lint tools accept. + * Currently only Literal Expressions (guarded by casts) are handled. + * + * @note namespace is lazy because it should not typically be needed + */ + def onStmt(namespace: => Namespace)(stmt: Statement): Statement = + stmt.map(onStmt(namespace)) match { + // Proper union types would deduplicate this code + case r: DefRegister if illegalClockExpr(r.clock) => + val node = DefNode(r.info, namespace.newTemp, r.clock) + val rx = r.copy(clock = WRef(node)) + Block(Seq(node, rx)) + case p: Print if illegalClockExpr(p.clk) => + val node = DefNode(p.info, namespace.newTemp, p.clk) + val px = p.copy(clk = WRef(node)) + Block(Seq(node, px)) + case s: Stop if illegalClockExpr(s.clk) => + val node = DefNode(s.info, namespace.newTemp, s.clk) + val sx = s.copy(clk = WRef(node)) + Block(Seq(node, sx)) + case other => other + } + + def onMod(mod: DefModule): DefModule = { + // It's actually *extremely* important that this Namespace is a lazy val + // onStmt accepts it lazily so that we don't perform the namespacing traversal unless necessary + // If we were to inline the declaration, it would create a Namespace for every problem, causing + // name collisions + lazy val namespace = Namespace(mod) + mod.map(onStmt(namespace)) + } +} + +/** Ensure Clocks to be emitted are legal Verilog */ +class LegalizeClocksTransform extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + def execute(state: CircuitState): CircuitState = { + val modulesx = state.circuit.modules.map(LegalizeClocksTransform.onMod(_)) + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} |
