diff options
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/antlr4/FIRRTL.g4 | 1 | ||||
| -rw-r--r-- | src/main/proto/firrtl.proto | 5 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Compiler.scala | 5 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 7 | ||||
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 15 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 9 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Visitor.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 43 | ||||
| -rw-r--r-- | src/main/scala/firrtl/checks/CheckResets.scala | 75 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 7 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckWidths.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 126 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/proto/FromProto.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/proto/ToProto.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/InferResets.scala | 253 |
17 files changed, 488 insertions, 70 deletions
diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index c49bb948..c3b4e74e 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -52,6 +52,7 @@ type | 'Fixed' ('<' intLit '>')? ('<' '<' intLit '>' '>')? | 'Clock' | 'AsyncReset' + | 'Reset' | 'Analog' ('<' intLit '>')? | '{' field* '}' // Bundle | type '[' intLit ']' // Vector diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto index 2552b989..b8f6db98 100644 --- a/src/main/proto/firrtl.proto +++ b/src/main/proto/firrtl.proto @@ -277,6 +277,10 @@ message Firrtl { // Empty. } + message ResetType { + // Empty. + } + message BundleType { message Field { // Required. @@ -315,6 +319,7 @@ message Firrtl { FixedType fixed_type = 7; AnalogType analog_type = 8; AsyncResetType async_reset_type = 9; + ResetType reset_type = 10; } } diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index b72fd4ce..367defb5 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -334,8 +334,9 @@ object CompilerUtils extends LazyLogging { case ChirrtlForm => Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm) case HighForm => - Seq(new IRToWorkingIR, new ResolveAndCheck, new transforms.DedupModules, - new HighFirrtlToMiddleFirrtl) ++ getLoweringTransforms(MidForm, outputForm) + Seq(new IRToWorkingIR, new ResolveAndCheck, + new transforms.DedupModules, new HighFirrtlToMiddleFirrtl) ++ + getLoweringTransforms(MidForm, outputForm) case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) case LowForm => throwInternalError("getLoweringTransforms - LowForm") // should be caught by if above case UnknownForm => throwInternalError("getLoweringTransforms - UnknownForm") // should be caught by if above diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 8e6408fe..854e1876 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -231,7 +231,12 @@ class VerilogEmitter extends SeqTransform with Emitter { x match { case (e: DoPrim) => emit(op_stream(e), top + 1) case (e: Mux) => { - if(e.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly") + if (e.tpe == ClockType) { + throw EmitterException("Cannot emit clock muxes directly") + } + if (e.tpe == AsyncResetType) { + throw EmitterException("Cannot emit async reset muxes directly") + } emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1) } case (e: ValidIf) => emit(Seq(cast(e.value)),top + 1) diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 316baec9..274ccf74 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -46,7 +46,8 @@ class ResolveAndCheck extends CoreTransform { passes.ResolveGenders, passes.CheckGenders, new passes.InferWidths, - passes.CheckWidths) + passes.CheckWidths, + new firrtl.transforms.InferResets) } /** Expands aggregate connects, removes dynamic accesses, and when @@ -68,6 +69,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { passes.ResolveKinds, passes.InferTypes, passes.CheckTypes, + new checks.CheckResets, passes.ResolveGenders, new passes.InferWidths, passes.CheckWidths, diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index 0f1ecff7..1a513352 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -203,43 +203,38 @@ object PrimOps extends LazyLogging { case _: UIntType => UIntType(w1) case _: SIntType => UIntType(w1) case _: FixedType => UIntType(w1) - case ClockType => UIntType(IntWidth(1)) + case ClockType | AsyncResetType | ResetType => UIntType(IntWidth(1)) case AnalogType(w) => UIntType(w1) - case AsyncResetType => UIntType(IntWidth(1)) case _ => UnknownType } case AsSInt => t1 match { case _: UIntType => SIntType(w1) case _: SIntType => SIntType(w1) case _: FixedType => SIntType(w1) - case ClockType => SIntType(IntWidth(1)) + case ClockType | AsyncResetType | ResetType => SIntType(IntWidth(1)) case _: AnalogType => SIntType(w1) - case AsyncResetType => SIntType(IntWidth(1)) case _ => UnknownType } case AsFixedPoint => t1 match { case _: UIntType => FixedType(w1, c1) case _: SIntType => FixedType(w1, c1) case _: FixedType => FixedType(w1, c1) - case ClockType => FixedType(IntWidth(1), c1) + case ClockType | AsyncResetType | ResetType => FixedType(IntWidth(1), c1) case _: AnalogType => FixedType(w1, c1) - case AsyncResetType => FixedType(IntWidth(1), c1) case _ => UnknownType } case AsClock => t1 match { case _: UIntType => ClockType case _: SIntType => ClockType - case ClockType => ClockType + case ClockType | AsyncResetType | ResetType => ClockType case _: AnalogType => ClockType - case AsyncResetType => ClockType case _ => UnknownType } case AsAsyncReset => t1 match { case _: UIntType => AsyncResetType case _: SIntType => AsyncResetType - case ClockType => AsyncResetType + case ClockType | AsyncResetType | ResetType => AsyncResetType case _: AnalogType => AsyncResetType - case AsyncResetType => AsyncResetType case _ => UnknownType } case Shl => t1 match { diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 72003608..206afc09 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -445,6 +445,7 @@ object Utils extends LazyLogging { } def get_valid_points(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Seq[(Int,Int)] = { + import passes.CheckTypes.legalResetType //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) (t1, t2) match { case (_: UIntType, _: UIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil @@ -474,6 +475,14 @@ object Utils extends LazyLogging { ilen + get_size(t1x.tpe), jlen + get_size(t2x.tpe)) }._1 case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil + case (AsyncResetType, AsyncResetType) => if (flip1 == flip2) Seq((0, 0)) else Nil + // The following two cases handle driving ResetType from other legal reset types + // Flippedness is important here because ResetType can be driven by other reset types, but it + // cannot *drive* other reset types + case (ResetType, other) => + if (legalResetType(other) && flip1 == Default && flip1 == flip2) Seq((0, 0)) else Nil + case (other, ResetType) => + if (legalResetType(other) && flip1 == Flip && flip1 == flip2) Seq((0, 0)) else Nil case _ => throwInternalError(s"get_valid_points: shouldn't be here - ($t1, $t2)") } } diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 302c6142..6d9f0d31 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -131,6 +131,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } case "Clock" => ClockType case "AsyncReset" => AsyncResetType + case "Reset" => ResetType case "Analog" => if (ctx.getChildCount > 1) AnalogType(getWidth(ctx.intLit(0))) else AnalogType(UnknownWidth) case "{" => BundleType(ctx.field.asScala.map(visitField)) diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 241a89b8..c1839a22 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -7,6 +7,7 @@ import Utils._ import firrtl.ir._ import WrappedExpression._ import WrappedWidth._ +import firrtl.passes.CheckTypes.legalResetType trait Kind case object WireKind extends Kind @@ -216,32 +217,46 @@ case class ExpWidth(arg1: Width) extends Width with HasMapWidth { object WrappedType { def apply(t: Type) = new WrappedType(t) def wt(t: Type) = apply(t) -} -class WrappedType(val t: Type) { - def wt(tx: Type) = new WrappedType(tx) - override def equals(o: Any): Boolean = o match { - case (t2: WrappedType) => (t, t2.t) match { + // Check if it is legal for the source type to drive the sink type + // Which is which matters because ResetType can be driven by itself, Bool, or AsyncResetType, but + // it cannot drive Bool nor AsyncResetType + private def compare(sink: Type, source: Type): Boolean = + (sink, source) match { case (_: UIntType, _: UIntType) => true case (_: SIntType, _: SIntType) => true case (ClockType, ClockType) => true case (AsyncResetType, AsyncResetType) => true + case (ResetType, tpe) => legalResetType(tpe) + case (tpe, ResetType) => legalResetType(tpe) case (_: FixedType, _: FixedType) => true // Analog totally skips out of the Firrtl type system. // The only way Analog can play with another Analog component is through Attach. // Ohterwise, we'd need to special case it during ExpandWhens, Lowering, // ExpandConnects, etc. case (_: AnalogType, _: AnalogType) => false - case (t1: VectorType, t2: VectorType) => - t1.size == t2.size && wt(t1.tpe) == wt(t2.tpe) - case (t1: BundleType, t2: BundleType) => - t1.fields.size == t2.fields.size && ( - (t1.fields zip t2.fields) forall { case (f1, f2) => - f1.flip == f2.flip && f1.name == f2.name - }) && ((t1.fields zip t2.fields) forall { case (f1, f2) => - wt(f1.tpe) == wt(f2.tpe) - }) + case (sink: VectorType, source: VectorType) => + sink.size == source.size && compare(sink.tpe, source.tpe) + case (sink: BundleType, source: BundleType) => + (sink.fields.size == source.fields.size) && + sink.fields.zip(source.fields).forall { case (f1, f2) => + (f1.flip == f2.flip) && (f1.name == f2.name) && (f1.flip match { + case Default => compare(f1.tpe, f2.tpe) + // We allow UInt<1> and AsyncReset to drive Reset but not the other way around + case Flip => compare(f2.tpe, f1.tpe) + }) + } case _ => false } +} +class WrappedType(val t: Type) { + def wt(tx: Type) = new WrappedType(tx) + // TODO Better name? + /** Strict comparison except Reset accepts AsyncReset, Reset, and `UInt<1>` + */ + def superTypeOf(that: WrappedType): Boolean = WrappedType.compare(this.t, that.t) + + override def equals(o: Any): Boolean = o match { + case (t2: WrappedType) => WrappedType.compare(this.t, t2.t) case _ => false } } diff --git a/src/main/scala/firrtl/checks/CheckResets.scala b/src/main/scala/firrtl/checks/CheckResets.scala new file mode 100644 index 00000000..d6337f9e --- /dev/null +++ b/src/main/scala/firrtl/checks/CheckResets.scala @@ -0,0 +1,75 @@ +// See LICENSE for license details. + +package firrtl.checks + +import firrtl._ +import firrtl.passes.{Errors, PassException} +import firrtl.ir._ +import firrtl.traversals.Foreachers._ +import firrtl.WrappedExpression._ + +import scala.collection.mutable + +object CheckResets { + class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String) extends PassException( + s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'") + + // Map of Initialization Expression to check + private type RegCheckList = mutable.ListBuffer[(Expression, DefRegister)] + // Record driving for literal propagation + // Indicates *driven by* + private type DirectDriverMap = mutable.HashMap[WrappedExpression, Expression] + +} + +// Must run after ExpandWhens +// Requires +// - static single connections of ground types +class CheckResets extends Transform { + def inputForm: CircuitForm = MidForm + def outputForm: CircuitForm = MidForm + + import CheckResets._ + + private def onStmt(regCheck: RegCheckList, drivers: DirectDriverMap)(stmt: Statement): Unit = { + stmt match { + case DefNode(_, name, expr) => drivers += we(WRef(name)) -> expr + case Connect(_, lhs, rhs) => drivers += we(lhs) -> rhs + case reg @ DefRegister(_,_,_,_, reset, init) if reset.tpe == AsyncResetType => + regCheck += init -> reg + case _ => // Do nothing + } + stmt.foreach(onStmt(regCheck, drivers)) + } + + private def findDriver(drivers: DirectDriverMap)(expr: Expression): Expression = + drivers.get(we(expr)) match { + case Some(lit: Literal) => lit + case Some(other) => findDriver(drivers)(other) + case None => expr + } + + private def onMod(errors: Errors)(mod: DefModule): Unit = { + val regCheck = new RegCheckList() + val drivers = new DirectDriverMap() + mod.foreach(onStmt(regCheck, drivers)) + for ((init, reg) <- regCheck) { + for (subInit <- Utils.create_exps(init)) { + findDriver(drivers)(subInit) match { + case lit: Literal => // All good + case other => + val e = new NonLiteralAsyncResetValueException(reg.info, mod.name, reg.name, other.serialize) + errors.append(e) + } + } + } + } + + def execute(state: CircuitState): CircuitState = { + val errors = new Errors + state.circuit.foreach(onMod(errors)) + errors.trigger() + state + } +} + diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 8124e1e6..9268865b 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -591,6 +591,13 @@ case object ClockType extends GroundType { def mapWidth(f: Width => Width): Type = this def foreachWidth(f: Width => Unit): Unit = Unit } +/* Abstract reset, will be inferred to UInt<1> or AsyncReset */ +case object ResetType extends GroundType { + val width = IntWidth(1) + def serialize: String = "Reset" + def mapWidth(f: Width => Width): Type = this + def foreachWidth(f: Width => Unit): Unit = Unit +} case object AsyncResetType extends GroundType { val width = IntWidth(1) def serialize: String = "AsyncReset" diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index b0d9085b..07784e19 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -108,6 +108,7 @@ object CheckWidths extends Pass { sx.reset.tpe match { case UIntType(IntWidth(w)) if w == 1 => case AsyncResetType => + case ResetType => case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) } case _ => diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 972a018e..a17a5a2e 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -57,8 +57,11 @@ trait CheckHighFormLike { s"$info: [module $mname] Primop $op argument $value < 0.") class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: Int, msb: Int) extends PassException( s"$info: [module $mname] Primop $op lsb $lsb > $msb.") - class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String) extends PassException( - s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'") + class ResetInputException(info: Info, mname: String, expr: Expression) extends PassException( + s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}") + class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression) extends PassException( + s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}") + // Is Chirrtl allowed for this check? If not, return an error def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] @@ -188,8 +191,6 @@ trait CheckHighFormLike { case DefRegister(info, name, tpe, _, reset, init) => if (hasFlip(tpe)) errors.append(new RegWithFlipException(info, mname, name)) - if (reset.tpe == AsyncResetType && !init.isInstanceOf[Literal]) - errors.append(new NonLiteralAsyncResetValueException(info, mname, name, init.serialize)) case sx: DefMemory => if (sx.readLatency < 0 || sx.writeLatency <= 0) errors.append(new IllegalMemLatencyException(info, mname, sx.name)) @@ -218,15 +219,36 @@ trait CheckHighFormLike { p.tpe foreach checkHighFormW(p.info, mname) } + // Search for ResetType Ports of direction + def findBadResetTypePorts(m: DefModule, dir: Direction): Seq[(Port, Expression)] = { + val bad = to_gender(dir) + for { + port <- m.ports + ref = WRef(port).copy(gender = to_gender(port.direction)) + expr <- create_exps(ref) + if ((expr.tpe == ResetType) && (gender(expr) == bad)) + } yield (port, expr) + } + def checkHighFormM(m: DefModule): Unit = { val names = new NameSet m foreach checkHighFormP(m.name, names) m foreach checkHighFormS(m.info, m.name, names) + m match { + case _: Module => + case ext: ExtModule => + for ((port, expr) <- findBadResetTypePorts(ext, Output)) { + errors.append(new ResetExtModuleOutputException(port.info, ext.name, expr)) + } + } } c.modules foreach checkHighFormM - c.modules count (_.name == c.main) match { - case 1 => + c.modules.filter(_.name == c.main) match { + case Seq(topMod) => + for ((port, expr) <- findBadResetTypePorts(topMod, Input)) { + errors.append(new ResetInputException(port.info, topMod.name, expr)) + } case _ => errors.append(new NoTopModuleException(c.info, c.main)) } errors.trigger() @@ -263,8 +285,12 @@ object CheckTypes extends Pass { s"$info: [module $mname] Index is not of UIntType.") class EnableNotUInt(info: Info, mname: String) extends PassException( s"$info: [module $mname] Enable is not of UIntType.") - class InvalidConnect(info: Info, mname: String, lhs: String, rhs: String) extends PassException( - s"$info: [module $mname] Type mismatch. Cannot connect $lhs to $rhs.") + class InvalidConnect(info: Info, mname: String, con: String, lhs: Expression, rhs: Expression) + extends PassException({ + val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}" + val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}" + s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe" + }) class InvalidRegInit(info: Info, mname: String) extends PassException( s"$info: [module $mname] Type of init must match type of DefRegister.") class PrintfArgNotGround(info: Info, mname: String) extends PassException( @@ -308,11 +334,52 @@ object CheckTypes extends Pass { class IllegalAttachExp(info: Info, mname: String, expName: String) extends PassException( s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.") class IllegalResetType(info: Info, mname: String, exp: String) extends PassException( - s"$info: [module $mname] Register resets must have type UInt<1>: $exp.") + s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp.") class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException( s"$info: [module $mname] Uninferred type: $exp." ) + def legalResetType(tpe: Type): Boolean = tpe match { + case UIntType(IntWidth(w)) if w == 1 => true + case AsyncResetType => true + case ResetType => true + case UIntType(UnknownWidth) => + // cannot catch here, though width may ultimately be wrong + true + case _ => false + } + + private def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { + (t1, t2) match { + case (ClockType, ClockType) => flip1 == flip2 + case (_: UIntType, _: UIntType) => flip1 == flip2 + case (_: SIntType, _: SIntType) => flip1 == flip2 + case (_: FixedType, _: FixedType) => flip1 == flip2 + case (_: AnalogType, _: AnalogType) => true + case (AsyncResetType, AsyncResetType) => flip1 == flip2 + case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2 + case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2 + case (t1: BundleType, t2: BundleType) => + val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())( + (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) ))) + t2.fields forall (f2 => + t1_fields get f2.name match { + case None => true + case Some((f1_tpe, f1_flip)) => + bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip)) + } + ) + case (t1: VectorType, t2: VectorType) => + bulk_equals(t1.tpe, t2.tpe, flip1, flip2) + case (_, _) => false + } + } + + def validConnect(con: Connect): Boolean = wt(con.loc.tpe).superTypeOf(wt(con.expr.tpe)) + + def validPartialConnect(con: PartialConnect): Boolean = + bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default) + //;---------------- Helper Functions -------------- def ut: UIntType = UIntType(UnknownWidth) def st: SIntType = SIntType(UnknownWidth) @@ -414,48 +481,23 @@ object CheckTypes extends Pass { e foreach check_types_e(info, mname) } - def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { - //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) - (t1, t2) match { - case (ClockType, ClockType) => flip1 == flip2 - case (_: UIntType, _: UIntType) => flip1 == flip2 - case (_: SIntType, _: SIntType) => flip1 == flip2 - case (_: FixedType, _: FixedType) => flip1 == flip2 - case (_: AnalogType, _: AnalogType) => true - case (t1: BundleType, t2: BundleType) => - val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())( - (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) ))) - t2.fields forall (f2 => - t1_fields get f2.name match { - case None => true - case Some((f1_tpe, f1_flip)) => - bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip)) - } - ) - case (t1: VectorType, t2: VectorType) => - bulk_equals(t1.tpe, t2.tpe, flip1, flip2) - case (_, _) => false - } - } - def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = { val info = get_info(s) match { case NoInfo => minfo case x => x } s match { - case sx: Connect if wt(sx.loc.tpe) != wt(sx.expr.tpe) => - errors.append(new InvalidConnect(info, mname, sx.loc.serialize, sx.expr.serialize)) - case sx: PartialConnect if !bulk_equals(sx.loc.tpe, sx.expr.tpe, Default, Default) => - errors.append(new InvalidConnect(info, mname, sx.loc.serialize, sx.expr.serialize)) + case sx: Connect if !validConnect(sx) => + val conMsg = sx.copy(info = NoInfo).serialize + errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr)) + case sx: PartialConnect if !validPartialConnect(sx) => + val conMsg = sx.copy(info = NoInfo).serialize + errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr)) case sx: DefRegister => sx.tpe match { case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname)) case t => } - sx.reset.tpe match { - case UIntType(IntWidth(w)) if w == 1 => - case AsyncResetType => - case UIntType(UnknownWidth) => // cannot catch here, though width may ultimately be wrong - case _ => errors.append(new IllegalResetType(info, mname, sx.name)) + if (!legalResetType(sx.reset.tpe)) { + errors.append(new IllegalResetType(info, mname, sx.name)) } if (sx.clock.tpe != ClockType) { errors.append(new RegReqClk(info, mname, sx.name)) diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index cf6f2ae0..8f663afd 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -262,6 +262,8 @@ class InferWidths extends Transform with ResolvedAnnotationPaths { }) } case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe) + case (ResetType, _) => Nil + case (_, ResetType) => Nil } def run(c: Circuit, extra: Seq[WGeq]): Circuit = { diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala index 44e505f1..22c90316 100644 --- a/src/main/scala/firrtl/proto/FromProto.scala +++ b/src/main/scala/firrtl/proto/FromProto.scala @@ -256,6 +256,7 @@ object FromProto { case FIXED_TYPE_FIELD_NUMBER => convert(tpe.getFixedType) case CLOCK_TYPE_FIELD_NUMBER => ir.ClockType case ASYNC_RESET_TYPE_FIELD_NUMBER => ir.AsyncResetType + case RESET_TYPE_FIELD_NUMBER => ir.ResetType case ANALOG_TYPE_FIELD_NUMBER => convert(tpe.getAnalogType) case BUNDLE_TYPE_FIELD_NUMBER => ir.BundleType(tpe.getBundleType.getFieldList.asScala.map(convert(_))) diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index c67f446c..17adb698 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -343,6 +343,9 @@ object ToProto { case ir.AsyncResetType => val at = Firrtl.Type.AsyncResetType.newBuilder() tb.setAsyncResetType(at) + case ir.ResetType => + val rt = Firrtl.Type.ResetType.newBuilder() + tb.setResetType(rt) case ir.AnalogType(width) => val at = Firrtl.Type.AnalogType.newBuilder() convert(width).foreach(at.setWidth) diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala new file mode 100644 index 00000000..70e2b76c --- /dev/null +++ b/src/main/scala/firrtl/transforms/InferResets.scala @@ -0,0 +1,253 @@ +// See LICENSE for license details. + +package firrtl.transforms + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ +import firrtl.annotations.{ReferenceTarget, TargetToken} +import firrtl.Utils.toTarget +import firrtl.passes.{Pass, PassException, Errors, InferTypes} + +import scala.collection.mutable +import scala.util.Try + +object InferResets { + final class DifferingDriverTypesException private (msg: String) extends PassException(msg) + object DifferingDriverTypesException { + def apply(target: ReferenceTarget, tpes: Seq[(Type, Seq[TypeDriver])]): DifferingDriverTypesException = { + val xs = tpes.map { case (t, ds) => s"${ds.map(_.target().serialize).mkString(", ")} of type ${t.serialize}" } + val msg = s"${target.serialize} driven with multiple types!" + xs.mkString("\n ", "\n ", "") + new DifferingDriverTypesException(msg) + } + } + + /** Type hierarchy to represent the type of the thing driving a [[ResetType]] */ + private sealed trait ResetDriver + // When a [[ResetType]] is driven by another ResetType, we track the target so that we can infer + // the same type as the driver + private case class TargetDriver(target: ReferenceTarget) extends ResetDriver { + override def toString: String = s"TargetDriver(${target.serialize})" + } + // When a [[ResetType]] is driven by something of type Bool or AsyncResetType, we keep track of it + // as a constraint on the type we should infer to be + // We keep the target around (lazily) so that we can report errors + private case class TypeDriver(tpe: Type, target: () => ReferenceTarget) extends ResetDriver { + override def toString: String = s"TypeDriver(${tpe.serialize}, $target)" + } + + + // Type hierarchy representing the path to a leaf type in an aggregate type structure + // Used by this [[InferResets]] to pinpoint instances of [[ResetType]] and their inferred type + private sealed trait TypeTree + private case class BundleTree(fields: Map[String, TypeTree]) extends TypeTree + private case class VectorTree(subType: TypeTree) extends TypeTree + // TODO ensure is only AsyncResetType or BoolType + private case class GroundTree(tpe: Type) extends TypeTree + + private object TypeTree { + // Given groups of [[TargetToken]]s and Types corresponding to them, construct a [[TypeTree]] + // that allows us to lookup the type of each leaf node in the aggregate structure + // TODO make return Try[TypeTree] + def fromTokens(tokens: (Seq[TargetToken], Type)*): TypeTree = tokens match { + case Seq((Seq(), tpe)) => GroundTree(tpe) + // VectorTree + case (TargetToken.Index(_) +: _, _) +: _ => + // Vectors must all have the same type, so we only process Index 0 + // If the subtype is an aggregate, there can be multiple of each index + val ts = tokens.collect { case (TargetToken.Index(0) +: tail, tpe) => (tail, tpe) } + VectorTree(fromTokens(ts:_*)) + // BundleTree + case (TargetToken.Field(_) +: _, _) +: _ => + val fields = + tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n } + .mapValues { ts => + fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }:_*) + } + BundleTree(fields) + } + } +} + +/** Infers the concrete type of [[Reset]]s by their connections + * This is a global inference because ports can be of type [[Reset]] + * @note This transform should be run before [[DedupModules]] so that similar Modules from + * generator languages like Chisel can infer differently + */ +// TODO should we error if a DefMemory is of type AsyncReset? In CheckTypes? +class InferResets extends Transform { + def inputForm: CircuitForm = HighForm + def outputForm: CircuitForm = HighForm + + import InferResets._ + + // Collect all drivers for circuit elements of type ResetType + private def analyze(c: Circuit): Map[ReferenceTarget, List[ResetDriver]] = { + type DriverMap = mutable.HashMap[ReferenceTarget, mutable.ListBuffer[ResetDriver]] + def onMod(mod: DefModule): DriverMap = { + val instMap = mutable.Map[String, String]() + // We need to convert submodule port targets into targets on the Module port itself + def makeTarget(expr: Expression): ReferenceTarget = { + val target = toTarget(c.main, mod.name)(expr) + Utils.kind(expr) match { + case InstanceKind => + val mod = instMap(target.ref) + val port = target.component.head match { + case TargetToken.Field(name) => name + case bad => Utils.throwInternalError(s"Unexpected token $bad") + } + target.copy(module = mod, ref = port, component = target.component.tail) + case _ => target + } + } + def onStmt(map: DriverMap)(stmt: Statement): Unit = { + // Mark driver of a ResetType leaf + def markResetDriver(lhs: Expression, rhs: Expression): Unit = { + val lflip = Utils.to_flip(Utils.gender(lhs)) + if ((lflip == Default && lhs.tpe == ResetType) || + (lflip == Flip && rhs.tpe == ResetType)) { + val (loc, exp) = lflip match { + case Default => (lhs, rhs) + case Flip => (rhs, lhs) + } + val target = makeTarget(loc) + val driver = exp.tpe match { + case ResetType => TargetDriver(makeTarget(exp)) + case tpe => TypeDriver(tpe, () => makeTarget(exp)) + } + map.getOrElseUpdate(target, mutable.ListBuffer()) += driver + } + } + stmt match { + // TODO + // - Each connect duplicates a bunch of code from ExpandConnects, could be cleaner + // - The full create_exps duplication is inefficient, there has to be a better way + case Connect(_, lhs, rhs) => + val locs = Utils.create_exps(lhs) + val exps = Utils.create_exps(rhs) + for ((loc, exp) <- locs.zip(exps)) { + markResetDriver(loc, exp) + } + case PartialConnect(_, lhs, rhs) => + val points = Utils.get_valid_points(lhs.tpe, rhs.tpe, Default, Default) + val locs = Utils.create_exps(lhs) + val exps = Utils.create_exps(rhs) + for ((i, j) <- points) { + markResetDriver(locs(i), exps(j)) + } + case WDefInstance(_, inst, module, _) => + instMap += (inst -> module) + case Conditionally(_, _, con, alt) => + val conMap = new DriverMap + val altMap = new DriverMap + onStmt(conMap)(con) + onStmt(altMap)(alt) + // Default to outerscope if not found in alt + val altLookup = altMap.orElse(map).lift + for (key <- conMap.keys ++ altMap.keys) { + val ds = map.getOrElseUpdate(key, mutable.ListBuffer()) + conMap.get(key).foreach(ds ++= _) + altLookup(key).foreach(ds ++= _) + } + case other => other.foreach(onStmt(map)) + } + } + val types = new DriverMap + mod.foreach(onStmt(types)) + types + } + c.modules.foldLeft(Map[ReferenceTarget, List[ResetDriver]]()) { + case (map, mod) => map ++ onMod(mod).mapValues(_.toList) + } + } + + // Determine the type driving a given ResetType + private def resolve(map: Map[ReferenceTarget, List[ResetDriver]]): Try[Map[ReferenceTarget, Type]] = { + val res = mutable.Map[ReferenceTarget, Type]() + val errors = new Errors + def rec(target: ReferenceTarget): Type = { + val drivers = map(target) + res.getOrElseUpdate(target, { + val tpes = drivers.map { + case TargetDriver(t) => TypeDriver(rec(t), () => t) + case td: TypeDriver => td + }.groupBy(_.tpe) + if (tpes.keys.size != 1) { + // Multiple types of driver! + errors.append(DifferingDriverTypesException(target, tpes.toSeq)) + } + tpes.keys.head + }) + } + for ((target, _) <- map) { + rec(target) + } + Try { + errors.trigger() + res.toMap + } + } + + private def fixupType(tpe: Type, tree: TypeTree): Type = (tpe, tree) match { + case (BundleType(fields), BundleTree(map)) => + val fieldsx = + fields.map(f => map.get(f.name) match { + case Some(t) => f.copy(tpe = fixupType(f.tpe, t)) + case None => f + }) + BundleType(fieldsx) + case (VectorType(vtpe, size), VectorTree(t)) => + VectorType(fixupType(vtpe, t), size) + case (_, GroundTree(t)) => t + case x => throw new Exception(s"Error! Unexpected pair $x") + } + + // Assumes all ReferenceTargets are in the same module + private def makeDeclMap(map: Map[ReferenceTarget, Type]): Map[String, TypeTree] = + map.groupBy(_._1.ref).mapValues { ts => + TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }:_*) + } + + private def implPort(map: Map[String, TypeTree])(port: Port): Port = + map.get(port.name) + .map(tree => port.copy(tpe = fixupType(port.tpe, tree))) + .getOrElse(port) + private def implStmt(map: Map[String, TypeTree])(stmt: Statement): Statement = + stmt.map(implStmt(map)) match { + case decl: IsDeclaration if map.contains(decl.name) => + val tree = map(decl.name) + decl match { + case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree)) + case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree)) + // TODO Can this really happen? + case mem: DefMemory => mem.copy(dataType = fixupType(mem.dataType, tree)) + case other => other + } + case other => other + } + + private def implement(c: Circuit, map: Map[ReferenceTarget, Type]): Circuit = { + val modMaps = map.groupBy(_._1.module) + def onMod(mod: DefModule): DefModule = { + modMaps.get(mod.name).map { tmap => + val declMap = makeDeclMap(tmap) + mod.map(implPort(declMap)).map(implStmt(declMap)) + }.getOrElse(mod) + } + c.map(onMod) + } + + private def fixupPasses: Seq[Pass] = Seq( + InferTypes + ) + + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val analysis = analyze(c) + val inferred = resolve(analysis) + val result = inferred.map(m => implement(c, m)).get + val fixedup = fixupPasses.foldLeft(result)((c, p) => p.run(c)) + state.copy(circuit = fixedup) + } +} |
