diff options
| author | Adam Izraelevitz | 2018-02-21 14:30:00 -0800 |
|---|---|---|
| committer | GitHub | 2018-02-21 14:30:00 -0800 |
| commit | 65bbf155003a86cd836f7ff4a2def6af91794780 (patch) | |
| tree | 49c968e051a36c323fd0a5839ea6e1432b2f56aa /src/main | |
| parent | edcb81a34dbf8a04d0b011aa1ca07c6e19598f23 (diff) | |
Change primop arg type (#587)
* Changed primops to not accept mixed-type args
* Changed return type of sub of two uints to uint
* Added negative tests
* Removed rocket.fir. Manually changed RocketCore to not mix mul arg types. Added integration tests
* Clarified test description and remove println
* Fixed use of throwInternalError
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 24 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckWidths.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 97 |
3 files changed, 43 insertions, 84 deletions
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index 1ca005d7..0e88ff45 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -129,86 +129,64 @@ object PrimOps extends LazyLogging { e copy (tpe = e.op match { case Add => (t1, t2) match { case (_: UIntType, _: UIntType) => UIntType(PLUS(MAX(w1, w2), IntWidth(1))) - case (_: UIntType, _: SIntType) => SIntType(PLUS(MAX(w1, MINUS(w2, IntWidth(1))), IntWidth(2))) - case (_: SIntType, _: UIntType) => SIntType(PLUS(MAX(w2, MINUS(w1, IntWidth(1))), IntWidth(2))) case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1))) case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2), MAX(MINUS(w1, p1), MINUS(w2, p2))), IntWidth(1)), MAX(p1, p2)) case _ => UnknownType } case Sub => (t1, t2) match { - case (_: UIntType, _: UIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1))) - case (_: UIntType, _: SIntType) => SIntType(MAX(PLUS(w2, IntWidth(1)), PLUS(w1, IntWidth(2)))) - case (_: SIntType, _: UIntType) => SIntType(MAX(PLUS(w1, IntWidth(1)), PLUS(w2, IntWidth(2)))) + case (_: UIntType, _: UIntType) => UIntType(PLUS(MAX(w1, w2), IntWidth(1))) case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1))) case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))),IntWidth(1)), MAX(p1, p2)) case _ => UnknownType } case Mul => (t1, t2) match { case (_: UIntType, _: UIntType) => UIntType(PLUS(w1, w2)) - case (_: UIntType, _: SIntType) => SIntType(PLUS(w1, w2)) - case (_: SIntType, _: UIntType) => SIntType(PLUS(w1, w2)) case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, w2)) case (_: FixedType, _: FixedType) => FixedType(PLUS(w1, w2), PLUS(p1, p2)) case _ => UnknownType } case Div => (t1, t2) match { case (_: UIntType, _: UIntType) => UIntType(w1) - case (_: UIntType, _: SIntType) => SIntType(PLUS(w1, IntWidth(1))) - case (_: SIntType, _: UIntType) => SIntType(w1) case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, IntWidth(1))) case _ => UnknownType } case Rem => (t1, t2) match { case (_: UIntType, _: UIntType) => UIntType(MIN(w1, w2)) - case (_: UIntType, _: SIntType) => UIntType(MIN(w1, w2)) - case (_: SIntType, _: UIntType) => SIntType(MIN(w1, PLUS(w2, IntWidth(1)))) case (_: SIntType, _: SIntType) => SIntType(MIN(w1, w2)) case _ => UnknownType } case Lt => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Leq => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Gt => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Geq => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Eq => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Neq => (t1, t2) match { case (_: UIntType, _: UIntType) => Utils.BoolType - case (_: SIntType, _: UIntType) => Utils.BoolType - case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 55391d99..7406f09a 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -22,8 +22,8 @@ object CheckWidths extends Pass { s"$info : [module $mname] Width of dshl shift amount cannot be larger than $DshlMaxWidth bits.") class NegWidthException(info:Info, mname: String) extends PassException( s"$info: [module $mname] Width cannot be negative or zero.") - class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt) extends PassException( - s"$info: [module $mname] High bit $hi in bits operator is larger than input width $width.") + class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String) extends PassException( + s"$info: [module $mname] High bit $hi in bits operator is larger than input width $width in $exp.") class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( s"$info: [module $mname] Parameter $n in head operator is larger than input width $width.") class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( @@ -69,7 +69,7 @@ object CheckWidths extends Pass { case _ => } case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) => - errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe)) + errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe), e.serialize) case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) => errors append new HeadWidthException(info, mname, n, bitWidth(a.tpe)) case DoPrim(Tail, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= n) => diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index ce599112..6934fca2 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -112,7 +112,7 @@ object CheckHighForm extends Pass { if (npercents != i) errors.append(new BadPrintfIncorrectNumException(info, mname)) } - def checkValidLoc(info: Info, mname: String, e: Expression) = e match { + def checkValidLoc(info: Info, mname: String, e: Expression): Unit = e match { case _: UIntLiteral | _: SIntLiteral | _: DoPrim => errors.append(new InvalidLOCException(info, mname)) case _ => // Do Nothing @@ -254,6 +254,7 @@ object CheckTypes extends Pass { class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException( s"$info: [module $mname] Primop $op requires all operands to have the same type.") class OpNoMixFix(info:Info, mname: String, op: String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type.") + class OpNotCorrectType(info:Info, mname: String, op: String, tpes: Seq[String]) extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.") class OpNotAnalog(info: Info, mname: String, exp: String) extends PassException( s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.") class NodePassiveType(info: Info, mname: String) extends PassException( @@ -276,6 +277,9 @@ object CheckTypes extends Pass { s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.") class IllegalResetType(info: Info, mname: String, exp: String) extends PassException( s"$info: [module $mname] Register resets must have type UInt<1>: $exp.") + class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException( + s"$info: [module $mname] Uninferred type: $exp." + ) //;---------------- Helper Functions -------------- def ut: UIntType = UIntType(UnknownWidth) @@ -290,65 +294,42 @@ object CheckTypes extends Pass { case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe)) case tx => true } - def check_types_primop(info: Info, mname: String, e: DoPrim) { - def all_same_type (ls:Seq[Expression]) { - if (ls exists (x => wt(ls.head.tpe) != wt(e.tpe))) - errors.append(new OpNotAllSameType(info, mname, e.op.serialize)) - } - def allUSC(ls: Seq[Expression]) { - val error = ls.foldLeft(false)((error, x) => x.tpe match { - case (_: UIntType| _: SIntType| ClockType) => error - case _ => true - }) - if (error) errors.append(new OpNotGround(info, mname, e.op.serialize)) - } - def allUSF(ls: Seq[Expression]) { - val error = ls.foldLeft(false)((error, x) => x.tpe match { - case (_: UIntType| _: SIntType| _: FixedType) => error - case _ => true - }) - if (error) errors.append(new OpNotGround(info, mname, e.op.serialize)) - } - def allUS(ls: Seq[Expression]) { - if (ls exists (x => x.tpe match { - case _: UIntType | _: SIntType => false - case _ => true - })) errors.append(new OpNotGround(info, mname, e.op.serialize)) - } - def allF(ls: Seq[Expression]) { - val error = ls.foldLeft(false)((error, x) => x.tpe match { - case _:FixedType => error - case _ => true - }) - if (error) errors.append(new OpNotGround(info, mname, e.op.serialize)) - } - def strictFix(ls: Seq[Expression]) = - ls.filter(!_.tpe.isInstanceOf[FixedType]).size match { - case 0 => - case x if(x == ls.size) => - case x => errors.append(new OpNoMixFix(info, mname, e.op.serialize)) + def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = { + def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean): Unit = { + exprs.foldLeft((false, false, false, false)) { + case ((isUInt, isSInt, isClock, isFix), expr) => expr.tpe match { + case u: UIntType => (true, isSInt, isClock, isFix) + case s: SIntType => (isUInt, true, isClock, isFix) + case ClockType => (isUInt, isSInt, true, isFix) + case f: FixedType => (isUInt, isSInt, isClock, true) + case UnknownType => + errors.append(new IllegalUnknownType(info, mname, e.serialize)) + (isUInt, isSInt, isClock, isFix) + case other => throwInternalError(Some(s"Illegal Type: ${other.serialize}")) + } + } match { + // (UInt, SInt, Clock, Fixed) + case (isAll, false, false, false) if isAll == okUInt => + case (false, isAll, false, false) if isAll == okSInt => + case (false, false, isAll, false) if isAll == okClock => + case (false, false, false, isAll) if isAll == okFix => + case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize))) } - def all_uint (ls: Seq[Expression]) { - if (ls exists (x => x.tpe match { - case _: UIntType => false - case _ => true - })) errors.append(new OpNotAllUInt(info, mname, e.op.serialize)) - } - def is_uint (x:Expression) { - if (x.tpe match { - case _: UIntType => false - case _ => true - }) errors.append(new OpNotUInt(info, mname, e.op.serialize, x.serialize)) } e.op match { - case AsUInt | AsSInt | AsFixedPoint => - case AsClock => allUSC(e.args) - case Dshl => is_uint(e.args(1)); allUSF(e.args) - case Dshr => is_uint(e.args(1)); allUSF(e.args) - case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => allUSF(e.args); strictFix(e.args) - case Pad | Shl | Shr | Cat | Bits | Head | Tail => allUSF(e.args) - case BPShl | BPShr | BPSet => allF(e.args) - case _ => allUS(e.args) + case AsUInt | AsSInt | AsClock | AsFixedPoint => + // All types are ok + case Dshl | Dshr => + checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true) + checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false) + case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true) + case Pad | Shl | Shr | Cat | Bits | Head | Tail => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true) + case BPShl | BPShr | BPSet => + checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true) + case _ => + checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false) } } @@ -421,7 +402,7 @@ object CheckTypes extends Pass { ) case (t1: VectorType, t2: VectorType) => bulk_equals(t1.tpe, t2.tpe, flip1, flip2) - case (t1, t2) => false + case (_, _) => false } } |
