diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/transforms | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/transforms')
28 files changed, 1368 insertions, 1163 deletions
diff --git a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala index a57973d5..5000e07a 100644 --- a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala +++ b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala @@ -2,7 +2,7 @@ package firrtl.transforms -import java.io.{File, FileNotFoundException, FileInputStream, FileOutputStream, PrintWriter} +import java.io.{File, FileInputStream, FileNotFoundException, FileOutputStream, PrintWriter} import firrtl._ import firrtl.annotations._ @@ -11,31 +11,32 @@ import scala.collection.immutable.ListSet sealed trait BlackBoxHelperAnno extends Annotation -case class BlackBoxTargetDirAnno(targetDir: String) extends BlackBoxHelperAnno - with NoTargetAnnotation { +case class BlackBoxTargetDirAnno(targetDir: String) extends BlackBoxHelperAnno with NoTargetAnnotation { override def serialize: String = s"targetDir\n$targetDir" } -case class BlackBoxResourceAnno(target: ModuleName, resourceId: String) extends BlackBoxHelperAnno +case class BlackBoxResourceAnno(target: ModuleName, resourceId: String) + extends BlackBoxHelperAnno with SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(target = n) override def serialize: String = s"resource\n$resourceId" } -case class BlackBoxInlineAnno(target: ModuleName, name: String, text: String) extends BlackBoxHelperAnno +case class BlackBoxInlineAnno(target: ModuleName, name: String, text: String) + extends BlackBoxHelperAnno with SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(target = n) override def serialize: String = s"inline\n$name\n$text" } -case class BlackBoxPathAnno(target: ModuleName, path: String) extends BlackBoxHelperAnno +case class BlackBoxPathAnno(target: ModuleName, path: String) + extends BlackBoxHelperAnno with SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(target = n) override def serialize: String = s"path\n$path" } -case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackBoxHelperAnno - with NoTargetAnnotation { +case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackBoxHelperAnno with NoTargetAnnotation { override def serialize: String = s"resourceFileName\n$resourceFileName" } @@ -43,8 +44,10 @@ case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackB * @param fileName the name of the BlackBox file (only used for error message generation) * @param e an underlying exception that generated this */ -class BlackBoxNotFoundException(fileName: String, message: String) extends FirrtlUserException( - s"BlackBox '$fileName' not found. Did you misspell it? Is it in src/{main,test}/resources?\n$message") +class BlackBoxNotFoundException(fileName: String, message: String) + extends FirrtlUserException( + s"BlackBox '$fileName' not found. Did you misspell it? Is it in src/{main,test}/resources?\n$message" + ) /** Handle source for Verilog ExtModules (BlackBoxes) * @@ -72,15 +75,16 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration { */ def collectAnnos(annos: Seq[Annotation]): (ListSet[BlackBoxHelperAnno], File, File) = annos.foldLeft((ListSet.empty[BlackBoxHelperAnno], DefaultTargetDir, new File(defaultFileListName))) { - case ((acc, tdir, flistName), anno) => anno match { - case BlackBoxTargetDirAnno(dir) => - val targetDir = new File(dir) - if (!targetDir.exists()) { FileUtils.makeDirectory(targetDir.getAbsolutePath) } - (acc, targetDir, flistName) - case BlackBoxResourceFileNameAnno(fileName) => (acc, tdir, new File(fileName)) - case a: BlackBoxHelperAnno => (acc + a, tdir, flistName) - case _ => (acc, tdir, flistName) - } + case ((acc, tdir, flistName), anno) => + anno match { + case BlackBoxTargetDirAnno(dir) => + val targetDir = new File(dir) + if (!targetDir.exists()) { FileUtils.makeDirectory(targetDir.getAbsolutePath) } + (acc, targetDir, flistName) + case BlackBoxResourceFileNameAnno(fileName) => (acc, tdir, new File(fileName)) + case a: BlackBoxHelperAnno => (acc + a, tdir, flistName) + case _ => (acc, tdir, flistName) + } } /** @@ -112,14 +116,15 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration { case BlackBoxInlineAnno(_, name, text) => val outFile = new File(targetDir, name) (text, outFile) - }.map { case (text, file) => - writeTextToFile(text, file) - file + }.map { + case (text, file) => + writeTextToFile(text, file) + file } // Issue #917 - We don't want to list Verilog header files ("*.vh") in our file list - they will automatically be included by reference. def isHeader(name: String) = name.endsWith(".h") || name.endsWith(".vh") || name.endsWith(".svh") - val verilogSourcesOnly = (resourceFiles ++ inlineFiles).filterNot{ f => isHeader(f.getName()) } + val verilogSourcesOnly = (resourceFiles ++ inlineFiles).filterNot { f => isHeader(f.getName()) } val filelistFile = if (flistName.isAbsolute()) flistName else new File(targetDir, flistName.getName()) // We need the canonical path here, so verilator will create a path to the file that works from the targetDir, @@ -137,12 +142,14 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration { } object BlackBoxSourceHelper { + /** Safely access a file converting [[FileNotFoundException]]s and [[NullPointerException]]s into * [[BlackBoxNotFoundException]]s * @param fileName the name of the file to be accessed (only used for error message generation) * @param code some code to run */ - private def safeFile[A](fileName: String)(code: => A) = try { code } catch { + private def safeFile[A](fileName: String)(code: => A) = try { code } + catch { case e @ (_: FileNotFoundException | _: NullPointerException) => throw new BlackBoxNotFoundException(fileName, e.getMessage) } diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index 6403be23..ee4c1d0b 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -24,6 +24,7 @@ import firrtl.options.{Dependency, RegisteredTransform, ShellOption} case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None) object LogicNode { + /** * Construct a LogicNode from a *Low FIRRTL* reference or subfield that refers to a component. * Since aggregate types appear in Low FIRRTL only as the full types of instances or memories, @@ -39,11 +40,11 @@ object LogicNode { case s: WSubField => s.expr match { case modref: WRef => - LogicNode(s.name,Some(modref.name)) + LogicNode(s.name, Some(modref.name)) case memport: WSubField => memport.expr match { case memref: WRef => - LogicNode(s.name,Some(memref.name),Some(memport.name)) + LogicNode(s.name, Some(memref.name), Some(memport.name)) case _ => throwInternalError(s"LogicNode: unrecognized subsubfield expression - $memport") } case _ => throwInternalError(s"LogicNode: unrecognized subfield expression - $s") @@ -56,9 +57,8 @@ object CheckCombLoops { type ConnMap = DiGraph[LogicNode] with EdgeData[LogicNode, Info] type MutableConnMap = MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info] - - class CombLoopException(info: Info, mname: String, cycle: Seq[String]) extends PassException( - s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n")) + class CombLoopException(info: Info, mname: String, cycle: Seq[String]) + extends PassException(s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n")) } case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation @@ -73,7 +73,7 @@ case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarge override def update(renames: RenameMap): Seq[Annotation] = { val sources = renames.get(source).getOrElse(Seq(source)) val sinks = renames.get(sink).getOrElse(Seq(sink)) - val paths = sources flatMap { s => sinks.map((s, _)) } + val paths = sources.flatMap { s => sinks.map((s, _)) } paths.collect { case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink) } @@ -82,8 +82,8 @@ case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarge case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget]) extends Annotation { override def update(renames: RenameMap): Seq[Annotation] = { - val newSources = sources.flatMap { s => renames(s) }.collect {case x: ReferenceTarget if x.isLocal => x} - val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x} + val newSources = sources.flatMap { s => renames(s) }.collect { case x: ReferenceTarget if x.isLocal => x } + val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x } newSinks.map(snk => CombinationalPath(snk, newSources)) } } @@ -98,14 +98,10 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget * @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules * @note The pass will throw exceptions on "false paths" */ -class CheckCombLoops extends Transform - with RegisteredTransform - with DependencyAPIMigration { +class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), - Dependency(firrtl.transforms.RemoveReset) ) + Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset)) override def optionalPrerequisites = Seq.empty @@ -119,17 +115,21 @@ class CheckCombLoops extends Transform new ShellOption[Unit]( longOption = "no-check-comb-loops", toAnnotationSeq = (_: Unit) => Seq(DontCheckCombLoopsAnnotation), - helpText = "Disable combinational loop checking" ) ) + helpText = "Disable combinational loop checking" + ) + ) private def getExprDeps(deps: MutableConnMap, v: LogicNode, info: Info)(e: Expression): Unit = e match { - case r: WRef => deps.addEdgeIfValid(v, LogicNode(r), info) + case r: WRef => deps.addEdgeIfValid(v, LogicNode(r), info) case s: WSubField => deps.addEdgeIfValid(v, LogicNode(s), info) case _ => e.foreach(getExprDeps(deps, v, info)) } private def getStmtDeps( simplifiedModules: mutable.Map[String, AbstractConnMap], - deps: MutableConnMap)(s: Statement): Unit = s match { + deps: MutableConnMap + )(s: Statement + ): Unit = s match { case Connect(info, loc, expr) => val lhs = LogicNode(loc) if (deps.contains(lhs)) { @@ -152,9 +152,9 @@ class CheckCombLoops extends Transform case i: WDefInstance => val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) iGraph.getVertices.foreach(deps.addVertex(_)) - iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } }) + iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v, _) } }) case _ => - s.foreach(getStmtDeps(simplifiedModules,deps)) + s.foreach(getStmtDeps(simplifiedModules, deps)) } // Pretty-print a LogicNode with a prepended hierarchical path @@ -169,24 +169,26 @@ class CheckCombLoops extends Transform * recovered. */ private def expandInstancePaths( - m: String, + m: String, moduleGraphs: mutable.Map[String, ConnMap], - moduleDeps: Map[String, Map[String, String]], - hierPrefix: Seq[String], - path: Seq[LogicNode]): Seq[String] = { + moduleDeps: Map[String, Map[String, String]], + hierPrefix: Seq[String], + path: Seq[LogicNode] + ): Seq[String] = { // Recover info from edge data, add to error string def info(u: LogicNode, v: LogicNode): String = moduleGraphs(m).getEdgeData(u, v).map(_.toString).mkString("\t", "", "") // lhs comes after rhs - val pathNodes = (path zip path.tail) map { case (rhs, lhs) => - if (lhs.inst.isDefined && !lhs.memport.isDefined && lhs.inst == rhs.inst) { - val child = moduleDeps(m)(lhs.inst.get) - val newHierPrefix = hierPrefix :+ lhs.inst.get - val subpath = moduleGraphs(child).path(lhs.copy(inst=None),rhs.copy(inst=None)).reverse - expandInstancePaths(child, moduleGraphs, moduleDeps, newHierPrefix, subpath) - } else { - Seq(prettyPrintAbsoluteRef(hierPrefix, lhs) ++ info(lhs, rhs)) - } + val pathNodes = (path.zip(path.tail)).map { + case (rhs, lhs) => + if (lhs.inst.isDefined && !lhs.memport.isDefined && lhs.inst == rhs.inst) { + val child = moduleDeps(m)(lhs.inst.get) + val newHierPrefix = hierPrefix :+ lhs.inst.get + val subpath = moduleGraphs(child).path(lhs.copy(inst = None), rhs.copy(inst = None)).reverse + expandInstancePaths(child, moduleGraphs, moduleDeps, newHierPrefix, subpath) + } else { + Seq(prettyPrintAbsoluteRef(hierPrefix, lhs) ++ info(lhs, rhs)) + } } pathNodes.flatten } @@ -238,12 +240,13 @@ class CheckCombLoops extends Transform val errors = new Errors() val extModulePaths = state.annotations.groupBy { case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module) - case ann: Annotation => CircuitTarget(c.main) + case ann: Annotation => CircuitTarget(c.main) } - val moduleMap = c.modules.map({m => (m.name,m) }).toMap + val moduleMap = c.modules.map({ m => (m.name, m) }).toMap val iGraph = InstanceKeyGraph(c).graph - val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap - val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } + val moduleDeps = + iGraph.getEdgeMap.map({ case (k, v) => (k.module, (v.map { i => (i.name, i.module) }).toMap) }).toMap + val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse.map { moduleMap(_) } val moduleGraphs = new mutable.HashMap[String, ConnMap] val simplifiedModuleGraphs = new mutable.HashMap[String, AbstractConnMap] topoSortedModules.foreach { @@ -252,7 +255,8 @@ class CheckCombLoops extends Transform val extModuleDeps = new MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info] portSet.foreach(extModuleDeps.addVertex(_)) extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect { - case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref)) + case a: ExtModulePathAnnotation => + extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref)) } moduleGraphs(em.name) = extModuleDeps simplifiedModuleGraphs(em.name) = extModuleDeps.simplify(portSet) @@ -270,7 +274,7 @@ class CheckCombLoops extends Transform for (scc <- internalDeps.findSCCs.filter(_.length > 1)) { val sccSubgraph = internalDeps.subgraph(scc.toSet) val cycle = findCycleInSCC(sccSubgraph) - (cycle zip cycle.tail).foreach({ case (a,b) => require(internalDeps.getEdges(a).contains(b)) }) + (cycle.zip(cycle.tail)).foreach({ case (a, b) => require(internalDeps.getEdges(a).contains(b)) }) // Reverse to make sure LHS comes after RHS, print repeated vertex at start for legibility val intuitiveCycle = cycle.reverse val repeatedInitial = prettyPrintAbsoluteRef(Seq(m.name), intuitiveCycle.head) @@ -280,10 +284,11 @@ class CheckCombLoops extends Transform case m => throwInternalError(s"Module ${m.name} has unrecognized type") } val mt = ModuleTarget(c.main, c.main) - val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty => - val sink = mt.ref(from.name) - val sources = tos.map(to => mt.ref(to.name)) - CombinationalPath(sink, sources.toSeq) + val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { + case (from, tos) if tos.nonEmpty => + val sink = mt.ref(from.name) + val sources = tos.map(to => mt.ref(to.name)) + CombinationalPath(sink, sources.toSeq) } (state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs, moduleGraphs) } @@ -291,7 +296,7 @@ class CheckCombLoops extends Transform /** * Returns a Map from Module name to port connectivity */ - def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = { + def analyze(state: CircuitState): collection.Map[String, DiGraph[String]] = { val (result, errors, connectivity, _) = run(state) connectivity.map { case (k, v) => (k, v.transformNodes(ln => ln.name)) @@ -301,7 +306,7 @@ class CheckCombLoops extends Transform /** * Returns a Map from Module name to complete netlist connectivity */ - def analyzeFull(state: CircuitState): collection.Map[String,DiGraph[LogicNode]] = { + def analyzeFull(state: CircuitState): collection.Map[String, DiGraph[LogicNode]] = { run(state)._4 } diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala index 7fa01e46..3014d0e3 100644 --- a/src/main/scala/firrtl/transforms/CombineCats.scala +++ b/src/main/scala/firrtl/transforms/CombineCats.scala @@ -1,4 +1,3 @@ - package firrtl package transforms @@ -14,26 +13,30 @@ import scala.collection.mutable case class MaxCatLenAnnotation(maxCatLen: Int) extends NoTargetAnnotation object CombineCats { + /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them paired with their Cat length */ type Netlist = mutable.HashMap[WrappedExpression, (Int, Expression)] def expandCatArgs(maxCatLen: Int, netlist: Netlist)(expr: Expression): (Int, Expression) = expr match { - case cat@DoPrim(Cat, args, _, _) => + case cat @ DoPrim(Cat, args, _, _) => val (a0Len, a0Expanded) = expandCatArgs(maxCatLen - 1, netlist)(args.head) val (a1Len, a1Expanded) = expandCatArgs(maxCatLen - a0Len, netlist)(args(1)) (a0Len + a1Len, cat.copy(args = Seq(a0Expanded, a1Expanded)).asInstanceOf[Expression]) case other => - netlist.get(we(expr)).collect { - case (len, cat@DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat) - }.getOrElse((1, other)) + netlist + .get(we(expr)) + .collect { + case (len, cat @ DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat) + } + .getOrElse((1, other)) } def onStmt(maxCatLen: Int, netlist: Netlist)(stmt: Statement): Statement = { stmt.map(onStmt(maxCatLen, netlist)) match { - case node@DefNode(_, name, value) => + case node @ DefNode(_, name, value) => val catLenAndVal = value match { - case cat@DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat) - case other => (1, other) + case cat @ DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat) + case other => (1, other) } netlist(we(WRef(name))) = catLenAndVal node.copy(value = catLenAndVal._2) @@ -55,16 +58,16 @@ object CombineCats { class CombineCats extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( Dependency(passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions) ) + Seq( + Dependency(passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions) + ) override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Seq( - Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform) = false diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index ce36dd72..dc9b2bbe 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -28,7 +28,7 @@ object ConstantPropagation { /** Pads e to the width of t */ def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { - case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) + case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e } @@ -44,38 +44,40 @@ object ConstantPropagation { case lit: Literal => require(hi >= lo) UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) - case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { - case t: UIntType => x - case _ => asUInt(x, e.tpe) - } + case x if bitWidth(e.tpe) == bitWidth(x.tpe) => + x.tpe match { + case t: UIntType => x + case _ => asUInt(x, e.tpe) + } case _ => e } } def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head - case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) - // take sign bit if shift amount is larger than arg width - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) - case _ => e - } + case x => + e.args.head match { + // TODO when amount >= x.width, return a zero-width wire + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1))) + // take sign bit if shift amount is larger than arg width + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1))) + case _ => e + } } - - /********************************************** - * REGISTER CONSTANT PROPAGATION HELPER TYPES * - **********************************************/ + /** ******************************************** + * REGISTER CONSTANT PROPAGATION HELPER TYPES * + * ******************************************** + */ // A utility class that is somewhat like an Option but with two variants containing Nothing. // for register constant propagation (register or literal). private abstract class ConstPropBinding[+T] { def resolve[V >: T](that: ConstPropBinding[V]): ConstPropBinding[V] = (this, that) match { - case (x, y) if (x == y) => x + case (x, y) if (x == y) => x case (x, UnboundConstant) => x case (UnboundConstant, y) => y - case _ => NonConstant + case _ => NonConstant } } @@ -103,21 +105,23 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res override def prerequisites = ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq + ++ firrtl.stage.Forms.LowForm + - Dependency(firrtl.passes.Legalize) + + Dependency(firrtl.passes.RemoveValidIf)).toSeq override def optionalPrerequisites = Seq.empty override def optionalPrerequisiteOf = - Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq( + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] + ) override def invalidates(a: Transform): Boolean = a match { case firrtl.passes.Legalize => true - case _ => false + case _ => false } override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation]) @@ -130,7 +134,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } sealed trait FoldCommutativeOp extends SimplifyBinaryOp { - def fold(c1: Literal, c2: Literal): Expression + def fold(c1: Literal, c2: Literal): Expression def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression override def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { @@ -138,7 +142,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) case (lhs, rhs) if (lhs == rhs) => matchingArgsValue(e, lhs) - case _ => e + case _ => e } } @@ -177,20 +181,20 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res */ def apply(prim: DoPrim): Expression = prim.args.head match { case a: Literal => simplifyLiteral(a) - case _ => prim + case _ => prim } } object FoldADD extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match { - case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) - case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) + case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1)) + case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1)) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, w) if v == BigInt(0) => rhs case SIntLiteral(v, w) if v == BigInt(0) => rhs - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = e } @@ -209,77 +213,81 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res object FoldAND extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = { - val width = (c1.width max c2.width).asInstanceOf[IntWidth] + val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth] UIntLiteral.masked(c1.value & c2.value, width) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) - case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe) } object FoldOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = { - val width = (c1.width max c2.width).asInstanceOf[IntWidth] + val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth] UIntLiteral.masked((c1.value | c2.value), width) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe) } object FoldXOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = { - val width = (c1.width max c2.width).asInstanceOf[IntWidth] + val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth] UIntLiteral.masked((c1.value ^ c2.value), width) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe)) } object FoldEqual extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => + DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(1) } object FoldNotEqual extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => + DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0) } private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { - case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) + case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => + UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) case _ => e } private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head - case x => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) - case _ => e - } + case x => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) + case _ => e + } } private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match { @@ -296,53 +304,55 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case _ => e } - private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { def isUInt(e: Expression): Boolean = e.tpe match { case UIntType(_) => true - case _ => false + case _ => false } def isZero(e: Expression) = e match { - case UIntLiteral(value, _) => value == BigInt(0) - case SIntLiteral(value, _) => value == BigInt(0) - case _ => false - } + case UIntLiteral(value, _) => value == BigInt(0) + case SIntLiteral(value, _) => value == BigInt(0) + case _ => false + } x match { - case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero - case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one - case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero - case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one - case ex => ex + case DoPrim(Lt, Seq(a, b), _, _) if isUInt(a) && isZero(b) => zero + case DoPrim(Leq, Seq(a, b), _, _) if isZero(a) && isUInt(b) => one + case DoPrim(Gt, Seq(a, b), _, _) if isZero(a) && isUInt(b) => zero + case DoPrim(Geq, Seq(a, b), _, _) if isUInt(a) && isZero(b) => one + case ex => ex } } def foldIfOutsideRange(x: Expression): Expression = { //Note, only abides by a partial ordering case class Range(min: BigInt, max: BigInt) { - def === (that: Range) = + def ===(that: Range) = Seq(this.min, this.max, that.min, that.max) - .sliding(2,1) + .sliding(2, 1) .map(x => x.head == x(1)) .reduce(_ && _) - def > (that: Range) = this.min > that.max - def >= (that: Range) = this.min >= that.max - def < (that: Range) = this.max < that.min - def <= (that: Range) = this.max <= that.min + def >(that: Range) = this.min > that.max + def >=(that: Range) = this.min >= that.max + def <(that: Range) = this.max < that.min + def <=(that: Range) = this.max <= that.min } def range(e: Expression): Range = e match { case UIntLiteral(value, _) => Range(value, value) case SIntLiteral(value, _) => Range(value, value) - case _ => e.tpe match { - case SIntType(IntWidth(width)) => Range( - min = BigInt(0) - BigInt(2).pow(width.toInt - 1), - max = BigInt(2).pow(width.toInt - 1) - BigInt(1) - ) - case UIntType(IntWidth(width)) => Range( - min = BigInt(0), - max = BigInt(2).pow(width.toInt) - BigInt(1) - ) - } + case _ => + e.tpe match { + case SIntType(IntWidth(width)) => + Range( + min = BigInt(0) - BigInt(2).pow(width.toInt - 1), + max = BigInt(2).pow(width.toInt - 1) - BigInt(1) + ) + case UIntType(IntWidth(width)) => + Range( + min = BigInt(0), + max = BigInt(2).pow(width.toInt) - BigInt(1) + ) + } } // Calculates an expression's range of values x match { @@ -351,27 +361,28 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res def r1 = range(ex.args(1)) ex.op match { // Always true - case Lt if r0 < r1 => one + case Lt if r0 < r1 => one case Leq if r0 <= r1 => one - case Gt if r0 > r1 => one + case Gt if r0 > r1 => one case Geq if r0 >= r1 => one // Always false - case Lt if r0 >= r1 => zero + case Lt if r0 >= r1 => zero case Leq if r0 > r1 => zero - case Gt if r0 <= r1 => zero + case Gt if r0 <= r1 => zero case Geq if r0 < r1 => zero - case _ => ex + case _ => ex } case ex => ex } } def foldIfMatchingArgs(x: Expression) = x match { - case DoPrim(op, Seq(a, b), _, _) if (a == b) => op match { - case (Lt | Gt) => zero - case (Leq | Geq) => one - case _ => x - } + case DoPrim(op, Seq(a, b), _, _) if (a == b) => + op match { + case (Lt | Gt) => zero + case (Leq | Geq) => one + case _ => x + } case _ => x } foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e))) @@ -393,43 +404,47 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } private def constPropPrim(e: DoPrim): Expression = e.op match { - case Shl => foldShiftLeft(e) - case Dshl => foldDynamicShiftLeft(e) - case Shr => foldShiftRight(e) - case Dshr => foldDynamicShiftRight(e) - case Cat => foldConcat(e) - case Add => FoldADD(e) - case Sub => SimplifySUB(e) - case Div => SimplifyDIV(e) - case Rem => SimplifyREM(e) - case And => FoldAND(e) - case Or => FoldOR(e) - case Xor => FoldXOR(e) - case Eq => FoldEqual(e) - case Neq => FoldNotEqual(e) - case Andr => FoldANDR(e) - case Orr => FoldORR(e) - case Xorr => FoldXORR(e) + case Shl => foldShiftLeft(e) + case Dshl => foldDynamicShiftLeft(e) + case Shr => foldShiftRight(e) + case Dshr => foldDynamicShiftRight(e) + case Cat => foldConcat(e) + case Add => FoldADD(e) + case Sub => SimplifySUB(e) + case Div => SimplifyDIV(e) + case Rem => SimplifyREM(e) + case And => FoldAND(e) + case Or => FoldOR(e) + case Xor => FoldXOR(e) + case Eq => FoldEqual(e) + case Neq => FoldNotEqual(e) + case Andr => FoldANDR(e) + case Orr => FoldORR(e) + case Xorr => FoldXORR(e) case (Lt | Leq | Gt | Geq) => foldComparison(e) - case Not => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) - case _ => e - } + case Not => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) + case _ => e + } case AsUInt => e.args.head match { case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case arg => arg.tpe match { - case _: UIntType => arg - case _ => e - } + case arg => + arg.tpe match { + case _: UIntType => arg + case _ => e + } } - case AsSInt => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) - case arg => arg.tpe match { - case _: SIntType => arg - case _ => e + case AsSInt => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w)) + case arg => + arg.tpe match { + case _: SIntType => arg + case _ => e + } } - } case AsClock => val arg = e.args.head arg.tpe match { @@ -442,25 +457,27 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case AsyncResetType => arg case _ => e } - case Pad => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) - case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head - case _ => e - } + case Pad => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w))) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w))) + case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head + case _ => e + } case (Bits | Head | Tail) => constPropBitExtract(e) - case _ => e + case _ => e } private def constPropMuxCond(m: Mux) = m.cond match { case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) - case _ => m + case _ => m } private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { case _ if m.tval == m.fval => m.tval case (t: UIntLiteral, f: UIntLiteral) - if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => m.cond + if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => + m.cond case (t: UIntLiteral, _) if t.value == BigInt(1) && bitWidth(m.tpe) == BigInt(1) => DoPrim(Or, Seq(m.cond, m.fval), Nil, m.tpe) case (_, f: UIntLiteral) if f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => @@ -479,15 +496,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Is "a" a "better name" than "b"? private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') - def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) - def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) - - private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = { - val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) + def optimize(e: Expression): Expression = + constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + def optimize(e: Expression, nodeMap: NodeMap): Expression = + constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + + private def constPropExpression( + nodeMap: NodeMap, + instMap: collection.Map[Instance, OfModule], + constSubOutputs: Map[OfModule, Map[String, Literal]] + )(e: Expression + ): Expression = { + val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs)) val propagated = old match { case p: DoPrim => constPropPrim(p) - case m: Mux => constPropMux(m) - case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => + case m: Mux => constPropMux(m) + case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) => val module = instMap(inst.Instance) @@ -506,17 +530,17 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res * @todo generalize source locator propagation across Expressions and delete this method * @todo is the `orElse` the way we want to do propagation here? */ - private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String]) - (stmt: Statement): Statement = stmt match { - // We check rname because inlining it would cause the original declaration to go away - case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => - val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) - node.copy(info = InfoExpr.orElse(info1, info0)) - case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => - val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) - con.copy(info = InfoExpr.orElse(info1, info0)) - case other => other - } + private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])(stmt: Statement): Statement = + stmt match { + // We check rname because inlining it would cause the original declaration to go away + case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + node.copy(info = InfoExpr.orElse(info1, info0)) + case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + con.copy(info = InfoExpr.orElse(info1, info0)) + case other => other + } /* Constant propagate a Module * @@ -538,12 +562,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res */ @tailrec private def constPropModule( - m: Module, - dontTouches: Set[String], - instMap: collection.Map[Instance, OfModule], - constInputs: Map[String, Literal], - constSubOutputs: Map[OfModule, Map[String, Literal]] - ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { + m: Module, + dontTouches: Set[String], + instMap: collection.Map[Instance, OfModule], + constInputs: Map[String, Literal], + constSubOutputs: Map[OfModule, Map[String, Literal]] + ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { var nPropagated = 0L val nodeMap = new NodeMap() @@ -571,13 +595,13 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // to constant wires, we don't need to worry about propagating primops or muxes since we'll do // that on the next iteration if necessary def backPropExpr(expr: Expression): Expression = { - val old = expr map backPropExpr + val old = expr.map(backPropExpr) val propagated = old match { // When swapping, we swap both rhs and lhs - case ref @ WRef(rname, _,_,_) if swapMap.contains(rname) => + case ref @ WRef(rname, _, _, _) if swapMap.contains(rname) => ref.copy(name = swapMap(rname)) // Only const prop on the rhs - case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => + case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case x => x } @@ -590,27 +614,29 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res def backPropStmt(stmt: Statement): Statement = stmt match { case reg: DefRegister if (WrappedExpression.weq(reg.init, WRef(reg))) => // Self-init reset is an idiom for "no reset," and must be handled separately - swapMap.get(reg.name) - .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName))) - .getOrElse(reg) - case s => s map backPropExpr match { - case decl: IsDeclaration if swapMap.contains(decl.name) => - val newName = swapMap(decl.name) - nPropagated += 1 - decl match { - case node: DefNode => node.copy(name = newName) - case wire: DefWire => wire.copy(name = newName) - case reg: DefRegister => reg.copy(name = newName) - case other => throwInternalError() - } - case other => other map backPropStmt - } + swapMap + .get(reg.name) + .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName))) + .getOrElse(reg) + case s => + s.map(backPropExpr) match { + case decl: IsDeclaration if swapMap.contains(decl.name) => + val newName = swapMap(decl.name) + nPropagated += 1 + decl match { + case node: DefNode => node.copy(name = newName) + case wire: DefWire => wire.copy(name = newName) + case reg: DefRegister => reg.copy(name = newName) + case other => throwInternalError() + } + case other => other.map(backPropStmt) + } } // When propagating a reference, check if we want to keep the name that would be deleted def propagateRef(lname: String, value: Expression, info: Info): Unit = { value match { - case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => + case WRef(rname, _, kind, _) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a // node declaration or the single connection to a wire or register swapMap += (lname -> rname, rname -> lname) @@ -639,25 +665,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns // This requires that reset has been made explicit case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches(lname) => - - /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment. - * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of - * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can - * be resolved: each side is allowed one candidate register and one candidate literal to - * appear in their source trees, referring to the potential constant propagation case that - * they could allow. If the two are compatible (no different bound sources of either of - * the two types), they can be resolved by combining sources. Otherwise, they propagate - * NonConstant values. When encountering a node reference, it expands the node by to its - * RHS assignment and recurses. - * - * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This - * can stifle register constant propagations, which looks at drivers through value-preserving - * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into - * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole - * de-optimization that does not actually appear in the output code. - * - * @return a RegCPEntry describing the constant prop-compatible sources driving this expression - */ + /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment. + * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of + * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can + * be resolved: each side is allowed one candidate register and one candidate literal to + * appear in their source trees, referring to the potential constant propagation case that + * they could allow. If the two are compatible (no different bound sources of either of + * the two types), they can be resolved by combining sources. Otherwise, they propagate + * NonConstant values. When encountering a node reference, it expands the node by to its + * RHS assignment and recurses. + * + * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This + * can stifle register constant propagations, which looks at drivers through value-preserving + * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into + * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole + * de-optimization that does not actually appear in the output code. + * + * @return a RegCPEntry describing the constant prop-compatible sources driving this expression + */ val unbound = RegCPEntry(UnboundConstant, UnboundConstant) val selfBound = RegCPEntry(BoundConstant(lname), UnboundConstant) @@ -684,11 +709,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Updates nodeMap after analyzing the returned value from regConstant def updateNodeMapIfConstant(e: Expression): Unit = regConstant(e, selfBound) match { - case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero) + case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero) case RegCPEntry(BoundConstant(_), UnboundConstant) => nodeMap(lname) = padCPExp(zero) - case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) + case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) case RegCPEntry(BoundConstant(_), BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) - case _ => + case _ => } def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe)) @@ -733,11 +758,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Unify two maps using f to combine values of duplicate keys private def unify[K, V](a: Map[K, V], b: Map[K, V])(f: (V, V) => V): Map[K, V] = - b.foldLeft(a) { case (acc, (k, v)) => - acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) + b.foldLeft(a) { + case (acc, (k, v)) => + acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) } - private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = { val iGraph = InstanceKeyGraph(c) val moduleDeps = iGraph.getChildInstanceMap @@ -754,9 +779,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // are driven with the same constant value. Then, if we find a Module input where each instance // is driven with the same constant (and not seen in a previous iteration), we iterate again @tailrec - def iterate(toVisit: Set[OfModule], - modules: Map[OfModule, Module], - constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = { + def iterate( + toVisit: Set[OfModule], + modules: Map[OfModule, Module], + constInputs: Map[OfModule, Map[String, Literal]] + ): Map[OfModule, DefModule] = { if (toVisit.isEmpty) modules else { // Order from leaf modules to root so that any module driving an output @@ -767,31 +794,36 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Aggreagte Module outputs that are driven constant for use by instaniating Modules // Aggregate submodule inputs driven constant for checking later val (modulesx, _, constInputsx) = - order.foldLeft((modules, - Map[OfModule, Map[String, Literal]](), - Map[OfModule, Map[String, Seq[Literal]]]())) { + order.foldLeft((modules, Map[OfModule, Map[String, Literal]](), Map[OfModule, Map[String, Seq[Literal]]]())) { case ((mmap, constOutputs, constInputsAcc), mname) => val dontTouches = dontTouchMap.getOrElse(mname, Set.empty) - val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname), - constInputs.getOrElse(mname, Map.empty), constOutputs) + val (mx, mco, mci) = constPropModule( + modules(mname), + dontTouches, + moduleDeps(mname), + constInputs.getOrElse(mname, Map.empty), + constOutputs + ) // Accumulate all Literals used to drive a particular Module port val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d)) (mmap + (mname -> mx), constOutputs + (mname -> mco), constInputsx) } // Determine which module inputs have all of the same, new constants driving them - val newProppedInputs = constInputsx.flatMap { case (mname, ports) => - val portsx = ports.flatMap { case (pname, lits) => - val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false) - val isModule = modules.contains(mname) // ExtModules are not contained in modules - val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1 - if (isModule && newPort && allSameConst) Some(pname -> lits.head) - else None - } - if (portsx.nonEmpty) Some(mname -> portsx) else None + val newProppedInputs = constInputsx.flatMap { + case (mname, ports) => + val portsx = ports.flatMap { + case (pname, lits) => + val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false) + val isModule = modules.contains(mname) // ExtModules are not contained in modules + val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1 + if (isModule && newPort && allSameConst) Some(pname -> lits.head) + else None + } + if (portsx.nonEmpty) Some(mname -> portsx) else None } val modsWithConstInputs = newProppedInputs.keySet val newToVisit = modsWithConstInputs ++ - modsWithConstInputs.flatMap(parentGraph.reachableFrom) + modsWithConstInputs.flatMap(parentGraph.reachableFrom) // Combine const inputs (there can't be duplicate values in the inner maps) val nextConstInputs = unify(constInputs, newProppedInputs)((a, b) => a ++ b) iterate(newToVisit.toSet, modulesx, nextConstInputs) @@ -805,7 +837,6 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res c.modules.map(m => mmap.getOrElse(m.OfModule, m)) } - Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index c883bdfb..fb1bd1f6 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -1,4 +1,3 @@ - package firrtl.transforms import firrtl._ @@ -8,7 +7,7 @@ import firrtl.annotations._ import firrtl.graph._ import firrtl.analyses.InstanceKeyGraph import firrtl.Mappers._ -import firrtl.Utils.{throwInternalError, kind} +import firrtl.Utils.{kind, throwInternalError} import firrtl.MemoizedHash._ import firrtl.options.{Dependency, RegisteredTransform, ShellOption} @@ -29,29 +28,34 @@ import collection.mutable * circumstances of their instantiation in their parent module, they will still not be removed. To * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication. */ -class DeadCodeElimination extends Transform +class DeadCodeElimination + extends Transform with ResolvedAnnotationPaths with RegisteredTransform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats], - Dependency(passes.CommonSubexpressionElimination) ) + Seq( + Dependency(firrtl.passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[firrtl.transforms.CombineCats], + Dependency(passes.CommonSubexpressionElimination) + ) override def optionalPrerequisites = Seq.empty override def optionalPrerequisiteOf = - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], - Dependency(passes.VerilogPrep), - Dependency[firrtl.AddDescriptionNodes] ) + Seq( + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(passes.VerilogPrep), + Dependency[firrtl.AddDescriptionNodes] + ) override def invalidates(a: Transform) = false @@ -59,7 +63,9 @@ class DeadCodeElimination extends Transform new ShellOption[Unit]( longOption = "no-dce", toAnnotationSeq = (_: Unit) => Seq(NoDCEAnnotation), - helpText = "Disable dead code elimination" ) ) + helpText = "Disable dead code elimination" + ) + ) /** Based on LogicNode ins CheckCombLoops, currently kind of faking it */ private type LogicNode = MemoizedHash[WrappedExpression] @@ -72,6 +78,7 @@ class DeadCodeElimination extends Transform val loweredName = LowerTypes.loweredName(component.name.split('.')) apply(component.module.name, WRef(loweredName)) } + /** External Modules are representated as a single node driven by all inputs and driving all * outputs */ @@ -87,7 +94,7 @@ class DeadCodeElimination extends Transform def rec(e: Expression): Expression = { e match { case ref @ (_: WRef | _: WSubField) => refs += ref - case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec + case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.map(rec) case ignore @ (_: Literal) => // Do nothing case unexpected => throwInternalError() } @@ -98,9 +105,7 @@ class DeadCodeElimination extends Transform } // Gets all dependencies and constructs LogicNodes from them - private def getDepsImpl(mname: String, - instMap: collection.Map[String, String]) - (expr: Expression): Seq[LogicNode] = + private def getDepsImpl(mname: String, instMap: collection.Map[String, String])(expr: Expression): Seq[LogicNode] = extractRefs(expr).map { e => if (kind(e) == InstanceKind) { val (inst, tail) = Utils.splitRef(e) @@ -110,11 +115,12 @@ class DeadCodeElimination extends Transform } } - /** Construct the dependency graph within this module */ - private def setupDepGraph(depGraph: MutableDiGraph[LogicNode], - instMap: collection.Map[String, String]) - (mod: Module): Unit = { + private def setupDepGraph( + depGraph: MutableDiGraph[LogicNode], + instMap: collection.Map[String, String] + )(mod: Module + ): Unit = { def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr) def onStmt(stmt: Statement): Unit = stmt match { @@ -150,7 +156,7 @@ class DeadCodeElimination extends Transform val node = getDeps(loc) match { case Seq(elt) => elt } getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref)) // Simulation constructs are treated as top-level outputs - case Stop(_,_, clk, en) => + case Stop(_, _, clk, en) => Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) case Print(_, _, args, clk, en) => (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) @@ -172,12 +178,14 @@ class DeadCodeElimination extends Transform } // TODO Make immutable? - private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]], - doTouchExtMods: Set[String], - c: Circuit): MutableDiGraph[LogicNode] = { + private def createDependencyGraph( + instMaps: collection.Map[String, collection.Map[String, String]], + doTouchExtMods: Set[String], + c: Circuit + ): MutableDiGraph[LogicNode] = { val depGraph = new MutableDiGraph[LogicNode] c.modules.foreach { - case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod) + case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod) case ext: ExtModule => // Connect all inputs to all outputs val node = LogicNode(ext) @@ -205,23 +213,25 @@ class DeadCodeElimination extends Transform depGraph } - private def deleteDeadCode(instMap: collection.Map[String, String], - deadNodes: collection.Set[LogicNode], - moduleMap: collection.Map[String, DefModule], - renames: RenameMap, - topName: String, - doTouchExtMods: Set[String]) - (mod: DefModule): Option[DefModule] = { + private def deleteDeadCode( + instMap: collection.Map[String, String], + deadNodes: collection.Set[LogicNode], + moduleMap: collection.Map[String, DefModule], + renames: RenameMap, + topName: String, + doTouchExtMods: Set[String] + )(mod: DefModule + ): Option[DefModule] = { // For log-level debug def deleteMsg(decl: IsDeclaration): String = { val tpe = decl match { - case _: DefNode => "node" + case _: DefNode => "node" case _: DefRegister => "reg" - case _: DefWire => "wire" - case _: Port => "port" - case _: DefMemory => "mem" + case _: DefWire => "wire" + case _: Port => "port" + case _: DefMemory => "mem" case (_: DefInstance | _: WDefInstance) => "inst" - case _: Module => "module" + case _: Module => "module" case _: ExtModule => "extmodule" } val ref = decl match { @@ -237,7 +247,7 @@ class DeadCodeElimination extends Transform def deleteIfNotEnabled(stmt: Statement, en: Expression): Statement = en match { case UIntLiteral(v, _) if v == BigInt(0) => EmptyStmt - case _ => stmt + case _ => stmt } def onStmt(stmt: Statement): Statement = { @@ -256,12 +266,11 @@ class DeadCodeElimination extends Transform logger.debug(deleteMsg(decl)) renames.delete(decl.name) EmptyStmt - } - else decl - case print: Print => deleteIfNotEnabled(print, print.en) - case stop: Stop => deleteIfNotEnabled(stop, stop.en) + } else decl + case print: Print => deleteIfNotEnabled(print, print.en) + case stop: Stop => deleteIfNotEnabled(stop, stop.en) case formal: Verification => deleteIfNotEnabled(formal, formal.en) - case con: Connect => + case con: Connect => val node = getDeps(con.loc) match { case Seq(elt) => elt } if (deadNodes.contains(node)) EmptyStmt else con case Attach(info, exprs) => // If any exprs are dead then all are @@ -270,7 +279,7 @@ class DeadCodeElimination extends Transform case IsInvalid(info, expr) => val node = getDeps(expr) match { case Seq(elt) => elt } if (deadNodes.contains(node)) EmptyStmt else IsInvalid(info, expr) - case block: Block => block map onStmt + case block: Block => block.map(onStmt) case other => other } stmtx match { // Check if module empty @@ -300,8 +309,7 @@ class DeadCodeElimination extends Transform if (portsx.isEmpty && doTouchExtMods.contains(ext.name)) { logger.debug(deleteMsg(mod)) None - } - else { + } else { if (ext.ports != portsx) throwInternalError() // Sanity check Some(ext.copy(ports = portsx)) } @@ -309,14 +317,13 @@ class DeadCodeElimination extends Transform } - def run(state: CircuitState, - dontTouches: Seq[LogicNode], - doTouchExtMods: Set[String]): CircuitState = { + def run(state: CircuitState, dontTouches: Seq[LogicNode], doTouchExtMods: Set[String]): CircuitState = { val c = state.circuit val moduleMap = c.modules.map(m => m.name -> m).toMap val iGraph = InstanceKeyGraph(c) - val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) => - k.module -> v.map(i => i.name -> i.module).toMap + val moduleDeps = iGraph.graph.getEdgeMap.map({ + case (k, v) => + k.module -> v.map(i => i.name -> i.module).toMap }) val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_)) @@ -347,11 +354,12 @@ class DeadCodeElimination extends Transform // themselves. We iterate over the modules in a topological order from leaves to the top. The // current status of the modulesxMap is used to either delete instances or update their types val modulesxMap = mutable.HashMap.empty[String, DefModule] - topoSortedModules.foreach { case mod => - deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match { - case Some(m) => modulesxMap += m.name -> m - case None => renames.delete(ModuleName(mod.name, CircuitName(c.main))) - } + topoSortedModules.foreach { + case mod => + deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match { + case Some(m) => modulesxMap += m.name -> m + case None => renames.delete(ModuleName(mod.name, CircuitName(c.main))) + } } // Preserve original module order diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index 627af11f..18e32cbc 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -20,7 +20,6 @@ import scala.annotation.tailrec // Datastructures import scala.collection.mutable - /** A component, e.g. register etc. Must be declared only once under the TopAnnotation */ case class NoDedupAnnotation(target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] { def duplicate(n: ModuleTarget): NoDedupAnnotation = NoDedupAnnotation(n) @@ -36,7 +35,9 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt new ShellOption[Unit]( longOption = "no-dedup", toAnnotationSeq = _ => Seq(NoCircuitDedupAnnotation), - helpText = "Do NOT dedup modules" ) ) + helpText = "Do NOT dedup modules" + ) + ) } @@ -46,12 +47,13 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt * @param original Original module * @param index the normalized position of the original module in the original module list, fraction between 0 and 1 */ -case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double) extends MultiTargetAnnotation { +case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double) + extends MultiTargetAnnotation { override val targets: Seq[Seq[Target]] = Seq(Seq(original), duplicate.toList) override def duplicate(n: Seq[Seq[Target]]): Annotation = { n.toList match { case Seq(_, List(dup: IsModule)) => DedupedResult(original, Some(dup), index) - case _ => DedupedResult(original, None, -1) + case _ => DedupedResult(original, None, -1) } } } @@ -96,7 +98,7 @@ class DedupModules extends Transform with DependencyAPIMigration { val noDedups = state.circuit.main +: state.annotations.collect { case NoDedupAnnotation(ModuleTarget(_, m)) => m } val (remainingAnnotations, dupResults) = state.annotations.partition { case _: DupedResult => false - case _ => true + case _ => true } val previouslyDupedMap = dupResults.flatMap { case DupedResult(newModules, original) => @@ -114,9 +116,11 @@ class DedupModules extends Transform with DependencyAPIMigration { * @param noDedups Modules not to dedup * @return Deduped Circuit and corresponding RenameMap */ - def run(c: Circuit, - noDedups: Seq[String], - previouslyDupedMap: Map[String, String]): (Circuit, RenameMap, AnnotationSeq) = { + def run( + c: Circuit, + noDedups: Seq[String], + previouslyDupedMap: Map[String, String] + ): (Circuit, RenameMap, AnnotationSeq) = { // RenameMap val componentRenameMap = RenameMap() @@ -124,13 +128,16 @@ class DedupModules extends Transform with DependencyAPIMigration { // Maps module name to corresponding dedup module val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, previouslyDupedMap, componentRenameMap) - val dedupCliques = dedupMap.foldLeft(Map.empty[String, Set[String]]) { - case (dedupCliqueMap, (orig: String, dupMod: DefModule)) => - val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig - dedupCliqueMap + (dupMod.name -> set) - }.flatMap { case (dedupName, set) => - set.map { _ -> set } - } + val dedupCliques = dedupMap + .foldLeft(Map.empty[String, Set[String]]) { + case (dedupCliqueMap, (orig: String, dupMod: DefModule)) => + val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig + dedupCliqueMap + (dupMod.name -> set) + } + .flatMap { + case (dedupName, set) => + set.map { _ -> set } + } // Use old module list to preserve ordering // Lookup what a module deduped to, if its a duplicate, remove it @@ -149,9 +156,10 @@ class DedupModules extends Transform with DependencyAPIMigration { val ct = CircuitTarget(c.main) - val map = dedupMap.map { case (from, to) => - logger.debug(s"[Dedup] $from -> ${to.name}") - ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name)) + val map = dedupMap.map { + case (from, to) => + logger.debug(s"[Dedup] $from -> ${to.name}") + ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name)) } val moduleRenameMap = RenameMap() moduleRenameMap.recordAll(map) @@ -159,15 +167,19 @@ class DedupModules extends Transform with DependencyAPIMigration { // Build instanceify renaming map val instanceGraph = InstanceKeyGraph(c) val instanceify = RenameMap() - val moduleName2Index = c.modules.map(_.name).zipWithIndex.map { case (n, i) => - { - c.modules.size match { - case 0 => (n, 0.0) - case 1 => (n, 1.0) - case d => (n, i.toDouble / (d - 1)) + val moduleName2Index = c.modules + .map(_.name) + .zipWithIndex + .map { + case (n, i) => { + c.modules.size match { + case 0 => (n, 0.0) + case 1 => (n, 1.0) + case d => (n, i.toDouble / (d - 1)) + } } } - }.toMap + .toMap // get the ordered set of instances a module, includes new Deduped modules val getChildrenInstances = { @@ -182,56 +194,62 @@ class DedupModules extends Transform with DependencyAPIMigration { } val instanceNameMap: Map[OfModule, Map[Instance, Instance]] = { - dedupMap.map { case (oldName, dedupedMod) => - val key = OfModule(oldName) - val value = getChildrenInstances(oldName).zip(getChildrenInstances(dedupedMod.name)).map { - case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name) - }.toMap - key -> value + dedupMap.map { + case (oldName, dedupedMod) => + val key = OfModule(oldName) + val value = getChildrenInstances(oldName) + .zip(getChildrenInstances(dedupedMod.name)) + .map { + case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name) + } + .toMap + key -> value }.toMap } - val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap { case mt@ModuleTarget(c, m) if dedupCliques(m).size > 1 => - dedupMap.get(m) match { - case None => Nil - case Some(module: DefModule) => - val paths = instanceGraph.findInstancesInHierarchy(m) - // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option - val newTargets = paths.map { path => - val root: IsModule = ct.module(c) - path.foldLeft(root -> root) { case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) => - if(mod == c) { - val mod = CircuitTarget(c).module(c) - mod -> mod - } else { - val enclosingMod = oldRelPath match { - case i: InstanceTarget => i.ofModule - case m: ModuleTarget => m.module - } - val instMap = instanceNameMap(OfModule(enclosingMod)) - val newInstName = instMap(Instance(name)).value - val old = oldRelPath.instOf(name, mod) - old -> newRelPath.instOf(newInstName, mod) + val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap { + case mt @ ModuleTarget(c, m) if dedupCliques(m).size > 1 => + dedupMap.get(m) match { + case None => Nil + case Some(module: DefModule) => + val paths = instanceGraph.findInstancesInHierarchy(m) + // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option + val newTargets = paths.map { path => + val root: IsModule = ct.module(c) + path.foldLeft(root -> root) { + case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) => + if (mod == c) { + val mod = CircuitTarget(c).module(c) + mod -> mod + } else { + val enclosingMod = oldRelPath match { + case i: InstanceTarget => i.ofModule + case m: ModuleTarget => m.module + } + val instMap = instanceNameMap(OfModule(enclosingMod)) + val newInstName = instMap(Instance(name)).value + val old = oldRelPath.instOf(name, mod) + old -> newRelPath.instOf(newInstName, mod) + } } } - } - // Add all relative paths to referredModule to map to new instances - def addRecord(from: IsMember, to: IsMember): Unit = from match { - case x: ModuleTarget => - instanceify.record(x, to) - case x: IsComponent => - instanceify.record(x, to) - addRecord(x.stripHierarchy(1), to) - } - // Instanceify deduped Modules! - if (dedupCliques(module.name).size > 1) { - newTargets.foreach { case (from, to) => addRecord(from, to) } - } - // Return Deduped Results - if (newTargets.size == 1) { - Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m))) - } else Nil - } + // Add all relative paths to referredModule to map to new instances + def addRecord(from: IsMember, to: IsMember): Unit = from match { + case x: ModuleTarget => + instanceify.record(x, to) + case x: IsComponent => + instanceify.record(x, to) + addRecord(x.stripHierarchy(1), to) + } + // Instanceify deduped Modules! + if (dedupCliques(module.name).size > 1) { + newTargets.foreach { case (from, to) => addRecord(from, to) } + } + // Return Deduped Results + if (newTargets.size == 1) { + Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m))) + } else Nil + } case noDedups => Nil } @@ -242,6 +260,7 @@ class DedupModules extends Transform with DependencyAPIMigration { /** Utility functions for [[DedupModules]] */ object DedupModules extends LazyLogging { + /** Change's a module's internal signal names, types, infos, and modules. * @param rename Function to rename a signal. Called on declaration and references. * @param retype Function to retype a signal. Called on declaration, references, and subfields @@ -250,14 +269,16 @@ object DedupModules extends LazyLogging { * @param module Module to change internals * @return Changed Module */ - def changeInternals(rename: String=>String, - retype: String=>Type=>Type, - reinfo: Info=>Info, - renameOfModule: (String, String)=>String, - renameExps: Boolean = true - )(module: DefModule): DefModule = { + def changeInternals( + rename: String => String, + retype: String => Type => Type, + reinfo: Info => Info, + renameOfModule: (String, String) => String, + renameExps: Boolean = true + )(module: DefModule + ): DefModule = { def onPort(p: Port): Port = Port(reinfo(p.info), rename(p.name), p.direction, retype(p.name)(p.tpe)) - def onExp(e: Expression): Expression = e match { + def onExp(e: Expression): Expression = e match { case WRef(n, t, k, g) => WRef(rename(n), retype(n)(t), k, g) case WSubField(expr, n, tpe, kind) => val fieldIndex = expr.tpe.asInstanceOf[BundleType].fields.indexWhere(f => f.name == n) @@ -266,12 +287,12 @@ object DedupModules extends LazyLogging { val finalExpr = WSubField(newExpr, newField.name, newField.tpe, kind) //TODO: renameMap.rename(e.serialize, finalExpr.serialize) finalExpr - case other => other map onExp + case other => other.map(onExp) } def onStmt(s: Statement): Statement = s match { case DefNode(info, name, value) => retype(name)(value.tpe) - if(renameExps) DefNode(reinfo(info), rename(name), onExp(value)) + if (renameExps) DefNode(reinfo(info), rename(name), onExp(value)) else DefNode(reinfo(info), rename(name), value) case WDefInstance(i, n, m, t) => val newmod = renameOfModule(n, m) @@ -283,12 +304,18 @@ object DedupModules extends LazyLogging { val oldType = MemPortUtils.memType(d) val newType = retype(d.name)(oldType) val index = oldType - .asInstanceOf[BundleType].fields.headOption - .map(_.tpe.asInstanceOf[BundleType].fields.indexWhere( - { - case Field("data" | "wdata" | "rdata", _, _) => true - case _ => false - })) + .asInstanceOf[BundleType] + .fields + .headOption + .map( + _.tpe + .asInstanceOf[BundleType] + .fields + .indexWhere({ + case Field("data" | "wdata" | "rdata", _, _) => true + case _ => false + }) + ) val newDataType = index match { case Some(i) => //If index nonempty, then there exists a port @@ -299,15 +326,15 @@ object DedupModules extends LazyLogging { // associate it with the type of the memory (as the memory type is different than the datatype) retype(d.name + ";&*^$")(d.dataType) } - d.copy(dataType = newDataType) map rename map reinfo + d.copy(dataType = newDataType).map(rename).map(reinfo) case h: IsDeclaration => - val temp = h map rename map retype(h.name) map reinfo - if(renameExps) temp map onExp else temp + val temp = h.map(rename).map(retype(h.name)).map(reinfo) + if (renameExps) temp.map(onExp) else temp case other => - val temp = other map reinfo map onStmt - if(renameExps) temp map onExp else temp + val temp = other.map(reinfo).map(onStmt) + if (renameExps) temp.map(onExp) else temp } - module map onPort map onStmt + module.map(onPort).map(onStmt) } /** Dedup a module's instances based on dedup map @@ -321,11 +348,13 @@ object DedupModules extends LazyLogging { * @param renameMap Will be modified to keep track of renames in this function * @return fixed up module deduped instances */ - def dedupInstances(top: CircuitTarget, - originalModule: String, - moduleMap: Map[String, DefModule], - name2name: Map[String, String], - renameMap: RenameMap): DefModule = { + def dedupInstances( + top: CircuitTarget, + originalModule: String, + moduleMap: Map[String, DefModule], + name2name: Map[String, String], + renameMap: RenameMap + ): DefModule = { val module = moduleMap(originalModule) // If black box, return it (it has no instances) @@ -340,7 +369,8 @@ object DedupModules extends LazyLogging { } val typeMap = mutable.HashMap[String, Type]() def retype(name: String)(tpe: Type): Type = { - if (typeMap.contains(name)) typeMap(name) else { + if (typeMap.contains(name)) typeMap(name) + else { if (instanceModuleMap.contains(name)) { val newType = Utils.module_type(getNewModule(instanceModuleMap(name))) typeMap(name) = newType @@ -360,7 +390,7 @@ object DedupModules extends LazyLogging { def renameOfModule(instance: String, ofModule: String): String = { name2name(ofModule) } - changeInternals({n => n}, retype, {i => i}, renameOfModule)(module) + changeInternals({ n => n }, retype, { i => i }, renameOfModule)(module) } @tailrec @@ -415,10 +445,11 @@ object DedupModules extends LazyLogging { * @return A map from tag to names of modules with the same structure and * a RenameMap which maps Module names to their Tag. */ - def buildRTLTags(top: CircuitTarget, - moduleLinearization: Seq[DefModule], - noDedups: Set[String] - ): (collection.Map[String, collection.Set[String]], RenameMap) = { + def buildRTLTags( + top: CircuitTarget, + moduleLinearization: Seq[DefModule], + noDedups: Set[String] + ): (collection.Map[String, collection.Set[String]], RenameMap) = { // maps hash code to human readable tag val hashToTag = mutable.HashMap[ir.HashCode, String]() @@ -449,9 +480,9 @@ object DedupModules extends LazyLogging { moduleNameToTag(originalModule.name) = hashToTag(hash) } - val tag2all = hashToNames.map{ case (hash, names) => hashToTag(hash) -> names.toSet } + val tag2all = hashToNames.map { case (hash, names) => hashToTag(hash) -> names.toSet } val tagMap = RenameMap() - moduleNameToTag.foreach{ case (name, tag) => tagMap.record(top.module(name), top.module(tag)) } + moduleNameToTag.foreach { case (name, tag) => tagMap.record(top.module(name), top.module(tag)) } (tag2all, tagMap) } @@ -461,10 +492,12 @@ object DedupModules extends LazyLogging { * @param renameMap rename map to populate when deduping * @return Map of original Module name -> Deduped Module */ - def deduplicate(circuit: Circuit, - noDedups: Set[String], - previousDupResults: Map[String, String], - renameMap: RenameMap): Map[String, DefModule] = { + def deduplicate( + circuit: Circuit, + noDedups: Set[String], + previousDupResults: Map[String, String], + renameMap: RenameMap + ): Map[String, DefModule] = { val (moduleMap, moduleLinearization) = { val iGraph = InstanceKeyGraph(circuit) @@ -479,13 +512,14 @@ object DedupModules extends LazyLogging { val (tag2all, tagMap) = buildRTLTags(top, moduleLinearization, noDedups) // Set tag2name to be the best dedup module name - val moduleIndex = circuit.modules.zipWithIndex.map{case (m, i) => m.name -> i}.toMap + val moduleIndex = circuit.modules.zipWithIndex.map { case (m, i) => m.name -> i }.toMap // returns the module matching the circuit name or the module with lower index otherwise def order(l: String, r: String): String = { if (l == main) l else if (r == main) r - else if (moduleIndex(l) < moduleIndex(r)) l else r + else if (moduleIndex(l) < moduleIndex(r)) l + else r } // Maps a module's tag to its deduplicated module @@ -499,7 +533,7 @@ object DedupModules extends LazyLogging { tag2name(tag) = dedupName val dedupModule = moduleMap(dedupWithoutOldName) match { case e: ExtModule => e.copy(name = dedupName) - case e: Module => e.copy(name = dedupName) + case e: Module => e.copy(name = dedupName) } dedupName -> dedupModule }.toMap @@ -508,32 +542,32 @@ object DedupModules extends LazyLogging { val name2name = moduleMap.keysIterator.map { originalModule => tagMap.get(top.module(originalModule)) match { case Some(Seq(Target(_, Some(tag), Nil))) => originalModule -> tag2name(tag) - case None => originalModule -> originalModule - case other => throwInternalError(other.toString) + case None => originalModule -> originalModule + case other => throwInternalError(other.toString) } }.toMap // Build Remap for modules with deduped module references val dedupedName2module = tag2name.map { - case (tag, name) => name -> DedupModules.dedupInstances( - top, name, moduleMapWithOldNames, name2name, renameMap) + case (tag, name) => name -> DedupModules.dedupInstances(top, name, moduleMapWithOldNames, name2name, renameMap) } // Build map from original name to corresponding deduped module // It is important to flatMap before looking up the DefModules so that they aren't hashed val name2module: Map[String, DefModule] = tag2all.flatMap { case (tag, names) => names.map(_ -> tag) } - .mapValues(tag => dedupedName2module(tag2name(tag))) - .toMap + .mapValues(tag => dedupedName2module(tag2name(tag))) + .toMap // Build renameMap val indexedTargets = mutable.HashMap[String, IndexedSeq[ReferenceTarget]]() - name2module.foreach { case (originalName, depModule) => - if(originalName != depModule.name) { - val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule)) - val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName)) - computeRenameMap(fromSeq, toSeq, renameMap) - } + name2module.foreach { + case (originalName, depModule) => + if (originalName != depModule.name) { + val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule)) + val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName)) + computeRenameMap(fromSeq, toSeq, renameMap) + } } name2module @@ -549,18 +583,21 @@ object DedupModules extends LazyLogging { tpe } - changeInternals(rename, retype, {i => i}, {(x, y) => x}, renameExps = false)(m) + changeInternals(rename, retype, { i => i }, { (x, y) => x }, renameExps = false)(m) refs.toIndexedSeq } - def computeRenameMap(originalNames: IndexedSeq[ReferenceTarget], - dedupedNames: IndexedSeq[ReferenceTarget], - renameMap: RenameMap): Unit = { + def computeRenameMap( + originalNames: IndexedSeq[ReferenceTarget], + dedupedNames: IndexedSeq[ReferenceTarget], + renameMap: RenameMap + ): Unit = { originalNames.zip(dedupedNames).foreach { - case (o, d) => if (o.component != d.component || o.ref != d.ref) { - renameMap.record(o, d.copy(module = o.module)) - } + case (o, d) => + if (o.component != d.component || o.ref != d.ref) { + renameMap.record(o, d.copy(module = o.module)) + } } } diff --git a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala index a1e49d62..bfab31bf 100644 --- a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala +++ b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala @@ -33,7 +33,7 @@ object FixAddingNegativeLiterals { */ def fixupModule(m: DefModule): DefModule = { val namespace = Namespace(m) - m map fixupStatement(namespace) + m.map(fixupStatement(namespace)) } /** Returns a statement with fixed additions of negative literals @@ -43,8 +43,8 @@ object FixAddingNegativeLiterals { */ def fixupStatement(namespace: Namespace)(s: Statement): Statement = { val stmtBuffer = mutable.ListBuffer[Statement]() - val ret = s map fixupStatement(namespace) map fixupOnExpr(Utils.get_info(s), namespace, stmtBuffer) - if(stmtBuffer.isEmpty) { + val ret = s.map(fixupStatement(namespace)).map(fixupOnExpr(Utils.get_info(s), namespace, stmtBuffer)) + if (stmtBuffer.isEmpty) { ret } else { stmtBuffer += ret @@ -58,8 +58,7 @@ object FixAddingNegativeLiterals { * @param e expression to fixup * @return generated statements and the fixed expression */ - def fixupExpression(info: Info, namespace: Namespace) - (e: Expression): (Seq[Statement], Expression) = { + def fixupExpression(info: Info, namespace: Namespace)(e: Expression): (Seq[Statement], Expression) = { val stmtBuffer = mutable.ListBuffer[Statement]() val retExpr = fixupOnExpr(info, namespace, stmtBuffer)(e) (stmtBuffer.toList, retExpr) @@ -72,12 +71,16 @@ object FixAddingNegativeLiterals { * @param e expression to fixup * @return fixed expression */ - private def fixupOnExpr(info: Info, namespace: Namespace, stmtBuffer: mutable.ListBuffer[Statement]) - (e: Expression): Expression = { + private def fixupOnExpr( + info: Info, + namespace: Namespace, + stmtBuffer: mutable.ListBuffer[Statement] + )(e: Expression + ): Expression = { // Helper function to create the subtraction expression def fixupAdd(expr: Expression, litValue: BigInt, litWidth: BigInt): DoPrim = { - if(litValue == minNegValue(litWidth)) { + if (litValue == minNegValue(litWidth)) { val posLiteral = SIntLiteral(-litValue) assert(posLiteral.width.asInstanceOf[IntWidth].width - 1 == litWidth) val sub = DefNode(info, namespace.newTemp, setType(DoPrim(Sub, Seq(expr, posLiteral), Nil, UnknownType))) @@ -91,10 +94,10 @@ object FixAddingNegativeLiterals { } } - e map fixupOnExpr(info, namespace, stmtBuffer) match { - case DoPrim(Add, Seq(arg, lit@SIntLiteral(value, w@IntWidth(width))), Nil, t: SIntType) if value < 0 => + e.map(fixupOnExpr(info, namespace, stmtBuffer)) match { + case DoPrim(Add, Seq(arg, lit @ SIntLiteral(value, w @ IntWidth(width))), Nil, t: SIntType) if value < 0 => fixupAdd(arg, value, width) - case DoPrim(Add, Seq(lit@SIntLiteral(value, w@IntWidth(width)), arg), Nil, t: SIntType) if value < 0 => + case DoPrim(Add, Seq(lit @ SIntLiteral(value, w @ IntWidth(width)), arg), Nil, t: SIntType) if value < 0 => fixupAdd(arg, value, width) case other => other } diff --git a/src/main/scala/firrtl/transforms/Flatten.scala b/src/main/scala/firrtl/transforms/Flatten.scala index cc5b3504..36e71470 100644 --- a/src/main/scala/firrtl/transforms/Flatten.scala +++ b/src/main/scala/firrtl/transforms/Flatten.scala @@ -7,7 +7,7 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.annotations._ import scala.collection.mutable -import firrtl.passes.{InlineInstances,PassException} +import firrtl.passes.{InlineInstances, PassException} import firrtl.stage.Forms /** Tags an annotation to be consumed by this transform */ @@ -25,101 +25,114 @@ case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named */ class Flatten extends Transform with DependencyAPIMigration { - override def prerequisites = Forms.LowForm - override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Forms.LowEmitters + override def prerequisites = Forms.LowForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Forms.LowEmitters override def invalidates(a: Transform) = false - val inlineTransform = new InlineInstances - - private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = - anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) { - case ((modNames, instNames), ann) => ann match { - case FlattenAnnotation(CircuitName(c)) => - (circuit.modules.collect { - case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) - }.toSet, instNames) - case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) - case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) - case _ => throw new PassException("Annotation must be a FlattenAnnotation") - } - } - - /** + val inlineTransform = new InlineInstances + + private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) { + case ((modNames, instNames), ann) => + ann match { + case FlattenAnnotation(CircuitName(c)) => + ( + circuit.modules.collect { + case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) + }.toSet, + instNames + ) + case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) + case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) + case _ => throw new PassException("Annotation must be a FlattenAnnotation") + } + } + + /** * Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and * by rewriting the original circuit to refer to the new modules that will be inlined later. * @return modified circuit and ModuleNames to inline */ - def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = { - val modMap = c.modules.map(m => m.name->m).toMap - val seedMods = mutable.Map.empty[String, String] - val newModDefs = mutable.Set.empty[DefModule] - val nsp = Namespace(c) - - /** + def duplicateSubCircuitsFromAnno( + c: Circuit, + mods: Set[ModuleName], + insts: Set[ComponentName] + ): (Circuit, Set[ModuleName]) = { + val modMap = c.modules.map(m => m.name -> m).toMap + val seedMods = mutable.Map.empty[String, String] + val newModDefs = mutable.Set.empty[DefModule] + val nsp = Namespace(c) + + /** * We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later. * It populates seedMods where we capture the mapping between the original module name of the instances came from annotation * to a new module name that we will create as a replica of the original one. * Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged. */ - def rewriteMod(parent: DefModule)(x: Statement): Statement = x match { - case _: Block => x map rewriteMod(parent) - case WDefInstance(info, instName, moduleName, instTpe) => - if (insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main)))) - || mods.contains(ModuleName(parent.name, CircuitName(c.main)))) { - val newModName = if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN") - seedMods += moduleName -> newModName - WDefInstance(info, instName, newModName, instTpe) - } else x - case _ => x - } - - val modifMods = c.modules map { m => m map rewriteMod(m) } - - /** + def rewriteMod(parent: DefModule)(x: Statement): Statement = x match { + case _: Block => x.map(rewriteMod(parent)) + case WDefInstance(info, instName, moduleName, instTpe) => + if ( + insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main)))) + || mods.contains(ModuleName(parent.name, CircuitName(c.main))) + ) { + val newModName = + if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN") + seedMods += moduleName -> newModName + WDefInstance(info, instName, newModName, instTpe) + } else x + case _ => x + } + + val modifMods = c.modules.map { m => m.map(rewriteMod(m)) } + + /** * Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations). * Populates newModDefs, which are replicated modules used in the subcircuit that we create * by recursively traversing modules captured inside seedMods and replicating them */ - def recDupMods(mods: Map[String, String]): Unit = { - val replMods = mutable.Map.empty[String, String] - - def dupMod(x: Statement): Statement = x match { - case _: Block => x map dupMod - case WDefInstance(info, instName, moduleName, instTpe) => modMap(moduleName) match { - case m: Module => - val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN") - replMods += moduleName -> newModName - WDefInstance(info, instName, newModName, instTpe) - case _ => x // Ignore extmodules - } - case _ => x - } - - def dupName(name: String): String = mods(name) - val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName } - - newModDefs ++= newMods - - if(replMods.size > 0) recDupMods(replMods.toMap) - - } - recDupMods(seedMods.toMap) - - //convert newly created modules to ModuleName for inlining next (outside this function) - val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) } - (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet) - } - - override def execute(state: CircuitState): CircuitState = { - val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a } - annos match { - case Nil => state - case myAnnotations => - val (modNames, instNames) = collectAnns(state.circuit, myAnnotations) - // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline - val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames) - inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations) - } - } + def recDupMods(mods: Map[String, String]): Unit = { + val replMods = mutable.Map.empty[String, String] + + def dupMod(x: Statement): Statement = x match { + case _: Block => x.map(dupMod) + case WDefInstance(info, instName, moduleName, instTpe) => + modMap(moduleName) match { + case m: Module => + val newModName = + if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN") + replMods += moduleName -> newModName + WDefInstance(info, instName, newModName, instTpe) + case _ => x // Ignore extmodules + } + case _ => x + } + + def dupName(name: String): String = mods(name) + val newMods = mods.map { case (origName, newName) => modMap(origName).map(dupMod).map(dupName) } + + newModDefs ++= newMods + + if (replMods.size > 0) recDupMods(replMods.toMap) + + } + recDupMods(seedMods.toMap) + + //convert newly created modules to ModuleName for inlining next (outside this function) + val modsToInline = newModDefs.map { m => ModuleName(m.name, CircuitName(c.main)) } + (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet) + } + + override def execute(state: CircuitState): CircuitState = { + val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a } + annos match { + case Nil => state + case myAnnotations => + val (modNames, instNames) = collectAnns(state.circuit, myAnnotations) + // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline + val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames) + inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations) + } + } } diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index a2399b5a..b582fe2a 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -119,7 +119,7 @@ object FlattenRegUpdate { def rec(e: Expression): (Info, Expression) = { val (info, expr) = kind(e) match { case NodeKind | WireKind if !endpoints(e) => unwrap(netlist.getOrElse(e, e)) - case _ => unwrap(e) + case _ => unwrap(e) } expr match { case Mux(cond, tval, fval, tpe) => @@ -128,16 +128,18 @@ object FlattenRegUpdate { val infox = combineInfos(info, tinfo, finfo) (infox, Mux(cond, tvalx, fvalx, tpe)) // Return the original expression to end flattening - case _ => unwrap(e) + case _ => unwrap(e) } } rec(start) } def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { - case reg @ DefRegister(_, rname, _,_, resetCond, _) => - assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero, - "Synchronous reset should have already been made explicit!") + case reg @ DefRegister(_, rname, _, _, resetCond, _) => + assert( + resetCond.tpe == AsyncResetType || resetCond == Utils.zero, + "Synchronous reset should have already been made explicit!" + ) val ref = WRef(reg) val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref)) val update = Connect(info, ref, rhs) @@ -145,7 +147,7 @@ object FlattenRegUpdate { reg // Remove connections to Registers so we preserve LowFirrtl single-connection semantics case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt - case other => other + case other => other } val bodyx = onStmt(mod.body) @@ -163,12 +165,14 @@ object FlattenRegUpdate { class FlattenRegUpdate extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic], - Dependency[InlineBitExtractionsTransform], - Dependency[InlineCastsTransform], - Dependency[LegalizeClocksTransform] ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[InlineCastsTransform], + Dependency[LegalizeClocksTransform] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized @@ -181,7 +185,7 @@ class FlattenRegUpdate extends Transform with DependencyAPIMigration { def execute(state: CircuitState): CircuitState = { val modulesx = state.circuit.modules.map { - case mod: Module => FlattenRegUpdate.flattenReg(mod) + case mod: Module => FlattenRegUpdate.flattenReg(mod) case ext: ExtModule => ext } state.copy(circuit = state.circuit.copy(modules = modulesx)) diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala index 166feba0..0db67f1e 100644 --- a/src/main/scala/firrtl/transforms/GroupComponents.scala +++ b/src/main/scala/firrtl/transforms/GroupComponents.scala @@ -10,7 +10,6 @@ import firrtl.stage.Forms import scala.collection.mutable - /** * Specifies a group of components, within a module, to pull out into their own module * Components that are only connected to a group's components will also be included @@ -21,8 +20,14 @@ import scala.collection.mutable * @param outputSuffix suggested suffix of any output ports of the new module * @param inputSuffix suggested suffix of any input ports of the new module */ -case class GroupAnnotation(components: Seq[ComponentName], newModule: String, newInstance: String, outputSuffix: Option[String] = None, inputSuffix: Option[String] = None) extends Annotation { - if(components.nonEmpty) { +case class GroupAnnotation( + components: Seq[ComponentName], + newModule: String, + newInstance: String, + outputSuffix: Option[String] = None, + inputSuffix: Option[String] = None) + extends Annotation { + if (components.nonEmpty) { require(components.forall(_.module == components.head.module), "All components must be in the same module.") require(components.forall(!_.name.contains('.')), "No components can be a subcomponent.") } @@ -35,7 +40,7 @@ case class GroupAnnotation(components: Seq[ComponentName], newModule: String, ne /* Only keeps components renamed to components */ def update(renames: RenameMap): Seq[Annotation] = { - val newComponents = components.flatMap{c => renames.get(c).getOrElse(Seq(c))}.collect { + val newComponents = components.flatMap { c => renames.get(c).getOrElse(Seq(c)) }.collect { case c: ComponentName => c } Seq(GroupAnnotation(newComponents, newModule, newInstance, outputSuffix, inputSuffix)) @@ -58,7 +63,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { } override def execute(state: CircuitState): CircuitState = { - val groups = state.annotations.collect {case g: GroupAnnotation => g} + val groups = state.annotations.collect { case g: GroupAnnotation => g } val module2group = groups.groupBy(_.currentModule) val mnamespace = Namespace(state.circuit) val newModules = state.circuit.modules.flatMap { @@ -74,13 +79,12 @@ class GroupComponents extends Transform with DependencyAPIMigration { val namespace = Namespace(m) val groupRoots = groups.map(_.components.map(_.name)) val totalSum = groupRoots.map(_.size).sum - val union = groupRoots.foldLeft(Set.empty[String]){(all, set) => all.union(set.toSet)} + val union = groupRoots.foldLeft(Set.empty[String]) { (all, set) => all.union(set.toSet) } - require(groupRoots.forall{_.forall{namespace.contains}}, "All names should be in this module") + require(groupRoots.forall { _.forall { namespace.contains } }, "All names should be in this module") require(totalSum == union.size, "No name can be in more than one group") require(groupRoots.forall(_.nonEmpty), "All groupRoots must by non-empty") - // Order of groups, according to their label. The label is the first root in the group val labelOrder = groups.collect({ case g: GroupAnnotation => g.components.head.name }) @@ -90,8 +94,8 @@ class GroupComponents extends Transform with DependencyAPIMigration { // Group roots, by label // The label "" indicates the original module, and components belonging to that group will remain // in the original module (not get moved into a new module) - val label2group: Map[String, MSet[String]] = groups.collect{ - case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name):_*) + val label2group: Map[String, MSet[String]] = groups.collect { + case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name): _*) }.toMap + ("" -> mutable.Set("")) // Name of new module containing each group, by label @@ -105,7 +109,6 @@ class GroupComponents extends Transform with DependencyAPIMigration { // Build set of components not in set val notSet = label2group.map { case (key, value) => key -> union.diff(value) } - // Get all dependencies between components val deps = getComponentConnectivity(m) @@ -114,13 +117,14 @@ class GroupComponents extends Transform with DependencyAPIMigration { // For each group (by label), add connectivity between nodes in set // Populate reachableNodes with reachability, where blacklist is their notSet - label2group.foreach { case (label, set) => - set.foreach { x => - deps.addPairWithEdge(label, x) - } - deps.reachableFrom(label, notSet(label)) foreach { node => - reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label - } + label2group.foreach { + case (label, set) => + set.foreach { x => + deps.addPairWithEdge(label, x) + } + deps.reachableFrom(label, notSet(label)).foreach { node => + reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label + } } // Unused nodes are not reachable from any group nor the root--add them to root group @@ -129,12 +133,13 @@ class GroupComponents extends Transform with DependencyAPIMigration { } // Add nodes who are reached by a single group, to that group - reachableNodes.foreach { case (node, membership) => - if(membership.size == 1) { - label2group(membership.head) += node - } else { - label2group("") += node - } + reachableNodes.foreach { + case (node, membership) => + if (membership.size == 1) { + label2group(membership.head) += node + } else { + label2group("") += node + } } applyGrouping(m, labelOrder, label2group, label2module, label2instance, label2annotation) @@ -150,19 +155,21 @@ class GroupComponents extends Transform with DependencyAPIMigration { * @param label2annotation annotation specifying the group, by label * @return new modules, including each group's module and the new split module */ - def applyGrouping( m: Module, - labelOrder: Seq[String], - label2group: Map[String, MSet[String]], - label2module: Map[String, String], - label2instance: Map[String, String], - label2annotation: Map[String, GroupAnnotation] - ): Seq[Module] = { + def applyGrouping( + m: Module, + labelOrder: Seq[String], + label2group: Map[String, MSet[String]], + label2module: Map[String, String], + label2instance: Map[String, String], + label2annotation: Map[String, GroupAnnotation] + ): Seq[Module] = { // Maps node to group val byNode = mutable.HashMap[String, String]() - label2group.foreach { case (group, nodes) => - nodes.foreach { node => - byNode(node) = group - } + label2group.foreach { + case (group, nodes) => + nodes.foreach { node => + byNode(node) = group + } } val groupNamespace = label2group.map { case (head, set) => head -> Namespace(set.toSeq) } @@ -180,7 +187,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { val portNames = groupPortNames(group) val suffix = d match { case Output => label2annotation(group).outputSuffix.getOrElse("") - case Input => label2annotation(group).inputSuffix.getOrElse("") + case Input => label2annotation(group).inputSuffix.getOrElse("") } val newName = groupNamespace(group).newName(source + suffix) val portName = portNames.getOrElseUpdate(source, newName) @@ -192,7 +199,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { val portName = addPort(group, exp, Output) val connectStatement = exp.tpe match { case AnalogType(_) => Attach(NoInfo, Seq(WRef(portName), exp)) - case _ => Connect(NoInfo, WRef(portName), exp) + case _ => Connect(NoInfo, WRef(portName), exp) } groupStatements(group) += connectStatement portName @@ -201,7 +208,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { // Given the sink is in a group, tidy up source references def inGroupFixExps(group: String, added: mutable.ArrayBuffer[Statement])(e: Expression): Expression = e match { case _: Literal => e - case _: DoPrim | _: Mux | _: ValidIf => e map inGroupFixExps(group, added) + case _: DoPrim | _: Mux | _: ValidIf => e.map(inGroupFixExps(group, added)) case otherExp: Expression => val wref = getWRef(otherExp) val source = wref.name @@ -238,10 +245,10 @@ class GroupComponents extends Transform with DependencyAPIMigration { // Given the sink is in the parent module, tidy up source references belonging to groups def inTopFixExps(e: Expression): Expression = e match { - case _: DoPrim | _: Mux | _: ValidIf => e map inTopFixExps + case _: DoPrim | _: Mux | _: ValidIf => e.map(inTopFixExps) case otherExp: Expression => val wref = getWRef(otherExp) - if(byNode(wref.name) != "") { + if (byNode(wref.name) != "") { // Get the name of source's group val otherGroup = byNode(wref.name) @@ -260,7 +267,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { case r: IsDeclaration if byNode(r.name) != "" => val topStmts = mutable.ArrayBuffer[Statement]() val group = byNode(r.name) - groupStatements(group) += r mapExpr inGroupFixExps(group, topStmts) + groupStatements(group) += r.mapExpr(inGroupFixExps(group, topStmts)) Block(topStmts.toSeq) case c: Connect if byNode(getWRef(c.loc).name) != "" => // Sink is in a group @@ -276,20 +283,26 @@ class GroupComponents extends Transform with DependencyAPIMigration { // TODO Attach if all are in a group? case _: IsDeclaration | _: Connect | _: Attach => // Sink is in Top - val ret = s mapExpr inTopFixExps + val ret = s.mapExpr(inTopFixExps) ret - case other => other map onStmt + case other => other.map(onStmt) } } - // Build datastructures - val newTopBody = Block(labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body))) + val newTopBody = Block( + labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body)) + ) val finalTopBody = Block(Utils.squashEmpty(newTopBody).asInstanceOf[Block].stmts.distinct) // For all group labels (not including the original module label), return a new Module. - val newModules = labelOrder.filter(_ != "") map { group => - Module(NoInfo, label2module(group), groupPorts(group).distinct.toSeq, Block(groupStatements(group).distinct.toSeq)) + val newModules = labelOrder.filter(_ != "").map { group => + Module( + NoInfo, + label2module(group), + groupPorts(group).distinct.toSeq, + Block(groupStatements(group).distinct.toSeq) + ) } Seq(m.copy(body = finalTopBody)) ++ newModules } @@ -298,7 +311,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { case w: WRef => w case other => var w = WRef("") - other mapExpr { e => w = getWRef(e); e} + other.mapExpr { e => w = getWRef(e); e } w } @@ -317,25 +330,25 @@ class GroupComponents extends Transform with DependencyAPIMigration { bidirGraph.addPairWithEdge(sink.name, name) bidirGraph.addPairWithEdge(name, sink.name) w - case other => other map onExpr(sink) + case other => other.map(onExpr(sink)) } def onStmt(stmt: Statement): Unit = stmt match { case w: WDefInstance => case h: IsDeclaration => bidirGraph.addVertex(h.name) - h map onExpr(WRef(h.name)) + h.map(onExpr(WRef(h.name))) case Attach(_, exprs) => // Add edge between each expression - exprs.tail map onExpr(getWRef(exprs.head)) + exprs.tail.map(onExpr(getWRef(exprs.head))) case Connect(_, loc, expr) => onExpr(getWRef(loc))(expr) - case q @ Stop(_,_, clk, en) => + case q @ Stop(_, _, clk, en) => val simName = simNamespace.newTemp simulations(simName) = q - Seq(clk, en) map onExpr(WRef(simName)) + Seq(clk, en).map(onExpr(WRef(simName))) case q @ Print(_, _, args, clk, en) => val simName = simNamespace.newTemp simulations(simName) = q - (args :+ clk :+ en) map onExpr(WRef(simName)) + (args :+ clk :+ en).map(onExpr(WRef(simName))) case Block(stmts) => stmts.foreach(onStmt) case ignore @ (_: IsInvalid | EmptyStmt) => // do nothing case other => throw new Exception(s"Unexpected Statement $other") @@ -358,7 +371,7 @@ class GroupAndDedup extends GroupComponents { override def invalidates(a: Transform): Boolean = a match { case _: DedupModules => true - case _ => super.invalidates(a) + case _ => super.invalidates(a) } } diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala index dd073001..376382cc 100644 --- a/src/main/scala/firrtl/transforms/InferResets.scala +++ b/src/main/scala/firrtl/transforms/InferResets.scala @@ -7,9 +7,9 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ import firrtl.annotations.{ReferenceTarget, TargetToken} -import firrtl.Utils.{toTarget, throwInternalError} +import firrtl.Utils.{throwInternalError, toTarget} import firrtl.options.Dependency -import firrtl.passes.{Pass, PassException, InferTypes} +import firrtl.passes.{InferTypes, Pass, PassException} import firrtl.graph.MutableDiGraph import scala.collection.mutable @@ -83,14 +83,13 @@ object InferResets { // Vectors must all have the same type, so we only process Index 0 // If the subtype is an aggregate, there can be multiple of each index val ts = tokens.collect { case (TargetToken.Index(0) +: tail, tpe) => (tail, tpe) } - VectorTree(fromTokens(ts:_*)) + VectorTree(fromTokens(ts: _*)) // BundleTree case (TargetToken.Field(_) +: _, _) +: _ => val fields = - tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n } - .mapValues { ts => - fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }:_*) - }.toMap + tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n }.mapValues { ts => + fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }: _*) + }.toMap BundleTree(fields) } } @@ -113,14 +112,16 @@ object InferResets { class InferResets extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.ResolveFlows), - Dependency[passes.InferWidths] ) ++ stage.Forms.WorkingIR + Seq( + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.ResolveFlows), + Dependency[passes.InferWidths] + ) ++ stage.Forms.WorkingIR override def invalidates(a: Transform): Boolean = a match { case _: checks.CheckResets | passes.CheckTypes => true - case _ => false + case _ => false } import InferResets._ @@ -138,7 +139,7 @@ class InferResets extends Transform with DependencyAPIMigration { val mod = instMap(target.ref) val port = target.component.head match { case TargetToken.Field(name) => name - case bad => Utils.throwInternalError(s"Unexpected token $bad") + case bad => Utils.throwInternalError(s"Unexpected token $bad") } target.copy(module = mod, ref = port, component = target.component.tail) case _ => target @@ -148,17 +149,18 @@ class InferResets extends Transform with DependencyAPIMigration { // Mark driver of a ResetType leaf def markResetDriver(lhs: Expression, rhs: Expression): Unit = { val con = Utils.flow(lhs) match { - case SinkFlow if lhs.tpe == ResetType => Some((lhs, rhs)) + case SinkFlow if lhs.tpe == ResetType => Some((lhs, rhs)) case SourceFlow if rhs.tpe == ResetType => Some((rhs, lhs)) // If sink is not ResetType, do nothing - case _ => None + case _ => None } - con.foreach { case (loc, exp) => - val driver = exp.tpe match { - case ResetType => TargetDriver(makeTarget(exp)) - case tpe => TypeDriver(tpe, () => makeTarget(exp)) - } - map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver + con.foreach { + case (loc, exp) => + val driver = exp.tpe match { + case ResetType => TargetDriver(makeTarget(exp)) + case tpe => TypeDriver(tpe, () => makeTarget(exp)) + } + map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver } } stmt match { @@ -227,7 +229,7 @@ class InferResets extends Transform with DependencyAPIMigration { private def resolve(map: Map[ReferenceTarget, List[ResetDriver]]): Try[Map[ReferenceTarget, Type]] = { val graph = new MutableDiGraph[Node] val asyncNode = Typ(AsyncResetType) - val syncNode = Typ(Utils.BoolType) + val syncNode = Typ(Utils.BoolType) for ((target, drivers) <- map) { val v = Var(target) drivers.foreach { @@ -247,7 +249,7 @@ class InferResets extends Transform with DependencyAPIMigration { // do the actual inference, the check is simply if syncNode is reachable from asyncNode graph.addPairWithEdge(v, u) case InvalidDriver => - graph.addVertex(v) // Must be in the graph or won't be inferred + graph.addVertex(v) // Must be in the graph or won't be inferred } } val async = graph.reachableFrom(asyncNode) @@ -257,7 +259,7 @@ class InferResets extends Transform with DependencyAPIMigration { case (a, _) if a.contains(syncNode) => throw InferResetsException(graph.path(asyncNode, syncNode)) case (a, s) => (a.view.collect { case Var(t) => t -> asyncNode.tpe } ++ - s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap + s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap } } } @@ -265,34 +267,40 @@ class InferResets extends Transform with DependencyAPIMigration { private def fixupType(tpe: Type, tree: TypeTree): Type = (tpe, tree) match { case (BundleType(fields), BundleTree(map)) => val fieldsx = - fields.map(f => map.get(f.name) match { - case Some(t) => f.copy(tpe = fixupType(f.tpe, t)) - case None => f - }) + fields.map(f => + map.get(f.name) match { + case Some(t) => f.copy(tpe = fixupType(f.tpe, t)) + case None => f + } + ) BundleType(fieldsx) case (VectorType(vtpe, size), VectorTree(t)) => VectorType(fixupType(vtpe, t), size) case (_, GroundTree(t)) => t - case x => throw new Exception(s"Error! Unexpected pair $x") + case x => throw new Exception(s"Error! Unexpected pair $x") } // Assumes all ReferenceTargets are in the same module private def makeDeclMap(map: Map[ReferenceTarget, Type]): Map[String, TypeTree] = - map.groupBy(_._1.ref).mapValues { ts => - TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }:_*) - }.toMap + map + .groupBy(_._1.ref) + .mapValues { ts => + TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }: _*) + } + .toMap private def implPort(map: Map[String, TypeTree])(port: Port): Port = - map.get(port.name) - .map(tree => port.copy(tpe = fixupType(port.tpe, tree))) - .getOrElse(port) + map + .get(port.name) + .map(tree => port.copy(tpe = fixupType(port.tpe, tree))) + .getOrElse(port) private def implStmt(map: Map[String, TypeTree])(stmt: Statement): Statement = stmt.map(implStmt(map)) match { case decl: IsDeclaration if map.contains(decl.name) => val tree = map(decl.name) decl match { - case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree)) - case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree)) + case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree)) + case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree)) // TODO Can this really happen? case mem: DefMemory => mem.copy(dataType = fixupType(mem.dataType, tree)) case other => other @@ -303,10 +311,13 @@ class InferResets extends Transform with DependencyAPIMigration { private def implement(c: Circuit, map: Map[ReferenceTarget, Type]): Circuit = { val modMaps = map.groupBy(_._1.module) def onMod(mod: DefModule): DefModule = { - modMaps.get(mod.name).map { tmap => - val declMap = makeDeclMap(tmap) - mod.map(implPort(declMap)).map(implStmt(declMap)) - }.getOrElse(mod) + modMaps + .get(mod.name) + .map { tmap => + val declMap = makeDeclMap(tmap) + mod.map(implPort(declMap)).map(implStmt(declMap)) + } + .getOrElse(mod) } c.map(onMod) } diff --git a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala index 515bf407..100b598f 100644 --- a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala +++ b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala @@ -6,7 +6,7 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ import firrtl.options.Dependency -import firrtl.PrimOps.{Bits, Head, Tail, Shr} +import firrtl.PrimOps.{Bits, Head, Shr, Tail} import firrtl.Utils.{isBitExtract, isTemp} import firrtl.WrappedExpression._ @@ -19,8 +19,8 @@ object InlineBitExtractionsTransform { // Note that this can have false negatives but MUST NOT have false positives. private def isSimpleExpr(expr: Expression): Boolean = expr match { case _: WRef | _: Literal | _: WSubField => true - case DoPrim(op, args, _,_) if isBitExtract(op) => args.forall(isSimpleExpr) - case _ => false + case DoPrim(op, args, _, _) if isBitExtract(op) => args.forall(isSimpleExpr) + case _ => false } // replace Head/Tail/Shr with Bits for easier back-to-back Bits Extractions @@ -28,12 +28,12 @@ object InlineBitExtractionsTransform { case DoPrim(Head, rhs, c, tpe) if isSimpleExpr(expr) => val msb = bitWidth(rhs.head.tpe) - 1 val lsb = bitWidth(rhs.head.tpe) - c.head - DoPrim(Bits, rhs, Seq(msb,lsb), tpe) + DoPrim(Bits, rhs, Seq(msb, lsb), tpe) case DoPrim(Tail, rhs, c, tpe) if isSimpleExpr(expr) => val msb = bitWidth(rhs.head.tpe) - c.head - 1 - DoPrim(Bits, rhs, Seq(msb,0), tpe) + DoPrim(Bits, rhs, Seq(msb, 0), tpe) case DoPrim(Shr, rhs, c, tpe) if isSimpleExpr(expr) => - DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe)-1, c.head), tpe) + DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe) - 1, c.head), tpe) case _ => expr // Not a candidate } @@ -49,26 +49,28 @@ object InlineBitExtractionsTransform { */ def onExpr(netlist: Netlist)(expr: Expression): Expression = { expr.map(onExpr(netlist)) match { - case e @ WRef(name, _,_,_) => - netlist.get(we(e)) - .filter(isBitExtract) - .getOrElse(e) + case e @ WRef(name, _, _, _) => + netlist + .get(we(e)) + .filter(isBitExtract) + .getOrElse(e) // replace back-to-back Bits Extractions case lhs @ DoPrim(lop, ival, lc, ltpe) if isSimpleExpr(lhs) => ival.head match { case of @ DoPrim(rop, rhs, rc, rtpe) if isSimpleExpr(of) => (lop, rop) match { - case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head min rc.head), ltpe) + case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head.min(rc.head)), ltpe) case (Tail, Tail) => DoPrim(Tail, rhs, Seq(lc.head + rc.head), ltpe) - case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe) - case (_,_) => (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match { - case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) => - DoPrim(Bits, rhs, Seq(lmsb+rlsb,llsb+rlsb), ltpe) - case (_,_) => lhs // Not a candidate - } + case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe) + case (_, _) => + (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match { + case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) => + DoPrim(Bits, rhs, Seq(lmsb + rlsb, llsb + rlsb), ltpe) + case (_, _) => lhs // Not a candidate + } } - case _ => lhs // Not a candidate - } + case _ => lhs // Not a candidate + } case other => other // Not a candidate } } @@ -97,9 +99,11 @@ object InlineBitExtractionsTransform { class InlineBitExtractionsTransform extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic] ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala index 3dac938e..0efc0727 100644 --- a/src/main/scala/firrtl/transforms/InlineCasts.scala +++ b/src/main/scala/firrtl/transforms/InlineCasts.scala @@ -8,7 +8,7 @@ import firrtl.Mappers._ import firrtl.PrimOps.Pad import firrtl.options.Dependency -import firrtl.Utils.{isCast, isBitExtract, NodeMap} +import firrtl.Utils.{isBitExtract, isCast, NodeMap} object InlineCastsTransform { @@ -17,8 +17,8 @@ object InlineCastsTransform { // Note that this can have false negatives but MUST NOT have false positives private def isSimpleCast(castSeen: Boolean)(expr: Expression): Boolean = expr match { case _: WRef | _: Literal | _: WSubField => castSeen - case DoPrim(op, args, _,_) if isCast(op) => args.forall(isSimpleCast(true)) - case _ => false + case DoPrim(op, args, _, _) if isCast(op) => args.forall(isSimpleCast(true)) + case _ => false } /** Recursively replace [[WRef]]s with new [[firrtl.ir.Expression Expression]]s @@ -31,17 +31,20 @@ object InlineCastsTransform { def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match { // Anything that may generate a part-select should not be inlined! case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr - case e => e.map(onExpr(replace)) match { - case e @ WRef(name, _,_,_) => - replace.get(name) - .filter(isSimpleCast(castSeen=false)) - .getOrElse(e) - case e @ DoPrim(op, Seq(WRef(name, _,_,_)), _,_) if isCast(op) => - replace.get(name) - .map(value => e.copy(args = Seq(value))) - .getOrElse(e) - case other => other // Not a candidate - } + case e => + e.map(onExpr(replace)) match { + case e @ WRef(name, _, _, _) => + replace + .get(name) + .filter(isSimpleCast(castSeen = false)) + .getOrElse(e) + case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) => + replace + .get(name) + .map(value => e.copy(args = Seq(value))) + .getOrElse(e) + case other => other // Not a candidate + } } /** Inline casts in a Statement @@ -69,11 +72,13 @@ object InlineCastsTransform { class InlineCastsTransform extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic], - Dependency[InlineBitExtractionsTransform], - Dependency[PropagatePresetAnnotations] ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[PropagatePresetAnnotations] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/LegalizeClocks.scala b/src/main/scala/firrtl/transforms/LegalizeClocks.scala index f439fdc9..248775d9 100644 --- a/src/main/scala/firrtl/transforms/LegalizeClocks.scala +++ b/src/main/scala/firrtl/transforms/LegalizeClocks.scala @@ -18,8 +18,8 @@ object LegalizeClocksTransform { // Currently only looks for literals nested within casts private def illegalClockExpr(expr: Expression): Boolean = expr match { case _: Literal => true - case DoPrim(op, args, _,_) if isCast(op) => args.exists(illegalClockExpr) - case _ => false + case DoPrim(op, args, _, _) if isCast(op) => args.exists(illegalClockExpr) + case _ => false } /** Legalize Clocks in a Statement @@ -66,11 +66,13 @@ object LegalizeClocksTransform { class LegalizeClocksTransform extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic], - Dependency[InlineBitExtractionsTransform], - Dependency[InlineCastsTransform] ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[InlineCastsTransform] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/LegalizeReductions.scala b/src/main/scala/firrtl/transforms/LegalizeReductions.scala index 2e60aae7..33a10349 100644 --- a/src/main/scala/firrtl/transforms/LegalizeReductions.scala +++ b/src/main/scala/firrtl/transforms/LegalizeReductions.scala @@ -6,17 +6,16 @@ import firrtl.Mappers._ import firrtl.options.Dependency import firrtl.Utils.BoolType - object LegalizeAndReductionsTransform { private def allOnesOfType(tpe: Type): Literal = tpe match { case UIntType(width @ IntWidth(x)) => UIntLiteral((BigInt(1) << x.toInt) - 1, width) - case SIntType(width) => SIntLiteral(-1, width) + case SIntType(width) => SIntLiteral(-1, width) } def onExpr(expr: Expression): Expression = expr.map(onExpr) match { - case DoPrim(PrimOps.Andr, Seq(arg), _,_) if bitWidth(arg.tpe) > 64 => + case DoPrim(PrimOps.Andr, Seq(arg), _, _) if bitWidth(arg.tpe) > 64 => DoPrim(PrimOps.Eq, Seq(arg, allOnesOfType(arg.tpe)), Seq(), BoolType) case other => other } @@ -35,8 +34,7 @@ class LegalizeAndReductionsTransform extends Transform with DependencyAPIMigrati override def prerequisites = firrtl.stage.Forms.WorkingIR ++ - Seq( Dependency(passes.CheckTypes), - Dependency(passes.CheckWidths)) + Seq(Dependency(passes.CheckTypes), Dependency(passes.CheckWidths)) override def optionalPrerequisites = Nil diff --git a/src/main/scala/firrtl/transforms/ManipulateNames.scala b/src/main/scala/firrtl/transforms/ManipulateNames.scala index f15b546f..d0b12e66 100644 --- a/src/main/scala/firrtl/transforms/ManipulateNames.scala +++ b/src/main/scala/firrtl/transforms/ManipulateNames.scala @@ -57,8 +57,9 @@ sealed trait ManipulateNamesListAnnotation[A <: ManipulateNames[_]] extends Mult * @note $noteLocalTargets */ case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]]( - targets: Seq[Seq[Target]], - transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] { + targets: Seq[Seq[Target]], + transform: Dependency[A]) + extends ManipulateNamesListAnnotation[A] { override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a) @@ -77,8 +78,9 @@ case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]]( * @note $noteLocalTargets */ case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]]( - targets: Seq[Seq[Target]], - transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] { + targets: Seq[Seq[Target]], + transform: Dependency[A]) + extends ManipulateNamesListAnnotation[A] { override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a) @@ -94,19 +96,21 @@ case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]]( * @param oldTargets the old targets */ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]]( - targets: Seq[Seq[Target]], - transform: Dependency[A], - oldTargets: Seq[Seq[Target]]) extends MultiTargetAnnotation { + targets: Seq[Seq[Target]], + transform: Dependency[A], + oldTargets: Seq[Seq[Target]]) + extends MultiTargetAnnotation { override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a) override def update(renames: RenameMap) = { val (targetsx, oldTargetsx) = targets.zip(oldTargets).foldLeft((Seq.empty[Seq[Target]], Seq.empty[Seq[Target]])) { - case ((accT, accO), (t, o)) => t.flatMap(renames(_)) match { - /* If the target was deleted, delete the old target */ - case tx if tx.isEmpty => (accT, accO) - case tx => (Seq(tx) ++ accT, Seq(o) ++ accO) - } + case ((accT, accO), (t, o)) => + t.flatMap(renames(_)) match { + /* If the target was deleted, delete the old target */ + case tx if tx.isEmpty => (accT, accO) + case tx => (Seq(tx) ++ accT, Seq(o) ++ accO) + } } targetsx match { /* If all targets were deleted, delete the annotation */ @@ -117,9 +121,13 @@ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]]( /** Return [[firrtl.RenameMap RenameMap]] from old targets to new targets */ def toRenameMap: RenameMap = { - val m = oldTargets.zip(targets).flatMap { - case (a, b) => a.map(_ -> b) - }.toMap.asInstanceOf[Map[CompleteTarget, Seq[CompleteTarget]]] + val m = oldTargets + .zip(targets) + .flatMap { + case (a, b) => a.map(_ -> b) + } + .toMap + .asInstanceOf[Map[CompleteTarget, Seq[CompleteTarget]]] RenameMap.create(m) } @@ -132,25 +140,28 @@ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]]( * @param allow a function that returns true if a [[firrtl.annotations.Target Target]] should be renamed */ private class RenameDataStructure( - circuit: ir.Circuit, + circuit: ir.Circuit, val renames: RenameMap, - val block: Target => Boolean, - val allow: Target => Boolean) { + val block: Target => Boolean, + val allow: Target => Boolean) { /** A mapping of targets to associated namespaces */ val namespaces: mutable.HashMap[CompleteTarget, Namespace] = mutable.HashMap(CircuitTarget(circuit.main) -> Namespace(circuit)) - /** Wraps a HashMap to provide better error messages when accessing a non-existing element */ + /** Wraps a HashMap to provide better error messages when accessing a non-existing element */ class InstanceHashMap { type Key = ReferenceTarget type Value = Either[ReferenceTarget, InstanceTarget] private val m = mutable.HashMap[Key, Value]() - def apply(key: ReferenceTarget): Value = m.getOrElse(key, { - throw new FirrtlUserException( - s"""|Reference target '${key.serialize}' did not exist in mapping of reference targets to insts/mems. - | This is indicative of a circuit that has not been run through LowerTypes.""".stripMargin) - }) + def apply(key: ReferenceTarget): Value = m.getOrElse( + key, { + throw new FirrtlUserException( + s"""|Reference target '${key.serialize}' did not exist in mapping of reference targets to insts/mems. + | This is indicative of a circuit that has not been run through LowerTypes.""".stripMargin + ) + } + ) def update(key: Key, value: Value): Unit = m.update(key, value) } @@ -165,17 +176,17 @@ private class RenameDataStructure( /** Transform for manipulate all the names in a FIRRTL circuit. * @tparam A the type of the child transform */ -abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Transform with DependencyAPIMigration { +abstract class ManipulateNames[A <: ManipulateNames[_]: ClassTag] extends Transform with DependencyAPIMigration { /** A function used to manipulate a name in a FIRRTL circuit */ def manipulate: (String, Namespace) => Option[String] - override def prerequisites: Seq[TransformDependency] = Seq(Dependency(firrtl.passes.LowerTypes)) - override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty + override def prerequisites: Seq[TransformDependency] = Seq(Dependency(firrtl.passes.LowerTypes)) + override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty override def optionalPrerequisiteOf: Seq[TransformDependency] = Forms.LowEmitters override def invalidates(a: Transform) = a match { case _: analyses.GetNamespace => true - case _ => false + case _ => false } /** Compute a new name for some target and record the rename if the new name differs. If the top module or the circuit @@ -192,27 +203,31 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans case a if r.skip(a) => (name, None) /* Circuit renaming */ - case a@ CircuitTarget(b) => manipulate(b, r.namespaces(a)) match { - case Some(str) => (str, Some(a.copy(circuit = str))) - case None => (b, None) - } + case a @ CircuitTarget(b) => + manipulate(b, r.namespaces(a)) match { + case Some(str) => (str, Some(a.copy(circuit = str))) + case None => (b, None) + } /* Module renaming for non-top modules */ - case a@ ModuleTarget(_, b) => manipulate(b, r.namespaces(a.circuitTarget)) match { - case Some(str) => (str, Some(a.copy(module = str))) - case None => (b, None) - } + case a @ ModuleTarget(_, b) => + manipulate(b, r.namespaces(a.circuitTarget)) match { + case Some(str) => (str, Some(a.copy(module = str))) + case None => (b, None) + } /* Instance renaming */ - case a@ InstanceTarget(_, _, Nil, b, c) => manipulate(b, r.namespaces(a.moduleTarget)) match { - case Some(str) => (str, Some(a.copy(instance = str))) - case None => (b, None) - } + case a @ InstanceTarget(_, _, Nil, b, c) => + manipulate(b, r.namespaces(a.moduleTarget)) match { + case Some(str) => (str, Some(a.copy(instance = str))) + case None => (b, None) + } /* Rename either a module component or a memory */ - case a@ ReferenceTarget(_, _, _, b, Nil) => manipulate(b, r.namespaces(a.moduleTarget)) match { - case Some(str) => (str, Some(a.copy(ref = str))) - case None => (b, None) - } + case a @ ReferenceTarget(_, _, _, b, Nil) => + manipulate(b, r.namespaces(a.moduleTarget)) match { + case Some(str) => (str, Some(a.copy(ref = str))) + case None => (b, None) + } /* Rename an instance port or a memory reader/writer/readwriter */ - case a@ ReferenceTarget(_, _, _, b, (token@ TargetToken.Field(c)) :: Nil) => + case a @ ReferenceTarget(_, _, _, b, (token @ TargetToken.Field(c)) :: Nil) => val ref = r.instanceMap(a.moduleTarget.ref(b)) match { case Right(inst) => inst.ofModuleTarget case Left(mem) => mem @@ -224,8 +239,8 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans } /* Record the optional rename. If the circuit was renamed, also rename the top module. If the top module was * renamed, also rename the circuit. */ - ax.foreach( - axx => target match { + ax.foreach(axx => + target match { case c: CircuitTarget => r.renames.rename(target, r.renames(axx)) r.renames.rename(c.module(c.circuit), CircuitTarget(namex).module(namex)) @@ -252,21 +267,26 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans r.renames.underlying.get(t) match { case Some(ax) if ax.size == 1 => ax match { - case Seq(foo: CircuitTarget) => foo.name - case Seq(foo: ModuleTarget) => foo.module - case Seq(foo: InstanceTarget) => foo.instance - case Seq(foo: ReferenceTarget) => foo.tokens.last match { - case TargetToken.Ref(value) => value - case TargetToken.Field(value) => value - case _ => Utils.throwInternalError( - s"""|Reference target '${t.serialize}'must end in 'Ref' or 'Field' + case Seq(foo: CircuitTarget) => foo.name + case Seq(foo: ModuleTarget) => foo.module + case Seq(foo: InstanceTarget) => foo.instance + case Seq(foo: ReferenceTarget) => + foo.tokens.last match { + case TargetToken.Ref(value) => value + case TargetToken.Field(value) => value + case _ => + Utils.throwInternalError( + s"""|Reference target '${t.serialize}'must end in 'Ref' or 'Field' | This is indicative of a circuit that has not been run through LowerTypes.""", - Some(new MatchError(foo.serialize))) - } + Some(new MatchError(foo.serialize)) + ) + } } - case s@ Some(ax) => Utils.throwInternalError( - s"""Found multiple renames '${t}' -> [${ax.map(_.serialize).mkString(",")}]. This should be impossible.""", - Some(new MatchError(s))) + case s @ Some(ax) => + Utils.throwInternalError( + s"""Found multiple renames '${t}' -> [${ax.map(_.serialize).mkString(",")}]. This should be impossible.""", + Some(new MatchError(s)) + ) case None => name } @@ -280,27 +300,34 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans /* A reference to something inside this module */ case w: WRef => w.copy(name = maybeRename(w.name, r, Target.asTarget(t)(w))) /* This is either the subfield of an instance or a subfield of a memory reader/writer/readwriter */ - case w@ WSubField(expr, ref, _, _) => expr match { - /* This is an instance */ - case we@ WRef(inst, _, _, _) => - val tx = Target.asTarget(t)(we) - val (rTarget: ReferenceTarget, iTarget: InstanceTarget) = r.instanceMap(tx) match { - case Right(a) => (a.ofModuleTarget.ref(ref), a) - case a@ Left(ref) => throw new FirrtlUserException( - s"""|Unexpected '${ref.serialize}' in instanceMap for key '${tx.serialize}' on expression '${w.serialize}'. - | This is indicative of a circuit that has not been run through LowerTypes.""", new MatchError(a)) - } - w.copy(we.copy(name=maybeRename(inst, r, iTarget)), name=maybeRename(ref, r, rTarget)) - /* This is a reader/writer/readwriter */ - case ws@ WSubField(expr, port, _, _) => expr match { - /* This is the memory. */ - case wr@ WRef(mem, _, _, _) => - w.copy( - expr=ws.copy( - expr=wr.copy(name=maybeRename(mem, r, t.ref(mem))), - name=maybeRename(port, r, t.ref(mem).field(port)))) + case w @ WSubField(expr, ref, _, _) => + expr match { + /* This is an instance */ + case we @ WRef(inst, _, _, _) => + val tx = Target.asTarget(t)(we) + val (rTarget: ReferenceTarget, iTarget: InstanceTarget) = r.instanceMap(tx) match { + case Right(a) => (a.ofModuleTarget.ref(ref), a) + case a @ Left(ref) => + throw new FirrtlUserException( + s"""|Unexpected '${ref.serialize}' in instanceMap for key '${tx.serialize}' on expression '${w.serialize}'. + | This is indicative of a circuit that has not been run through LowerTypes.""", + new MatchError(a) + ) + } + w.copy(we.copy(name = maybeRename(inst, r, iTarget)), name = maybeRename(ref, r, rTarget)) + /* This is a reader/writer/readwriter */ + case ws @ WSubField(expr, port, _, _) => + expr match { + /* This is the memory. */ + case wr @ WRef(mem, _, _, _) => + w.copy( + expr = ws.copy( + expr = wr.copy(name = maybeRename(mem, r, t.ref(mem))), + name = maybeRename(port, r, t.ref(mem).field(port)) + ) + ) + } } - } case e => e.map(onExpression(_: ir.Expression, r, t)) } @@ -310,30 +337,31 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans * and readwriters. */ private def onStatement(s: ir.Statement, r: RenameDataStructure, t: ModuleTarget): ir.Statement = s match { - case decl: ir.IsDeclaration => decl match { - case decl@ WDefInstance(_, inst, mod, _) => - val modx = maybeRename(mod, r, t.circuitTarget.module(mod)) - val instx = doRename(inst, r, t.instOf(inst, mod)) - r.instanceMap(t.ref(inst)) = Right(t.instOf(inst, mod)) - decl.copy(name = instx, module = modx) - case decl: ir.DefMemory => - val namex = doRename(decl.name, r, t.ref(decl.name)) - val tx = t.ref(decl.name) - r.namespaces(tx) = Namespace(decl.readers ++ decl.writers ++ decl.readwriters) - r.instanceMap(tx) = Left(tx) - decl - .copy( - name = namex, - readers = decl.readers.map(_r => doRename(_r, r, tx.field(_r))), - writers = decl.writers.map(_w => doRename(_w, r, tx.field(_w))), - readwriters = decl.readwriters.map(_rw => doRename(_rw, r, tx.field(_rw))) - ) - .map(onExpression(_: ir.Expression, r, t)) - case decl => - decl - .map(doRename(_: String, r, t.ref(decl.name))) - .map(onExpression(_: ir.Expression, r, t)) - } + case decl: ir.IsDeclaration => + decl match { + case decl @ WDefInstance(_, inst, mod, _) => + val modx = maybeRename(mod, r, t.circuitTarget.module(mod)) + val instx = doRename(inst, r, t.instOf(inst, mod)) + r.instanceMap(t.ref(inst)) = Right(t.instOf(inst, mod)) + decl.copy(name = instx, module = modx) + case decl: ir.DefMemory => + val namex = doRename(decl.name, r, t.ref(decl.name)) + val tx = t.ref(decl.name) + r.namespaces(tx) = Namespace(decl.readers ++ decl.writers ++ decl.readwriters) + r.instanceMap(tx) = Left(tx) + decl + .copy( + name = namex, + readers = decl.readers.map(_r => doRename(_r, r, tx.field(_r))), + writers = decl.writers.map(_w => doRename(_w, r, tx.field(_w))), + readwriters = decl.readwriters.map(_rw => doRename(_rw, r, tx.field(_rw))) + ) + .map(onExpression(_: ir.Expression, r, t)) + case decl => + decl + .map(doRename(_: String, r, t.ref(decl.name))) + .map(onExpression(_: ir.Expression, r, t)) + } case s => s .map(onStatement(_: ir.Statement, r, t)) @@ -362,7 +390,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans */ val onName: String => String = t.circuit match { case `main` => maybeRename(_, r, moduleTarget) - case _ => doRename(_, r, moduleTarget) + case _ => doRename(_, r, moduleTarget) } m @@ -380,11 +408,11 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans * @return the circuit with manipulated names */ def run( - c: ir.Circuit, + c: ir.Circuit, renames: RenameMap, - block: Target => Boolean, - allow: Target => Boolean) - : ir.Circuit = { + block: Target => Boolean, + allow: Target => Boolean + ): ir.Circuit = { val t = CircuitTarget(c.main) /* If the circuit is a skip, return the original circuit. Otherwise, walk all the modules and rename them. Rename the @@ -427,8 +455,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans .toMap /* Replace the old modules making sure that they are still in the same order */ - c.copy(modules = c.modules.map(m => modulesx(t.module(m.name))), - main = mainx) + c.copy(modules = c.modules.map(m => modulesx(t.module(m.name))), main = mainx) } } @@ -436,18 +463,20 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans def execute(state: CircuitState): CircuitState = { val block = state.annotations.collect { - case ManipulateNamesBlocklistAnnotation(targetSeq, t) => t.getObject match { - case _: A => targetSeq - case _ => Nil - } + case ManipulateNamesBlocklistAnnotation(targetSeq, t) => + t.getObject match { + case _: A => targetSeq + case _ => Nil + } }.flatten.flatten.toSet val allow = { val allowx = state.annotations.collect { - case ManipulateNamesAllowlistAnnotation(targetSeq, t) => t.getObject match { - case _: A => targetSeq - case _ => Nil - } + case ManipulateNamesAllowlistAnnotation(targetSeq, t) => + t.getObject match { + case _: A => targetSeq + case _ => Nil + } }.flatten.flatten allowx match { @@ -461,17 +490,19 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans val annotationsx = state.annotations.flatMap { /* Consume blocklist annotations */ - case foo@ ManipulateNamesBlocklistAnnotation(_, t) => t.getObject match { - case _: A => None - case _ => Some(foo) - } + case foo @ ManipulateNamesBlocklistAnnotation(_, t) => + t.getObject match { + case _: A => None + case _ => Some(foo) + } /* Convert allowlist annotations to result annotations */ - case foo@ ManipulateNamesAllowlistAnnotation(a, t) => + case foo @ ManipulateNamesAllowlistAnnotation(a, t) => t.getObject match { - case _: A => (a, a.map(_.map(renames(_)).flatten)) match { - case (a, b) => Some(ManipulateNamesAllowlistResultAnnotation(b, t, a)) - } - case _ => Some(foo) + case _: A => + (a, a.map(_.map(renames(_)).flatten)) match { + case (a, b) => Some(ManipulateNamesAllowlistResultAnnotation(b, t, a)) + } + case _ => Some(foo) } case a => Some(a) } diff --git a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala index ff44afec..5532d0f0 100644 --- a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala +++ b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala @@ -1,4 +1,3 @@ - package firrtl package transforms @@ -34,17 +33,19 @@ trait DontTouchAllTargets extends HasDontTouches { self: Annotation => * DCE treats the component as a top-level sink of the circuit */ case class DontTouchAnnotation(target: ReferenceTarget) - extends SingleTargetAnnotation[ReferenceTarget] with DontTouchAllTargets { + extends SingleTargetAnnotation[ReferenceTarget] + with DontTouchAllTargets { def targets = Seq(target) def duplicate(n: ReferenceTarget) = this.copy(n) } object DontTouchAnnotation { - class DontTouchNotFoundException(module: String, component: String) extends PassException( - s"""|Target marked dontTouch ($module.$component) not found! - |It was probably accidentally deleted. Please check that your custom transforms are not responsible and then - |file an issue on GitHub: https://github.com/freechipsproject/firrtl/issues/new""".stripMargin - ) + class DontTouchNotFoundException(module: String, component: String) + extends PassException( + s"""|Target marked dontTouch ($module.$component) not found! + |It was probably accidentally deleted. Please check that your custom transforms are not responsible and then + |file an issue on GitHub: https://github.com/freechipsproject/firrtl/issues/new""".stripMargin + ) def errorNotFound(module: String, component: String) = throw new DontTouchNotFoundException(module, component) @@ -58,7 +59,6 @@ object DontTouchAnnotation { * * @note Unlike [[DontTouchAnnotation]], we don't care if the annotation is deleted */ -case class OptimizableExtModuleAnnotation(target: ModuleName) extends - SingleTargetAnnotation[ModuleName] { +case class OptimizableExtModuleAnnotation(target: ModuleName) extends SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(n) } diff --git a/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala index da803837..97db0219 100644 --- a/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala +++ b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala @@ -11,8 +11,10 @@ import firrtl.options.Dependency import scala.collection.mutable object PropagatePresetAnnotations { - val advice = "Please Note that a Preset-annotated AsyncReset shall NOT be casted to other types with any of the following functions: asInterval, asUInt, asSInt, asClock, asFixedPoint, asAsyncReset." - case class TreeCleanUpOrphanException(message: String) extends FirrtlUserException(s"Node left an orphan during tree cleanup: $message $advice") + val advice = + "Please Note that a Preset-annotated AsyncReset shall NOT be casted to other types with any of the following functions: asInterval, asUInt, asSInt, asClock, asFixedPoint, asAsyncReset." + case class TreeCleanUpOrphanException(message: String) + extends FirrtlUserException(s"Node left an orphan during tree cleanup: $message $advice") } /** Propagate PresetAnnotations to all children of targeted AsyncResets @@ -39,9 +41,11 @@ object PropagatePresetAnnotations { class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic]) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized @@ -52,7 +56,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { import PropagatePresetAnnotations._ private type TargetSet = mutable.HashSet[ReferenceTarget] - private type TargetMap = mutable.HashMap[ReferenceTarget,String] + private type TargetMap = mutable.HashMap[ReferenceTarget, String] private type TargetSetMap = mutable.HashMap[ReferenceTarget, TargetSet] private val toCleanUp = new TargetSet() @@ -71,7 +75,11 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { * @param presetAnnos all the annotations * @return updated annotations */ - private def propagate(cs: CircuitState, presetAnnos: Seq[PresetAnnotation], otherAnnos: Seq[Annotation]): AnnotationSeq = { + private def propagate( + cs: CircuitState, + presetAnnos: Seq[PresetAnnotation], + otherAnnos: Seq[Annotation] + ): AnnotationSeq = { val presets = presetAnnos.groupBy(_.target) // store all annotated asyncreset references val asyncToAnnotate = new TargetSet() @@ -85,34 +93,34 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { val circuitTarget = CircuitTarget(cs.circuit.main) /* - * WALK I PHASE 1 FUNCTIONS - */ + * WALK I PHASE 1 FUNCTIONS + */ /* Walk current module - * - process ports - * - store connections & entry points for PHASE 2 - * - process statements - * - Instances => record local instances for cross module AsyncReset Tree Buidling - * - Registers => store AsyncReset bound registers for PHASE 2 - * - Wire => store AsyncReset Connections & entry points for PHASE 2 - * - Connect => store AsyncReset Connections & entry points for PHASE 2 - * - * @param m module - */ + * - process ports + * - store connections & entry points for PHASE 2 + * - process statements + * - Instances => record local instances for cross module AsyncReset Tree Buidling + * - Registers => store AsyncReset bound registers for PHASE 2 + * - Wire => store AsyncReset Connections & entry points for PHASE 2 + * - Connect => store AsyncReset Connections & entry points for PHASE 2 + * + * @param m module + */ def processModule(m: DefModule): Unit = { val moduleTarget = circuitTarget.module(m.name) val localInstances = new TargetMap() /* Recursively process a given type - * Recursive on Bundle and Vector Type only - * Store Register and Connections for AsyncResetType - * @param tpe [[Type]] to be processed - * @param target [[ReferenceTarget]] associated to the tpe - * @param all Boolean indicating whether all subelements of the current - * tpe should also be stored as Annotated AsyncReset entry points - */ + * Recursive on Bundle and Vector Type only + * Store Register and Connections for AsyncResetType + * @param tpe [[Type]] to be processed + * @param target [[ReferenceTarget]] associated to the tpe + * @param all Boolean indicating whether all subelements of the current + * tpe should also be stored as Annotated AsyncReset entry points + */ def processType(tpe: Type, target: ReferenceTarget, all: Boolean): Unit = { - if(tpe == AsyncResetType){ + if (tpe == AsyncResetType) { asyncRegMap(target) = new TargetSet() asyncCoMap(target) = new TargetSet() if (presets.contains(target) || all) { @@ -121,14 +129,13 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } else { tpe match { case b: BundleType => - b.fields.foreach{ - (x: Field) => - val tar = target.field(x.name) - processType(x.tpe, tar, (presets.contains(tar) || all)) + b.fields.foreach { (x: Field) => + val tar = target.field(x.name) + processType(x.tpe, tar, (presets.contains(tar) || all)) } case v: VectorType => - for(i <- 0 until v.size) { + for (i <- 0 until v.size) { val tar = target.index(i) processType(v.tpe, tar, (presets.contains(tar) || all)) } @@ -143,19 +150,19 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } /* Recursively search for the ReferenceTarget of a given Expression - * @param e Targeted Expression - * @param ta Local ReferenceTarget of the Targeted Expression - * @return a ReferenceTarget in case of success, a GenericTarget otherwise - * @throws [[InternalError]] on unexpected recursive path return results - */ - def getRef(e: Expression, ta: ReferenceTarget, annoCo: Boolean = false) : Target = { + * @param e Targeted Expression + * @param ta Local ReferenceTarget of the Targeted Expression + * @return a ReferenceTarget in case of success, a GenericTarget otherwise + * @throws [[InternalError]] on unexpected recursive path return results + */ + def getRef(e: Expression, ta: ReferenceTarget, annoCo: Boolean = false): Target = { e match { case w: WRef => moduleTarget.ref(w.name) case w: WSubField => getRef(w.expr, ta, annoCo) match { case rt: ReferenceTarget => - if(localInstances.contains(rt)){ - val remote_ref = circuitTarget.module(localInstances(rt)) + if (localInstances.contains(rt)) { + val remote_ref = circuitTarget.module(localInstances(rt)) if (annoCo) asyncCoMap(ta) += rt.field(w.name) remote_ref.ref(w.name) @@ -163,7 +170,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { rt.field(w.name) } case remote_target => remote_target - } + } case w: WSubIndex => getRef(w.expr, ta, annoCo) match { case remote_target: ReferenceTarget => @@ -179,7 +186,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { def processRegister(r: DefRegister): Unit = { getRef(r.reset, moduleTarget.ref(r.name), false) match { - case rt : ReferenceTarget => + case rt: ReferenceTarget => if (asyncRegMap.contains(rt)) { asyncRegMap(rt) += moduleTarget.ref(r.name) } @@ -189,12 +196,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } def processConnect(c: Connect): Unit = { - getRef(c.expr, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match { + getRef(c.expr, ReferenceTarget("", "", Seq.empty, "", Seq.empty)) match { case rhs: ReferenceTarget => if (presets.contains(rhs) || asyncRegMap.contains(rhs)) { getRef(c.loc, rhs, true) match { - case lhs : ReferenceTarget => - if(asyncRegMap.contains(rhs)){ + case lhs: ReferenceTarget => + if (asyncRegMap.contains(rhs)) { asyncRegMap(rhs) += lhs } else { asyncToAnnotate += lhs @@ -211,10 +218,10 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { val target = moduleTarget.ref(n.name) processType(n.value.tpe, target, presets.contains(target)) - getRef(n.value, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match { + getRef(n.value, ReferenceTarget("", "", Seq.empty, "", Seq.empty)) match { case rhs: ReferenceTarget => if (presets.contains(rhs) || asyncRegMap.contains(rhs)) { - if(asyncRegMap.contains(rhs)){ + if (asyncRegMap.contains(rhs)) { asyncRegMap(rhs) += target } else { asyncToAnnotate += target @@ -227,18 +234,18 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { def processStatements(statement: Statement): Unit = { statement match { - case i : WDefInstance => + case i: WDefInstance => localInstances(moduleTarget.ref(i.name)) = i.module - case r : DefRegister => processRegister(r) - case w : DefWire => processWire(w) - case n : DefNode => processNode(n) - case c : Connect => processConnect(c) - case s => s.foreachStmt(processStatements) + case r: DefRegister => processRegister(r) + case w: DefWire => processWire(w) + case n: DefNode => processNode(n) + case c: Connect => processConnect(c) + case s => s.foreachStmt(processStatements) } } def processPorts(port: Port): Unit = { - if(port.tpe == AsyncResetType){ + if (port.tpe == AsyncResetType) { val target = moduleTarget.ref(port.name) asyncRegMap(target) = new TargetSet() asyncCoMap(target) = new TargetSet() @@ -263,17 +270,17 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { /** Annotate a given target and all its children according to the asyncCoMap */ def annotateCo(ta: ReferenceTarget): Unit = { - if (asyncCoMap.contains(ta)){ + if (asyncCoMap.contains(ta)) { toCleanUp += ta - asyncCoMap(ta) foreach( (t: ReferenceTarget) => { + asyncCoMap(ta).foreach((t: ReferenceTarget) => { toCleanUp += t }) } } /** Annotate all registers somehow connected to the orignal annotated async reset */ - def annotateRegSet(set: TargetSet) : Unit = { - set foreach ( (ta: ReferenceTarget) => { + def annotateRegSet(set: TargetSet): Unit = { + set.foreach((ta: ReferenceTarget) => { annotateCo(ta) if (asyncRegMap.contains(ta)) { annotateRegSet(asyncRegMap(ta)) @@ -287,8 +294,8 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { * Walk AsyncReset Trees with all Annotated AsyncReset as entry points * Annotate all leaf registers and intermediate wires, nodes, connectors along the way */ - def annotateAsyncSet(set: TargetSet) : Unit = { - set foreach ((t: ReferenceTarget) => { + def annotateAsyncSet(set: TargetSet): Unit = { + set.foreach((t: ReferenceTarget) => { annotateCo(t) if (asyncRegMap.contains(t)) annotateRegSet(asyncRegMap(t)) @@ -300,7 +307,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { */ cs.circuit.foreachModule(processModule) // PHASE 1 : Initialize - annotateAsyncSet(asyncToAnnotate) // PHASE 2 : Annotate + annotateAsyncSet(asyncToAnnotate) // PHASE 2 : Annotate otherAnnos ++ newAnnos } @@ -312,21 +319,21 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { * Clean-up useless reset tree (not relying on DCE) * Disconnect preset registers from their reset tree */ - private def cleanUpPresetTree(circuit: Circuit, annos: AnnotationSeq) : Circuit = { - val presetRegs = annos.collect {case a : PresetRegAnnotation => a}.groupBy(_.target) + private def cleanUpPresetTree(circuit: Circuit, annos: AnnotationSeq): Circuit = { + val presetRegs = annos.collect { case a: PresetRegAnnotation => a }.groupBy(_.target) val circuitTarget = CircuitTarget(circuit.main) def processModule(m: DefModule): DefModule = { val moduleTarget = circuitTarget.module(m.name) val localInstances = new TargetMap() - def getRef(e: Expression) : Target = { + def getRef(e: Expression): Target = { e match { case w: WRef => moduleTarget.ref(w.name) case w: WSubField => getRef(w.expr) match { case rt: ReferenceTarget => - if(localInstances.contains(rt)){ + if (localInstances.contains(rt)) { circuitTarget.module(localInstances(rt)).ref(w.name) } else { rt.field(w.name) @@ -341,14 +348,13 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { case DoPrim(op, args, _, _) => op match { case AsInterval | AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset => getRef(args.head) - case _ => Target(None, None, Seq.empty) + case _ => Target(None, None, Seq.empty) } case _ => Target(None, None, Seq.empty) } } - - def processRegister(r: DefRegister) : DefRegister = { + def processRegister(r: DefRegister): DefRegister = { if (presetRegs.contains(moduleTarget.ref(r.name))) { r.copy(reset = UIntLiteral(0)) } else { @@ -356,7 +362,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } } - def processWire(w: DefWire) : Statement = { + def processWire(w: DefWire): Statement = { if (toCleanUp.contains(moduleTarget.ref(w.name))) { EmptyStmt } else { @@ -364,12 +370,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } } - def processNode(n: DefNode) : Statement = { + def processNode(n: DefNode): Statement = { if (toCleanUp.contains(moduleTarget.ref(n.name))) { EmptyStmt } else { getRef(n.value) match { - case rt : ReferenceTarget if(toCleanUp.contains(rt)) => + case rt: ReferenceTarget if (toCleanUp.contains(rt)) => throw TreeCleanUpOrphanException(s"Orphan (${moduleTarget.ref(n.name)}) the way.") case _ => n } @@ -380,7 +386,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { getRef(c.expr) match { case rhs: ReferenceTarget if (toCleanUp.contains(rhs)) => getRef(c.loc) match { - case lhs : ReferenceTarget if(!toCleanUp.contains(lhs)) => + case lhs: ReferenceTarget if (!toCleanUp.contains(lhs)) => throw TreeCleanUpOrphanException(s"Orphan ${lhs} connected deleted node $rhs.") case _ => EmptyStmt } @@ -388,7 +394,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } } - def processInstance(i: WDefInstance) : WDefInstance = { + def processInstance(i: WDefInstance): WDefInstance = { localInstances(moduleTarget.ref(i.name)) = i.module val tpe = i.tpe match { case b: BundleType => @@ -401,12 +407,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { def processStatements(statement: Statement): Statement = { statement match { - case i : WDefInstance => processInstance(i) - case r : DefRegister => processRegister(r) - case w : DefWire => processWire(w) - case n : DefNode => processNode(n) - case c : Connect => processConnect(c) - case s => s.mapStmt(processStatements) + case i: WDefInstance => processInstance(i) + case r: DefRegister => processRegister(r) + case w: DefWire => processWire(w) + case n: DefNode => processNode(n) + case c: Connect => processConnect(c) + case s => s.mapStmt(processStatements) } } @@ -422,10 +428,10 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { def execute(state: CircuitState): CircuitState = { // Collect all user-defined PresetAnnotation - val (presets, otherAnnos) = state.annotations.partition { case _: PresetAnnotation => true ; case _ => false } + val (presets, otherAnnos) = state.annotations.partition { case _: PresetAnnotation => true; case _ => false } // No PresetAnnotation => no need to walk the IR - if (presets.isEmpty){ + if (presets.isEmpty) { state } else { // PHASE I - Propagate diff --git a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala index 840a3d99..ae3bc693 100644 --- a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala +++ b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala @@ -21,10 +21,11 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames { * @return Some name if a rename occurred, None otherwise * @note prefix uniqueness is not respected */ - override def manipulate = (n: String, ns: Namespace) => keywords.contains(n) match { - case true => Some(Uniquify.findValidPrefix(n + inlineDelim, Seq(""), ns.cloneUnderlying ++ keywords)) - case false => None - } + override def manipulate = (n: String, ns: Namespace) => + keywords.contains(n) match { + case true => Some(Uniquify.findValidPrefix(n + inlineDelim, Seq(""), ns.cloneUnderlying ++ keywords)) + case false => None + } } @@ -32,14 +33,16 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames { class VerilogRename extends RemoveKeywordCollisions(v_keywords) { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic], - Dependency[InlineBitExtractionsTransform], - Dependency[InlineCastsTransform], - Dependency[LegalizeClocksTransform], - Dependency[FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup) ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[InlineCastsTransform], + Dependency[LegalizeClocksTransform], + Dependency[FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup) + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index 6b3a9d07..8736e21b 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -18,8 +18,7 @@ import scala.collection.{immutable, mutable} object RemoveReset extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize) ) + Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize)) override def optionalPrerequisites = Seq.empty @@ -58,7 +57,7 @@ object RemoveReset extends Transform with DependencyAPIMigration { reg.copy(reset = Utils.zero, init = WRef(reg)) case reg @ DefRegister(_, rname, _, _, Utils.zero, _) => reg.copy(init = WRef(reg)) // canonicalize - case reg @ DefRegister(info , rname, _, _, reset, init) if reset.tpe != AsyncResetType => + case reg @ DefRegister(info, rname, _, _, reset, init) if reset.tpe != AsyncResetType => // Add register reset to map resets(rname) = Reset(reset, init, info) reg.copy(reset = Utils.zero, init = WRef(reg)) @@ -68,7 +67,7 @@ object RemoveReset extends Transform with DependencyAPIMigration { // Use reg source locator for mux enable and true value since that's where they're defined val infox = MultiInfo(reset.info, reset.info, info) Connect(infox, ref, Mux(reset.cond, reset.value, expr, muxType)) - case other => other map onStmt + case other => other.map(onStmt) } } m.map(onStmt) diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index f692e513..31fa3b6f 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -8,11 +8,11 @@ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ import firrtl.WrappedExpression._ -import firrtl.graph.{MutableDiGraph, CyclicException} +import firrtl.graph.{CyclicException, MutableDiGraph} import firrtl.options.Dependency import scala.collection.mutable -import scala.util.{Try, Success, Failure} +import scala.util.{Failure, Success, Try} /** Replace wires with nodes in a legal, flow-forward order * @@ -23,11 +23,13 @@ import scala.util.{Try, Success, Failure} class RemoveWires extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), - Dependency(passes.ResolveKinds), - Dependency(transforms.RemoveReset), - Dependency[transforms.CheckCombLoops] ) + Seq( + Dependency(passes.LowerTypes), + Dependency(passes.Legalize), + Dependency(passes.ResolveKinds), + Dependency(transforms.RemoveReset), + Dependency[transforms.CheckCombLoops] + ) override def optionalPrerequisites = Seq(Dependency[checks.CheckResets]) @@ -35,7 +37,7 @@ class RemoveWires extends Transform with DependencyAPIMigration { override def invalidates(a: Transform) = a match { case passes.ResolveKinds => true - case _ => false + case _ => false } // Extract all expressions that are references to a Node, Wire, or Reg @@ -44,7 +46,7 @@ class RemoveWires extends Transform with DependencyAPIMigration { val refs = mutable.ArrayBuffer.empty[WRef] def rec(e: Expression): Expression = { e match { - case ref @ WRef(_,_, WireKind | NodeKind | RegKind, _) => refs += ref + case ref @ WRef(_, _, WireKind | NodeKind | RegKind, _) => refs += ref case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec) case _ => // Do nothing } @@ -57,7 +59,8 @@ class RemoveWires extends Transform with DependencyAPIMigration { // Transform netlist into DefNodes private def getOrderedNodes( netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)], - regInfo: mutable.Map[WrappedExpression, DefRegister]): Try[Seq[Statement]] = { + regInfo: mutable.Map[WrappedExpression, DefRegister] + ): Try[Seq[Statement]] = { val digraph = new MutableDiGraph[WrappedExpression] for ((sink, (exprs, _)) <- netlist) { digraph.addVertex(sink) @@ -106,21 +109,22 @@ class RemoveWires extends Transform with DependencyAPIMigration { case reg: DefRegister => val resetDep = reg.reset.tpe match { case AsyncResetType => Some(reg.reset) - case _ => None + case _ => None } val initDep = Some(reg.init).filter(we(WRef(reg)) != we(_)) // Dependency exists IF reg doesn't init itself regInfo(we(WRef(reg))) = reg netlist(we(WRef(reg))) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info) case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires decls += decl - case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match { - case WireKind => - // Be sure to pad the rhs since nodes get their type from the rhs - val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) - val dinfo = wireInfo(lhs) - netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo)) - case _ => otherStmts += con // Other connections just pass through - } + case con @ Connect(cinfo, lhs, rhs) => + kind(lhs) match { + case WireKind => + // Be sure to pad the rhs since nodes get their type from the rhs + val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) + val dinfo = wireInfo(lhs) + netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo)) + case _ => otherStmts += con // Other connections just pass through + } case invalid @ IsInvalid(info, expr) => kind(expr) match { case WireKind => @@ -146,8 +150,10 @@ class RemoveWires extends Transform with DependencyAPIMigration { // If we hit a CyclicException, just abort removing wires case Failure(c: CyclicException) => val problematicNode = c.node - logger.warn(s"Cycle found in module $name, " + - s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode") + logger.warn( + s"Cycle found in module $name, " + + s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode" + ) mod case Failure(other) => throw other } @@ -155,7 +161,6 @@ class RemoveWires extends Transform with DependencyAPIMigration { } } - def execute(state: CircuitState): CircuitState = state.copy(circuit = state.circuit.map(onModule)) } diff --git a/src/main/scala/firrtl/transforms/RenameModules.scala b/src/main/scala/firrtl/transforms/RenameModules.scala index d37f8c39..16fd655a 100644 --- a/src/main/scala/firrtl/transforms/RenameModules.scala +++ b/src/main/scala/firrtl/transforms/RenameModules.scala @@ -44,7 +44,7 @@ class RenameModules extends Transform with DependencyAPIMigration { moduleOrder.foreach(collectNameMapping(namespace.get, nameMappings)) val modulesx = state.circuit.modules.map { - case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings) + case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings) case ext: ExtModule => ext } diff --git a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala index a93087b9..14c84b91 100644 --- a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala +++ b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala @@ -80,8 +80,7 @@ object ReplaceTruncatingArithmetic { class ReplaceTruncatingArithmetic extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals] ) + Seq(Dependency[BlackBoxSourceHelper], Dependency[FixAddingNegativeLiterals]) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/SimplifyMems.scala b/src/main/scala/firrtl/transforms/SimplifyMems.scala index a056c7da..7790d060 100644 --- a/src/main/scala/firrtl/transforms/SimplifyMems.scala +++ b/src/main/scala/firrtl/transforms/SimplifyMems.scala @@ -33,12 +33,13 @@ class SimplifyMems extends Transform with DependencyAPIMigration { def onExpr(e: Expression): Expression = e.map(onExpr) match { case wr @ WRef(name, _, MemKind, _) if memAdapters.contains(name) => wr.copy(kind = WireKind) - case e => e + case e => e } def simplifyMem(mem: DefMemory): Statement = { val adapterDecl = DefWire(mem.info, mem.name, memType(mem)) - val simpleMemDecl = mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType)) + val simpleMemDecl = + mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType)) val oldRT = mTarget.ref(mem.name) val adapterConnects = memType(simpleMemDecl).fields.flatMap { case Field(pName, Flip, pType: BundleType) => @@ -63,8 +64,10 @@ class SimplifyMems extends Transform with DependencyAPIMigration { def canSimplify(mem: DefMemory) = mem.dataType match { case at: AggregateType => - val wMasks = mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask"))) - val rwMasks = mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask"))) + val wMasks = + mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask"))) + val rwMasks = + mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask"))) (wMasks ++ rwMasks).flatten.isEmpty case _ => false } diff --git a/src/main/scala/firrtl/transforms/TopWiring.scala b/src/main/scala/firrtl/transforms/TopWiring.scala index f5a5e2a3..b35fed22 100644 --- a/src/main/scala/firrtl/transforms/TopWiring.scala +++ b/src/main/scala/firrtl/transforms/TopWiring.scala @@ -4,7 +4,7 @@ package TopWiring import firrtl._ import firrtl.ir._ -import firrtl.passes.{InferTypes, LowerTypes, ResolveKinds, ResolveFlows, ExpandConnects} +import firrtl.passes.{ExpandConnects, InferTypes, LowerTypes, ResolveFlows, ResolveKinds} import firrtl.annotations._ import firrtl.Mappers._ import firrtl.analyses.InstanceKeyGraph @@ -13,22 +13,21 @@ import firrtl.options.Dependency import collection.mutable -/** Annotation for optional output files, and what directory to put those files in (absolute path) **/ -case class TopWiringOutputFilesAnnotation(dirName: String, - outputFunction: (String,Seq[((ComponentName, Type, Boolean, - Seq[String],String), Int)], - CircuitState) => CircuitState) extends NoTargetAnnotation +/** Annotation for optional output files, and what directory to put those files in (absolute path) * */ +case class TopWiringOutputFilesAnnotation( + dirName: String, + outputFunction: (String, Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], + CircuitState) => CircuitState) + extends NoTargetAnnotation /** Annotation for indicating component to be wired, and what prefix to add to the ports that are generated */ -case class TopWiringAnnotation(target: ComponentName, prefix: String) extends - SingleTargetAnnotation[ComponentName] { +case class TopWiringAnnotation(target: ComponentName, prefix: String) extends SingleTargetAnnotation[ComponentName] { def duplicate(n: ComponentName) = this.copy(target = n) } - /** Punch out annotated ports out to the toplevel of the circuit. - This also has an option to pass a function as a parmeter to generate - custom output files as a result of the additional ports + * This also has an option to pass a function as a parmeter to generate + * custom output files as a result of the additional ports * @note This *does* work for deduped modules */ class TopWiringTransform extends Transform with DependencyAPIMigration { @@ -39,116 +38,133 @@ class TopWiringTransform extends Transform with DependencyAPIMigration { override def invalidates(a: Transform): Boolean = a match { case InferTypes | ResolveKinds | ResolveFlows | ExpandConnects => true - case _ => false + case _ => false } type InstPath = Seq[String] /** Get the names of the targets that need to be wired */ private def getSourceNames(state: CircuitState): Map[ComponentName, String] = { - state.annotations.collect { case TopWiringAnnotation(srcname,prefix) => - (srcname -> prefix) }.toMap.withDefaultValue("") + state.annotations.collect { + case TopWiringAnnotation(srcname, prefix) => + (srcname -> prefix) + }.toMap.withDefaultValue("") } - /** Get the names of the modules which include the targets that need to be wired */ private def getSourceModNames(state: CircuitState): Seq[String] = { - state.annotations.collect { case TopWiringAnnotation(ComponentName(_,ModuleName(srcmodname, _)),_) => srcmodname } + state.annotations.collect { case TopWiringAnnotation(ComponentName(_, ModuleName(srcmodname, _)), _) => srcmodname } } - - /** Get the Type of each wire to be connected * * Find the definition of each wire in sourceList, and get the type and whether or not it's a port * Update the results in sourceMap */ - private def getSourceTypes(sourceList: Map[ComponentName, String], - sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], - currentmodule: ModuleName, state: CircuitState)(s: Statement): Statement = s match { + private def getSourceTypes( + sourceList: Map[ComponentName, String], + sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], + currentmodule: ModuleName, + state: CircuitState + )(s: Statement + ): Statement = s match { // If target wire, add name and size to to sourceMap case w: IsDeclaration => if (sourceList.keys.toSeq.contains(ComponentName(w.name, currentmodule))) { - val (isport, tpe, prefix) = w match { - case d: DefWire => (false, d.tpe, sourceList(ComponentName(w.name,currentmodule))) - case d: DefNode => (false, d.value.tpe, sourceList(ComponentName(w.name,currentmodule))) - case d: DefRegister => (false, d.tpe, sourceList(ComponentName(w.name,currentmodule))) - case d: Port => (true, d.tpe, sourceList(ComponentName(w.name,currentmodule))) - case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}") - } - sourceMap.get(currentmodule.name) match { - case Some(xs:Seq[(ComponentName, Type, Boolean, InstPath, String)]) => - sourceMap.update(currentmodule.name, xs :+( - (ComponentName(w.name,currentmodule), tpe, isport ,Seq[String](w.name), prefix) )) - case None => - sourceMap(currentmodule.name) = Seq((ComponentName(w.name,currentmodule), - tpe, isport ,Seq[String](w.name), prefix)) - } + val (isport, tpe, prefix) = w match { + case d: DefWire => (false, d.tpe, sourceList(ComponentName(w.name, currentmodule))) + case d: DefNode => (false, d.value.tpe, sourceList(ComponentName(w.name, currentmodule))) + case d: DefRegister => (false, d.tpe, sourceList(ComponentName(w.name, currentmodule))) + case d: Port => (true, d.tpe, sourceList(ComponentName(w.name, currentmodule))) + case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}") + } + sourceMap.get(currentmodule.name) match { + case Some(xs: Seq[(ComponentName, Type, Boolean, InstPath, String)]) => + sourceMap.update( + currentmodule.name, + xs :+ ((ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix)) + ) + case None => + sourceMap(currentmodule.name) = Seq( + (ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix) + ) + } } w // Return argument unchanged (ok because DefWire has no Statement children) // If not, apply to all children Statement - case _ => s map getSourceTypes(sourceList, sourceMap, currentmodule, state) + case _ => s.map(getSourceTypes(sourceList, sourceMap, currentmodule, state)) } - - /** Get the Type of each port to be connected * * Similar to getSourceTypes, but specifically for ports since they are not found in statements. * Find the definition of each port in sourceList, and get the type and whether or not it's a port * Update the results in sourceMap */ - private def getSourceTypesPorts(sourceList: Map[ComponentName, String], sourceMap: mutable.Map[String, - Seq[(ComponentName, Type, Boolean, InstPath, String)]], - currentmodule: ModuleName, state: CircuitState)(s: Port): CircuitState = s match { + private def getSourceTypesPorts( + sourceList: Map[ComponentName, String], + sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], + currentmodule: ModuleName, + state: CircuitState + )(s: Port + ): CircuitState = s match { // If target port, add name and size to to sourceMap case w: IsDeclaration => if (sourceList.keys.toSeq.contains(ComponentName(w.name, currentmodule))) { - val (isport, tpe, prefix) = w match { - case d: Port => (true, d.tpe, sourceList(ComponentName(w.name,currentmodule))) - case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}") - } - sourceMap.get(currentmodule.name) match { - case Some(xs:Seq[(ComponentName, Type, Boolean, InstPath, String)]) => - sourceMap.update(currentmodule.name, xs :+( - (ComponentName(w.name,currentmodule), tpe, isport ,Seq[String](w.name), prefix) )) - case None => - sourceMap(currentmodule.name) = Seq((ComponentName(w.name,currentmodule), - tpe, isport ,Seq[String](w.name), prefix)) - } + val (isport, tpe, prefix) = w match { + case d: Port => (true, d.tpe, sourceList(ComponentName(w.name, currentmodule))) + case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}") + } + sourceMap.get(currentmodule.name) match { + case Some(xs: Seq[(ComponentName, Type, Boolean, InstPath, String)]) => + sourceMap.update( + currentmodule.name, + xs :+ ((ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix)) + ) + case None => + sourceMap(currentmodule.name) = Seq( + (ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix) + ) + } } state // Return argument unchanged (ok because DefWire has no Statement children) // If not, apply to all children Statement case _ => state } - /** Create a map of Module name to target wires under this module * * These paths are relative but cross module (they refer down through instance hierarchy) */ - private def getSourcesMap(state: CircuitState): Map[String,Seq[(ComponentName, Type, Boolean, InstPath, String)]] = { + private def getSourcesMap(state: CircuitState): Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]] = { val sSourcesModNames = getSourceModNames(state) val sSourcesNames = getSourceNames(state) val instGraph = firrtl.analyses.InstanceKeyGraph(state.circuit) - val cMap = instGraph.getChildInstances.map{ case (m, wdis) => - (m -> wdis.map{ case wdi => (wdi.name, wdi.module) }.toSeq) }.toMap + val cMap = instGraph.getChildInstances.map { + case (m, wdis) => + (m -> wdis.map { case wdi => (wdi.name, wdi.module) }.toSeq) + }.toMap val topSort = instGraph.moduleOrder.reverse // Map of component name to relative instance paths that result in a debug wire val sourcemods: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]] = mutable.Map(sSourcesModNames.map(_ -> Seq()): _*) - state.circuit.modules.foreach { m => m map - getSourceTypes(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)) , state) } - state.circuit.modules.foreach { m => m.ports.foreach { - p => Seq(p) map - getSourceTypesPorts(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)) , state) }} + state.circuit.modules.foreach { m => + m.map(getSourceTypes(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)), state)) + } + state.circuit.modules.foreach { m => + m.ports.foreach { p => + Seq(p).map( + getSourceTypesPorts(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)), state) + ) + } + } for (mod <- topSort) { - val seqChildren: Seq[(ComponentName,Type,Boolean,InstPath,String)] = cMap(mod.name).flatMap { + val seqChildren: Seq[(ComponentName, Type, Boolean, InstPath, String)] = cMap(mod.name).flatMap { case (inst, module) => - sourcemods.get(module).map( _.map { case (a,b,c,path,p) => (a,b,c, inst +: path, p)}) + sourcemods.get(module).map(_.map { case (a, b, c, path, p) => (a, b, c, inst +: path, p) }) }.flatten if (seqChildren.nonEmpty) { sourcemods(mod.name) = sourcemods.getOrElse(mod.name, Seq()) ++ seqChildren @@ -158,108 +174,113 @@ class TopWiringTransform extends Transform with DependencyAPIMigration { sourcemods.toMap } - - /** Process a given DefModule * * For Modules that contain or are in the parent hierarchy to modules containing target wires * 1. Add ports for each target wire this module is parent to * 2. Connect these ports to ports of instances that are parents to some number of target wires */ - private def onModule(sources: Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], - portnamesmap : mutable.Map[String,String], - instgraph : firrtl.analyses.InstanceKeyGraph, - namespacemap : Map[String, Namespace]) - (module: DefModule): DefModule = { + private def onModule( + sources: Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], + portnamesmap: mutable.Map[String, String], + instgraph: firrtl.analyses.InstanceKeyGraph, + namespacemap: Map[String, Namespace] + )(module: DefModule + ): DefModule = { val namespace = namespacemap(module.name) sources.get(module.name) match { case Some(p) => - val newPorts = p.map{ case (ComponentName(cname,_), tpe, _ , path, prefix) => { - val newportname = portnamesmap.get(prefix + path.mkString("_")) match { - case Some(pn) => pn - case None => { - val npn = namespace.newName(prefix + path.mkString("_")) - portnamesmap(prefix + path.mkString("_")) = npn - npn - } + val newPorts = p.map { + case (ComponentName(cname, _), tpe, _, path, prefix) => { + val newportname = portnamesmap.get(prefix + path.mkString("_")) match { + case Some(pn) => pn + case None => { + val npn = namespace.newName(prefix + path.mkString("_")) + portnamesmap(prefix + path.mkString("_")) = npn + npn } - Port(NoInfo, newportname, Output, tpe) - } } + } + Port(NoInfo, newportname, Output, tpe) + } + } // Add connections to Module val childInstances = instgraph.getChildInstances.toMap module match { case m: Module => - val connections: Seq[Connect] = p.map { case (ComponentName(cname,_), _, _ , path, prefix) => + val connections: Seq[Connect] = p.map { + case (ComponentName(cname, _), _, _, path, prefix) => val modRef = portnamesmap.get(prefix + path.mkString("_")) match { - case Some(pn) => WRef(pn) - case None => { - portnamesmap(prefix + path.mkString("_")) = namespace.newName(prefix + path.mkString("_")) - WRef(portnamesmap(prefix + path.mkString("_"))) - } + case Some(pn) => WRef(pn) + case None => { + portnamesmap(prefix + path.mkString("_")) = namespace.newName(prefix + path.mkString("_")) + WRef(portnamesmap(prefix + path.mkString("_"))) + } } path.size match { - case 1 => { - val leafRef = WRef(path.head.mkString("")) - Connect(NoInfo, modRef, leafRef) - } - case _ => { - val instportname = portnamesmap.get(prefix + path.tail.mkString("_")) match { - case Some(ipn) => ipn - case None => { - val instmod = childInstances(module.name).collectFirst { - case wdi if wdi.name == path.head => wdi.module}.get - val instnamespace = namespacemap(instmod) - portnamesmap(prefix + path.tail.mkString("_")) = - instnamespace.newName(prefix + path.tail.mkString("_")) - portnamesmap(prefix + path.tail.mkString("_")) - } - } - val instRef = WSubField(WRef(path.head), instportname) - Connect(NoInfo, modRef, instRef) + case 1 => { + val leafRef = WRef(path.head.mkString("")) + Connect(NoInfo, modRef, leafRef) + } + case _ => { + val instportname = portnamesmap.get(prefix + path.tail.mkString("_")) match { + case Some(ipn) => ipn + case None => { + val instmod = childInstances(module.name).collectFirst { + case wdi if wdi.name == path.head => wdi.module + }.get + val instnamespace = namespacemap(instmod) + portnamesmap(prefix + path.tail.mkString("_")) = + instnamespace.newName(prefix + path.tail.mkString("_")) + portnamesmap(prefix + path.tail.mkString("_")) + } + } + val instRef = WSubField(WRef(path.head), instportname) + Connect(NoInfo, modRef, instRef) } } } - m.copy(ports = m.ports ++ newPorts, body = Block(Seq(m.body) ++ connections )) + m.copy(ports = m.ports ++ newPorts, body = Block(Seq(m.body) ++ connections)) case e: ExtModule => e.copy(ports = e.ports ++ newPorts) - } + } case None => module // unchanged if no paths } } - /** Dummy function that is currently unused. Can be used to fill an outputFunction requirment in the future */ - def topWiringDummyOutputFilesFunction(dir: String, - mapping: Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], - state: CircuitState): CircuitState = { - state + /** Dummy function that is currently unused. Can be used to fill an outputFunction requirment in the future */ + def topWiringDummyOutputFilesFunction( + dir: String, + mapping: Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], + state: CircuitState + ): CircuitState = { + state } - def execute(state: CircuitState): CircuitState = { - val outputTuples: Seq[(String, - (String,Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], - CircuitState) => CircuitState)] = state.annotations.collect { - case TopWiringOutputFilesAnnotation(td,of) => (td, of) } + val outputTuples: Seq[ + (String, (String, Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], CircuitState) => CircuitState) + ] = state.annotations.collect { + case TopWiringOutputFilesAnnotation(td, of) => (td, of) + } // Do actual work of this transform val sources = getSourcesMap(state) val (nstate, nmappings) = if (sources.nonEmpty) { - val portnamesmap: mutable.Map[String,String] = mutable.Map() + val portnamesmap: mutable.Map[String, String] = mutable.Map() val instgraph = InstanceKeyGraph(state.circuit) - val namespacemap = state.circuit.modules.map{ case m => (m.name -> Namespace(m)) }.toMap - val modulesx = state.circuit.modules map onModule(sources, portnamesmap, instgraph, namespacemap) + val namespacemap = state.circuit.modules.map { case m => (m.name -> Namespace(m)) }.toMap + val modulesx = state.circuit.modules.map(onModule(sources, portnamesmap, instgraph, namespacemap)) val newCircuit = state.circuit.copy(modules = modulesx) val mappings = sources(state.circuit.main).zipWithIndex val annosx = state.annotations.filter { case _: TopWiringAnnotation => false - case _ => true + case _ => true } (state.copy(circuit = newCircuit, annotations = annosx), mappings) - } - else { (state, List.empty) } + } else { (state, List.empty) } //Generate output files based on the mapping. outputTuples.map { case (dir, outputfunction) => outputfunction(dir, nmappings, nstate) } nstate diff --git a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala index 7370fcfb..cdbee495 100644 --- a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala +++ b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala @@ -1,4 +1,3 @@ - package firrtl.transforms.formal import firrtl.ir.{Circuit, Formal, Statement, Verification} @@ -7,7 +6,6 @@ import firrtl.{CircuitState, DependencyAPIMigration, Transform} import firrtl.annotations.NoTargetAnnotation import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption} - /** * Assert Submodule Assumptions * @@ -16,12 +14,13 @@ import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption} * overly restrictive assume in a child module can prevent the model checker * from searching valid inputs and states in the parent module. */ -class AssertSubmoduleAssumptions extends Transform - with RegisteredTransform - with DependencyAPIMigration - with PreservesAll[Transform] { +class AssertSubmoduleAssumptions + extends Transform + with RegisteredTransform + with DependencyAPIMigration + with PreservesAll[Transform] { - override def prerequisites: Seq[TransformDependency] = Seq.empty + override def prerequisites: Seq[TransformDependency] = Seq.empty override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty override def optionalPrerequisiteOf: Seq[TransformDependency] = firrtl.stage.Forms.MidEmitters @@ -29,9 +28,10 @@ class AssertSubmoduleAssumptions extends Transform val options = Seq( new ShellOption[Unit]( longOption = "no-asa", - toAnnotationSeq = (_: Unit) => Seq( - DontAssertSubmoduleAssumptionsAnnotation), - helpText = "Disable assert submodule assumptions" ) ) + toAnnotationSeq = (_: Unit) => Seq(DontAssertSubmoduleAssumptionsAnnotation), + helpText = "Disable assert submodule assumptions" + ) + ) def assertAssumption(s: Statement): Statement = s match { case Verification(Formal.Assume, info, clk, cond, en, msg) => @@ -50,8 +50,7 @@ class AssertSubmoduleAssumptions extends Transform } def execute(state: CircuitState): CircuitState = { - val noASA = state.annotations.contains( - DontAssertSubmoduleAssumptionsAnnotation) + val noASA = state.annotations.contains(DontAssertSubmoduleAssumptionsAnnotation) if (noASA) { logger.info("Skipping assert submodule assumptions") state diff --git a/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala b/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala index ddead331..5928c79c 100644 --- a/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala +++ b/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala @@ -14,10 +14,8 @@ import firrtl.options.Dependency object ConvertAsserts extends Transform with DependencyAPIMigration { override def prerequisites = Nil override def optionalPrerequisites = Nil - override def optionalPrerequisiteOf = Seq( - Dependency[VerilogEmitter], - Dependency[MinimumVerilogEmitter], - Dependency[RemoveVerificationStatements]) + override def optionalPrerequisiteOf = + Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter], Dependency[RemoveVerificationStatements]) override def invalidates(a: Transform): Boolean = false @@ -28,7 +26,7 @@ object ConvertAsserts extends Transform with DependencyAPIMigration { val stop = Stop(i, 1, clk, gatedNPred) msg match { case StringLit("") => stop - case _ => Block(Print(i, msg, Nil, clk, gatedNPred), stop) + case _ => Block(Print(i, msg, Nil, clk, gatedNPred), stop) } case s => s.mapStmt(convertAsserts) } diff --git a/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala b/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala index 72890c07..1e6d2c72 100644 --- a/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala +++ b/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala @@ -1,4 +1,3 @@ - package firrtl.transforms.formal import firrtl.ir.{Circuit, EmptyStmt, Statement, Verification} @@ -6,7 +5,6 @@ import firrtl.{CircuitState, DependencyAPIMigration, MinimumVerilogEmitter, Tran import firrtl.options.{Dependency, PreservesAll, StageUtils} import firrtl.stage.TransformManager.TransformDependency - /** * Remove Verification Statements * @@ -14,15 +12,12 @@ import firrtl.stage.TransformManager.TransformDependency * This is intended to be required by the Verilog emitter to ensure compatibility * with the Verilog 2001 standard. */ -class RemoveVerificationStatements extends Transform - with DependencyAPIMigration - with PreservesAll[Transform] { +class RemoveVerificationStatements extends Transform with DependencyAPIMigration with PreservesAll[Transform] { - override def prerequisites: Seq[TransformDependency] = Seq.empty + override def prerequisites: Seq[TransformDependency] = Seq.empty override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency(ConvertAsserts)) override def optionalPrerequisiteOf: Seq[TransformDependency] = - Seq( Dependency[VerilogEmitter], - Dependency[MinimumVerilogEmitter]) + Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter]) private var removedCounter = 0 @@ -43,11 +38,13 @@ class RemoveVerificationStatements extends Transform def execute(state: CircuitState): CircuitState = { val newState = state.copy(circuit = run(state.circuit)) if (removedCounter > 0) { - StageUtils.dramaticWarning(s"$removedCounter verification statements " + - "(assert, assume or cover) " + - "were removed when compiling to Verilog because the basic Verilog " + - "standard does not support them. If this was not intended, compile " + - "to System Verilog instead using the `-X sverilog` compiler flag.") + StageUtils.dramaticWarning( + s"$removedCounter verification statements " + + "(assert, assume or cover) " + + "were removed when compiling to Verilog because the basic Verilog " + + "standard does not support them. If this was not intended, compile " + + "to System Verilog instead using the `-X sverilog` compiler flag." + ) } newState } |
