diff options
| author | Jack Koenig | 2017-12-12 16:06:31 -0800 |
|---|---|---|
| committer | GitHub | 2017-12-12 16:06:31 -0800 |
| commit | df579547f769843b76922dbb055ea26839b1d7d4 (patch) | |
| tree | f557811fb961a3125bbfef95815eb81f72ca8346 /src/main | |
| parent | 0fd0c66adcf1226ee5947cdaa5629bf59c4123f1 (diff) | |
| parent | 0d794d57df7b388109d7a0834d3b5be8f79892be (diff) | |
Merge pull request #684 from freechipsproject/remove-wires
Remove wires, replacing them with nodes
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 17 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 15 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/RemoveWires.scala | 132 |
4 files changed, 159 insertions, 8 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 8dd9b180..f032868a 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -86,7 +86,8 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform { passes.InferWidths, passes.Legalize, new firrtl.transforms.RemoveReset, - new firrtl.transforms.CheckCombLoops) + new firrtl.transforms.CheckCombLoops, + new firrtl.transforms.RemoveWires) } /** Runs a series of optimization passes on LowFirrtl diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 090ad8ec..a3ad4231 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -29,10 +29,23 @@ case class MultiInfo(infos: Seq[Info]) extends Info { case MultiInfo(seq) => seq flatMap collectStringLits case NoInfo => Seq.empty } - override def toString: String = - collectStringLits(this).map(_.serialize).mkString(" @[", " ", "]") + override def toString: String = { + val parts = collectStringLits(this) + if (parts.nonEmpty) parts.map(_.serialize).mkString(" @[", " ", "]") + else "" + } def ++(that: Info): Info = MultiInfo(Seq(this, that)) } +object MultiInfo { + def apply(infos: Info*) = { + val infosx = infos.filterNot(_ == NoInfo) + infosx.size match { + case 0 => NoInfo + case 1 => infosx.head + case _ => new MultiInfo(infosx) + } + } +} trait HasName { val name: String diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 84b63e3d..ca48cbb5 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -16,15 +16,20 @@ import firrtl.analyses.InstanceGraph import annotation.tailrec import collection.mutable -class ConstantPropagation extends Transform { - def inputForm = LowForm - def outputForm = LowForm - - private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { +object ConstantPropagation { + /** Pads e to the width of t */ + def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e } +} + +class ConstantPropagation extends Transform { + import ConstantPropagation._ + def inputForm = LowForm + def outputForm = LowForm + private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) trait FoldLogicalOp { diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala new file mode 100644 index 00000000..a1fb32db --- /dev/null +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -0,0 +1,132 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.WrappedExpression._ +import firrtl.graph.{DiGraph, MutableDiGraph, CyclicException} + +import scala.collection.mutable +import scala.util.{Try, Success, Failure} + +/** Replace wires with nodes in a legal, flow-forward order + * + * This pass must run after LowerTypes because Aggregate-type + * wires have multiple connections that may be impossible to order in a + * flow-foward way + */ +class RemoveWires extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + // Extract all expressions that are references to a Wire or Node + // Since we are operating on LowForm, they can only be WRefs + private def extractNodeWireRefs(expr: Expression): Seq[WRef] = { + val refs = mutable.ArrayBuffer.empty[WRef] + def rec(e: Expression): Expression = { + e match { + case ref @ WRef(_,_, WireKind | NodeKind, _) => refs += ref + case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec + case _ => // Do nothing + } + e + } + rec(expr) + refs + } + + // Transform netlist into DefNodes + private def getOrderedNodes( + netlist: mutable.LinkedHashMap[WrappedExpression, (Expression, Info)]): Try[Seq[DefNode]] = { + val digraph = new MutableDiGraph[WrappedExpression] + for ((sink, (expr, _)) <- netlist) { + digraph.addVertex(sink) + for (source <- extractNodeWireRefs(expr)) { + digraph.addPairWithEdge(sink, source) + } + } + + // We could reverse edge directions and not have to do this reverse, but doing it this way does + // a MUCH better job of preserving the logic order as expressed by the designer + // See RemoveWireTests for illustration + Try { + val ordered = digraph.linearize.reverse + ordered.map { key => + val WRef(name, _,_,_) = key.e1 + val (rhs, info) = netlist(key) + DefNode(info, name, rhs) + } + } + } + + private def onModule(m: DefModule): DefModule = { + // Store all non-node declarations here (like reg, inst, and mem) + val decls = mutable.ArrayBuffer.empty[Statement] + // Store all "other" statements here, non-wire, non-node connections, printfs, etc. + val otherStmts = mutable.ArrayBuffer.empty[Statement] + // Add nodes and wire connection here + val netlist = mutable.LinkedHashMap.empty[WrappedExpression, (Expression, Info)] + // Info at definition of wires for combining into node + val wireInfo = mutable.HashMap.empty[WrappedExpression, Info] + + def onStmt(stmt: Statement): Statement = { + stmt match { + case DefNode(info, name, expr) => + netlist(we(WRef(name))) = (expr, info) + case wire: DefWire if !wire.tpe.isInstanceOf[AnalogType] => // Remove all non-Analog wires + wireInfo(WRef(wire)) = wire.info + case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires + decls += decl + case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match { + case WireKind => + // Be sure to pad the rhs since nodes get their type from the rhs + val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) + val dinfo = wireInfo(lhs) + netlist(we(lhs)) = (paddedRhs, MultiInfo(dinfo, cinfo)) + case _ => otherStmts += con // Other connections just pass through + } + case invalid @ IsInvalid(info, expr) => + kind(expr) match { + case WireKind => + val width = expr.tpe match { case GroundType(width) => width } // LowFirrtl + netlist(we(expr)) = (ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe), info) + case _ => otherStmts += invalid + } + case other @ (_: Print | _: Stop | _: Attach) => + otherStmts += other + case EmptyStmt => // Dont bother keeping EmptyStmts around + case block: Block => block map onStmt + case _ => throwInternalError + } + stmt + } + + m match { + case mod @ Module(info, name, ports, body) => + onStmt(body) + getOrderedNodes(netlist) match { + case Success(logic) => + Module(info, name, ports, Block(decls ++ logic ++ otherStmts)) + // If we hit a CyclicException, just abort removing wires + case Failure(_: CyclicException) => + logger.warn(s"Cycle found in module $name, " + + "wires will not be removed which can prevent optimizations!") + mod + case Failure(other) => throw other + } + case m: ExtModule => m + } + } + + private val cleanup = Seq( + passes.ResolveKinds + ) + + def execute(state: CircuitState): CircuitState = { + val result = state.copy(circuit = state.circuit map onModule) + cleanup.foldLeft(result) { case (in, xform) => xform.execute(in) } + } +} |
