aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSchuyler Eldridge2019-02-05 14:03:08 -0500
committerSchuyler Eldridge2019-02-05 14:09:42 -0500
commit0a88492bfbbfe7e446b74776ec59cab69e73585b (patch)
tree3d7a3bacd8debc917cd5525d6fdecdee6a50e31c /src
parenta77122b4bb8756636c169473af3dc367b14698ef (diff)
Do Shr constant propagation in Legalize
This uses the foldShiftRight method of the ConstantPropagation Transform when legalizing Shr PrimOps. This has the effect of removing literals with bit extracts from the MinimumVerilogCompiler. This makes the formerly private foldShiftRight method of a public method of the ConstantPropagation companion object. Tests in the MimimumVerilogCompilerSpec are updated to check that Shr is handled as intended. Signed-off-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala30
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala22
-rw-r--r--src/test/scala/firrtlTests/CompilerTests.scala20
3 files changed, 41 insertions, 31 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index bb65201b..04bfb19c 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -183,19 +183,23 @@ object ExpandConnects extends Pass {
object Legalize extends Pass {
private def legalizeShiftRight(e: DoPrim): Expression = {
require(e.op == Shr)
- val amount = e.consts.head.toInt
- val width = bitWidth(e.args.head.tpe)
- lazy val msb = width - 1
- if (amount >= width) {
- e.tpe match {
- case UIntType(_) => zero
- case SIntType(_) =>
- val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
- DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
- case t => error(s"Unsupported type $t for Primop Shift Right")
- }
- } else {
- e
+ e.args.head match {
+ case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e)
+ case _ =>
+ val amount = e.consts.head.toInt
+ val width = bitWidth(e.args.head.tpe)
+ lazy val msb = width - 1
+ if (amount >= width) {
+ e.tpe match {
+ case UIntType(_) => zero
+ case SIntType(_) =>
+ val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
+ DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
+ case t => error(s"Unsupported type $t for Primop Shift Right")
+ }
+ } else {
+ e
+ }
}
}
private def legalizeBitExtract(expr: DoPrim): Expression = {
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 54338719..6618312a 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -45,6 +45,17 @@ object ConstantPropagation {
case _ => e
}
}
+
+ def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
+ case 0 => e.args.head
+ case x => e.args.head match {
+ // TODO when amount >= x.width, return a zero-width wire
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
+ // take sign bit if shift amount is larger than arg width
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
+ case _ => e
+ }
+ }
}
class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
@@ -144,17 +155,6 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
case _ => e
}
- private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
- case 0 => e.args.head
- case x => e.args.head match {
- // TODO when amount >= x.width, return a zero-width wire
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
- // take sign bit if shift amount is larger than arg width
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
- case _ => e
- }
- }
-
private def foldDynamicShiftRight(e: DoPrim) = e.args.last match {
case UIntLiteral(v, IntWidth(w)) =>
val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType)
diff --git a/src/test/scala/firrtlTests/CompilerTests.scala b/src/test/scala/firrtlTests/CompilerTests.scala
index df83dd38..dc70847a 100644
--- a/src/test/scala/firrtlTests/CompilerTests.scala
+++ b/src/test/scala/firrtlTests/CompilerTests.scala
@@ -158,19 +158,25 @@ class VerilogCompilerSpec extends CompilerSpec with Matchers {
class MinimumVerilogCompilerSpec extends CompilerSpec with Matchers {
val input = """|circuit Top:
| module Top:
- | output b: UInt<1>[2]
- | node c = UInt<1>("h1")
- | b[0] <= c
- | b[1] is invalid
+ | output b: UInt<1>[3]
+ | node c = bits(UInt<3>("h7"), 2, 2)
+ | node d = shr(UInt<3>("h7"), 2)
+ | b[0] is invalid
+ | b[1] <= c
+ | b[2] <= d
|""".stripMargin
val check = """|module Top(
| output b_0,
- | output b_1
+ | output b_1,
+ | output b_2
|);
| wire c;
+ | wire d;
| assign c = 1'h1;
- | assign b_0 = c;
- | assign b_1 = 1'h0;
+ | assign d = 1'h1;
+ | assign b_0 = 1'h0;
+ | assign b_1 = c;
+ | assign b_2 = d;
|endmodule
|""".stripMargin
def compiler = new MinimumVerilogCompiler()