diff options
| -rw-r--r-- | src/main/scala/firrtl/Compiler.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/IR.scala | 21 | ||||
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 55 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/IntegrationSpec.scala | 4 | ||||
| -rw-r--r-- | test/integration/RightShiftTester.fir | 141 |
7 files changed, 218 insertions, 14 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 782d43cb..aa3eeace 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -73,6 +73,7 @@ object VerilogCompiler extends Compiler { RemoveAccesses, ExpandWhens, CheckInitialization, + Legalize, ConstProp, ResolveKinds, InferTypes, diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index e9c90fd6..4b88f526 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -211,7 +211,11 @@ object VerilogEmitter extends Emitter { } } case SHIFT_LEFT_OP => Seq(cast(a0())," << ",c0()) - case SHIFT_RIGHT_OP => Seq(a0(),"[", long_BANG(tpe(a0())) - 1,":",c0(),"]") + case SHIFT_RIGHT_OP => { + if (c0 >= long_BANG(tpe(a0))) + error("Verilog emitter does not support SHIFT_RIGHT >= arg width") + Seq(a0(),"[", long_BANG(tpe(a0())) - 1,":",c0(),"]") + } case NEG_OP => Seq("-{",cast(a0()),"}") case CONVERT_OP => { tpe(a0()) match { diff --git a/src/main/scala/firrtl/IR.scala b/src/main/scala/firrtl/IR.scala index ce22f010..17b16052 100644 --- a/src/main/scala/firrtl/IR.scala +++ b/src/main/scala/firrtl/IR.scala @@ -42,7 +42,7 @@ case class FileInfo(file: String, line: Int, column: Int) extends Info { override def toString(): String = s"$file@$line.$column" } -case class FIRRTLException(str:String) extends Exception +case class FIRRTLException(str: String) extends Exception(str) trait AST { def serialize: String = firrtl.Serialize.serialize(this) @@ -110,7 +110,24 @@ case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends S case class Print(info: Info, string: String, args: Seq[Expression], clk: Expression, en: Expression) extends Stmt case class Empty() extends Stmt -trait Width extends AST +trait Width extends AST { + def +(x: Width): Width = (this, x) match { + case (a: IntWidth, b: IntWidth) => IntWidth(a.width + b.width) + case _ => UnknownWidth() + } + def -(x: Width): Width = (this, x) match { + case (a: IntWidth, b: IntWidth) => IntWidth(a.width - b.width) + case _ => UnknownWidth() + } + def max(x: Width): Width = (this, x) match { + case (a: IntWidth, b: IntWidth) => IntWidth(a.width max b.width) + case _ => UnknownWidth() + } + def min(x: Width): Width = (this, x) match { + case (a: IntWidth, b: IntWidth) => IntWidth(a.width min b.width) + case _ => UnknownWidth() + } +} case class IntWidth(width: BigInt) extends Width case class UnknownWidth() extends Width diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index 8a2865bb..ed3752f9 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -254,8 +254,8 @@ object PrimOps extends LazyLogging { } case SHIFT_RIGHT_OP => { val t = (t1()) match { - case (t1:UIntType) => UIntType(MINUS(w1(),c1())) - case (t1:SIntType) => SIntType(MINUS(w1(),c1())) + case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),ONE)) + case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),ONE)) case (t1) => UnknownType() } DoPrim(o,a,c,t) diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 7490c479..1e8ceae2 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -1128,24 +1128,63 @@ object ExpandWhens extends Pass { } } +// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts +// TODO replace UInt with zero-width wire instead +object Legalize extends Pass { + def name = "Legalize" + def legalizeShiftRight (e: DoPrim): Expression = e.op match { + case SHIFT_RIGHT_OP => { + val amount = e.consts(0).toInt + val width = long_BANG(tpe(e.args(0))) + lazy val msb = width - 1 + if (amount >= width) { + e.tpe match { + case t: UIntType => UIntValue(0, IntWidth(1)) + case t: SIntType => + DoPrim(BITS_SELECT_OP, e.args, Seq(msb, msb), SIntType(IntWidth(1))) + case t => error(s"Unsupported type ${t} for Primop Shift Right") + } + } else { + e + } + } + case _ => e + } + def run (c: Circuit): Circuit = { + def legalizeE (e: Expression): Expression = { + e map (legalizeE) match { + case e: DoPrim => legalizeShiftRight(e) + case e => e + } + } + def legalizeS (s: Stmt): Stmt = s map (legalizeS) map (legalizeE) + def legalizeM (m: Module): Module = m map (legalizeS) + Circuit(c.info, c.modules.map(legalizeM), c.main) + } +} + object ConstProp extends Pass { def name = "Constant Propogation" var mname = "" + def const_prop_e (e:Expression) : Expression = { e map (const_prop_e) match { case (e:DoPrim) => { e.op match { case SHIFT_RIGHT_OP => { - (e.args(0)) match { - case (x:UIntValue) => { - val b = x.value >> e.consts(0).toInt - UIntValue(b,tpe(e).as[UIntType].get.width) + val amount = e.consts(0).toInt + e.args(0) match { + case x: UIntValue => { + val v = x.value >> amount + val w = (x.width - IntWidth(amount)) max IntWidth(1) + UIntValue(v, w) } - case (x:SIntValue) => { - val b = x.value >> e.consts(0).toInt - SIntValue(b,tpe(e).as[SIntType].get.width) + case x: SIntValue => { // take sign bit if shift amount is larger than arg width + val v = x.value >> amount + val w = (x.width - IntWidth(amount)) max IntWidth(1) + SIntValue(v, w) } - case (x) => e + case _ => e } } case BITS_SELECT_OP => { diff --git a/src/test/scala/firrtlTests/IntegrationSpec.scala b/src/test/scala/firrtlTests/IntegrationSpec.scala index eb5a7fa1..e9afd739 100644 --- a/src/test/scala/firrtlTests/IntegrationSpec.scala +++ b/src/test/scala/firrtlTests/IntegrationSpec.scala @@ -8,7 +8,9 @@ class IntegrationSpec extends FirrtlPropSpec { case class Test(name: String, dir: String) - val runTests = Seq(Test("GCDTester", "/integration")) + val runTests = Seq(Test("GCDTester", "/integration"), + Test("RightShiftTester", "/integration")) + runTests foreach { test => property(s"${test.name} should execute correctly") { diff --git a/test/integration/RightShiftTester.fir b/test/integration/RightShiftTester.fir new file mode 100644 index 00000000..b73b98ff --- /dev/null +++ b/test/integration/RightShiftTester.fir @@ -0,0 +1,141 @@ +circuit RightShiftTester : + module RightShift : + input clk : Clock + input reset : UInt<1> + output io : {flip i : UInt<1>, flip j : SInt<1>, i_shifted : UInt, j_shifted : SInt, k_shifted : UInt, l_shifted : UInt, m_shifted : SInt, n_shifted : SInt, o_shifted : UInt} + + io is invalid + wire k : UInt<16> + k is invalid + k <= UInt<1>("h01") + wire o : UInt<32> + o is invalid + o <= UInt<21>("h012d687") + node T_19 = shr(io.i, 1) + io.i_shifted <= T_19 + node T_20 = shr(io.j, 1) + io.j_shifted <= T_20 + node T_21 = shr(k, 18) + io.k_shifted <= T_21 + node T_23 = shr(UInt<4>("h0f"), 4) + io.l_shifted <= T_23 + node T_25 = shr(asSInt(UInt<1>("h01")), 5) + io.m_shifted <= T_25 + node T_27 = shr(asSInt(UInt<3>("h03")), 4) + io.n_shifted <= T_27 + node T_28 = shr(o, 16) + io.o_shifted <= T_28 + + module RightShiftTester : + input clk : Clock + input reset : UInt<1> + output io : {} + + io is invalid + inst dut of RightShift + dut.io is invalid + dut.clk <= clk + dut.reset <= reset + reg T_6 : UInt<2>, clk with : (reset => (reset, UInt<2>("h00"))) + when UInt<1>("h01") : + node T_8 = eq(T_6, UInt<2>("h03")) + node T_10 = and(UInt<1>("h00"), T_8) + node T_13 = add(T_6, UInt<1>("h01")) + node T_14 = tail(T_13, 1) + node T_15 = mux(T_10, UInt<1>("h00"), T_14) + T_6 <= T_15 + skip + node done = and(UInt<1>("h01"), T_8) + when done : + node T_18 = eq(reset, UInt<1>("h00")) + when T_18 : + stop(clk, UInt<1>(1), 0) + skip + skip + dut.io.i <= UInt<1>("h01") + dut.io.j <= asSInt(UInt<1>("h01")) + node T_22 = eq(dut.io.i_shifted, UInt<1>("h00")) + node T_24 = eq(reset, UInt<1>("h00")) + when T_24 : + node T_26 = eq(T_22, UInt<1>("h00")) + when T_26 : + node T_28 = eq(reset, UInt<1>("h00")) + when T_28 : + printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:47 assert(dut.io.i_shifted === UInt(0))\n") + skip + stop(clk, UInt<1>(1), 1) + skip + skip + node T_30 = eq(dut.io.j_shifted, asSInt(UInt<1>("h01"))) + node T_32 = eq(reset, UInt<1>("h00")) + when T_32 : + node T_34 = eq(T_30, UInt<1>("h00")) + when T_34 : + node T_36 = eq(reset, UInt<1>("h00")) + when T_36 : + printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:48 assert(dut.io.j_shifted === SInt(-1))\n") + skip + stop(clk, UInt<1>(1), 1) + skip + skip + node T_38 = eq(dut.io.k_shifted, UInt<1>("h00")) + node T_40 = eq(reset, UInt<1>("h00")) + when T_40 : + node T_42 = eq(T_38, UInt<1>("h00")) + when T_42 : + node T_44 = eq(reset, UInt<1>("h00")) + when T_44 : + printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:49 assert(dut.io.k_shifted === UInt(0))\n") + skip + stop(clk, UInt<1>(1), 1) + skip + skip + node T_46 = eq(dut.io.l_shifted, UInt<1>("h00")) + node T_48 = eq(reset, UInt<1>("h00")) + when T_48 : + node T_50 = eq(T_46, UInt<1>("h00")) + when T_50 : + node T_52 = eq(reset, UInt<1>("h00")) + when T_52 : + printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:50 assert(dut.io.l_shifted === UInt(0))\n") + skip + stop(clk, UInt<1>(1), 1) + skip + skip + node T_54 = eq(dut.io.m_shifted, asSInt(UInt<1>("h01"))) + node T_56 = eq(reset, UInt<1>("h00")) + when T_56 : + node T_58 = eq(T_54, UInt<1>("h00")) + when T_58 : + node T_60 = eq(reset, UInt<1>("h00")) + when T_60 : + printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:51 assert(dut.io.m_shifted === SInt(-1))\n") + skip + stop(clk, UInt<1>(1), 1) + skip + skip + node T_62 = eq(dut.io.n_shifted, asSInt(UInt<1>("h00"))) + node T_64 = eq(reset, UInt<1>("h00")) + when T_64 : + node T_66 = eq(T_62, UInt<1>("h00")) + when T_66 : + node T_68 = eq(reset, UInt<1>("h00")) + when T_68 : + printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:52 assert(dut.io.n_shifted === SInt(0))\n") + skip + stop(clk, UInt<1>(1), 1) + skip + skip + node T_70 = eq(dut.io.o_shifted, UInt<5>("h012")) + node T_72 = eq(reset, UInt<1>("h00")) + when T_72 : + node T_74 = eq(T_70, UInt<1>("h00")) + when T_74 : + node T_76 = eq(reset, UInt<1>("h00")) + when T_76 : + printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:53 assert(dut.io.o_shifted === UInt(18))\n") + skip + stop(clk, UInt<1>(1), 1) + skip + skip + |
