diff options
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index da7f1a46..8a273476 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -67,7 +67,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } object FoldADD extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = (c1, c2) match { + def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match { case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) } @@ -137,6 +137,13 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } } + private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match { + case UIntLiteral(v, IntWidth(w)) => + val shl = DoPrim(Shl, Seq(e.args.head), Seq(v), UnknownType) + pad(PrimOps.set_primop_type(shl), e.tpe) + case _ => e + } + private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head case x => e.args.head match { @@ -148,6 +155,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } } + private def foldDynamicShiftRight(e: DoPrim) = e.args.last match { + case UIntLiteral(v, IntWidth(w)) => + val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType) + pad(PrimOps.set_primop_type(shr), e.tpe) + case _ => e + } + + private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { def isUInt(e: Expression): Boolean = e.tpe match { @@ -221,7 +236,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { private def constPropPrim(e: DoPrim): Expression = e.op match { case Shl => foldShiftLeft(e) + case Dshl => foldDynamicShiftLeft(e) case Shr => foldShiftRight(e) + case Dshr => foldDynamicShiftRight(e) case Cat => foldConcat(e) case Add => FoldADD(e) case And => FoldAND(e) @@ -277,6 +294,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = { val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) val propagated = old match { @@ -290,7 +308,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) case x => x } - propagated + // We're done when the Expression no longer changes + if (propagated eq old) propagated + else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) } /** Constant propagate a Module |
