aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/Compiler.scala1
-rw-r--r--src/main/scala/firrtl/Emitter.scala6
-rw-r--r--src/main/scala/firrtl/IR.scala21
-rw-r--r--src/main/scala/firrtl/PrimOps.scala4
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala55
-rw-r--r--src/test/scala/firrtlTests/IntegrationSpec.scala4
-rw-r--r--test/integration/RightShiftTester.fir141
7 files changed, 218 insertions, 14 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index 782d43cb..aa3eeace 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -73,6 +73,7 @@ object VerilogCompiler extends Compiler {
RemoveAccesses,
ExpandWhens,
CheckInitialization,
+ Legalize,
ConstProp,
ResolveKinds,
InferTypes,
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index e9c90fd6..4b88f526 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -211,7 +211,11 @@ object VerilogEmitter extends Emitter {
}
}
case SHIFT_LEFT_OP => Seq(cast(a0())," << ",c0())
- case SHIFT_RIGHT_OP => Seq(a0(),"[", long_BANG(tpe(a0())) - 1,":",c0(),"]")
+ case SHIFT_RIGHT_OP => {
+ if (c0 >= long_BANG(tpe(a0)))
+ error("Verilog emitter does not support SHIFT_RIGHT >= arg width")
+ Seq(a0(),"[", long_BANG(tpe(a0())) - 1,":",c0(),"]")
+ }
case NEG_OP => Seq("-{",cast(a0()),"}")
case CONVERT_OP => {
tpe(a0()) match {
diff --git a/src/main/scala/firrtl/IR.scala b/src/main/scala/firrtl/IR.scala
index ce22f010..17b16052 100644
--- a/src/main/scala/firrtl/IR.scala
+++ b/src/main/scala/firrtl/IR.scala
@@ -42,7 +42,7 @@ case class FileInfo(file: String, line: Int, column: Int) extends Info {
override def toString(): String = s"$file@$line.$column"
}
-case class FIRRTLException(str:String) extends Exception
+case class FIRRTLException(str: String) extends Exception(str)
trait AST {
def serialize: String = firrtl.Serialize.serialize(this)
@@ -110,7 +110,24 @@ case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends S
case class Print(info: Info, string: String, args: Seq[Expression], clk: Expression, en: Expression) extends Stmt
case class Empty() extends Stmt
-trait Width extends AST
+trait Width extends AST {
+ def +(x: Width): Width = (this, x) match {
+ case (a: IntWidth, b: IntWidth) => IntWidth(a.width + b.width)
+ case _ => UnknownWidth()
+ }
+ def -(x: Width): Width = (this, x) match {
+ case (a: IntWidth, b: IntWidth) => IntWidth(a.width - b.width)
+ case _ => UnknownWidth()
+ }
+ def max(x: Width): Width = (this, x) match {
+ case (a: IntWidth, b: IntWidth) => IntWidth(a.width max b.width)
+ case _ => UnknownWidth()
+ }
+ def min(x: Width): Width = (this, x) match {
+ case (a: IntWidth, b: IntWidth) => IntWidth(a.width min b.width)
+ case _ => UnknownWidth()
+ }
+}
case class IntWidth(width: BigInt) extends Width
case class UnknownWidth() extends Width
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala
index 8a2865bb..ed3752f9 100644
--- a/src/main/scala/firrtl/PrimOps.scala
+++ b/src/main/scala/firrtl/PrimOps.scala
@@ -254,8 +254,8 @@ object PrimOps extends LazyLogging {
}
case SHIFT_RIGHT_OP => {
val t = (t1()) match {
- case (t1:UIntType) => UIntType(MINUS(w1(),c1()))
- case (t1:SIntType) => SIntType(MINUS(w1(),c1()))
+ case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),ONE))
+ case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),ONE))
case (t1) => UnknownType()
}
DoPrim(o,a,c,t)
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 7490c479..1e8ceae2 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -1128,24 +1128,63 @@ object ExpandWhens extends Pass {
}
}
+// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts
+// TODO replace UInt with zero-width wire instead
+object Legalize extends Pass {
+ def name = "Legalize"
+ def legalizeShiftRight (e: DoPrim): Expression = e.op match {
+ case SHIFT_RIGHT_OP => {
+ val amount = e.consts(0).toInt
+ val width = long_BANG(tpe(e.args(0)))
+ lazy val msb = width - 1
+ if (amount >= width) {
+ e.tpe match {
+ case t: UIntType => UIntValue(0, IntWidth(1))
+ case t: SIntType =>
+ DoPrim(BITS_SELECT_OP, e.args, Seq(msb, msb), SIntType(IntWidth(1)))
+ case t => error(s"Unsupported type ${t} for Primop Shift Right")
+ }
+ } else {
+ e
+ }
+ }
+ case _ => e
+ }
+ def run (c: Circuit): Circuit = {
+ def legalizeE (e: Expression): Expression = {
+ e map (legalizeE) match {
+ case e: DoPrim => legalizeShiftRight(e)
+ case e => e
+ }
+ }
+ def legalizeS (s: Stmt): Stmt = s map (legalizeS) map (legalizeE)
+ def legalizeM (m: Module): Module = m map (legalizeS)
+ Circuit(c.info, c.modules.map(legalizeM), c.main)
+ }
+}
+
object ConstProp extends Pass {
def name = "Constant Propogation"
var mname = ""
+
def const_prop_e (e:Expression) : Expression = {
e map (const_prop_e) match {
case (e:DoPrim) => {
e.op match {
case SHIFT_RIGHT_OP => {
- (e.args(0)) match {
- case (x:UIntValue) => {
- val b = x.value >> e.consts(0).toInt
- UIntValue(b,tpe(e).as[UIntType].get.width)
+ val amount = e.consts(0).toInt
+ e.args(0) match {
+ case x: UIntValue => {
+ val v = x.value >> amount
+ val w = (x.width - IntWidth(amount)) max IntWidth(1)
+ UIntValue(v, w)
}
- case (x:SIntValue) => {
- val b = x.value >> e.consts(0).toInt
- SIntValue(b,tpe(e).as[SIntType].get.width)
+ case x: SIntValue => { // take sign bit if shift amount is larger than arg width
+ val v = x.value >> amount
+ val w = (x.width - IntWidth(amount)) max IntWidth(1)
+ SIntValue(v, w)
}
- case (x) => e
+ case _ => e
}
}
case BITS_SELECT_OP => {
diff --git a/src/test/scala/firrtlTests/IntegrationSpec.scala b/src/test/scala/firrtlTests/IntegrationSpec.scala
index eb5a7fa1..e9afd739 100644
--- a/src/test/scala/firrtlTests/IntegrationSpec.scala
+++ b/src/test/scala/firrtlTests/IntegrationSpec.scala
@@ -8,7 +8,9 @@ class IntegrationSpec extends FirrtlPropSpec {
case class Test(name: String, dir: String)
- val runTests = Seq(Test("GCDTester", "/integration"))
+ val runTests = Seq(Test("GCDTester", "/integration"),
+ Test("RightShiftTester", "/integration"))
+
runTests foreach { test =>
property(s"${test.name} should execute correctly") {
diff --git a/test/integration/RightShiftTester.fir b/test/integration/RightShiftTester.fir
new file mode 100644
index 00000000..b73b98ff
--- /dev/null
+++ b/test/integration/RightShiftTester.fir
@@ -0,0 +1,141 @@
+circuit RightShiftTester :
+ module RightShift :
+ input clk : Clock
+ input reset : UInt<1>
+ output io : {flip i : UInt<1>, flip j : SInt<1>, i_shifted : UInt, j_shifted : SInt, k_shifted : UInt, l_shifted : UInt, m_shifted : SInt, n_shifted : SInt, o_shifted : UInt}
+
+ io is invalid
+ wire k : UInt<16>
+ k is invalid
+ k <= UInt<1>("h01")
+ wire o : UInt<32>
+ o is invalid
+ o <= UInt<21>("h012d687")
+ node T_19 = shr(io.i, 1)
+ io.i_shifted <= T_19
+ node T_20 = shr(io.j, 1)
+ io.j_shifted <= T_20
+ node T_21 = shr(k, 18)
+ io.k_shifted <= T_21
+ node T_23 = shr(UInt<4>("h0f"), 4)
+ io.l_shifted <= T_23
+ node T_25 = shr(asSInt(UInt<1>("h01")), 5)
+ io.m_shifted <= T_25
+ node T_27 = shr(asSInt(UInt<3>("h03")), 4)
+ io.n_shifted <= T_27
+ node T_28 = shr(o, 16)
+ io.o_shifted <= T_28
+
+ module RightShiftTester :
+ input clk : Clock
+ input reset : UInt<1>
+ output io : {}
+
+ io is invalid
+ inst dut of RightShift
+ dut.io is invalid
+ dut.clk <= clk
+ dut.reset <= reset
+ reg T_6 : UInt<2>, clk with : (reset => (reset, UInt<2>("h00")))
+ when UInt<1>("h01") :
+ node T_8 = eq(T_6, UInt<2>("h03"))
+ node T_10 = and(UInt<1>("h00"), T_8)
+ node T_13 = add(T_6, UInt<1>("h01"))
+ node T_14 = tail(T_13, 1)
+ node T_15 = mux(T_10, UInt<1>("h00"), T_14)
+ T_6 <= T_15
+ skip
+ node done = and(UInt<1>("h01"), T_8)
+ when done :
+ node T_18 = eq(reset, UInt<1>("h00"))
+ when T_18 :
+ stop(clk, UInt<1>(1), 0)
+ skip
+ skip
+ dut.io.i <= UInt<1>("h01")
+ dut.io.j <= asSInt(UInt<1>("h01"))
+ node T_22 = eq(dut.io.i_shifted, UInt<1>("h00"))
+ node T_24 = eq(reset, UInt<1>("h00"))
+ when T_24 :
+ node T_26 = eq(T_22, UInt<1>("h00"))
+ when T_26 :
+ node T_28 = eq(reset, UInt<1>("h00"))
+ when T_28 :
+ printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:47 assert(dut.io.i_shifted === UInt(0))\n")
+ skip
+ stop(clk, UInt<1>(1), 1)
+ skip
+ skip
+ node T_30 = eq(dut.io.j_shifted, asSInt(UInt<1>("h01")))
+ node T_32 = eq(reset, UInt<1>("h00"))
+ when T_32 :
+ node T_34 = eq(T_30, UInt<1>("h00"))
+ when T_34 :
+ node T_36 = eq(reset, UInt<1>("h00"))
+ when T_36 :
+ printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:48 assert(dut.io.j_shifted === SInt(-1))\n")
+ skip
+ stop(clk, UInt<1>(1), 1)
+ skip
+ skip
+ node T_38 = eq(dut.io.k_shifted, UInt<1>("h00"))
+ node T_40 = eq(reset, UInt<1>("h00"))
+ when T_40 :
+ node T_42 = eq(T_38, UInt<1>("h00"))
+ when T_42 :
+ node T_44 = eq(reset, UInt<1>("h00"))
+ when T_44 :
+ printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:49 assert(dut.io.k_shifted === UInt(0))\n")
+ skip
+ stop(clk, UInt<1>(1), 1)
+ skip
+ skip
+ node T_46 = eq(dut.io.l_shifted, UInt<1>("h00"))
+ node T_48 = eq(reset, UInt<1>("h00"))
+ when T_48 :
+ node T_50 = eq(T_46, UInt<1>("h00"))
+ when T_50 :
+ node T_52 = eq(reset, UInt<1>("h00"))
+ when T_52 :
+ printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:50 assert(dut.io.l_shifted === UInt(0))\n")
+ skip
+ stop(clk, UInt<1>(1), 1)
+ skip
+ skip
+ node T_54 = eq(dut.io.m_shifted, asSInt(UInt<1>("h01")))
+ node T_56 = eq(reset, UInt<1>("h00"))
+ when T_56 :
+ node T_58 = eq(T_54, UInt<1>("h00"))
+ when T_58 :
+ node T_60 = eq(reset, UInt<1>("h00"))
+ when T_60 :
+ printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:51 assert(dut.io.m_shifted === SInt(-1))\n")
+ skip
+ stop(clk, UInt<1>(1), 1)
+ skip
+ skip
+ node T_62 = eq(dut.io.n_shifted, asSInt(UInt<1>("h00")))
+ node T_64 = eq(reset, UInt<1>("h00"))
+ when T_64 :
+ node T_66 = eq(T_62, UInt<1>("h00"))
+ when T_66 :
+ node T_68 = eq(reset, UInt<1>("h00"))
+ when T_68 :
+ printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:52 assert(dut.io.n_shifted === SInt(0))\n")
+ skip
+ stop(clk, UInt<1>(1), 1)
+ skip
+ skip
+ node T_70 = eq(dut.io.o_shifted, UInt<5>("h012"))
+ node T_72 = eq(reset, UInt<1>("h00"))
+ when T_72 :
+ node T_74 = eq(T_70, UInt<1>("h00"))
+ when T_74 :
+ node T_76 = eq(reset, UInt<1>("h00"))
+ when T_76 :
+ printf(clk, UInt<1>(1), "Assertion failed\n at RightShift.scala:53 assert(dut.io.o_shifted === UInt(18))\n")
+ skip
+ stop(clk, UInt<1>(1), 1)
+ skip
+ skip
+