aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/antlr4/FIRRTL.g41
-rw-r--r--src/main/proto/firrtl.proto5
-rw-r--r--src/main/scala/firrtl/Compiler.scala5
-rw-r--r--src/main/scala/firrtl/Emitter.scala7
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala4
-rw-r--r--src/main/scala/firrtl/PrimOps.scala15
-rw-r--r--src/main/scala/firrtl/Utils.scala9
-rw-r--r--src/main/scala/firrtl/Visitor.scala1
-rw-r--r--src/main/scala/firrtl/WIR.scala43
-rw-r--r--src/main/scala/firrtl/checks/CheckResets.scala75
-rw-r--r--src/main/scala/firrtl/ir/IR.scala7
-rw-r--r--src/main/scala/firrtl/passes/CheckWidths.scala1
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala126
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala2
-rw-r--r--src/main/scala/firrtl/proto/FromProto.scala1
-rw-r--r--src/main/scala/firrtl/proto/ToProto.scala3
-rw-r--r--src/main/scala/firrtl/transforms/InferResets.scala253
-rw-r--r--src/test/scala/firrtlTests/AsyncResetSpec.scala137
-rw-r--r--src/test/scala/firrtlTests/InferResetsSpec.scala355
-rw-r--r--src/test/scala/firrtlTests/ProtoBufSpec.scala5
20 files changed, 983 insertions, 72 deletions
diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4
index c49bb948..c3b4e74e 100644
--- a/src/main/antlr4/FIRRTL.g4
+++ b/src/main/antlr4/FIRRTL.g4
@@ -52,6 +52,7 @@ type
| 'Fixed' ('<' intLit '>')? ('<' '<' intLit '>' '>')?
| 'Clock'
| 'AsyncReset'
+ | 'Reset'
| 'Analog' ('<' intLit '>')?
| '{' field* '}' // Bundle
| type '[' intLit ']' // Vector
diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto
index 2552b989..b8f6db98 100644
--- a/src/main/proto/firrtl.proto
+++ b/src/main/proto/firrtl.proto
@@ -277,6 +277,10 @@ message Firrtl {
// Empty.
}
+ message ResetType {
+ // Empty.
+ }
+
message BundleType {
message Field {
// Required.
@@ -315,6 +319,7 @@ message Firrtl {
FixedType fixed_type = 7;
AnalogType analog_type = 8;
AsyncResetType async_reset_type = 9;
+ ResetType reset_type = 10;
}
}
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index b72fd4ce..367defb5 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -334,8 +334,9 @@ object CompilerUtils extends LazyLogging {
case ChirrtlForm =>
Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm)
case HighForm =>
- Seq(new IRToWorkingIR, new ResolveAndCheck, new transforms.DedupModules,
- new HighFirrtlToMiddleFirrtl) ++ getLoweringTransforms(MidForm, outputForm)
+ Seq(new IRToWorkingIR, new ResolveAndCheck,
+ new transforms.DedupModules, new HighFirrtlToMiddleFirrtl) ++
+ getLoweringTransforms(MidForm, outputForm)
case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm)
case LowForm => throwInternalError("getLoweringTransforms - LowForm") // should be caught by if above
case UnknownForm => throwInternalError("getLoweringTransforms - UnknownForm") // should be caught by if above
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 8e6408fe..854e1876 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -231,7 +231,12 @@ class VerilogEmitter extends SeqTransform with Emitter {
x match {
case (e: DoPrim) => emit(op_stream(e), top + 1)
case (e: Mux) => {
- if(e.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly")
+ if (e.tpe == ClockType) {
+ throw EmitterException("Cannot emit clock muxes directly")
+ }
+ if (e.tpe == AsyncResetType) {
+ throw EmitterException("Cannot emit async reset muxes directly")
+ }
emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1)
}
case (e: ValidIf) => emit(Seq(cast(e.value)),top + 1)
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 316baec9..274ccf74 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -46,7 +46,8 @@ class ResolveAndCheck extends CoreTransform {
passes.ResolveGenders,
passes.CheckGenders,
new passes.InferWidths,
- passes.CheckWidths)
+ passes.CheckWidths,
+ new firrtl.transforms.InferResets)
}
/** Expands aggregate connects, removes dynamic accesses, and when
@@ -68,6 +69,7 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform {
passes.ResolveKinds,
passes.InferTypes,
passes.CheckTypes,
+ new checks.CheckResets,
passes.ResolveGenders,
new passes.InferWidths,
passes.CheckWidths,
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala
index 0f1ecff7..1a513352 100644
--- a/src/main/scala/firrtl/PrimOps.scala
+++ b/src/main/scala/firrtl/PrimOps.scala
@@ -203,43 +203,38 @@ object PrimOps extends LazyLogging {
case _: UIntType => UIntType(w1)
case _: SIntType => UIntType(w1)
case _: FixedType => UIntType(w1)
- case ClockType => UIntType(IntWidth(1))
+ case ClockType | AsyncResetType | ResetType => UIntType(IntWidth(1))
case AnalogType(w) => UIntType(w1)
- case AsyncResetType => UIntType(IntWidth(1))
case _ => UnknownType
}
case AsSInt => t1 match {
case _: UIntType => SIntType(w1)
case _: SIntType => SIntType(w1)
case _: FixedType => SIntType(w1)
- case ClockType => SIntType(IntWidth(1))
+ case ClockType | AsyncResetType | ResetType => SIntType(IntWidth(1))
case _: AnalogType => SIntType(w1)
- case AsyncResetType => SIntType(IntWidth(1))
case _ => UnknownType
}
case AsFixedPoint => t1 match {
case _: UIntType => FixedType(w1, c1)
case _: SIntType => FixedType(w1, c1)
case _: FixedType => FixedType(w1, c1)
- case ClockType => FixedType(IntWidth(1), c1)
+ case ClockType | AsyncResetType | ResetType => FixedType(IntWidth(1), c1)
case _: AnalogType => FixedType(w1, c1)
- case AsyncResetType => FixedType(IntWidth(1), c1)
case _ => UnknownType
}
case AsClock => t1 match {
case _: UIntType => ClockType
case _: SIntType => ClockType
- case ClockType => ClockType
+ case ClockType | AsyncResetType | ResetType => ClockType
case _: AnalogType => ClockType
- case AsyncResetType => ClockType
case _ => UnknownType
}
case AsAsyncReset => t1 match {
case _: UIntType => AsyncResetType
case _: SIntType => AsyncResetType
- case ClockType => AsyncResetType
+ case ClockType | AsyncResetType | ResetType => AsyncResetType
case _: AnalogType => AsyncResetType
- case AsyncResetType => AsyncResetType
case _ => UnknownType
}
case Shl => t1 match {
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 72003608..206afc09 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -445,6 +445,7 @@ object Utils extends LazyLogging {
}
def get_valid_points(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Seq[(Int,Int)] = {
+ import passes.CheckTypes.legalResetType
//;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2])
(t1, t2) match {
case (_: UIntType, _: UIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil
@@ -474,6 +475,14 @@ object Utils extends LazyLogging {
ilen + get_size(t1x.tpe), jlen + get_size(t2x.tpe))
}._1
case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil
+ case (AsyncResetType, AsyncResetType) => if (flip1 == flip2) Seq((0, 0)) else Nil
+ // The following two cases handle driving ResetType from other legal reset types
+ // Flippedness is important here because ResetType can be driven by other reset types, but it
+ // cannot *drive* other reset types
+ case (ResetType, other) =>
+ if (legalResetType(other) && flip1 == Default && flip1 == flip2) Seq((0, 0)) else Nil
+ case (other, ResetType) =>
+ if (legalResetType(other) && flip1 == Flip && flip1 == flip2) Seq((0, 0)) else Nil
case _ => throwInternalError(s"get_valid_points: shouldn't be here - ($t1, $t2)")
}
}
diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala
index 302c6142..6d9f0d31 100644
--- a/src/main/scala/firrtl/Visitor.scala
+++ b/src/main/scala/firrtl/Visitor.scala
@@ -131,6 +131,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
}
case "Clock" => ClockType
case "AsyncReset" => AsyncResetType
+ case "Reset" => ResetType
case "Analog" => if (ctx.getChildCount > 1) AnalogType(getWidth(ctx.intLit(0)))
else AnalogType(UnknownWidth)
case "{" => BundleType(ctx.field.asScala.map(visitField))
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 241a89b8..c1839a22 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -7,6 +7,7 @@ import Utils._
import firrtl.ir._
import WrappedExpression._
import WrappedWidth._
+import firrtl.passes.CheckTypes.legalResetType
trait Kind
case object WireKind extends Kind
@@ -216,32 +217,46 @@ case class ExpWidth(arg1: Width) extends Width with HasMapWidth {
object WrappedType {
def apply(t: Type) = new WrappedType(t)
def wt(t: Type) = apply(t)
-}
-class WrappedType(val t: Type) {
- def wt(tx: Type) = new WrappedType(tx)
- override def equals(o: Any): Boolean = o match {
- case (t2: WrappedType) => (t, t2.t) match {
+ // Check if it is legal for the source type to drive the sink type
+ // Which is which matters because ResetType can be driven by itself, Bool, or AsyncResetType, but
+ // it cannot drive Bool nor AsyncResetType
+ private def compare(sink: Type, source: Type): Boolean =
+ (sink, source) match {
case (_: UIntType, _: UIntType) => true
case (_: SIntType, _: SIntType) => true
case (ClockType, ClockType) => true
case (AsyncResetType, AsyncResetType) => true
+ case (ResetType, tpe) => legalResetType(tpe)
+ case (tpe, ResetType) => legalResetType(tpe)
case (_: FixedType, _: FixedType) => true
// Analog totally skips out of the Firrtl type system.
// The only way Analog can play with another Analog component is through Attach.
// Ohterwise, we'd need to special case it during ExpandWhens, Lowering,
// ExpandConnects, etc.
case (_: AnalogType, _: AnalogType) => false
- case (t1: VectorType, t2: VectorType) =>
- t1.size == t2.size && wt(t1.tpe) == wt(t2.tpe)
- case (t1: BundleType, t2: BundleType) =>
- t1.fields.size == t2.fields.size && (
- (t1.fields zip t2.fields) forall { case (f1, f2) =>
- f1.flip == f2.flip && f1.name == f2.name
- }) && ((t1.fields zip t2.fields) forall { case (f1, f2) =>
- wt(f1.tpe) == wt(f2.tpe)
- })
+ case (sink: VectorType, source: VectorType) =>
+ sink.size == source.size && compare(sink.tpe, source.tpe)
+ case (sink: BundleType, source: BundleType) =>
+ (sink.fields.size == source.fields.size) &&
+ sink.fields.zip(source.fields).forall { case (f1, f2) =>
+ (f1.flip == f2.flip) && (f1.name == f2.name) && (f1.flip match {
+ case Default => compare(f1.tpe, f2.tpe)
+ // We allow UInt<1> and AsyncReset to drive Reset but not the other way around
+ case Flip => compare(f2.tpe, f1.tpe)
+ })
+ }
case _ => false
}
+}
+class WrappedType(val t: Type) {
+ def wt(tx: Type) = new WrappedType(tx)
+ // TODO Better name?
+ /** Strict comparison except Reset accepts AsyncReset, Reset, and `UInt<1>`
+ */
+ def superTypeOf(that: WrappedType): Boolean = WrappedType.compare(this.t, that.t)
+
+ override def equals(o: Any): Boolean = o match {
+ case (t2: WrappedType) => WrappedType.compare(this.t, t2.t)
case _ => false
}
}
diff --git a/src/main/scala/firrtl/checks/CheckResets.scala b/src/main/scala/firrtl/checks/CheckResets.scala
new file mode 100644
index 00000000..d6337f9e
--- /dev/null
+++ b/src/main/scala/firrtl/checks/CheckResets.scala
@@ -0,0 +1,75 @@
+// See LICENSE for license details.
+
+package firrtl.checks
+
+import firrtl._
+import firrtl.passes.{Errors, PassException}
+import firrtl.ir._
+import firrtl.traversals.Foreachers._
+import firrtl.WrappedExpression._
+
+import scala.collection.mutable
+
+object CheckResets {
+ class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String) extends PassException(
+ s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'")
+
+ // Map of Initialization Expression to check
+ private type RegCheckList = mutable.ListBuffer[(Expression, DefRegister)]
+ // Record driving for literal propagation
+ // Indicates *driven by*
+ private type DirectDriverMap = mutable.HashMap[WrappedExpression, Expression]
+
+}
+
+// Must run after ExpandWhens
+// Requires
+// - static single connections of ground types
+class CheckResets extends Transform {
+ def inputForm: CircuitForm = MidForm
+ def outputForm: CircuitForm = MidForm
+
+ import CheckResets._
+
+ private def onStmt(regCheck: RegCheckList, drivers: DirectDriverMap)(stmt: Statement): Unit = {
+ stmt match {
+ case DefNode(_, name, expr) => drivers += we(WRef(name)) -> expr
+ case Connect(_, lhs, rhs) => drivers += we(lhs) -> rhs
+ case reg @ DefRegister(_,_,_,_, reset, init) if reset.tpe == AsyncResetType =>
+ regCheck += init -> reg
+ case _ => // Do nothing
+ }
+ stmt.foreach(onStmt(regCheck, drivers))
+ }
+
+ private def findDriver(drivers: DirectDriverMap)(expr: Expression): Expression =
+ drivers.get(we(expr)) match {
+ case Some(lit: Literal) => lit
+ case Some(other) => findDriver(drivers)(other)
+ case None => expr
+ }
+
+ private def onMod(errors: Errors)(mod: DefModule): Unit = {
+ val regCheck = new RegCheckList()
+ val drivers = new DirectDriverMap()
+ mod.foreach(onStmt(regCheck, drivers))
+ for ((init, reg) <- regCheck) {
+ for (subInit <- Utils.create_exps(init)) {
+ findDriver(drivers)(subInit) match {
+ case lit: Literal => // All good
+ case other =>
+ val e = new NonLiteralAsyncResetValueException(reg.info, mod.name, reg.name, other.serialize)
+ errors.append(e)
+ }
+ }
+ }
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ val errors = new Errors
+ state.circuit.foreach(onMod(errors))
+ errors.trigger()
+ state
+ }
+}
+
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index 8124e1e6..9268865b 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -591,6 +591,13 @@ case object ClockType extends GroundType {
def mapWidth(f: Width => Width): Type = this
def foreachWidth(f: Width => Unit): Unit = Unit
}
+/* Abstract reset, will be inferred to UInt<1> or AsyncReset */
+case object ResetType extends GroundType {
+ val width = IntWidth(1)
+ def serialize: String = "Reset"
+ def mapWidth(f: Width => Width): Type = this
+ def foreachWidth(f: Width => Unit): Unit = Unit
+}
case object AsyncResetType extends GroundType {
val width = IntWidth(1)
def serialize: String = "AsyncReset"
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala
index b0d9085b..07784e19 100644
--- a/src/main/scala/firrtl/passes/CheckWidths.scala
+++ b/src/main/scala/firrtl/passes/CheckWidths.scala
@@ -108,6 +108,7 @@ object CheckWidths extends Pass {
sx.reset.tpe match {
case UIntType(IntWidth(w)) if w == 1 =>
case AsyncResetType =>
+ case ResetType =>
case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name))
}
case _ =>
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 972a018e..a17a5a2e 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -57,8 +57,11 @@ trait CheckHighFormLike {
s"$info: [module $mname] Primop $op argument $value < 0.")
class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: Int, msb: Int) extends PassException(
s"$info: [module $mname] Primop $op lsb $lsb > $msb.")
- class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String) extends PassException(
- s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'")
+ class ResetInputException(info: Info, mname: String, expr: Expression) extends PassException(
+ s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}")
+ class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression) extends PassException(
+ s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}")
+
// Is Chirrtl allowed for this check? If not, return an error
def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException]
@@ -188,8 +191,6 @@ trait CheckHighFormLike {
case DefRegister(info, name, tpe, _, reset, init) =>
if (hasFlip(tpe))
errors.append(new RegWithFlipException(info, mname, name))
- if (reset.tpe == AsyncResetType && !init.isInstanceOf[Literal])
- errors.append(new NonLiteralAsyncResetValueException(info, mname, name, init.serialize))
case sx: DefMemory =>
if (sx.readLatency < 0 || sx.writeLatency <= 0)
errors.append(new IllegalMemLatencyException(info, mname, sx.name))
@@ -218,15 +219,36 @@ trait CheckHighFormLike {
p.tpe foreach checkHighFormW(p.info, mname)
}
+ // Search for ResetType Ports of direction
+ def findBadResetTypePorts(m: DefModule, dir: Direction): Seq[(Port, Expression)] = {
+ val bad = to_gender(dir)
+ for {
+ port <- m.ports
+ ref = WRef(port).copy(gender = to_gender(port.direction))
+ expr <- create_exps(ref)
+ if ((expr.tpe == ResetType) && (gender(expr) == bad))
+ } yield (port, expr)
+ }
+
def checkHighFormM(m: DefModule): Unit = {
val names = new NameSet
m foreach checkHighFormP(m.name, names)
m foreach checkHighFormS(m.info, m.name, names)
+ m match {
+ case _: Module =>
+ case ext: ExtModule =>
+ for ((port, expr) <- findBadResetTypePorts(ext, Output)) {
+ errors.append(new ResetExtModuleOutputException(port.info, ext.name, expr))
+ }
+ }
}
c.modules foreach checkHighFormM
- c.modules count (_.name == c.main) match {
- case 1 =>
+ c.modules.filter(_.name == c.main) match {
+ case Seq(topMod) =>
+ for ((port, expr) <- findBadResetTypePorts(topMod, Input)) {
+ errors.append(new ResetInputException(port.info, topMod.name, expr))
+ }
case _ => errors.append(new NoTopModuleException(c.info, c.main))
}
errors.trigger()
@@ -263,8 +285,12 @@ object CheckTypes extends Pass {
s"$info: [module $mname] Index is not of UIntType.")
class EnableNotUInt(info: Info, mname: String) extends PassException(
s"$info: [module $mname] Enable is not of UIntType.")
- class InvalidConnect(info: Info, mname: String, lhs: String, rhs: String) extends PassException(
- s"$info: [module $mname] Type mismatch. Cannot connect $lhs to $rhs.")
+ class InvalidConnect(info: Info, mname: String, con: String, lhs: Expression, rhs: Expression)
+ extends PassException({
+ val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}"
+ val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}"
+ s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe"
+ })
class InvalidRegInit(info: Info, mname: String) extends PassException(
s"$info: [module $mname] Type of init must match type of DefRegister.")
class PrintfArgNotGround(info: Info, mname: String) extends PassException(
@@ -308,11 +334,52 @@ object CheckTypes extends Pass {
class IllegalAttachExp(info: Info, mname: String, expName: String) extends PassException(
s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.")
class IllegalResetType(info: Info, mname: String, exp: String) extends PassException(
- s"$info: [module $mname] Register resets must have type UInt<1>: $exp.")
+ s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp.")
class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException(
s"$info: [module $mname] Uninferred type: $exp."
)
+ def legalResetType(tpe: Type): Boolean = tpe match {
+ case UIntType(IntWidth(w)) if w == 1 => true
+ case AsyncResetType => true
+ case ResetType => true
+ case UIntType(UnknownWidth) =>
+ // cannot catch here, though width may ultimately be wrong
+ true
+ case _ => false
+ }
+
+ private def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = {
+ (t1, t2) match {
+ case (ClockType, ClockType) => flip1 == flip2
+ case (_: UIntType, _: UIntType) => flip1 == flip2
+ case (_: SIntType, _: SIntType) => flip1 == flip2
+ case (_: FixedType, _: FixedType) => flip1 == flip2
+ case (_: AnalogType, _: AnalogType) => true
+ case (AsyncResetType, AsyncResetType) => flip1 == flip2
+ case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2
+ case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2
+ case (t1: BundleType, t2: BundleType) =>
+ val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())(
+ (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) )))
+ t2.fields forall (f2 =>
+ t1_fields get f2.name match {
+ case None => true
+ case Some((f1_tpe, f1_flip)) =>
+ bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip))
+ }
+ )
+ case (t1: VectorType, t2: VectorType) =>
+ bulk_equals(t1.tpe, t2.tpe, flip1, flip2)
+ case (_, _) => false
+ }
+ }
+
+ def validConnect(con: Connect): Boolean = wt(con.loc.tpe).superTypeOf(wt(con.expr.tpe))
+
+ def validPartialConnect(con: PartialConnect): Boolean =
+ bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default)
+
//;---------------- Helper Functions --------------
def ut: UIntType = UIntType(UnknownWidth)
def st: SIntType = SIntType(UnknownWidth)
@@ -414,48 +481,23 @@ object CheckTypes extends Pass {
e foreach check_types_e(info, mname)
}
- def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = {
- //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2])
- (t1, t2) match {
- case (ClockType, ClockType) => flip1 == flip2
- case (_: UIntType, _: UIntType) => flip1 == flip2
- case (_: SIntType, _: SIntType) => flip1 == flip2
- case (_: FixedType, _: FixedType) => flip1 == flip2
- case (_: AnalogType, _: AnalogType) => true
- case (t1: BundleType, t2: BundleType) =>
- val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())(
- (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) )))
- t2.fields forall (f2 =>
- t1_fields get f2.name match {
- case None => true
- case Some((f1_tpe, f1_flip)) =>
- bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip))
- }
- )
- case (t1: VectorType, t2: VectorType) =>
- bulk_equals(t1.tpe, t2.tpe, flip1, flip2)
- case (_, _) => false
- }
- }
-
def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = {
val info = get_info(s) match { case NoInfo => minfo case x => x }
s match {
- case sx: Connect if wt(sx.loc.tpe) != wt(sx.expr.tpe) =>
- errors.append(new InvalidConnect(info, mname, sx.loc.serialize, sx.expr.serialize))
- case sx: PartialConnect if !bulk_equals(sx.loc.tpe, sx.expr.tpe, Default, Default) =>
- errors.append(new InvalidConnect(info, mname, sx.loc.serialize, sx.expr.serialize))
+ case sx: Connect if !validConnect(sx) =>
+ val conMsg = sx.copy(info = NoInfo).serialize
+ errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr))
+ case sx: PartialConnect if !validPartialConnect(sx) =>
+ val conMsg = sx.copy(info = NoInfo).serialize
+ errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr))
case sx: DefRegister =>
sx.tpe match {
case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname))
case t =>
}
- sx.reset.tpe match {
- case UIntType(IntWidth(w)) if w == 1 =>
- case AsyncResetType =>
- case UIntType(UnknownWidth) => // cannot catch here, though width may ultimately be wrong
- case _ => errors.append(new IllegalResetType(info, mname, sx.name))
+ if (!legalResetType(sx.reset.tpe)) {
+ errors.append(new IllegalResetType(info, mname, sx.name))
}
if (sx.clock.tpe != ClockType) {
errors.append(new RegReqClk(info, mname, sx.name))
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index cf6f2ae0..8f663afd 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -262,6 +262,8 @@ class InferWidths extends Transform with ResolvedAnnotationPaths {
})
}
case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
+ case (ResetType, _) => Nil
+ case (_, ResetType) => Nil
}
def run(c: Circuit, extra: Seq[WGeq]): Circuit = {
diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala
index 44e505f1..22c90316 100644
--- a/src/main/scala/firrtl/proto/FromProto.scala
+++ b/src/main/scala/firrtl/proto/FromProto.scala
@@ -256,6 +256,7 @@ object FromProto {
case FIXED_TYPE_FIELD_NUMBER => convert(tpe.getFixedType)
case CLOCK_TYPE_FIELD_NUMBER => ir.ClockType
case ASYNC_RESET_TYPE_FIELD_NUMBER => ir.AsyncResetType
+ case RESET_TYPE_FIELD_NUMBER => ir.ResetType
case ANALOG_TYPE_FIELD_NUMBER => convert(tpe.getAnalogType)
case BUNDLE_TYPE_FIELD_NUMBER =>
ir.BundleType(tpe.getBundleType.getFieldList.asScala.map(convert(_)))
diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala
index c67f446c..17adb698 100644
--- a/src/main/scala/firrtl/proto/ToProto.scala
+++ b/src/main/scala/firrtl/proto/ToProto.scala
@@ -343,6 +343,9 @@ object ToProto {
case ir.AsyncResetType =>
val at = Firrtl.Type.AsyncResetType.newBuilder()
tb.setAsyncResetType(at)
+ case ir.ResetType =>
+ val rt = Firrtl.Type.ResetType.newBuilder()
+ tb.setResetType(rt)
case ir.AnalogType(width) =>
val at = Firrtl.Type.AnalogType.newBuilder()
convert(width).foreach(at.setWidth)
diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala
new file mode 100644
index 00000000..70e2b76c
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/InferResets.scala
@@ -0,0 +1,253 @@
+// See LICENSE for license details.
+
+package firrtl.transforms
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.traversals.Foreachers._
+import firrtl.annotations.{ReferenceTarget, TargetToken}
+import firrtl.Utils.toTarget
+import firrtl.passes.{Pass, PassException, Errors, InferTypes}
+
+import scala.collection.mutable
+import scala.util.Try
+
+object InferResets {
+ final class DifferingDriverTypesException private (msg: String) extends PassException(msg)
+ object DifferingDriverTypesException {
+ def apply(target: ReferenceTarget, tpes: Seq[(Type, Seq[TypeDriver])]): DifferingDriverTypesException = {
+ val xs = tpes.map { case (t, ds) => s"${ds.map(_.target().serialize).mkString(", ")} of type ${t.serialize}" }
+ val msg = s"${target.serialize} driven with multiple types!" + xs.mkString("\n ", "\n ", "")
+ new DifferingDriverTypesException(msg)
+ }
+ }
+
+ /** Type hierarchy to represent the type of the thing driving a [[ResetType]] */
+ private sealed trait ResetDriver
+ // When a [[ResetType]] is driven by another ResetType, we track the target so that we can infer
+ // the same type as the driver
+ private case class TargetDriver(target: ReferenceTarget) extends ResetDriver {
+ override def toString: String = s"TargetDriver(${target.serialize})"
+ }
+ // When a [[ResetType]] is driven by something of type Bool or AsyncResetType, we keep track of it
+ // as a constraint on the type we should infer to be
+ // We keep the target around (lazily) so that we can report errors
+ private case class TypeDriver(tpe: Type, target: () => ReferenceTarget) extends ResetDriver {
+ override def toString: String = s"TypeDriver(${tpe.serialize}, $target)"
+ }
+
+
+ // Type hierarchy representing the path to a leaf type in an aggregate type structure
+ // Used by this [[InferResets]] to pinpoint instances of [[ResetType]] and their inferred type
+ private sealed trait TypeTree
+ private case class BundleTree(fields: Map[String, TypeTree]) extends TypeTree
+ private case class VectorTree(subType: TypeTree) extends TypeTree
+ // TODO ensure is only AsyncResetType or BoolType
+ private case class GroundTree(tpe: Type) extends TypeTree
+
+ private object TypeTree {
+ // Given groups of [[TargetToken]]s and Types corresponding to them, construct a [[TypeTree]]
+ // that allows us to lookup the type of each leaf node in the aggregate structure
+ // TODO make return Try[TypeTree]
+ def fromTokens(tokens: (Seq[TargetToken], Type)*): TypeTree = tokens match {
+ case Seq((Seq(), tpe)) => GroundTree(tpe)
+ // VectorTree
+ case (TargetToken.Index(_) +: _, _) +: _ =>
+ // Vectors must all have the same type, so we only process Index 0
+ // If the subtype is an aggregate, there can be multiple of each index
+ val ts = tokens.collect { case (TargetToken.Index(0) +: tail, tpe) => (tail, tpe) }
+ VectorTree(fromTokens(ts:_*))
+ // BundleTree
+ case (TargetToken.Field(_) +: _, _) +: _ =>
+ val fields =
+ tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n }
+ .mapValues { ts =>
+ fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }:_*)
+ }
+ BundleTree(fields)
+ }
+ }
+}
+
+/** Infers the concrete type of [[Reset]]s by their connections
+ * This is a global inference because ports can be of type [[Reset]]
+ * @note This transform should be run before [[DedupModules]] so that similar Modules from
+ * generator languages like Chisel can infer differently
+ */
+// TODO should we error if a DefMemory is of type AsyncReset? In CheckTypes?
+class InferResets extends Transform {
+ def inputForm: CircuitForm = HighForm
+ def outputForm: CircuitForm = HighForm
+
+ import InferResets._
+
+ // Collect all drivers for circuit elements of type ResetType
+ private def analyze(c: Circuit): Map[ReferenceTarget, List[ResetDriver]] = {
+ type DriverMap = mutable.HashMap[ReferenceTarget, mutable.ListBuffer[ResetDriver]]
+ def onMod(mod: DefModule): DriverMap = {
+ val instMap = mutable.Map[String, String]()
+ // We need to convert submodule port targets into targets on the Module port itself
+ def makeTarget(expr: Expression): ReferenceTarget = {
+ val target = toTarget(c.main, mod.name)(expr)
+ Utils.kind(expr) match {
+ case InstanceKind =>
+ val mod = instMap(target.ref)
+ val port = target.component.head match {
+ case TargetToken.Field(name) => name
+ case bad => Utils.throwInternalError(s"Unexpected token $bad")
+ }
+ target.copy(module = mod, ref = port, component = target.component.tail)
+ case _ => target
+ }
+ }
+ def onStmt(map: DriverMap)(stmt: Statement): Unit = {
+ // Mark driver of a ResetType leaf
+ def markResetDriver(lhs: Expression, rhs: Expression): Unit = {
+ val lflip = Utils.to_flip(Utils.gender(lhs))
+ if ((lflip == Default && lhs.tpe == ResetType) ||
+ (lflip == Flip && rhs.tpe == ResetType)) {
+ val (loc, exp) = lflip match {
+ case Default => (lhs, rhs)
+ case Flip => (rhs, lhs)
+ }
+ val target = makeTarget(loc)
+ val driver = exp.tpe match {
+ case ResetType => TargetDriver(makeTarget(exp))
+ case tpe => TypeDriver(tpe, () => makeTarget(exp))
+ }
+ map.getOrElseUpdate(target, mutable.ListBuffer()) += driver
+ }
+ }
+ stmt match {
+ // TODO
+ // - Each connect duplicates a bunch of code from ExpandConnects, could be cleaner
+ // - The full create_exps duplication is inefficient, there has to be a better way
+ case Connect(_, lhs, rhs) =>
+ val locs = Utils.create_exps(lhs)
+ val exps = Utils.create_exps(rhs)
+ for ((loc, exp) <- locs.zip(exps)) {
+ markResetDriver(loc, exp)
+ }
+ case PartialConnect(_, lhs, rhs) =>
+ val points = Utils.get_valid_points(lhs.tpe, rhs.tpe, Default, Default)
+ val locs = Utils.create_exps(lhs)
+ val exps = Utils.create_exps(rhs)
+ for ((i, j) <- points) {
+ markResetDriver(locs(i), exps(j))
+ }
+ case WDefInstance(_, inst, module, _) =>
+ instMap += (inst -> module)
+ case Conditionally(_, _, con, alt) =>
+ val conMap = new DriverMap
+ val altMap = new DriverMap
+ onStmt(conMap)(con)
+ onStmt(altMap)(alt)
+ // Default to outerscope if not found in alt
+ val altLookup = altMap.orElse(map).lift
+ for (key <- conMap.keys ++ altMap.keys) {
+ val ds = map.getOrElseUpdate(key, mutable.ListBuffer())
+ conMap.get(key).foreach(ds ++= _)
+ altLookup(key).foreach(ds ++= _)
+ }
+ case other => other.foreach(onStmt(map))
+ }
+ }
+ val types = new DriverMap
+ mod.foreach(onStmt(types))
+ types
+ }
+ c.modules.foldLeft(Map[ReferenceTarget, List[ResetDriver]]()) {
+ case (map, mod) => map ++ onMod(mod).mapValues(_.toList)
+ }
+ }
+
+ // Determine the type driving a given ResetType
+ private def resolve(map: Map[ReferenceTarget, List[ResetDriver]]): Try[Map[ReferenceTarget, Type]] = {
+ val res = mutable.Map[ReferenceTarget, Type]()
+ val errors = new Errors
+ def rec(target: ReferenceTarget): Type = {
+ val drivers = map(target)
+ res.getOrElseUpdate(target, {
+ val tpes = drivers.map {
+ case TargetDriver(t) => TypeDriver(rec(t), () => t)
+ case td: TypeDriver => td
+ }.groupBy(_.tpe)
+ if (tpes.keys.size != 1) {
+ // Multiple types of driver!
+ errors.append(DifferingDriverTypesException(target, tpes.toSeq))
+ }
+ tpes.keys.head
+ })
+ }
+ for ((target, _) <- map) {
+ rec(target)
+ }
+ Try {
+ errors.trigger()
+ res.toMap
+ }
+ }
+
+ private def fixupType(tpe: Type, tree: TypeTree): Type = (tpe, tree) match {
+ case (BundleType(fields), BundleTree(map)) =>
+ val fieldsx =
+ fields.map(f => map.get(f.name) match {
+ case Some(t) => f.copy(tpe = fixupType(f.tpe, t))
+ case None => f
+ })
+ BundleType(fieldsx)
+ case (VectorType(vtpe, size), VectorTree(t)) =>
+ VectorType(fixupType(vtpe, t), size)
+ case (_, GroundTree(t)) => t
+ case x => throw new Exception(s"Error! Unexpected pair $x")
+ }
+
+ // Assumes all ReferenceTargets are in the same module
+ private def makeDeclMap(map: Map[ReferenceTarget, Type]): Map[String, TypeTree] =
+ map.groupBy(_._1.ref).mapValues { ts =>
+ TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }:_*)
+ }
+
+ private def implPort(map: Map[String, TypeTree])(port: Port): Port =
+ map.get(port.name)
+ .map(tree => port.copy(tpe = fixupType(port.tpe, tree)))
+ .getOrElse(port)
+ private def implStmt(map: Map[String, TypeTree])(stmt: Statement): Statement =
+ stmt.map(implStmt(map)) match {
+ case decl: IsDeclaration if map.contains(decl.name) =>
+ val tree = map(decl.name)
+ decl match {
+ case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree))
+ case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree))
+ // TODO Can this really happen?
+ case mem: DefMemory => mem.copy(dataType = fixupType(mem.dataType, tree))
+ case other => other
+ }
+ case other => other
+ }
+
+ private def implement(c: Circuit, map: Map[ReferenceTarget, Type]): Circuit = {
+ val modMaps = map.groupBy(_._1.module)
+ def onMod(mod: DefModule): DefModule = {
+ modMaps.get(mod.name).map { tmap =>
+ val declMap = makeDeclMap(tmap)
+ mod.map(implPort(declMap)).map(implStmt(declMap))
+ }.getOrElse(mod)
+ }
+ c.map(onMod)
+ }
+
+ private def fixupPasses: Seq[Pass] = Seq(
+ InferTypes
+ )
+
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val analysis = analyze(c)
+ val inferred = resolve(analysis)
+ val result = inferred.map(m => implement(c, m)).get
+ val fixedup = fixupPasses.foldLeft(result)((c, p) => p.run(c))
+ state.copy(circuit = fixedup)
+ }
+}
diff --git a/src/test/scala/firrtlTests/AsyncResetSpec.scala b/src/test/scala/firrtlTests/AsyncResetSpec.scala
index c1078a03..6fcb647a 100644
--- a/src/test/scala/firrtlTests/AsyncResetSpec.scala
+++ b/src/test/scala/firrtlTests/AsyncResetSpec.scala
@@ -3,7 +3,6 @@
package firrtlTests
import firrtl._
-import firrtl.ir._
import FirrtlCheckers._
class AsyncResetSpec extends FirrtlFlatSpec {
@@ -30,6 +29,25 @@ class AsyncResetSpec extends FirrtlFlatSpec {
result should containLine ("always @(posedge clock or posedge reset) begin")
}
+ it should "work in nested and flipped aggregates with regular and partial connect" in {
+ val result = compileBody(s"""
+ |output fizz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] }
+ |output buzz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] }
+ |fizz.bar <= fizz.foo
+ |buzz.bar <- buzz.foo
+ |""".stripMargin
+ )
+
+ result should containLine ("assign fizz_foo_0_b = fizz_bar_0_b;")
+ result should containLine ("assign fizz_foo_1_b = fizz_bar_1_b;")
+ result should containLine ("assign fizz_bar_0_a = fizz_foo_0_a;")
+ result should containLine ("assign fizz_bar_1_a = fizz_foo_1_a;")
+ result should containLine ("assign buzz_foo_0_b = buzz_bar_0_b;")
+ result should containLine ("assign buzz_foo_1_b = buzz_bar_1_b;")
+ result should containLine ("assign buzz_bar_0_a = buzz_foo_0_a;")
+ result should containLine ("assign buzz_bar_1_a = buzz_foo_1_a;")
+ }
+
it should "support casting to other types" in {
val result = compileBody(s"""
|input a : AsyncReset
@@ -77,7 +95,7 @@ class AsyncResetSpec extends FirrtlFlatSpec {
}
"Non-literals" should "NOT be allowed as reset values for AsyncReset" in {
- an [passes.CheckHighForm.NonLiteralAsyncResetValueException] shouldBe thrownBy {
+ an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy {
compileBody(s"""
|input clock : Clock
|input reset : AsyncReset
@@ -91,6 +109,121 @@ class AsyncResetSpec extends FirrtlFlatSpec {
}
}
+ "Late non-literals connections" should "NOT be allowed as reset values for AsyncReset" in {
+ an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy {
+ compileBody(s"""
+ |input clock : Clock
+ |input reset : AsyncReset
+ |input x : UInt<8>
+ |input y : UInt<8>
+ |output z : UInt<8>
+ |wire a : UInt<8>
+ |reg r : UInt<8>, clock with : (reset => (reset, a))
+ |a <= y
+ |r <= x
+ |z <= r""".stripMargin
+ )
+ }
+ }
+
+ "Hidden Non-literals" should "NOT be allowed as reset values for AsyncReset" in {
+ an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy {
+ compileBody(s"""
+ |input clock : Clock
+ |input reset : AsyncReset
+ |input x : UInt<1>[4]
+ |input y : UInt<1>
+ |output z : UInt<1>[4]
+ |wire literal : UInt<1>[4]
+ |literal[0] <= UInt<1>("h00")
+ |literal[1] <= y
+ |literal[2] <= UInt<1>("h00")
+ |literal[3] <= UInt<1>("h00")
+ |reg r : UInt<1>[4], clock with : (reset => (reset, literal))
+ |r <= x
+ |z <= r""".stripMargin
+ )
+ }
+ }
+ "Wire connected to non-literal" should "NOT be allowed as reset values for AsyncReset" in {
+ an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy {
+ compileBody(s"""
+ |input clock : Clock
+ |input reset : AsyncReset
+ |input x : UInt<1>
+ |input y : UInt<1>
+ |input cond : UInt<1>
+ |output z : UInt<1>
+ |wire w : UInt<1>
+ |w <= UInt(1)
+ |when cond :
+ | w <= y
+ |reg r : UInt<1>, clock with : (reset => (reset, w))
+ |r <= x
+ |z <= r""".stripMargin
+ )
+ }
+ }
+
+ "Complex literals" should "be allowed as reset values for AsyncReset" in {
+ val result = compileBody(s"""
+ |input clock : Clock
+ |input reset : AsyncReset
+ |input x : UInt<1>[4]
+ |output z : UInt<1>[4]
+ |wire literal : UInt<1>[4]
+ |literal[0] <= UInt<1>("h00")
+ |literal[1] <= UInt<1>("h00")
+ |literal[2] <= UInt<1>("h00")
+ |literal[3] <= UInt<1>("h00")
+ |reg r : UInt<1>[4], clock with : (reset => (reset, literal))
+ |r <= x
+ |z <= r""".stripMargin
+ )
+ result should containLine ("always @(posedge clock or posedge reset) begin")
+ }
+
+ "Complex literals of complex literals" should "be allowed as reset values for AsyncReset" in {
+ val result = compileBody(s"""
+ |input clock : Clock
+ |input reset : AsyncReset
+ |input x : UInt<1>[4]
+ |output z : UInt<1>[4]
+ |wire literal : UInt<1>[2]
+ |literal[0] <= UInt<1>("h01")
+ |literal[1] <= UInt<1>("h01")
+ |wire complex_literal : UInt<1>[4]
+ |complex_literal[0] <= literal[0]
+ |complex_literal[1] <= literal[1]
+ |complex_literal[2] <= UInt<1>("h00")
+ |complex_literal[3] <= UInt<1>("h00")
+ |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal))
+ |r <= x
+ |z <= r""".stripMargin
+ )
+ result should containLine ("always @(posedge clock or posedge reset) begin")
+ }
+ "Literals of bundle literals" should "be allowed as reset values for AsyncReset" in {
+ val result = compileBody(s"""
+ |input clock : Clock
+ |input reset : AsyncReset
+ |input x : UInt<1>[4]
+ |output z : UInt<1>[4]
+ |wire bundle : {a: UInt<1>, b: UInt<1>}
+ |bundle.a <= UInt<1>("h01")
+ |bundle.b <= UInt<1>("h01")
+ |wire complex_literal : UInt<1>[4]
+ |complex_literal[0] <= bundle.a
+ |complex_literal[1] <= bundle.b
+ |complex_literal[2] <= UInt<1>("h00")
+ |complex_literal[3] <= UInt<1>("h00")
+ |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal))
+ |r <= x
+ |z <= r""".stripMargin
+ )
+ result should containLine ("always @(posedge clock or posedge reset) begin")
+ }
+
"Every async reset reg" should "generate its own always block" in {
val result = compileBody(s"""
diff --git a/src/test/scala/firrtlTests/InferResetsSpec.scala b/src/test/scala/firrtlTests/InferResetsSpec.scala
new file mode 100644
index 00000000..ac13033a
--- /dev/null
+++ b/src/test/scala/firrtlTests/InferResetsSpec.scala
@@ -0,0 +1,355 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.ir._
+import firrtl.passes.{CheckHighForm, CheckTypes}
+import firrtl.transforms.InferResets
+import FirrtlCheckers._
+
+// TODO
+// - Test nodes in the connection
+// - Test with whens (is this allowed?)
+class InferResetsSpec extends FirrtlFlatSpec {
+ def compile(input: String, compiler: Compiler = new MiddleFirrtlCompiler): CircuitState =
+ compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty)
+
+ behavior of "ResetType"
+
+ val BoolType = UIntType(IntWidth(1))
+
+ it should "support casting to other types" in {
+ val result = compile(s"""
+ |circuit top:
+ | module top:
+ | input a : UInt<1>
+ | output v : UInt<1>
+ | output w : SInt<1>
+ | output x : Clock
+ | output y : Fixed<1><<0>>
+ | output z : AsyncReset
+ | wire r : Reset
+ | r <= a
+ | v <= asUInt(r)
+ | w <= asSInt(r)
+ | x <= asClock(r)
+ | y <= asFixedPoint(r, 0)
+ | z <= asAsyncReset(r)""".stripMargin
+ )
+ println(result.getEmittedCircuit)
+ result should containLine ("wire r : UInt<1>")
+ result should containLine ("r <= a")
+ result should containLine ("v <= asUInt(r)")
+ result should containLine ("w <= asSInt(r)")
+ result should containLine ("x <= asClock(r)")
+ result should containLine ("y <= asSInt(r)")
+ result should containLine ("z <= asAsyncReset(r)")
+ }
+
+ it should "work across Module boundaries" in {
+ val result = compile(s"""
+ |circuit top :
+ | module child :
+ | input clock : Clock
+ | input childReset : Reset
+ | input x : UInt<8>
+ | output z : UInt<8>
+ | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123)))
+ | r <= x
+ | z <= r
+ | module top :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | input x : UInt<8>
+ | output z : UInt<8>
+ | inst c of child
+ | c.clock <= clock
+ | c.childReset <= reset
+ | c.x <= x
+ | z <= c.z
+ |""".stripMargin
+ )
+ result should containTree { case Port(_, "childReset", Input, BoolType) => true }
+ }
+
+ it should "work across multiple Module boundaries" in {
+ val result = compile(s"""
+ |circuit top :
+ | module child :
+ | input resetIn : Reset
+ | output resetOut : Reset
+ | resetOut <= resetIn
+ | module top :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | input x : UInt<8>
+ | output z : UInt<8>
+ | inst c of child
+ | c.resetIn <= reset
+ | reg r : UInt<8>, clock with : (reset => (c.resetOut, UInt(123)))
+ | r <= x
+ | z <= r
+ |""".stripMargin
+ )
+ result should containTree { case Port(_, "resetIn", Input, BoolType) => true }
+ result should containTree { case Port(_, "resetOut", Output, BoolType) => true }
+ }
+
+ it should "work in nested and flipped aggregates with regular and partial connect" in {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | output fizz : { flip foo : { a : AsyncReset, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset }[2] }
+ | output buzz : { flip foo : { a : AsyncReset, c: UInt<1>, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset, c: UInt<8> }[2] }
+ | fizz.bar <= fizz.foo
+ | buzz.bar <- buzz.foo
+ |""".stripMargin,
+ new LowFirrtlCompiler
+ )
+ result should containTree { case Port(_, "fizz_foo_0_a", Input, AsyncResetType) => true }
+ result should containTree { case Port(_, "fizz_foo_0_b", Output, AsyncResetType) => true }
+ result should containTree { case Port(_, "fizz_foo_1_a", Input, AsyncResetType) => true }
+ result should containTree { case Port(_, "fizz_foo_1_b", Output, AsyncResetType) => true }
+ result should containTree { case Port(_, "fizz_bar_0_a", Output, AsyncResetType) => true }
+ result should containTree { case Port(_, "fizz_bar_0_b", Input, AsyncResetType) => true }
+ result should containTree { case Port(_, "fizz_bar_1_a", Output, AsyncResetType) => true }
+ result should containTree { case Port(_, "fizz_bar_1_b", Input, AsyncResetType) => true }
+ result should containTree { case Port(_, "buzz_foo_0_a", Input, AsyncResetType) => true }
+ result should containTree { case Port(_, "buzz_foo_0_b", Output, AsyncResetType) => true }
+ result should containTree { case Port(_, "buzz_foo_1_a", Input, AsyncResetType) => true }
+ result should containTree { case Port(_, "buzz_foo_1_b", Output, AsyncResetType) => true }
+ result should containTree { case Port(_, "buzz_bar_0_a", Output, AsyncResetType) => true }
+ result should containTree { case Port(_, "buzz_bar_0_b", Input, AsyncResetType) => true }
+ result should containTree { case Port(_, "buzz_bar_1_a", Output, AsyncResetType) => true }
+ result should containTree { case Port(_, "buzz_bar_1_b", Input, AsyncResetType) => true }
+ }
+
+ it should "NOT allow last connect semantics to pick the right type for Reset" in {
+ an [InferResets.DifferingDriverTypesException] shouldBe thrownBy {
+ compile(s"""
+ |circuit top :
+ | module top :
+ | input reset0 : AsyncReset
+ | input reset1 : UInt<1>
+ | output out : Reset
+ | wire w1 : Reset
+ | wire w2 : Reset
+ | w1 <= reset0
+ | w2 <= reset1
+ | out <= w1
+ | out <= w2
+ |""".stripMargin
+ )
+ }
+ }
+
+ it should "NOT support last connect semantics across whens" in {
+ an [InferResets.DifferingDriverTypesException] shouldBe thrownBy {
+ compile(s"""
+ |circuit top :
+ | module top :
+ | input reset0 : AsyncReset
+ | input reset1 : UInt<1>
+ | input en0 : UInt<1>
+ | output out : Reset
+ | wire w1 : Reset
+ | wire w2 : Reset
+ | w1 <= reset0
+ | w2 <= reset1
+ | out <= w1
+ | when en0 :
+ | out <= w2
+ |""".stripMargin
+ )
+ }
+ }
+
+ it should "not allow different Reset Types to drive a single Reset" in {
+ an [InferResets.DifferingDriverTypesException] shouldBe thrownBy {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | input reset0 : AsyncReset
+ | input reset1 : UInt<1>
+ | input en : UInt<1>
+ | output out : Reset
+ | wire w1 : Reset
+ | wire w2 : Reset
+ | w1 <= reset0
+ | w2 <= reset1
+ | out <= w1
+ | when en :
+ | out <= w2
+ |""".stripMargin
+ )
+ }
+ }
+
+ it should "allow ResetType to drive AsyncResets or UInt<1>" in {
+ val result1 = compile(s"""
+ |circuit top :
+ | module top :
+ | input in : UInt<1>
+ | output out : UInt<1>
+ | wire w : Reset
+ | w <= in
+ | out <= w
+ |""".stripMargin
+ )
+ result1 should containTree { case DefWire(_, "w", BoolType) => true }
+ val result2 = compile(s"""
+ |circuit top :
+ | module top :
+ | output foo : { flip a : UInt<1> }
+ | input bar : { flip a : UInt<1> }
+ | wire w : { flip a : Reset }
+ | foo <= w
+ | w <= bar
+ |""".stripMargin
+ )
+ val AggType = BundleType(Seq(Field("a", Flip, BoolType)))
+ result2 should containTree { case DefWire(_, "w", AggType) => true }
+ val result3 = compile(s"""
+ |circuit top :
+ | module top :
+ | input in : UInt<1>
+ | output out : UInt<1>
+ | wire w : Reset
+ | w <- in
+ | out <- w
+ |""".stripMargin
+ )
+ result3 should containTree { case DefWire(_, "w", BoolType) => true }
+ }
+
+ it should "error if a ResetType driving UInt<1> infers to AsyncReset" in {
+ an [Exception] shouldBe thrownBy {
+ compile(s"""
+ |circuit top :
+ | module top :
+ | input in : AsyncReset
+ | output out : UInt<1>
+ | wire w : Reset
+ | w <= in
+ | out <= w
+ |""".stripMargin
+ )
+ }
+ }
+
+ it should "error if a ResetType driving AsyncReset infers to UInt<1>" in {
+ an [Exception] shouldBe thrownBy {
+ compile(s"""
+ |circuit top :
+ | module top :
+ | input in : UInt<1>
+ | output out : AsyncReset
+ | wire w : Reset
+ | w <= in
+ | out <= w
+ |""".stripMargin
+ )
+ }
+ }
+
+ it should "not allow ResetType as an Input or ExtModule output" in {
+ // TODO what exception should be thrown here?
+ an [CheckHighForm.ResetInputException] shouldBe thrownBy {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | input in : { foo : Reset }
+ | output out : Reset
+ | out <= in.foo
+ |""".stripMargin
+ )
+ }
+ an [CheckHighForm.ResetExtModuleOutputException] shouldBe thrownBy {
+ val result = compile(s"""
+ |circuit top :
+ | extmodule ext :
+ | output out : { foo : Reset }
+ | module top :
+ | output out : Reset
+ | inst e of ext
+ | out <= e.out.foo
+ |""".stripMargin
+ )
+ }
+ }
+
+ it should "not allow Vecs to infer different Reset Types" in {
+ an [CheckTypes.InvalidConnect] shouldBe thrownBy {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | input reset0 : AsyncReset
+ | input reset1 : UInt<1>
+ | output out : Reset[2]
+ | out[0] <= reset0
+ | out[1] <= reset1
+ |""".stripMargin
+ )
+ }
+ }
+
+ // Or is this actually an error? The behavior is that out is inferred as AsyncReset[2]
+ ignore should "not allow Vecs only be partially inferred" in {
+ // Some exception should be thrown, TODO figure out which one
+ an [Exception] shouldBe thrownBy {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | input reset : AsyncReset
+ | output out : Reset[2]
+ | out is invalid
+ | out[0] <= reset
+ |""".stripMargin
+ )
+ }
+ }
+
+
+ it should "support inferring modules that would dedup differently" in {
+ val result = compile(s"""
+ |circuit top :
+ | module child :
+ | input clock : Clock
+ | input childReset : Reset
+ | input x : UInt<8>
+ | output z : UInt<8>
+ | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123)))
+ | r <= x
+ | z <= r
+ | module child_1 :
+ | input clock : Clock
+ | input childReset : Reset
+ | input x : UInt<8>
+ | output z : UInt<8>
+ | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123)))
+ | r <= x
+ | z <= r
+ | module top :
+ | input clock : Clock
+ | input reset1 : UInt<1>
+ | input reset2 : AsyncReset
+ | input x : UInt<8>[2]
+ | output z : UInt<8>[2]
+ | inst c of child
+ | c.clock <= clock
+ | c.childReset <= reset1
+ | c.x <= x[0]
+ | z[0] <= c.z
+ | inst c2 of child_1
+ | c2.clock <= clock
+ | c2.childReset <= reset2
+ | c2.x <= x[1]
+ | z[1] <= c2.z
+ |""".stripMargin
+ )
+ result should containTree { case Port(_, "childReset", Input, BoolType) => true }
+ result should containTree { case Port(_, "childReset", Input, AsyncResetType) => true }
+ }
+}
+
diff --git a/src/test/scala/firrtlTests/ProtoBufSpec.scala b/src/test/scala/firrtlTests/ProtoBufSpec.scala
index 526a194c..2f347c6d 100644
--- a/src/test/scala/firrtlTests/ProtoBufSpec.scala
+++ b/src/test/scala/firrtlTests/ProtoBufSpec.scala
@@ -180,4 +180,9 @@ class ProtoBufSpec extends FirrtlFlatSpec {
val port = ir.Port(ir.NoInfo, "reset", ir.Input, ir.AsyncResetType)
FromProto.convert(ToProto.convert(port).build) should equal (port)
}
+
+ it should "support ResetTypes" in {
+ val port = ir.Port(ir.NoInfo, "reset", ir.Input, ir.ResetType)
+ FromProto.convert(ToProto.convert(port).build) should equal (port)
+ }
}