diff options
| author | Schuyler Eldridge | 2019-02-05 14:03:08 -0500 |
|---|---|---|
| committer | Schuyler Eldridge | 2019-02-05 14:09:42 -0500 |
| commit | 0a88492bfbbfe7e446b74776ec59cab69e73585b (patch) | |
| tree | 3d7a3bacd8debc917cd5525d6fdecdee6a50e31c /src | |
| parent | a77122b4bb8756636c169473af3dc367b14698ef (diff) | |
Do Shr constant propagation in Legalize
This uses the foldShiftRight method of the ConstantPropagation
Transform when legalizing Shr PrimOps. This has the effect of removing
literals with bit extracts from the MinimumVerilogCompiler.
This makes the formerly private foldShiftRight method of a public
method of the ConstantPropagation companion object.
Tests in the MimimumVerilogCompilerSpec are updated to check that Shr
is handled as intended.
Signed-off-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 30 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 22 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/CompilerTests.scala | 20 |
3 files changed, 41 insertions, 31 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index bb65201b..04bfb19c 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -183,19 +183,23 @@ object ExpandConnects extends Pass { object Legalize extends Pass { private def legalizeShiftRight(e: DoPrim): Expression = { require(e.op == Shr) - val amount = e.consts.head.toInt - val width = bitWidth(e.args.head.tpe) - lazy val msb = width - 1 - if (amount >= width) { - e.tpe match { - case UIntType(_) => zero - case SIntType(_) => - val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) - DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) - case t => error(s"Unsupported type $t for Primop Shift Right") - } - } else { - e + e.args.head match { + case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e) + case _ => + val amount = e.consts.head.toInt + val width = bitWidth(e.args.head.tpe) + lazy val msb = width - 1 + if (amount >= width) { + e.tpe match { + case UIntType(_) => zero + case SIntType(_) => + val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) + DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) + case t => error(s"Unsupported type $t for Primop Shift Right") + } + } else { + e + } } } private def legalizeBitExtract(expr: DoPrim): Expression = { diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 54338719..6618312a 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -45,6 +45,17 @@ object ConstantPropagation { case _ => e } } + + def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { + case 0 => e.args.head + case x => e.args.head match { + // TODO when amount >= x.width, return a zero-width wire + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) + // take sign bit if shift amount is larger than arg width + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) + case _ => e + } + } } class ConstantPropagation extends Transform with ResolvedAnnotationPaths { @@ -144,17 +155,6 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case _ => e } - private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { - case 0 => e.args.head - case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) - // take sign bit if shift amount is larger than arg width - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) - case _ => e - } - } - private def foldDynamicShiftRight(e: DoPrim) = e.args.last match { case UIntLiteral(v, IntWidth(w)) => val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType) diff --git a/src/test/scala/firrtlTests/CompilerTests.scala b/src/test/scala/firrtlTests/CompilerTests.scala index df83dd38..dc70847a 100644 --- a/src/test/scala/firrtlTests/CompilerTests.scala +++ b/src/test/scala/firrtlTests/CompilerTests.scala @@ -158,19 +158,25 @@ class VerilogCompilerSpec extends CompilerSpec with Matchers { class MinimumVerilogCompilerSpec extends CompilerSpec with Matchers { val input = """|circuit Top: | module Top: - | output b: UInt<1>[2] - | node c = UInt<1>("h1") - | b[0] <= c - | b[1] is invalid + | output b: UInt<1>[3] + | node c = bits(UInt<3>("h7"), 2, 2) + | node d = shr(UInt<3>("h7"), 2) + | b[0] is invalid + | b[1] <= c + | b[2] <= d |""".stripMargin val check = """|module Top( | output b_0, - | output b_1 + | output b_1, + | output b_2 |); | wire c; + | wire d; | assign c = 1'h1; - | assign b_0 = c; - | assign b_1 = 1'h0; + | assign d = 1'h1; + | assign b_0 = 1'h0; + | assign b_1 = c; + | assign b_2 = d; |endmodule |""".stripMargin def compiler = new MinimumVerilogCompiler() |
