diff options
| author | Adam Izraelevitz | 2018-02-21 14:30:00 -0800 |
|---|---|---|
| committer | GitHub | 2018-02-21 14:30:00 -0800 |
| commit | 65bbf155003a86cd836f7ff4a2def6af91794780 (patch) | |
| tree | 49c968e051a36c323fd0a5839ea6e1432b2f56aa /src | |
| parent | edcb81a34dbf8a04d0b011aa1ca07c6e19598f23 (diff) | |
Change primop arg type (#587)
* Changed primops to not accept mixed-type args
* Changed return type of sub of two uints to uint
* Added negative tests
* Removed rocket.fir. Manually changed RocketCore to not mix mul arg types. Added integration tests
* Clarified test description and remove println
* Fixed use of throwInternalError
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 24 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckWidths.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 97 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/IntegrationSpec.scala | 6 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/WidthSpec.scala | 78 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala | 8 |
6 files changed, 92 insertions, 127 deletions
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index 1ca005d7..0e88ff45 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -129,86 +129,64 @@ object PrimOps extends LazyLogging { e copy (tpe = e.op match { case Add => (t1, t2) match { case (_: UIntType, _: UIntType) => UIntType(PLUS(MAX(w1, w2), IntWidth(1))) - case (_: UIntType, _: SIntType) => SIntType(PLUS(MAX(w1, MINUS(w2, IntWidth(1))), IntWidth(2))) - case (_: SIntType, _: UIntType) => SIntType(PLUS(MAX(w2, MINUS(w1, IntWidth(1))), IntWidth(2))) case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1))) case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2), MAX(MINUS(w1, p1), MINUS(w2, p2))), IntWidth(1)), MAX(p1, p2)) case _ => UnknownType } case Sub => (t1, t2) match { - case (_: UIntType, _: UIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1))) - case (_: UIntType, _: SIntType) => SIntType(MAX(PLUS(w2, IntWidth(1)), PLUS(w1, IntWidth(2)))) - case (_: SIntType, _: UIntType) => SIntType(MAX(PLUS(w1, IntWidth(1)), PLUS(w2, IntWidth(2)))) + case (_: UIntType, _: UIntType) => UIntType(PLUS(MAX(w1, w2), IntWidth(1))) case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1))) case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))),IntWidth(1)), MAX(p1, p2)) case _ => UnknownType } case Mul => (t1, t2) match { case (_: UIntType, _: UIntType) => UIntType(PLUS(w1, w2)) - case (_: UIntType, _: SIntType) => SIntType(PLUS(w1, w2)) - case (_: SIntType, _: UIntType) => SIntType(PLUS(w1, w2)) case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, w2)) case (_: FixedType, _: FixedType) => FixedType(PLUS(w1, w2), PLUS(p1, p2)) case _ => UnknownType } case Div => (t1, t2) match { case (_: UIntType, _: UIntType) => UIntType(w1) - case (_: UIntType, _: SIntType) => SIntType(PLUS(w1, IntWidth(1))) - case (_: SIntType, _: UIntType) => SIntType(w1) case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, IntWidth(1))) case _ => UnknownType } case Rem => (t1, t2) match { case (_: UIntType, _: UIntType) => UIntType(MIN(w1, w2)) - case (_: UIntType, _: SIntType) => UIntType(MIN(w1, w2)) - case (_: SIntType, _: UIntType) => SIntType(MIN(w1, PLUS(w2, IntWidth(1)))) case (_: SIntType, _: SIntType) => SIntType(MIN(w1, w2)) case _ => UnknownType } case Lt => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Leq => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Gt => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Geq => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Eq => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Neq => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 55391d99..7406f09a 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -22,8 +22,8 @@ object CheckWidths extends Pass { s"$info : [module $mname] Width of dshl shift amount cannot be larger than $DshlMaxWidth bits.") class NegWidthException(info:Info, mname: String) extends PassException( s"$info: [module $mname] Width cannot be negative or zero.") - class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt) extends PassException( - s"$info: [module $mname] High bit $hi in bits operator is larger than input width $width.") + class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String) extends PassException( + s"$info: [module $mname] High bit $hi in bits operator is larger than input width $width in $exp.") class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( s"$info: [module $mname] Parameter $n in head operator is larger than input width $width.") class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( @@ -69,7 +69,7 @@ object CheckWidths extends Pass { case _ => } case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) => - errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe)) + errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe), e.serialize) case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) => errors append new HeadWidthException(info, mname, n, bitWidth(a.tpe)) case DoPrim(Tail, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= n) => diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index ce599112..6934fca2 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -112,7 +112,7 @@ object CheckHighForm extends Pass { if (npercents != i) errors.append(new BadPrintfIncorrectNumException(info, mname)) } - def checkValidLoc(info: Info, mname: String, e: Expression) = e match { + def checkValidLoc(info: Info, mname: String, e: Expression): Unit = e match { case _: UIntLiteral | _: SIntLiteral | _: DoPrim => errors.append(new InvalidLOCException(info, mname)) case _ => // Do Nothing @@ -254,6 +254,7 @@ object CheckTypes extends Pass { class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException( s"$info: [module $mname] Primop $op requires all operands to have the same type.") class OpNoMixFix(info:Info, mname: String, op: String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type.") + class OpNotCorrectType(info:Info, mname: String, op: String, tpes: Seq[String]) extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.") class OpNotAnalog(info: Info, mname: String, exp: String) extends PassException( s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.") class NodePassiveType(info: Info, mname: String) extends PassException( @@ -276,6 +277,9 @@ object CheckTypes extends Pass { s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.") class IllegalResetType(info: Info, mname: String, exp: String) extends PassException( s"$info: [module $mname] Register resets must have type UInt<1>: $exp.") + class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException( + s"$info: [module $mname] Uninferred type: $exp." + ) //;---------------- Helper Functions -------------- def ut: UIntType = UIntType(UnknownWidth) @@ -290,65 +294,42 @@ object CheckTypes extends Pass { case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe)) case tx => true } - def check_types_primop(info: Info, mname: String, e: DoPrim) { - def all_same_type (ls:Seq[Expression]) { - if (ls exists (x => wt(ls.head.tpe) != wt(e.tpe))) - errors.append(new OpNotAllSameType(info, mname, e.op.serialize)) - } - def allUSC(ls: Seq[Expression]) { - val error = ls.foldLeft(false)((error, x) => x.tpe match { - case (_: UIntType| _: SIntType| ClockType) => error - case _ => true - }) - if (error) errors.append(new OpNotGround(info, mname, e.op.serialize)) - } - def allUSF(ls: Seq[Expression]) { - val error = ls.foldLeft(false)((error, x) => x.tpe match { - case (_: UIntType| _: SIntType| _: FixedType) => error - case _ => true - }) - if (error) errors.append(new OpNotGround(info, mname, e.op.serialize)) - } - def allUS(ls: Seq[Expression]) { - if (ls exists (x => x.tpe match { - case _: UIntType | _: SIntType => false - case _ => true - })) errors.append(new OpNotGround(info, mname, e.op.serialize)) - } - def allF(ls: Seq[Expression]) { - val error = ls.foldLeft(false)((error, x) => x.tpe match { - case _:FixedType => error - case _ => true - }) - if (error) errors.append(new OpNotGround(info, mname, e.op.serialize)) - } - def strictFix(ls: Seq[Expression]) = - ls.filter(!_.tpe.isInstanceOf[FixedType]).size match { - case 0 => - case x if(x == ls.size) => - case x => errors.append(new OpNoMixFix(info, mname, e.op.serialize)) + def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = { + def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean): Unit = { + exprs.foldLeft((false, false, false, false)) { + case ((isUInt, isSInt, isClock, isFix), expr) => expr.tpe match { + case u: UIntType => (true, isSInt, isClock, isFix) + case s: SIntType => (isUInt, true, isClock, isFix) + case ClockType => (isUInt, isSInt, true, isFix) + case f: FixedType => (isUInt, isSInt, isClock, true) + case UnknownType => + errors.append(new IllegalUnknownType(info, mname, e.serialize)) + (isUInt, isSInt, isClock, isFix) + case other => throwInternalError(Some(s"Illegal Type: ${other.serialize}")) + } + } match { + // (UInt, SInt, Clock, Fixed) + case (isAll, false, false, false) if isAll == okUInt => + case (false, isAll, false, false) if isAll == okSInt => + case (false, false, isAll, false) if isAll == okClock => + case (false, false, false, isAll) if isAll == okFix => + case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize))) } - def all_uint (ls: Seq[Expression]) { - if (ls exists (x => x.tpe match { - case _: UIntType => false - case _ => true - })) errors.append(new OpNotAllUInt(info, mname, e.op.serialize)) - } - def is_uint (x:Expression) { - if (x.tpe match { - case _: UIntType => false - case _ => true - }) errors.append(new OpNotUInt(info, mname, e.op.serialize, x.serialize)) } e.op match { - case AsUInt | AsSInt | AsFixedPoint => - case AsClock => allUSC(e.args) - case Dshl => is_uint(e.args(1)); allUSF(e.args) - case Dshr => is_uint(e.args(1)); allUSF(e.args) - case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => allUSF(e.args); strictFix(e.args) - case Pad | Shl | Shr | Cat | Bits | Head | Tail => allUSF(e.args) - case BPShl | BPShr | BPSet => allF(e.args) - case _ => allUS(e.args) + case AsUInt | AsSInt | AsClock | AsFixedPoint => + // All types are ok + case Dshl | Dshr => + checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true) + checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false) + case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true) + case Pad | Shl | Shr | Cat | Bits | Head | Tail => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true) + case BPShl | BPShr | BPSet => + checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true) + case _ => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false) } } @@ -421,7 +402,7 @@ object CheckTypes extends Pass { ) case (t1: VectorType, t2: VectorType) => bulk_equals(t1.tpe, t2.tpe, flip1, flip2) - case (t1, t2) => false + case (_, _) => false } } diff --git a/src/test/scala/firrtlTests/IntegrationSpec.scala b/src/test/scala/firrtlTests/IntegrationSpec.scala index 647aa91b..54923be9 100644 --- a/src/test/scala/firrtlTests/IntegrationSpec.scala +++ b/src/test/scala/firrtlTests/IntegrationSpec.scala @@ -48,6 +48,8 @@ class GCDSplitEmissionExecutionTest extends FirrtlFlatSpec { } } -class RocketCompilationTest extends CompilationTest("rocket", "/regress") -class BOOMRobCompilationTest extends CompilationTest("Rob", "/regress") +class RobCompilationTest extends CompilationTest("Rob", "/regress") +class RocketCoreCompilationTest extends CompilationTest("RocketCore", "/regress") +class ICacheCompilationTest extends CompilationTest("ICache", "/regress") +class FPUCompilationTest extends CompilationTest("FPU", "/regress") diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala index 770c2785..d1d02ee2 100644 --- a/src/test/scala/firrtlTests/WidthSpec.scala +++ b/src/test/scala/firrtlTests/WidthSpec.scala @@ -22,42 +22,6 @@ class WidthSpec extends FirrtlFlatSpec { } } - "Add of UInt<2> and SInt<2>" should "return SInt<4>" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - InferWidths) - val input = - """circuit Unit : - | module Unit : - | input x: UInt<2> - | input y: SInt<2> - | output z: SInt - | z <= add(x, y)""".stripMargin - val check = Seq( "output z : SInt<4>") - executeTest(input, check, passes) - } - "SInt<2> - UInt<3>" should "return SInt<5>" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - InferWidths) - val input = - """circuit Unit : - | module Unit : - | input x: UInt<3> - | input y: SInt<2> - | output z: SInt - | z <= sub(y, x)""".stripMargin - val check = Seq( "output z : SInt<5>") - executeTest(input, check, passes) - } "Dshl by 20 bits" should "result in an error" in { val passes = Seq( ToWorkingIR, @@ -121,4 +85,46 @@ class WidthSpec extends FirrtlFlatSpec { executeTest(input, Nil, passes) } } + + "Add of UInt<2> and SInt<2>" should "error" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + InferWidths) + val input = + """circuit Unit : + | module Unit : + | input x: UInt<2> + | input y: SInt<2> + | output z: SInt + | z <= add(x, y)""".stripMargin + val check = Seq( "output z : SInt<4>") + intercept[PassExceptions] { + executeTest(input, check, passes) + } + } + + "SInt<2> - UInt<3>" should "error" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + InferWidths) + val input = + """circuit Unit : + | module Unit : + | input x: UInt<3> + | input y: SInt<2> + | output z: SInt + | z <= sub(y, x)""".stripMargin + val check = Seq( "output z : SInt<5>") + intercept[PassExceptions] { + executeTest(input, check, passes) + } + } } diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala index 37209786..a866836f 100644 --- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala @@ -197,12 +197,11 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | input b : Fixed<7><<3>> - | input c : UInt<2> | output cat : UInt | output head : UInt | output tail : UInt | output bits : UInt - | cat <= cat(a, c) + | cat <= cat(a, b) | head <= head(a, 3) | tail <= tail(a, 3) | bits <= bits(a, 6, 3)""".stripMargin @@ -211,12 +210,11 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | module Unit : | input a : Fixed<10><<2>> | input b : Fixed<7><<3>> - | input c : UInt<2> - | output cat : UInt<12> + | output cat : UInt<17> | output head : UInt<3> | output tail : UInt<7> | output bits : UInt<4> - | cat <= cat(a, c) + | cat <= cat(a, b) | head <= head(a, 3) | tail <= tail(a, 3) | bits <= bits(a, 6, 3)""".stripMargin |
