aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/PrimOps.scala24
-rw-r--r--src/main/scala/firrtl/passes/CheckWidths.scala6
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala97
-rw-r--r--src/test/scala/firrtlTests/IntegrationSpec.scala6
-rw-r--r--src/test/scala/firrtlTests/WidthSpec.scala78
-rw-r--r--src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala8
6 files changed, 92 insertions, 127 deletions
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala
index 1ca005d7..0e88ff45 100644
--- a/src/main/scala/firrtl/PrimOps.scala
+++ b/src/main/scala/firrtl/PrimOps.scala
@@ -129,86 +129,64 @@ object PrimOps extends LazyLogging {
e copy (tpe = e.op match {
case Add => (t1, t2) match {
case (_: UIntType, _: UIntType) => UIntType(PLUS(MAX(w1, w2), IntWidth(1)))
- case (_: UIntType, _: SIntType) => SIntType(PLUS(MAX(w1, MINUS(w2, IntWidth(1))), IntWidth(2)))
- case (_: SIntType, _: UIntType) => SIntType(PLUS(MAX(w2, MINUS(w1, IntWidth(1))), IntWidth(2)))
case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2), MAX(MINUS(w1, p1), MINUS(w2, p2))), IntWidth(1)), MAX(p1, p2))
case _ => UnknownType
}
case Sub => (t1, t2) match {
- case (_: UIntType, _: UIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
- case (_: UIntType, _: SIntType) => SIntType(MAX(PLUS(w2, IntWidth(1)), PLUS(w1, IntWidth(2))))
- case (_: SIntType, _: UIntType) => SIntType(MAX(PLUS(w1, IntWidth(1)), PLUS(w2, IntWidth(2))))
+ case (_: UIntType, _: UIntType) => UIntType(PLUS(MAX(w1, w2), IntWidth(1)))
case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))),IntWidth(1)), MAX(p1, p2))
case _ => UnknownType
}
case Mul => (t1, t2) match {
case (_: UIntType, _: UIntType) => UIntType(PLUS(w1, w2))
- case (_: UIntType, _: SIntType) => SIntType(PLUS(w1, w2))
- case (_: SIntType, _: UIntType) => SIntType(PLUS(w1, w2))
case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, w2))
case (_: FixedType, _: FixedType) => FixedType(PLUS(w1, w2), PLUS(p1, p2))
case _ => UnknownType
}
case Div => (t1, t2) match {
case (_: UIntType, _: UIntType) => UIntType(w1)
- case (_: UIntType, _: SIntType) => SIntType(PLUS(w1, IntWidth(1)))
- case (_: SIntType, _: UIntType) => SIntType(w1)
case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, IntWidth(1)))
case _ => UnknownType
}
case Rem => (t1, t2) match {
case (_: UIntType, _: UIntType) => UIntType(MIN(w1, w2))
- case (_: UIntType, _: SIntType) => UIntType(MIN(w1, w2))
- case (_: SIntType, _: UIntType) => SIntType(MIN(w1, PLUS(w2, IntWidth(1))))
case (_: SIntType, _: SIntType) => SIntType(MIN(w1, w2))
case _ => UnknownType
}
case Lt => (t1, t2) match {
case (_: UIntType, _: UIntType) => Utils.BoolType
- case (_: SIntType, _: UIntType) => Utils.BoolType
- case (_: UIntType, _: SIntType) => Utils.BoolType
case (_: SIntType, _: SIntType) => Utils.BoolType
case (_: FixedType, _: FixedType) => Utils.BoolType
case _ => UnknownType
}
case Leq => (t1, t2) match {
case (_: UIntType, _: UIntType) => Utils.BoolType
- case (_: SIntType, _: UIntType) => Utils.BoolType
- case (_: UIntType, _: SIntType) => Utils.BoolType
case (_: SIntType, _: SIntType) => Utils.BoolType
case (_: FixedType, _: FixedType) => Utils.BoolType
case _ => UnknownType
}
case Gt => (t1, t2) match {
case (_: UIntType, _: UIntType) => Utils.BoolType
- case (_: SIntType, _: UIntType) => Utils.BoolType
- case (_: UIntType, _: SIntType) => Utils.BoolType
case (_: SIntType, _: SIntType) => Utils.BoolType
case (_: FixedType, _: FixedType) => Utils.BoolType
case _ => UnknownType
}
case Geq => (t1, t2) match {
case (_: UIntType, _: UIntType) => Utils.BoolType
- case (_: SIntType, _: UIntType) => Utils.BoolType
- case (_: UIntType, _: SIntType) => Utils.BoolType
case (_: SIntType, _: SIntType) => Utils.BoolType
case (_: FixedType, _: FixedType) => Utils.BoolType
case _ => UnknownType
}
case Eq => (t1, t2) match {
case (_: UIntType, _: UIntType) => Utils.BoolType
- case (_: SIntType, _: UIntType) => Utils.BoolType
- case (_: UIntType, _: SIntType) => Utils.BoolType
case (_: SIntType, _: SIntType) => Utils.BoolType
case (_: FixedType, _: FixedType) => Utils.BoolType
case _ => UnknownType
}
case Neq => (t1, t2) match {
case (_: UIntType, _: UIntType) => Utils.BoolType
- case (_: SIntType, _: UIntType) => Utils.BoolType
- case (_: UIntType, _: SIntType) => Utils.BoolType
case (_: SIntType, _: SIntType) => Utils.BoolType
case (_: FixedType, _: FixedType) => Utils.BoolType
case _ => UnknownType
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala
index 55391d99..7406f09a 100644
--- a/src/main/scala/firrtl/passes/CheckWidths.scala
+++ b/src/main/scala/firrtl/passes/CheckWidths.scala
@@ -22,8 +22,8 @@ object CheckWidths extends Pass {
s"$info : [module $mname] Width of dshl shift amount cannot be larger than $DshlMaxWidth bits.")
class NegWidthException(info:Info, mname: String) extends PassException(
s"$info: [module $mname] Width cannot be negative or zero.")
- class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt) extends PassException(
- s"$info: [module $mname] High bit $hi in bits operator is larger than input width $width.")
+ class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String) extends PassException(
+ s"$info: [module $mname] High bit $hi in bits operator is larger than input width $width in $exp.")
class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException(
s"$info: [module $mname] Parameter $n in head operator is larger than input width $width.")
class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException(
@@ -69,7 +69,7 @@ object CheckWidths extends Pass {
case _ =>
}
case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) =>
- errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe))
+ errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe), e.serialize)
case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) =>
errors append new HeadWidthException(info, mname, n, bitWidth(a.tpe))
case DoPrim(Tail, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= n) =>
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index ce599112..6934fca2 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -112,7 +112,7 @@ object CheckHighForm extends Pass {
if (npercents != i) errors.append(new BadPrintfIncorrectNumException(info, mname))
}
- def checkValidLoc(info: Info, mname: String, e: Expression) = e match {
+ def checkValidLoc(info: Info, mname: String, e: Expression): Unit = e match {
case _: UIntLiteral | _: SIntLiteral | _: DoPrim =>
errors.append(new InvalidLOCException(info, mname))
case _ => // Do Nothing
@@ -254,6 +254,7 @@ object CheckTypes extends Pass {
class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException(
s"$info: [module $mname] Primop $op requires all operands to have the same type.")
class OpNoMixFix(info:Info, mname: String, op: String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type.")
+ class OpNotCorrectType(info:Info, mname: String, op: String, tpes: Seq[String]) extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.")
class OpNotAnalog(info: Info, mname: String, exp: String) extends PassException(
s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.")
class NodePassiveType(info: Info, mname: String) extends PassException(
@@ -276,6 +277,9 @@ object CheckTypes extends Pass {
s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.")
class IllegalResetType(info: Info, mname: String, exp: String) extends PassException(
s"$info: [module $mname] Register resets must have type UInt<1>: $exp.")
+ class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException(
+ s"$info: [module $mname] Uninferred type: $exp."
+ )
//;---------------- Helper Functions --------------
def ut: UIntType = UIntType(UnknownWidth)
@@ -290,65 +294,42 @@ object CheckTypes extends Pass {
case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe))
case tx => true
}
- def check_types_primop(info: Info, mname: String, e: DoPrim) {
- def all_same_type (ls:Seq[Expression]) {
- if (ls exists (x => wt(ls.head.tpe) != wt(e.tpe)))
- errors.append(new OpNotAllSameType(info, mname, e.op.serialize))
- }
- def allUSC(ls: Seq[Expression]) {
- val error = ls.foldLeft(false)((error, x) => x.tpe match {
- case (_: UIntType| _: SIntType| ClockType) => error
- case _ => true
- })
- if (error) errors.append(new OpNotGround(info, mname, e.op.serialize))
- }
- def allUSF(ls: Seq[Expression]) {
- val error = ls.foldLeft(false)((error, x) => x.tpe match {
- case (_: UIntType| _: SIntType| _: FixedType) => error
- case _ => true
- })
- if (error) errors.append(new OpNotGround(info, mname, e.op.serialize))
- }
- def allUS(ls: Seq[Expression]) {
- if (ls exists (x => x.tpe match {
- case _: UIntType | _: SIntType => false
- case _ => true
- })) errors.append(new OpNotGround(info, mname, e.op.serialize))
- }
- def allF(ls: Seq[Expression]) {
- val error = ls.foldLeft(false)((error, x) => x.tpe match {
- case _:FixedType => error
- case _ => true
- })
- if (error) errors.append(new OpNotGround(info, mname, e.op.serialize))
- }
- def strictFix(ls: Seq[Expression]) =
- ls.filter(!_.tpe.isInstanceOf[FixedType]).size match {
- case 0 =>
- case x if(x == ls.size) =>
- case x => errors.append(new OpNoMixFix(info, mname, e.op.serialize))
+ def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = {
+ def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean): Unit = {
+ exprs.foldLeft((false, false, false, false)) {
+ case ((isUInt, isSInt, isClock, isFix), expr) => expr.tpe match {
+ case u: UIntType => (true, isSInt, isClock, isFix)
+ case s: SIntType => (isUInt, true, isClock, isFix)
+ case ClockType => (isUInt, isSInt, true, isFix)
+ case f: FixedType => (isUInt, isSInt, isClock, true)
+ case UnknownType =>
+ errors.append(new IllegalUnknownType(info, mname, e.serialize))
+ (isUInt, isSInt, isClock, isFix)
+ case other => throwInternalError(Some(s"Illegal Type: ${other.serialize}"))
+ }
+ } match {
+ // (UInt, SInt, Clock, Fixed)
+ case (isAll, false, false, false) if isAll == okUInt =>
+ case (false, isAll, false, false) if isAll == okSInt =>
+ case (false, false, isAll, false) if isAll == okClock =>
+ case (false, false, false, isAll) if isAll == okFix =>
+ case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize)))
}
- def all_uint (ls: Seq[Expression]) {
- if (ls exists (x => x.tpe match {
- case _: UIntType => false
- case _ => true
- })) errors.append(new OpNotAllUInt(info, mname, e.op.serialize))
- }
- def is_uint (x:Expression) {
- if (x.tpe match {
- case _: UIntType => false
- case _ => true
- }) errors.append(new OpNotUInt(info, mname, e.op.serialize, x.serialize))
}
e.op match {
- case AsUInt | AsSInt | AsFixedPoint =>
- case AsClock => allUSC(e.args)
- case Dshl => is_uint(e.args(1)); allUSF(e.args)
- case Dshr => is_uint(e.args(1)); allUSF(e.args)
- case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => allUSF(e.args); strictFix(e.args)
- case Pad | Shl | Shr | Cat | Bits | Head | Tail => allUSF(e.args)
- case BPShl | BPShr | BPSet => allF(e.args)
- case _ => allUS(e.args)
+ case AsUInt | AsSInt | AsClock | AsFixedPoint =>
+ // All types are ok
+ case Dshl | Dshr =>
+ checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true)
+ checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false)
+ case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq =>
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true)
+ case Pad | Shl | Shr | Cat | Bits | Head | Tail =>
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true)
+ case BPShl | BPShr | BPSet =>
+ checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true)
+ case _ =>
+ checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false)
}
}
@@ -421,7 +402,7 @@ object CheckTypes extends Pass {
)
case (t1: VectorType, t2: VectorType) =>
bulk_equals(t1.tpe, t2.tpe, flip1, flip2)
- case (t1, t2) => false
+ case (_, _) => false
}
}
diff --git a/src/test/scala/firrtlTests/IntegrationSpec.scala b/src/test/scala/firrtlTests/IntegrationSpec.scala
index 647aa91b..54923be9 100644
--- a/src/test/scala/firrtlTests/IntegrationSpec.scala
+++ b/src/test/scala/firrtlTests/IntegrationSpec.scala
@@ -48,6 +48,8 @@ class GCDSplitEmissionExecutionTest extends FirrtlFlatSpec {
}
}
-class RocketCompilationTest extends CompilationTest("rocket", "/regress")
-class BOOMRobCompilationTest extends CompilationTest("Rob", "/regress")
+class RobCompilationTest extends CompilationTest("Rob", "/regress")
+class RocketCoreCompilationTest extends CompilationTest("RocketCore", "/regress")
+class ICacheCompilationTest extends CompilationTest("ICache", "/regress")
+class FPUCompilationTest extends CompilationTest("FPU", "/regress")
diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala
index 770c2785..d1d02ee2 100644
--- a/src/test/scala/firrtlTests/WidthSpec.scala
+++ b/src/test/scala/firrtlTests/WidthSpec.scala
@@ -22,42 +22,6 @@ class WidthSpec extends FirrtlFlatSpec {
}
}
- "Add of UInt<2> and SInt<2>" should "return SInt<4>" in {
- val passes = Seq(
- ToWorkingIR,
- CheckHighForm,
- ResolveKinds,
- InferTypes,
- CheckTypes,
- InferWidths)
- val input =
- """circuit Unit :
- | module Unit :
- | input x: UInt<2>
- | input y: SInt<2>
- | output z: SInt
- | z <= add(x, y)""".stripMargin
- val check = Seq( "output z : SInt<4>")
- executeTest(input, check, passes)
- }
- "SInt<2> - UInt<3>" should "return SInt<5>" in {
- val passes = Seq(
- ToWorkingIR,
- CheckHighForm,
- ResolveKinds,
- InferTypes,
- CheckTypes,
- InferWidths)
- val input =
- """circuit Unit :
- | module Unit :
- | input x: UInt<3>
- | input y: SInt<2>
- | output z: SInt
- | z <= sub(y, x)""".stripMargin
- val check = Seq( "output z : SInt<5>")
- executeTest(input, check, passes)
- }
"Dshl by 20 bits" should "result in an error" in {
val passes = Seq(
ToWorkingIR,
@@ -121,4 +85,46 @@ class WidthSpec extends FirrtlFlatSpec {
executeTest(input, Nil, passes)
}
}
+
+ "Add of UInt<2> and SInt<2>" should "error" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ InferWidths)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input x: UInt<2>
+ | input y: SInt<2>
+ | output z: SInt
+ | z <= add(x, y)""".stripMargin
+ val check = Seq( "output z : SInt<4>")
+ intercept[PassExceptions] {
+ executeTest(input, check, passes)
+ }
+ }
+
+ "SInt<2> - UInt<3>" should "error" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ InferWidths)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input x: UInt<3>
+ | input y: SInt<2>
+ | output z: SInt
+ | z <= sub(y, x)""".stripMargin
+ val check = Seq( "output z : SInt<5>")
+ intercept[PassExceptions] {
+ executeTest(input, check, passes)
+ }
+ }
}
diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
index 37209786..a866836f 100644
--- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
+++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala
@@ -197,12 +197,11 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| input b : Fixed<7><<3>>
- | input c : UInt<2>
| output cat : UInt
| output head : UInt
| output tail : UInt
| output bits : UInt
- | cat <= cat(a, c)
+ | cat <= cat(a, b)
| head <= head(a, 3)
| tail <= tail(a, 3)
| bits <= bits(a, 6, 3)""".stripMargin
@@ -211,12 +210,11 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec {
| module Unit :
| input a : Fixed<10><<2>>
| input b : Fixed<7><<3>>
- | input c : UInt<2>
- | output cat : UInt<12>
+ | output cat : UInt<17>
| output head : UInt<3>
| output tail : UInt<7>
| output bits : UInt<4>
- | cat <= cat(a, c)
+ | cat <= cat(a, b)
| head <= head(a, 3)
| tail <= tail(a, 3)
| bits <= bits(a, 6, 3)""".stripMargin