aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2020-02-18 11:45:11 -0800
committerGitHub2020-02-18 11:45:11 -0800
commit6c7a15a2d9123fd94770dda15f9e9070ae6b2bdc (patch)
tree01903cee1942688d6ae09a724e16b06f0a56fb1b /src
parent38cf27e8fbdf201761b018afc93107f15cf17cb7 (diff)
Remove last connect semantics from reset inference (#1396)
* Revert "Infer resets last connect semantics (#1291)" * Fix handling of invalidated and undriven components of type Reset * Run CheckTypes after InferResets * Make reset inference bidirectional on connect * Support AsyncResetType in RemoveValidIf * Fix InferResets for parent constraints on child ports * Apply suggestions from code review * Add ScalaDoc to InferResets Co-authored-by: Albert Magyar <albert.magyar@gmail.com> Co-authored-by: Schuyler Eldridge <schuyler.eldridge@gmail.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala3
-rw-r--r--src/main/scala/firrtl/passes/RemoveValidIf.scala2
-rw-r--r--src/main/scala/firrtl/transforms/InferResets.scala133
-rw-r--r--src/test/scala/firrtlTests/InferResetsSpec.scala126
4 files changed, 205 insertions, 59 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 0d9b971b..14a8e637 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -49,7 +49,8 @@ class ResolveAndCheck extends CoreTransform {
new passes.TrimIntervals(),
new passes.InferWidths,
passes.CheckWidths,
- new firrtl.transforms.InferResets)
+ new firrtl.transforms.InferResets,
+ passes.CheckTypes)
}
/** Expands aggregate connects, removes dynamic accesses, and when
diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala
index 37a3f931..42eae7e5 100644
--- a/src/main/scala/firrtl/passes/RemoveValidIf.scala
+++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala
@@ -13,6 +13,7 @@ object RemoveValidIf extends Pass {
val SIntZero = SIntLiteral(BigInt(0), IntWidth(1))
val ClockZero = DoPrim(PrimOps.AsClock, Seq(UIntZero), Seq.empty, ClockType)
val FixedZero = FixedLiteral(BigInt(0), IntWidth(1), IntWidth(0))
+ val AsyncZero = DoPrim(PrimOps.AsAsyncReset, Seq(UIntZero), Nil, AsyncResetType)
/** Returns an [[firrtl.ir.Expression Expression]] equal to zero for a given [[firrtl.ir.GroundType GroundType]]
* @note Accepts [[firrtl.ir.Type Type]] but dyanmically expects [[firrtl.ir.GroundType GroundType]]
@@ -22,6 +23,7 @@ object RemoveValidIf extends Pass {
case _: SIntType => SIntZero
case ClockType => ClockZero
case _: FixedType => FixedZero
+ case AsyncResetType => AsyncZero
case other => throwInternalError(s"Unexpected type $other")
}
diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala
index fbc915e2..026b15fc 100644
--- a/src/main/scala/firrtl/transforms/InferResets.scala
+++ b/src/main/scala/firrtl/transforms/InferResets.scala
@@ -7,14 +7,17 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.annotations.{ReferenceTarget, TargetToken}
-import firrtl.Utils.toTarget
-import firrtl.passes.{Pass, PassException, Errors, InferTypes}
+import firrtl.Utils.{toTarget, throwInternalError}
+import firrtl.passes.{Pass, PassException, InferTypes}
+import firrtl.graph.MutableDiGraph
import scala.collection.mutable
import scala.util.Try
object InferResets {
+ @deprecated("This is no longer in use and will be removed", "1.3")
final class DifferingDriverTypesException private (msg: String) extends PassException(msg)
+ @deprecated("This is no longer in use and will be removed", "1.3")
object DifferingDriverTypesException {
def apply(target: ReferenceTarget, tpes: Seq[(Type, Seq[TypeDriver])]): DifferingDriverTypesException = {
val xs = tpes.map { case (t, ds) => s"${ds.map(_.target().serialize).mkString(", ")} of type ${t.serialize}" }
@@ -23,6 +26,15 @@ object InferResets {
}
}
+ final class InferResetsException private (msg: String) extends PassException(msg)
+ object InferResetsException {
+ private[InferResets] def apply(path: Seq[Node]): InferResetsException = {
+ val ps = path.collect { case Var(t) => t.serialize }.mkString("\n - ", "\n - ", "")
+ val msg = s"Reset-typed components connected to both AsyncReset and UInt<1>. Offending path:$ps"
+ new InferResetsException(msg)
+ }
+ }
+
/** Type hierarchy to represent the type of the thing driving a [[ResetType]] */
private sealed trait ResetDriver
// When a [[ResetType]] is driven by another ResetType, we track the target so that we can infer
@@ -34,7 +46,7 @@ object InferResets {
// as a constraint on the type we should infer to be
// We keep the target around (lazily) so that we can report errors
private case class TypeDriver(tpe: Type, target: () => ReferenceTarget) extends ResetDriver {
- override def toString: String = s"TypeDriver(${tpe.serialize}, $target)"
+ override def toString: String = s"TypeDriver(${tpe.serialize}, () => ?)"
}
// When a [[ResetType]] is invalidated, we record the InvalidDrive
// If there are no types but invalid drivers, we default to BoolType
@@ -42,6 +54,15 @@ object InferResets {
def defaultType: Type = Utils.BoolType
}
+ // Private type hierarchy used as DiGraph nodes for type inference
+ private sealed trait Node
+ private case class Var(target: ReferenceTarget) extends Node {
+ override def toString = target.serialize
+ }
+ private case class Typ(tpe: Type) extends Node {
+ override def toString = tpe.serialize
+ }
+
// Type hierarchy representing the path to a leaf type in an aggregate type structure
// Used by this [[InferResets]] to pinpoint instances of [[ResetType]] and their inferred type
private sealed trait TypeTree
@@ -74,8 +95,16 @@ object InferResets {
}
}
-/** Infers the concrete type of [[Reset]]s by their connections
- * This is a global inference because ports can be of type [[Reset]]
+/** Infers the concrete type of [[ResetType]]s by their connections
+ *
+ * There are 3 cases
+ * 1. An abstract reset driven by and/or driving only asynchronous resets will be inferred as
+ * asynchronous reset
+ * 1. An abstract reset driven by and/or driving both asynchronous and synchronous resets will
+ * error
+ * 1. Otherwise, the reset is inferred as synchronous (i.e. the abstract reset is only invalidated
+ * or is driven by or drives only synchronous resets)
+ * @note This is a global inference because ports can be of type [[ResetType]]
* @note This transform should be run before [[DedupModules]] so that similar Modules from
* generator languages like Chisel can infer differently
*/
@@ -88,8 +117,8 @@ class InferResets extends Transform {
// Collect all drivers for circuit elements of type ResetType
private def analyze(c: Circuit): Map[ReferenceTarget, List[ResetDriver]] = {
- type DriverMap = mutable.HashMap[ReferenceTarget, List[ResetDriver]]
- def onMod(mod: DefModule): DriverMap = {
+ type DriverMap = mutable.HashMap[ReferenceTarget, mutable.ListBuffer[ResetDriver]]
+ def onMod(types: DriverMap)(mod: DefModule): DriverMap = {
val instMap = mutable.Map[String, String]()
// We need to convert submodule port targets into targets on the Module port itself
def makeTarget(expr: Expression): ReferenceTarget = {
@@ -119,7 +148,7 @@ class InferResets extends Transform {
case ResetType => TargetDriver(makeTarget(exp))
case tpe => TypeDriver(tpe, () => makeTarget(exp))
}
- map(makeTarget(loc)) = driver :: Nil
+ map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver
}
}
stmt match {
@@ -146,7 +175,7 @@ class InferResets extends Transform {
// Unlike in markResetDriver, flow is irrelevant for invalidation
if (expr.tpe == ResetType) {
val target = makeTarget(expr)
- map(target) = InvalidDriver :: Nil
+ map.getOrElseUpdate(target, mutable.ListBuffer()) += InvalidDriver
}
}
case WDefInstance(_, inst, module, _) =>
@@ -156,54 +185,70 @@ class InferResets extends Transform {
val altMap = new DriverMap
onStmt(conMap)(con)
onStmt(altMap)(alt)
- // Default to outerscope if not found on either side
- val conLookup = conMap.orElse(map).lift
- val altLookup = altMap.orElse(map).lift
for (key <- conMap.keys ++ altMap.keys) {
- val values = conLookup(key).getOrElse(Nil) ++ altLookup(key).getOrElse(Nil)
- map(key) = values
+ val ds = map.getOrElseUpdate(key, mutable.ListBuffer())
+ conMap.get(key).foreach(ds ++= _)
+ altMap.get(key).foreach(ds ++= _)
}
case other => other.foreach(onStmt(map))
}
}
- val types = new DriverMap
mod.foreach(onStmt(types))
types
}
- c.modules.foldLeft(Map[ReferenceTarget, List[ResetDriver]]()) {
- case (map, mod) => map ++ onMod(mod).mapValues(_.toList)
- }
+ val res = new DriverMap
+ c.modules.foreach(m => onMod(res)(m))
+ res.mapValues(_.toList).toMap
}
- // Determine the type driving a given ResetType
+ /** Determine the type driving a given ResetType
+ *
+ * This is implemented as a graph traversal. Every type constraint is a forwards and backwards
+ * edge between the two components (where one of the components can be Bool or AsyncReset). Then,
+ * types are inferred by determining all of the nodes that are reachable from Bool and AsyncReset
+ * respectively. As an optimization, we actually only need to check reachability from one, since
+ * nodes that are not reachable from the one are either reachable from the other or reachable
+ * from neither. If unreachable, then the component must only be connected to invalidated
+ * components of type Reset, thus we can arbitrarily choose which reset to infer it to. As an
+ * optimization, we have edges from components back to Bool and AsyncReset which allows us to
+ * check if any node is erroneously constrained to be both by simply checking if Bool is
+ * reachable from AsyncReset.
+ */
private def resolve(map: Map[ReferenceTarget, List[ResetDriver]]): Try[Map[ReferenceTarget, Type]] = {
- val res = mutable.Map[ReferenceTarget, Type]()
- val errors = new Errors
- def rec(target: ReferenceTarget): Type = {
- res.getOrElseUpdate(target, {
- val drivers = map.getOrElse(target, Nil)
- val tpes = drivers.flatMap {
- case TargetDriver(t) => Some(TypeDriver(rec(t), () => t))
- case td: TypeDriver => Some(td)
- case InvalidDriver => None
- }.groupBy(_.tpe)
- tpes.keys.size match {
- // This can occur if something of type Reset has no driver
- case 0 => InvalidDriver.defaultType
- case 1 => tpes.keys.head
- case _ =>
- // Multiple types of driver!
- errors.append(DifferingDriverTypesException(target, tpes.toSeq))
- tpes.keys.head
- }
- })
- }
- for ((target, _) <- map) {
- rec(target)
+ val graph = new MutableDiGraph[Node]
+ val asyncNode = Typ(AsyncResetType)
+ val syncNode = Typ(Utils.BoolType)
+ for ((target, drivers) <- map) {
+ val v = Var(target)
+ drivers.foreach {
+ case TargetDriver(t) =>
+ val u = Var(t)
+ graph.addPairWithEdge(v, u)
+ graph.addPairWithEdge(u, v)
+ case TypeDriver(tpe, _) =>
+ // Use nodes we already made, saves memory
+ val u = tpe match {
+ case AsyncResetType => asyncNode
+ case Utils.BoolType => syncNode
+ case other => throwInternalError(s"Shouldn't have $other here")
+ }
+ graph.addPairWithEdge(u, v)
+ // This backwards edge allows us to check for nodes that infer to both at the same time we
+ // do the actual inference, the check is simply if syncNode is reachable from asyncNode
+ graph.addPairWithEdge(v, u)
+ case InvalidDriver =>
+ graph.addVertex(v) // Must be in the graph or won't be inferred
+ }
}
+ val async = graph.reachableFrom(asyncNode)
+ val sync = graph.getVertices -- async
Try {
- errors.trigger()
- res.toMap
+ (async, sync) match {
+ case (a, _) if a.contains(syncNode) => throw InferResetsException(graph.path(asyncNode, syncNode))
+ case (a, s) =>
+ (a.view.collect { case Var(t) => t -> asyncNode.tpe } ++
+ s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap
+ }
}
}
diff --git a/src/test/scala/firrtlTests/InferResetsSpec.scala b/src/test/scala/firrtlTests/InferResetsSpec.scala
index 501dce20..0bcc459c 100644
--- a/src/test/scala/firrtlTests/InferResetsSpec.scala
+++ b/src/test/scala/firrtlTests/InferResetsSpec.scala
@@ -5,7 +5,7 @@ package firrtlTests
import firrtl._
import firrtl.ir._
import firrtl.passes.{CheckHighForm, CheckTypes, CheckInitialization}
-import firrtl.transforms.InferResets
+import firrtl.transforms.{CheckCombLoops, InferResets}
import FirrtlCheckers._
// TODO
@@ -138,8 +138,8 @@ class InferResetsSpec extends FirrtlFlatSpec {
}
}
- it should "allow last connect semantics to pick the right type for Reset" in {
- val result =
+ it should "NOT allow last connect semantics to pick the right type for Reset" in {
+ an [InferResets.InferResetsException] shouldBe thrownBy {
compile(s"""
|circuit top :
| module top :
@@ -154,13 +154,11 @@ class InferResetsSpec extends FirrtlFlatSpec {
| out <= w1
|""".stripMargin
)
- result should containTree { case DefWire(_, "w0", AsyncResetType) => true }
- result should containTree { case DefWire(_, "w1", BoolType) => true }
- result should containTree { case Port(_, "out", Output, BoolType) => true }
+ }
}
- it should "support last connect semantics across whens" in {
- val result =
+ it should "NOT support last connect semantics across whens" in {
+ an [InferResets.InferResetsException] shouldBe thrownBy {
compile(s"""
|circuit top :
| module top :
@@ -182,14 +180,11 @@ class InferResetsSpec extends FirrtlFlatSpec {
| out <= w1
|""".stripMargin
)
- result should containTree { case DefWire(_, "w0", AsyncResetType) => true }
- result should containTree { case DefWire(_, "w1", AsyncResetType) => true }
- result should containTree { case DefWire(_, "w2", BoolType) => true }
- result should containTree { case Port(_, "out", Output, AsyncResetType) => true }
+ }
}
it should "not allow different Reset Types to drive a single Reset" in {
- an [InferResets.DifferingDriverTypesException] shouldBe thrownBy {
+ an [InferResets.InferResetsException] shouldBe thrownBy {
val result = compile(s"""
|circuit top :
| module top :
@@ -410,5 +405,108 @@ class InferResetsSpec extends FirrtlFlatSpec {
result should containTree { case Port(_, "childReset", Input, BoolType) => true }
result should containTree { case Port(_, "childReset", Input, AsyncResetType) => true }
}
-}
+ it should "infer based on what a component *drives* not just what drives it" in {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | input in : AsyncReset
+ | output out : Reset
+ | wire w : Reset
+ | w is invalid
+ | out <= w
+ | out <= in
+ |""".stripMargin)
+ result should containTree { case DefWire(_, "w", AsyncResetType) => true }
+ }
+
+ it should "infer from connections, ignoring the fact that the invalidation wins" in {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | input in : AsyncReset
+ | output out : Reset
+ | out <= in
+ | out is invalid
+ |""".stripMargin)
+ result should containTree { case Port(_, "out", Output, AsyncResetType) => true }
+ }
+
+ // The backwards type propagation constrains `w` to be the same as both `out0` and `out1`
+ it should "not allow an invalidated Wire to drive both a UInt<1> and an AsyncReset" in {
+ an [InferResets.InferResetsException] shouldBe thrownBy {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | input in0 : AsyncReset
+ | input in1 : UInt<1>
+ | output out0 : Reset
+ | output out1 : Reset
+ | wire w : Reset
+ | w is invalid
+ | out0 <= w
+ | out1 <= w
+ | out0 <= in0
+ | out1 <= in1
+ |""".stripMargin
+ )
+ }
+ }
+
+ it should "not propagate type info from downstream across a cast" in {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | input in0 : AsyncReset
+ | input in1 : UInt<1>
+ | output out0 : Reset
+ | output out1 : Reset
+ | wire w : Reset
+ | w is invalid
+ | out0 <= asAsyncReset(w)
+ | out1 <= w
+ | out0 <= in0
+ | out1 <= in1
+ |""".stripMargin
+ )
+ result should containTree { case Port(_, "out0", Output, AsyncResetType) => true }
+ }
+
+ // This tests for a bug unrelated to support or lackthereof for last connect in inference
+ it should "take into account both internal and external constraints on Module port types" in {
+ val result = compile(s"""
+ |circuit top :
+ | module child :
+ | input i : AsyncReset
+ | output o : Reset
+ | o <= i
+ | module top :
+ | input in : AsyncReset
+ | output out : AsyncReset
+ | inst c of child
+ | c.o is invalid
+ | c.i <= in
+ | out <= c.o
+ |""".stripMargin)
+ result should containTree { case Port(_, "o", Output, AsyncResetType) => true }
+ }
+
+ it should "not crash on combinational loops" in {
+ a [CheckCombLoops.CombLoopException] shouldBe thrownBy {
+ val result = compile(s"""
+ |circuit top :
+ | module top :
+ | input in : AsyncReset
+ | output out : Reset
+ | wire w0 : Reset
+ | wire w1 : Reset
+ | w0 <= in
+ | w0 <= w1
+ | w1 <= w0
+ | out <= in
+ |""".stripMargin,
+ compiler = new LowFirrtlCompiler
+ )
+ }
+ }
+}