diff options
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/transforms/InlineCasts.scala | 41 |
1 files changed, 24 insertions, 17 deletions
diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala index 71318eee..761252c1 100644 --- a/src/main/scala/firrtl/transforms/InlineCasts.scala +++ b/src/main/scala/firrtl/transforms/InlineCasts.scala @@ -28,23 +28,30 @@ object InlineCastsTransform { * @param expr the Expression being transformed * @return Returns expr with [[WRef]]s replaced by values found in replace */ - def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match { - // Anything that may generate a part-select should not be inlined! - case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr - case e => - e.map(onExpr(replace)) match { - case e @ WRef(name, _, _, _) => - replace - .get(name) - .filter(isSimpleCast(castSeen = false)) - .getOrElse(e) - case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) => - replace - .get(name) - .map(value => e.copy(args = Seq(value))) - .getOrElse(e) - case other => other // Not a candidate - } + def onExpr(replace: NodeMap)(expr: Expression): Expression = { + // Keep track if we've seen any non-cast expressions while recursing + def rec(hasNonCastParent: Boolean)(expr: Expression): Expression = expr match { + // Skip pads to avoid inlining literals into pads which results in invalid Verilog + case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr + case e => + e.map(rec(hasNonCastParent || !isCast(e))) match { + case e @ WRef(name, _, _, _) => + replace + .get(name) + .filter(isSimpleCast(castSeen = false)) + .getOrElse(e) + case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) => + replace + .get(name) + // Only inline the Expression if there is no non-cast parent in the expression tree OR + // if the subtree contains only casts and references. + .filter(x => !hasNonCastParent || isSimpleCast(castSeen = true)(x)) + .map(value => e.copy(args = Seq(value))) + .getOrElse(e) + case other => other // Not a candidate + } + } + rec(false)(expr) } /** Inline casts in a Statement |
