aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2019-01-14 09:49:11 -0800
committerGitHub2019-01-14 09:49:11 -0800
commitdf3a34f01d227ff9ad0e63a41ff10001ac01c01d (patch)
treee61e46216a10747ae38d7d63aa638a35a73c381d /src
parent5f0e893c9213464507418a532ee61347a5da26c8 (diff)
parent9636b550505e4803c6d7307af7e01d996d0f0ea8 (diff)
Merge pull request #992 from freechipsproject/const-prop-dshifts
Constant Propagate dshl and dshr with constant amounts
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala24
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala89
2 files changed, 111 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index da7f1a46..8a273476 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -67,7 +67,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
object FoldADD extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = (c1, c2) match {
+ def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match {
case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
}
@@ -137,6 +137,13 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
}
+ private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match {
+ case UIntLiteral(v, IntWidth(w)) =>
+ val shl = DoPrim(Shl, Seq(e.args.head), Seq(v), UnknownType)
+ pad(PrimOps.set_primop_type(shl), e.tpe)
+ case _ => e
+ }
+
private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
case x => e.args.head match {
@@ -148,6 +155,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
}
+ private def foldDynamicShiftRight(e: DoPrim) = e.args.last match {
+ case UIntLiteral(v, IntWidth(w)) =>
+ val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType)
+ pad(PrimOps.set_primop_type(shr), e.tpe)
+ case _ => e
+ }
+
+
private def foldComparison(e: DoPrim) = {
def foldIfZeroedArg(x: Expression): Expression = {
def isUInt(e: Expression): Boolean = e.tpe match {
@@ -221,7 +236,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
private def constPropPrim(e: DoPrim): Expression = e.op match {
case Shl => foldShiftLeft(e)
+ case Dshl => foldDynamicShiftLeft(e)
case Shr => foldShiftRight(e)
+ case Dshr => foldDynamicShiftRight(e)
case Cat => foldConcat(e)
case Add => FoldADD(e)
case And => FoldAND(e)
@@ -277,6 +294,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
+
private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = {
val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
val propagated = old match {
@@ -290,7 +308,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
case x => x
}
- propagated
+ // We're done when the Expression no longer changes
+ if (propagated eq old) propagated
+ else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated)
}
/** Constant propagate a Module
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 603ddc25..8a69fcaa 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -734,6 +734,24 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
""".stripMargin
(parse(exec(input))) should be(parse(check))
}
+
+ // Optimizing this mux gives: z <= pad(UInt<2>(0), 4)
+ // Thus this checks that we then optimize that pad
+ "ConstProp" should "optimize nested Expressions" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<4>
+ | z <= mux(UInt(1), UInt<2>(0), UInt<4>(0))
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<4>
+ | z <= UInt<4>("h0")
+ """.stripMargin
+ (parse(exec(input))) should be(parse(check))
+ }
}
// More sophisticated tests of the full compiler
@@ -1104,6 +1122,77 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
| z <= _T_61""".stripMargin
execute(input, check, Seq.empty)
}
+
+ behavior of "ConstProp"
+
+ it should "optimize shl of constants" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<7>
+ | z <= shl(UInt(5), 4)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<7>
+ | z <= UInt<7>("h50")
+ """.stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ it should "optimize shr of constants" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<1>
+ | z <= shr(UInt(5), 2)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<1>
+ | z <= UInt<1>("h1")
+ """.stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ // Due to #866, we need dshl optimized away or it'll become a dshlw and error in parsing
+ // Include cat to verify width is correct
+ it should "optimize dshl of constant" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<8>
+ | node n = dshl(UInt<1>(0), UInt<2>(0))
+ | z <= cat(UInt<4>("hf"), n)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<8>
+ | z <= UInt<8>("hf0")
+ """.stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ // Include cat and constants to verify width is correct
+ it should "optimize dshr of constant" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<8>
+ | node n = dshr(UInt<4>(0), UInt<2>(2))
+ | z <= cat(UInt<4>("hf"), n)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<8>
+ | z <= UInt<8>("hf0")
+ """.stripMargin
+ execute(input, check, Seq.empty)
+ }
}