diff options
| author | Jack Koenig | 2019-08-13 12:09:27 +0530 |
|---|---|---|
| committer | GitHub | 2019-08-13 12:09:27 +0530 |
| commit | f08f8dbb3c480220f92923a7f3242fcbb644b65e (patch) | |
| tree | 45cdb7543f6252ad2feb5aaf4e0e0580d3d27565 /src/main/scala/firrtl/passes | |
| parent | 63e88b6e1696e2c8d6da91f6f5eb128a9d0395ae (diff) | |
Infer reset (#1068)
* Add abstract "Reset" which can be inferred to AsyncReset or UInt<1>
* Enhance async reset initial value literal check to support aggregates
Diffstat (limited to 'src/main/scala/firrtl/passes')
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckWidths.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 126 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 2 |
3 files changed, 87 insertions, 42 deletions
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index b0d9085b..07784e19 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -108,6 +108,7 @@ object CheckWidths extends Pass { sx.reset.tpe match { case UIntType(IntWidth(w)) if w == 1 => case AsyncResetType => + case ResetType => case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) } case _ => diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 972a018e..a17a5a2e 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -57,8 +57,11 @@ trait CheckHighFormLike { s"$info: [module $mname] Primop $op argument $value < 0.") class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: Int, msb: Int) extends PassException( s"$info: [module $mname] Primop $op lsb $lsb > $msb.") - class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String) extends PassException( - s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'") + class ResetInputException(info: Info, mname: String, expr: Expression) extends PassException( + s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}") + class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression) extends PassException( + s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}") + // Is Chirrtl allowed for this check? If not, return an error def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] @@ -188,8 +191,6 @@ trait CheckHighFormLike { case DefRegister(info, name, tpe, _, reset, init) => if (hasFlip(tpe)) errors.append(new RegWithFlipException(info, mname, name)) - if (reset.tpe == AsyncResetType && !init.isInstanceOf[Literal]) - errors.append(new NonLiteralAsyncResetValueException(info, mname, name, init.serialize)) case sx: DefMemory => if (sx.readLatency < 0 || sx.writeLatency <= 0) errors.append(new IllegalMemLatencyException(info, mname, sx.name)) @@ -218,15 +219,36 @@ trait CheckHighFormLike { p.tpe foreach checkHighFormW(p.info, mname) } + // Search for ResetType Ports of direction + def findBadResetTypePorts(m: DefModule, dir: Direction): Seq[(Port, Expression)] = { + val bad = to_gender(dir) + for { + port <- m.ports + ref = WRef(port).copy(gender = to_gender(port.direction)) + expr <- create_exps(ref) + if ((expr.tpe == ResetType) && (gender(expr) == bad)) + } yield (port, expr) + } + def checkHighFormM(m: DefModule): Unit = { val names = new NameSet m foreach checkHighFormP(m.name, names) m foreach checkHighFormS(m.info, m.name, names) + m match { + case _: Module => + case ext: ExtModule => + for ((port, expr) <- findBadResetTypePorts(ext, Output)) { + errors.append(new ResetExtModuleOutputException(port.info, ext.name, expr)) + } + } } c.modules foreach checkHighFormM - c.modules count (_.name == c.main) match { - case 1 => + c.modules.filter(_.name == c.main) match { + case Seq(topMod) => + for ((port, expr) <- findBadResetTypePorts(topMod, Input)) { + errors.append(new ResetInputException(port.info, topMod.name, expr)) + } case _ => errors.append(new NoTopModuleException(c.info, c.main)) } errors.trigger() @@ -263,8 +285,12 @@ object CheckTypes extends Pass { s"$info: [module $mname] Index is not of UIntType.") class EnableNotUInt(info: Info, mname: String) extends PassException( s"$info: [module $mname] Enable is not of UIntType.") - class InvalidConnect(info: Info, mname: String, lhs: String, rhs: String) extends PassException( - s"$info: [module $mname] Type mismatch. Cannot connect $lhs to $rhs.") + class InvalidConnect(info: Info, mname: String, con: String, lhs: Expression, rhs: Expression) + extends PassException({ + val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}" + val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}" + s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe" + }) class InvalidRegInit(info: Info, mname: String) extends PassException( s"$info: [module $mname] Type of init must match type of DefRegister.") class PrintfArgNotGround(info: Info, mname: String) extends PassException( @@ -308,11 +334,52 @@ object CheckTypes extends Pass { class IllegalAttachExp(info: Info, mname: String, expName: String) extends PassException( s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.") class IllegalResetType(info: Info, mname: String, exp: String) extends PassException( - s"$info: [module $mname] Register resets must have type UInt<1>: $exp.") + s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp.") class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException( s"$info: [module $mname] Uninferred type: $exp." ) + def legalResetType(tpe: Type): Boolean = tpe match { + case UIntType(IntWidth(w)) if w == 1 => true + case AsyncResetType => true + case ResetType => true + case UIntType(UnknownWidth) => + // cannot catch here, though width may ultimately be wrong + true + case _ => false + } + + private def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { + (t1, t2) match { + case (ClockType, ClockType) => flip1 == flip2 + case (_: UIntType, _: UIntType) => flip1 == flip2 + case (_: SIntType, _: SIntType) => flip1 == flip2 + case (_: FixedType, _: FixedType) => flip1 == flip2 + case (_: AnalogType, _: AnalogType) => true + case (AsyncResetType, AsyncResetType) => flip1 == flip2 + case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2 + case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2 + case (t1: BundleType, t2: BundleType) => + val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())( + (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) ))) + t2.fields forall (f2 => + t1_fields get f2.name match { + case None => true + case Some((f1_tpe, f1_flip)) => + bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip)) + } + ) + case (t1: VectorType, t2: VectorType) => + bulk_equals(t1.tpe, t2.tpe, flip1, flip2) + case (_, _) => false + } + } + + def validConnect(con: Connect): Boolean = wt(con.loc.tpe).superTypeOf(wt(con.expr.tpe)) + + def validPartialConnect(con: PartialConnect): Boolean = + bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default) + //;---------------- Helper Functions -------------- def ut: UIntType = UIntType(UnknownWidth) def st: SIntType = SIntType(UnknownWidth) @@ -414,48 +481,23 @@ object CheckTypes extends Pass { e foreach check_types_e(info, mname) } - def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { - //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) - (t1, t2) match { - case (ClockType, ClockType) => flip1 == flip2 - case (_: UIntType, _: UIntType) => flip1 == flip2 - case (_: SIntType, _: SIntType) => flip1 == flip2 - case (_: FixedType, _: FixedType) => flip1 == flip2 - case (_: AnalogType, _: AnalogType) => true - case (t1: BundleType, t2: BundleType) => - val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())( - (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) ))) - t2.fields forall (f2 => - t1_fields get f2.name match { - case None => true - case Some((f1_tpe, f1_flip)) => - bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip)) - } - ) - case (t1: VectorType, t2: VectorType) => - bulk_equals(t1.tpe, t2.tpe, flip1, flip2) - case (_, _) => false - } - } - def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = { val info = get_info(s) match { case NoInfo => minfo case x => x } s match { - case sx: Connect if wt(sx.loc.tpe) != wt(sx.expr.tpe) => - errors.append(new InvalidConnect(info, mname, sx.loc.serialize, sx.expr.serialize)) - case sx: PartialConnect if !bulk_equals(sx.loc.tpe, sx.expr.tpe, Default, Default) => - errors.append(new InvalidConnect(info, mname, sx.loc.serialize, sx.expr.serialize)) + case sx: Connect if !validConnect(sx) => + val conMsg = sx.copy(info = NoInfo).serialize + errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr)) + case sx: PartialConnect if !validPartialConnect(sx) => + val conMsg = sx.copy(info = NoInfo).serialize + errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr)) case sx: DefRegister => sx.tpe match { case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname)) case t => } - sx.reset.tpe match { - case UIntType(IntWidth(w)) if w == 1 => - case AsyncResetType => - case UIntType(UnknownWidth) => // cannot catch here, though width may ultimately be wrong - case _ => errors.append(new IllegalResetType(info, mname, sx.name)) + if (!legalResetType(sx.reset.tpe)) { + errors.append(new IllegalResetType(info, mname, sx.name)) } if (sx.clock.tpe != ClockType) { errors.append(new RegReqClk(info, mname, sx.name)) diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index cf6f2ae0..8f663afd 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -262,6 +262,8 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { }) } case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe) + case (ResetType, _) => Nil + case (_, ResetType) => Nil } def run(c: Circuit, extra: Seq[WGeq]): Circuit = { |
