diff options
| author | Kevin Laeufer | 2020-07-29 15:25:34 -0700 |
|---|---|---|
| committer | GitHub | 2020-07-29 22:25:34 +0000 |
| commit | c02c9b7f33d67d8a65040c028395e881668294f6 (patch) | |
| tree | e6eaa4f2787e74759f4cfffa61f84bd08a03d4c2 | |
| parent | 3a6e352626915751b2b2a5d6aec4203fb8e83a1d (diff) | |
WiringTransform: fix non-determinism (#1799)
* WiringUtils.sinksToSources: make sinkInsts order deterministic
* WiringUtils: make owners a LinkedHashMap
* Wiring: only make something a Wire if it isn't a port already
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
| -rw-r--r-- | src/main/scala/firrtl/passes/wiring/Wiring.scala | 25 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/wiring/WiringUtils.scala | 47 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/WiringTests.scala | 12 |
3 files changed, 46 insertions, 38 deletions
diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala index c074168b..1ee509e2 100644 --- a/src/main/scala/firrtl/passes/wiring/Wiring.scala +++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala @@ -81,24 +81,28 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)) ) } // Determine "ownership" of sources to sinks via minimum distance - val owners = sinksToSources(sinks, source, iGraph) + val owners = sinksToSourcesSeq(sinks, source, iGraph) // Determine port and pending modifications for all sink--source // ownership pairs val meta = new mutable.HashMap[String, Modifications] .withDefaultValue(Modifications()) + + // only make something a wire if it isn't an output or input already + def makeWire(m: Modifications, portName: String): Modifications = + m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire)))) + def makeWireC(m: Modifications, portName: String, c: (String, String)): Modifications = + m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))), cons = (m.cons :+ c).distinct ) + owners.foreach { case (sink, source) => val lca = iGraph.lowestCommonAncestor(sink, source) // Compute metadata along Sink to LCA paths. - sink.drop(lca.size - 1).sliding(2).toList.reverse.map { + sink.drop(lca.size - 1).sliding(2).toList.reverse.foreach { case Seq(WDefInstance(_,_,pm,_), WDefInstance(_,ci,cm,_)) => val to = s"$ci.${portNames(cm)}" val from = s"${portNames(pm)}" - meta(pm) = meta(pm).copy( - addPortOrWire = Some((portNames(pm), DecWire)), - cons = (meta(pm).cons :+( (to, from) )).distinct - ) + meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) meta(cm) = meta(cm).copy( addPortOrWire = Some((portNames(cm), DecInput)) ) @@ -106,17 +110,12 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { case Seq(WDefInstance(_,_,pm,_)) => // Case where the source is also the LCA if (source.drop(lca.size).isEmpty) { - meta(pm) = meta(pm).copy ( - addPortOrWire = Some((portNames(pm), DecWire)) - ) + meta(pm) = makeWire(meta(pm), portNames(pm)) } else { val WDefInstance(_,ci,cm,_) = source.drop(lca.size).head val to = s"${portNames(pm)}" val from = s"$ci.${portNames(cm)}" - meta(pm) = meta(pm).copy( - addPortOrWire = Some((portNames(pm), DecWire)), - cons = (meta(pm).cons :+( (to, from) )).distinct - ) + meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala index 9eed358f..5f09bbe0 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala @@ -112,29 +112,39 @@ object WiringUtils { * @return a map of sink instance names to source instance names * @throws WiringException if a sink is equidistant to two sources */ - def sinksToSources(sinks: Seq[Named], + @deprecated("This method can lead to non-determinism in your compiler pass. Use sinksToSourcesSeq instead!", "Firrtl 1.4") + def sinksToSources(sinks: Seq[Named], source: String, i: InstanceGraph): Map[Seq[WDefInstance], Seq[WDefInstance]] = + sinksToSourcesSeq(sinks, source, i).toMap + + /** Return a map of sink instances to source instances that minimizes + * distance + * + * @param sinks a sequence of sink modules + * @param source the source module + * @param i a graph representing a circuit + * @return a map of sink instance names to source instance names + * @throws WiringException if a sink is equidistant to two sources + */ + def sinksToSourcesSeq(sinks: Seq[Named], source: String, i: InstanceGraph): - Map[Seq[WDefInstance], Seq[WDefInstance]] = { - val owners = new mutable.HashMap[Seq[WDefInstance], Vector[Seq[WDefInstance]]] - .withDefaultValue(Vector()) + Seq[(Seq[WDefInstance], Seq[WDefInstance])] = { + // The order of owners influences the order of the results, it thus needs to be deterministic with a LinkedHashMap. + val owners = new mutable.LinkedHashMap[Seq[WDefInstance], Vector[Seq[WDefInstance]]] val queue = new mutable.Queue[Seq[WDefInstance]] val visited = new mutable.HashMap[Seq[WDefInstance], Boolean] .withDefaultValue(false) - i.fullHierarchy.keys.filter { case WDefInstance(_,_,m,_) => m == source } - .foreach( i.fullHierarchy(_) - .foreach { l => - queue.enqueue(l) - owners(l) = Vector(l) - } - ) + val sourcePaths = i.fullHierarchy.collect { case (k,v) if k.module == source => v } + sourcePaths.flatten.foreach { l => + queue.enqueue(l) + owners(l) = Vector(l) + } - val sinkInsts = i.fullHierarchy.keys - .filter { case WDefInstance(_, _, module, _) => - sinks.map(getModuleName(_)).contains(module) } - .flatMap { k => i.fullHierarchy(k) } - .toSet + val sinkModuleNames = sinks.map(getModuleName).toSet + val sinkPaths = i.fullHierarchy.collect { case (k,v) if sinkModuleNames.contains(k.module) => v } + // sinkInsts needs to have unique entries but is also iterated over which is why we use a LinkedHashSet + val sinkInsts = mutable.LinkedHashSet() ++ sinkPaths.flatten /** If we're lucky and there is only one source, then that source owns * all sinks. If we're unlucky, we need to do a full (slow) BFS @@ -161,7 +171,7 @@ object WiringUtils { edges .filter( e => !visited(e) && e.nonEmpty ) .foreach{ v => - owners(v) = owners(v) ++ owners(u) + owners(v) = owners.getOrElse(v, Vector()) ++ owners(u) queue.enqueue(v) } } @@ -175,8 +185,7 @@ object WiringUtils { } } - owners - .collect { case (k, v) if sinkInsts.contains(k) => (k, v.flatten) }.toMap + owners.collect { case (k, v) if sinkInsts.contains(k) => (k, v.flatten) }.toSeq } /** Helper script to extract a module name from a named Module or Target */ diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala index 48089f0c..8ec6d5ce 100644 --- a/src/test/scala/firrtlTests/WiringTests.scala +++ b/src/test/scala/firrtlTests/WiringTests.scala @@ -83,9 +83,9 @@ class WiringTests extends FirrtlFlatSpec { | x.clock <= clock | inst d of D | d.clock <= clock - | d.r <= r - | r <= b.r | x.pin <= r + | r <= b.r + | d.r <= r | module B : | input clock: Clock | output r: UInt<5> @@ -169,9 +169,9 @@ class WiringTests extends FirrtlFlatSpec { | x.clock <= clock | inst d of D | d.clock <= clock - | d.r <= r - | r <= b.r | x.pin <= r + | r <= b.r + | d.r <= r | module B : | input clock: Clock | output r: UInt<5> @@ -256,9 +256,9 @@ class WiringTests extends FirrtlFlatSpec { | x.clock <= clock | inst d of D | d.clock <= clock - | d.r <= r - | r <= b.r | x.pin <= r + | r <= b.r + | d.r <= r | module B : | input clock: Clock | output r: UInt<5> |
