diff options
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 30 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 22 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/CompilerTests.scala | 20 |
3 files changed, 41 insertions, 31 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index bb65201b..04bfb19c 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -183,19 +183,23 @@ object ExpandConnects extends Pass { object Legalize extends Pass { private def legalizeShiftRight(e: DoPrim): Expression = { require(e.op == Shr) - val amount = e.consts.head.toInt - val width = bitWidth(e.args.head.tpe) - lazy val msb = width - 1 - if (amount >= width) { - e.tpe match { - case UIntType(_) => zero - case SIntType(_) => - val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) - DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) - case t => error(s"Unsupported type $t for Primop Shift Right") - } - } else { - e + e.args.head match { + case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e) + case _ => + val amount = e.consts.head.toInt + val width = bitWidth(e.args.head.tpe) + lazy val msb = width - 1 + if (amount >= width) { + e.tpe match { + case UIntType(_) => zero + case SIntType(_) => + val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) + DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) + case t => error(s"Unsupported type $t for Primop Shift Right") + } + } else { + e + } } } private def legalizeBitExtract(expr: DoPrim): Expression = { diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 54338719..6618312a 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -45,6 +45,17 @@ object ConstantPropagation { case _ => e } } + + def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { + case 0 => e.args.head + case x => e.args.head match { + // TODO when amount >= x.width, return a zero-width wire + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) + // take sign bit if shift amount is larger than arg width + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) + case _ => e + } + } } class ConstantPropagation extends Transform with ResolvedAnnotationPaths { @@ -144,17 +155,6 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { case _ => e } - private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { - case 0 => e.args.head - case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) - // take sign bit if shift amount is larger than arg width - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) - case _ => e - } - } - private def foldDynamicShiftRight(e: DoPrim) = e.args.last match { case UIntLiteral(v, IntWidth(w)) => val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType) diff --git a/src/test/scala/firrtlTests/CompilerTests.scala b/src/test/scala/firrtlTests/CompilerTests.scala index df83dd38..dc70847a 100644 --- a/src/test/scala/firrtlTests/CompilerTests.scala +++ b/src/test/scala/firrtlTests/CompilerTests.scala @@ -158,19 +158,25 @@ class VerilogCompilerSpec extends CompilerSpec with Matchers { class MinimumVerilogCompilerSpec extends CompilerSpec with Matchers { val input = """|circuit Top: | module Top: - | output b: UInt<1>[2] - | node c = UInt<1>("h1") - | b[0] <= c - | b[1] is invalid + | output b: UInt<1>[3] + | node c = bits(UInt<3>("h7"), 2, 2) + | node d = shr(UInt<3>("h7"), 2) + | b[0] is invalid + | b[1] <= c + | b[2] <= d |""".stripMargin val check = """|module Top( | output b_0, - | output b_1 + | output b_1, + | output b_2 |); | wire c; + | wire d; | assign c = 1'h1; - | assign b_0 = c; - | assign b_1 = 1'h0; + | assign d = 1'h1; + | assign b_0 = 1'h0; + | assign b_1 = c; + | assign b_2 = d; |endmodule |""".stripMargin def compiler = new MinimumVerilogCompiler() |
