diff options
| author | Jack Koenig | 2020-02-18 11:45:11 -0800 |
|---|---|---|
| committer | GitHub | 2020-02-18 11:45:11 -0800 |
| commit | 6c7a15a2d9123fd94770dda15f9e9070ae6b2bdc (patch) | |
| tree | 01903cee1942688d6ae09a724e16b06f0a56fb1b /src | |
| parent | 38cf27e8fbdf201761b018afc93107f15cf17cb7 (diff) | |
Remove last connect semantics from reset inference (#1396)
* Revert "Infer resets last connect semantics (#1291)"
* Fix handling of invalidated and undriven components of type Reset
* Run CheckTypes after InferResets
* Make reset inference bidirectional on connect
* Support AsyncResetType in RemoveValidIf
* Fix InferResets for parent constraints on child ports
* Apply suggestions from code review
* Add ScalaDoc to InferResets
Co-authored-by: Albert Magyar <albert.magyar@gmail.com>
Co-authored-by: Schuyler Eldridge <schuyler.eldridge@gmail.com>
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveValidIf.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/InferResets.scala | 133 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InferResetsSpec.scala | 126 |
4 files changed, 205 insertions, 59 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 0d9b971b..14a8e637 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -49,7 +49,8 @@ class ResolveAndCheck extends CoreTransform { new passes.TrimIntervals(), new passes.InferWidths, passes.CheckWidths, - new firrtl.transforms.InferResets) + new firrtl.transforms.InferResets, + passes.CheckTypes) } /** Expands aggregate connects, removes dynamic accesses, and when diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 37a3f931..42eae7e5 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -13,6 +13,7 @@ object RemoveValidIf extends Pass { val SIntZero = SIntLiteral(BigInt(0), IntWidth(1)) val ClockZero = DoPrim(PrimOps.AsClock, Seq(UIntZero), Seq.empty, ClockType) val FixedZero = FixedLiteral(BigInt(0), IntWidth(1), IntWidth(0)) + val AsyncZero = DoPrim(PrimOps.AsAsyncReset, Seq(UIntZero), Nil, AsyncResetType) /** Returns an [[firrtl.ir.Expression Expression]] equal to zero for a given [[firrtl.ir.GroundType GroundType]] * @note Accepts [[firrtl.ir.Type Type]] but dyanmically expects [[firrtl.ir.GroundType GroundType]] @@ -22,6 +23,7 @@ object RemoveValidIf extends Pass { case _: SIntType => SIntZero case ClockType => ClockZero case _: FixedType => FixedZero + case AsyncResetType => AsyncZero case other => throwInternalError(s"Unexpected type $other") } diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala index fbc915e2..026b15fc 100644 --- a/src/main/scala/firrtl/transforms/InferResets.scala +++ b/src/main/scala/firrtl/transforms/InferResets.scala @@ -7,14 +7,17 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ import firrtl.annotations.{ReferenceTarget, TargetToken} -import firrtl.Utils.toTarget -import firrtl.passes.{Pass, PassException, Errors, InferTypes} +import firrtl.Utils.{toTarget, throwInternalError} +import firrtl.passes.{Pass, PassException, InferTypes} +import firrtl.graph.MutableDiGraph import scala.collection.mutable import scala.util.Try object InferResets { + @deprecated("This is no longer in use and will be removed", "1.3") final class DifferingDriverTypesException private (msg: String) extends PassException(msg) + @deprecated("This is no longer in use and will be removed", "1.3") object DifferingDriverTypesException { def apply(target: ReferenceTarget, tpes: Seq[(Type, Seq[TypeDriver])]): DifferingDriverTypesException = { val xs = tpes.map { case (t, ds) => s"${ds.map(_.target().serialize).mkString(", ")} of type ${t.serialize}" } @@ -23,6 +26,15 @@ object InferResets { } } + final class InferResetsException private (msg: String) extends PassException(msg) + object InferResetsException { + private[InferResets] def apply(path: Seq[Node]): InferResetsException = { + val ps = path.collect { case Var(t) => t.serialize }.mkString("\n - ", "\n - ", "") + val msg = s"Reset-typed components connected to both AsyncReset and UInt<1>. Offending path:$ps" + new InferResetsException(msg) + } + } + /** Type hierarchy to represent the type of the thing driving a [[ResetType]] */ private sealed trait ResetDriver // When a [[ResetType]] is driven by another ResetType, we track the target so that we can infer @@ -34,7 +46,7 @@ object InferResets { // as a constraint on the type we should infer to be // We keep the target around (lazily) so that we can report errors private case class TypeDriver(tpe: Type, target: () => ReferenceTarget) extends ResetDriver { - override def toString: String = s"TypeDriver(${tpe.serialize}, $target)" + override def toString: String = s"TypeDriver(${tpe.serialize}, () => ?)" } // When a [[ResetType]] is invalidated, we record the InvalidDrive // If there are no types but invalid drivers, we default to BoolType @@ -42,6 +54,15 @@ object InferResets { def defaultType: Type = Utils.BoolType } + // Private type hierarchy used as DiGraph nodes for type inference + private sealed trait Node + private case class Var(target: ReferenceTarget) extends Node { + override def toString = target.serialize + } + private case class Typ(tpe: Type) extends Node { + override def toString = tpe.serialize + } + // Type hierarchy representing the path to a leaf type in an aggregate type structure // Used by this [[InferResets]] to pinpoint instances of [[ResetType]] and their inferred type private sealed trait TypeTree @@ -74,8 +95,16 @@ object InferResets { } } -/** Infers the concrete type of [[Reset]]s by their connections - * This is a global inference because ports can be of type [[Reset]] +/** Infers the concrete type of [[ResetType]]s by their connections + * + * There are 3 cases + * 1. An abstract reset driven by and/or driving only asynchronous resets will be inferred as + * asynchronous reset + * 1. An abstract reset driven by and/or driving both asynchronous and synchronous resets will + * error + * 1. Otherwise, the reset is inferred as synchronous (i.e. the abstract reset is only invalidated + * or is driven by or drives only synchronous resets) + * @note This is a global inference because ports can be of type [[ResetType]] * @note This transform should be run before [[DedupModules]] so that similar Modules from * generator languages like Chisel can infer differently */ @@ -88,8 +117,8 @@ class InferResets extends Transform { // Collect all drivers for circuit elements of type ResetType private def analyze(c: Circuit): Map[ReferenceTarget, List[ResetDriver]] = { - type DriverMap = mutable.HashMap[ReferenceTarget, List[ResetDriver]] - def onMod(mod: DefModule): DriverMap = { + type DriverMap = mutable.HashMap[ReferenceTarget, mutable.ListBuffer[ResetDriver]] + def onMod(types: DriverMap)(mod: DefModule): DriverMap = { val instMap = mutable.Map[String, String]() // We need to convert submodule port targets into targets on the Module port itself def makeTarget(expr: Expression): ReferenceTarget = { @@ -119,7 +148,7 @@ class InferResets extends Transform { case ResetType => TargetDriver(makeTarget(exp)) case tpe => TypeDriver(tpe, () => makeTarget(exp)) } - map(makeTarget(loc)) = driver :: Nil + map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver } } stmt match { @@ -146,7 +175,7 @@ class InferResets extends Transform { // Unlike in markResetDriver, flow is irrelevant for invalidation if (expr.tpe == ResetType) { val target = makeTarget(expr) - map(target) = InvalidDriver :: Nil + map.getOrElseUpdate(target, mutable.ListBuffer()) += InvalidDriver } } case WDefInstance(_, inst, module, _) => @@ -156,54 +185,70 @@ class InferResets extends Transform { val altMap = new DriverMap onStmt(conMap)(con) onStmt(altMap)(alt) - // Default to outerscope if not found on either side - val conLookup = conMap.orElse(map).lift - val altLookup = altMap.orElse(map).lift for (key <- conMap.keys ++ altMap.keys) { - val values = conLookup(key).getOrElse(Nil) ++ altLookup(key).getOrElse(Nil) - map(key) = values + val ds = map.getOrElseUpdate(key, mutable.ListBuffer()) + conMap.get(key).foreach(ds ++= _) + altMap.get(key).foreach(ds ++= _) } case other => other.foreach(onStmt(map)) } } - val types = new DriverMap mod.foreach(onStmt(types)) types } - c.modules.foldLeft(Map[ReferenceTarget, List[ResetDriver]]()) { - case (map, mod) => map ++ onMod(mod).mapValues(_.toList) - } + val res = new DriverMap + c.modules.foreach(m => onMod(res)(m)) + res.mapValues(_.toList).toMap } - // Determine the type driving a given ResetType + /** Determine the type driving a given ResetType + * + * This is implemented as a graph traversal. Every type constraint is a forwards and backwards + * edge between the two components (where one of the components can be Bool or AsyncReset). Then, + * types are inferred by determining all of the nodes that are reachable from Bool and AsyncReset + * respectively. As an optimization, we actually only need to check reachability from one, since + * nodes that are not reachable from the one are either reachable from the other or reachable + * from neither. If unreachable, then the component must only be connected to invalidated + * components of type Reset, thus we can arbitrarily choose which reset to infer it to. As an + * optimization, we have edges from components back to Bool and AsyncReset which allows us to + * check if any node is erroneously constrained to be both by simply checking if Bool is + * reachable from AsyncReset. + */ private def resolve(map: Map[ReferenceTarget, List[ResetDriver]]): Try[Map[ReferenceTarget, Type]] = { - val res = mutable.Map[ReferenceTarget, Type]() - val errors = new Errors - def rec(target: ReferenceTarget): Type = { - res.getOrElseUpdate(target, { - val drivers = map.getOrElse(target, Nil) - val tpes = drivers.flatMap { - case TargetDriver(t) => Some(TypeDriver(rec(t), () => t)) - case td: TypeDriver => Some(td) - case InvalidDriver => None - }.groupBy(_.tpe) - tpes.keys.size match { - // This can occur if something of type Reset has no driver - case 0 => InvalidDriver.defaultType - case 1 => tpes.keys.head - case _ => - // Multiple types of driver! - errors.append(DifferingDriverTypesException(target, tpes.toSeq)) - tpes.keys.head - } - }) - } - for ((target, _) <- map) { - rec(target) + val graph = new MutableDiGraph[Node] + val asyncNode = Typ(AsyncResetType) + val syncNode = Typ(Utils.BoolType) + for ((target, drivers) <- map) { + val v = Var(target) + drivers.foreach { + case TargetDriver(t) => + val u = Var(t) + graph.addPairWithEdge(v, u) + graph.addPairWithEdge(u, v) + case TypeDriver(tpe, _) => + // Use nodes we already made, saves memory + val u = tpe match { + case AsyncResetType => asyncNode + case Utils.BoolType => syncNode + case other => throwInternalError(s"Shouldn't have $other here") + } + graph.addPairWithEdge(u, v) + // This backwards edge allows us to check for nodes that infer to both at the same time we + // do the actual inference, the check is simply if syncNode is reachable from asyncNode + graph.addPairWithEdge(v, u) + case InvalidDriver => + graph.addVertex(v) // Must be in the graph or won't be inferred + } } + val async = graph.reachableFrom(asyncNode) + val sync = graph.getVertices -- async Try { - errors.trigger() - res.toMap + (async, sync) match { + case (a, _) if a.contains(syncNode) => throw InferResetsException(graph.path(asyncNode, syncNode)) + case (a, s) => + (a.view.collect { case Var(t) => t -> asyncNode.tpe } ++ + s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap + } } } diff --git a/src/test/scala/firrtlTests/InferResetsSpec.scala b/src/test/scala/firrtlTests/InferResetsSpec.scala index 501dce20..0bcc459c 100644 --- a/src/test/scala/firrtlTests/InferResetsSpec.scala +++ b/src/test/scala/firrtlTests/InferResetsSpec.scala @@ -5,7 +5,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ import firrtl.passes.{CheckHighForm, CheckTypes, CheckInitialization} -import firrtl.transforms.InferResets +import firrtl.transforms.{CheckCombLoops, InferResets} import FirrtlCheckers._ // TODO @@ -138,8 +138,8 @@ class InferResetsSpec extends FirrtlFlatSpec { } } - it should "allow last connect semantics to pick the right type for Reset" in { - val result = + it should "NOT allow last connect semantics to pick the right type for Reset" in { + an [InferResets.InferResetsException] shouldBe thrownBy { compile(s""" |circuit top : | module top : @@ -154,13 +154,11 @@ class InferResetsSpec extends FirrtlFlatSpec { | out <= w1 |""".stripMargin ) - result should containTree { case DefWire(_, "w0", AsyncResetType) => true } - result should containTree { case DefWire(_, "w1", BoolType) => true } - result should containTree { case Port(_, "out", Output, BoolType) => true } + } } - it should "support last connect semantics across whens" in { - val result = + it should "NOT support last connect semantics across whens" in { + an [InferResets.InferResetsException] shouldBe thrownBy { compile(s""" |circuit top : | module top : @@ -182,14 +180,11 @@ class InferResetsSpec extends FirrtlFlatSpec { | out <= w1 |""".stripMargin ) - result should containTree { case DefWire(_, "w0", AsyncResetType) => true } - result should containTree { case DefWire(_, "w1", AsyncResetType) => true } - result should containTree { case DefWire(_, "w2", BoolType) => true } - result should containTree { case Port(_, "out", Output, AsyncResetType) => true } + } } it should "not allow different Reset Types to drive a single Reset" in { - an [InferResets.DifferingDriverTypesException] shouldBe thrownBy { + an [InferResets.InferResetsException] shouldBe thrownBy { val result = compile(s""" |circuit top : | module top : @@ -410,5 +405,108 @@ class InferResetsSpec extends FirrtlFlatSpec { result should containTree { case Port(_, "childReset", Input, BoolType) => true } result should containTree { case Port(_, "childReset", Input, AsyncResetType) => true } } -} + it should "infer based on what a component *drives* not just what drives it" in { + val result = compile(s""" + |circuit top : + | module top : + | input in : AsyncReset + | output out : Reset + | wire w : Reset + | w is invalid + | out <= w + | out <= in + |""".stripMargin) + result should containTree { case DefWire(_, "w", AsyncResetType) => true } + } + + it should "infer from connections, ignoring the fact that the invalidation wins" in { + val result = compile(s""" + |circuit top : + | module top : + | input in : AsyncReset + | output out : Reset + | out <= in + | out is invalid + |""".stripMargin) + result should containTree { case Port(_, "out", Output, AsyncResetType) => true } + } + + // The backwards type propagation constrains `w` to be the same as both `out0` and `out1` + it should "not allow an invalidated Wire to drive both a UInt<1> and an AsyncReset" in { + an [InferResets.InferResetsException] shouldBe thrownBy { + val result = compile(s""" + |circuit top : + | module top : + | input in0 : AsyncReset + | input in1 : UInt<1> + | output out0 : Reset + | output out1 : Reset + | wire w : Reset + | w is invalid + | out0 <= w + | out1 <= w + | out0 <= in0 + | out1 <= in1 + |""".stripMargin + ) + } + } + + it should "not propagate type info from downstream across a cast" in { + val result = compile(s""" + |circuit top : + | module top : + | input in0 : AsyncReset + | input in1 : UInt<1> + | output out0 : Reset + | output out1 : Reset + | wire w : Reset + | w is invalid + | out0 <= asAsyncReset(w) + | out1 <= w + | out0 <= in0 + | out1 <= in1 + |""".stripMargin + ) + result should containTree { case Port(_, "out0", Output, AsyncResetType) => true } + } + + // This tests for a bug unrelated to support or lackthereof for last connect in inference + it should "take into account both internal and external constraints on Module port types" in { + val result = compile(s""" + |circuit top : + | module child : + | input i : AsyncReset + | output o : Reset + | o <= i + | module top : + | input in : AsyncReset + | output out : AsyncReset + | inst c of child + | c.o is invalid + | c.i <= in + | out <= c.o + |""".stripMargin) + result should containTree { case Port(_, "o", Output, AsyncResetType) => true } + } + + it should "not crash on combinational loops" in { + a [CheckCombLoops.CombLoopException] shouldBe thrownBy { + val result = compile(s""" + |circuit top : + | module top : + | input in : AsyncReset + | output out : Reset + | wire w0 : Reset + | wire w1 : Reset + | w0 <= in + | w0 <= w1 + | w1 <= w0 + | out <= in + |""".stripMargin, + compiler = new LowFirrtlCompiler + ) + } + } +} |
