diff options
| author | David Biancolin | 2019-01-21 18:50:51 -0500 |
|---|---|---|
| committer | GitHub | 2019-01-21 18:50:51 -0500 |
| commit | 10586d6a141859b843057ec9979011e26ad207f1 (patch) | |
| tree | ff23c30013159cdd1879b1e5c3dd5baca5bf4867 /src | |
| parent | 73ae6257fce586ac145b6ab348ce1b47634e7a46 (diff) | |
| parent | df3a34f01d227ff9ad0e63a41ff10001ac01c01d (diff) | |
Merge branch 'master' into top-wiring-aggregates
Diffstat (limited to 'src')
18 files changed, 517 insertions, 79 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 80ba42c4..87662800 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -373,6 +373,10 @@ trait Compiler extends LazyLogging { */ def transforms: Seq[Transform] + require(transforms.size >= 1, + s"Compiler transforms for '${this.getClass.getName}' must have at least ONE Transform! " + + "Use IdentityTransform if you need an identity/no-op transform.") + // Similar to (input|output)Form on [[Transform]] but derived from this Compiler's transforms def inputForm: CircuitForm = transforms.head.inputForm def outputForm: CircuitForm = transforms.last.outputForm diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 350fd433..eab928c2 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -2,6 +2,8 @@ package firrtl +import firrtl.transforms.IdentityTransform + sealed abstract class CoreTransform extends SeqTransform /** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting @@ -132,7 +134,7 @@ import firrtl.transforms.BlackBoxSourceHelper */ class NoneCompiler extends Compiler { def emitter = new ChirrtlEmitter - def transforms: Seq[Transform] = Seq.empty + def transforms: Seq[Transform] = Seq(new IdentityTransform(ChirrtlForm)) } /** Emits input circuit diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index b96cd253..a5d2571d 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -40,6 +40,11 @@ object WRef { def apply(reg: DefRegister): WRef = new WRef(reg.name, reg.tpe, RegKind, UNKNOWNGENDER) /** Creates a WRef from a Node */ def apply(node: DefNode): WRef = new WRef(node.name, node.value.tpe, NodeKind, MALE) + /** Creates a WRef from a Port */ + def apply(port: Port): WRef = new WRef(port.name, port.tpe, PortKind, UNKNOWNGENDER) + /** Creates a WRef from a WDefInstance */ + def apply(wi: WDefInstance): WRef = new WRef(wi.name, wi.tpe, InstanceKind, UNKNOWNGENDER) + /** Creates a WRef from an arbitrary string name */ def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind): WRef = new WRef(n, t, k, UNKNOWNGENDER) } case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender) extends Expression { diff --git a/src/main/scala/firrtl/annotations/AnnotationUtils.scala b/src/main/scala/firrtl/annotations/AnnotationUtils.scala index ba9220f7..72765ab7 100644 --- a/src/main/scala/firrtl/annotations/AnnotationUtils.scala +++ b/src/main/scala/firrtl/annotations/AnnotationUtils.scala @@ -51,8 +51,8 @@ object AnnotationUtils { case Some(_) => val i = s.indexWhere(c => "[].".contains(c)) s.slice(0, i) match { - case "" => Seq(s(i).toString) ++ tokenize(s.drop(i + 1)) - case x => Seq(x, s(i).toString) ++ tokenize(s.drop(i + 1)) + case "" => s(i).toString +: tokenize(s.drop(i + 1)) + case x => x +: s(i).toString +: tokenize(s.drop(i + 1)) } case None if s == "" => Nil case None => Seq(s) diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index cdf8e194..19ee56ca 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -607,7 +607,7 @@ case object UnknownType extends Type { } /** [[Port]] Direction */ -abstract class Direction extends FirrtlNode +sealed abstract class Direction extends FirrtlNode case object Input extends Direction { def serialize: String = "input" } diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index cc69be6f..5a9a60f8 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -103,7 +103,11 @@ object RemoveCHIRRTL extends Transform { rds map (_.name), wrs map (_.name), rws map (_.name)) Block(mem +: stmts) case sx: CDefMPort => - types(sx.name) = types(sx.mem) + types.get(sx.mem) match { + case Some(mem) => types(sx.name) = mem + case None => + throw new PassException(s"Undefined memory ${sx.mem} referenced by mport ${sx.name}") + } val addrs = ArrayBuffer[String]() val clks = ArrayBuffer[String]() val ens = ArrayBuffer[String]() diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 2b6fa55d..c13fa261 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -3,14 +3,16 @@ package firrtl.passes import com.typesafe.scalalogging.LazyLogging -import scala.annotation.tailrec +import scala.annotation.tailrec import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import MemPortUtils.memType +import scala.collection.mutable + /** Resolve name collisions that would occur in [[LowerTypes]] * * @note Must be run after [[InferTypes]] because [[ir.DefNode]]s need type @@ -78,36 +80,43 @@ object Uniquify extends Transform { t: BundleType, namespace: collection.mutable.HashSet[String]) (implicit sinfo: Info, mname: String): BundleType = { - def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): Type = t match { + def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): (Type, Seq[String]) = t match { case tx: BundleType => // First add everything - val newFields = tx.fields map { f => + val newFieldsAndElts = tx.fields map { f => val newName = findValidPrefix(f.name, Seq(""), namespace) namespace += newName Field(newName, f.flip, f.tpe) } map { f => f.tpe match { - case _: GroundType => f + case _: GroundType => (f, Seq[String](f.name)) case _ => - val tpe = recUniquifyNames(f.tpe, collection.mutable.HashSet()) - val elts = enumerateNames(tpe) + val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet()) // Need leading _ for findValidPrefix, it doesn't add _ for checks - val eltsNames = elts map (e => "_" + LowerTypes.loweredName(e)) + val eltsNames: Seq[String] = eltsx map (e => "_" + e) val prefix = findValidPrefix(f.name, eltsNames, namespace) // We added f.name in previous map, delete if we change it if (prefix != f.name) { namespace -= f.name namespace += prefix } - namespace ++= (elts map (e => LowerTypes.loweredName(prefix +: e))) - Field(prefix, f.flip, tpe) + val newElts: Seq[String] = eltsx map (e => LowerTypes.loweredName(prefix +: Seq(e))) + namespace ++= newElts + (Field(prefix, f.flip, tpe), prefix +: newElts) } } - BundleType(newFields) + val (newFields, elts) = newFieldsAndElts.unzip + (BundleType(newFields), elts.flatten) case tx: VectorType => - VectorType(recUniquifyNames(tx.tpe, namespace), tx.size) - case tx => tx + val (tpe, elts) = recUniquifyNames(tx.tpe, namespace) + val newElts = ((0 until tx.size) map (i => i.toString)) ++ + ((0 until tx.size) flatMap { i => + elts map (e => LowerTypes.loweredName(Seq(i.toString, e))) + }) + (VectorType(tpe, tx.size), newElts) + case tx => (tx, Nil) } - recUniquifyNames(t, namespace) match { + val (tpe, _) = recUniquifyNames(t, namespace) + tpe match { case tx: BundleType => tx case tx => throwInternalError(s"uniquifyNames: shouldn't be here - $tx") } diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index 7afce210..9016dca4 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -7,6 +7,8 @@ import scala.collection.immutable.HashSet import scala.collection.immutable.HashMap import annotation.tailrec +import Function.tupled + import firrtl._ import firrtl.ir._ import firrtl.passes.{Errors, PassException} @@ -26,6 +28,23 @@ object CheckCombLoops { case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation +case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarget) extends Annotation { + if (!source.isLocal || !sink.isLocal || source.module != sink.module) { + throwInternalError(s"ExtModulePathAnnotation must connect two local targets from the same module") + } + + override def getTargets: Seq[ReferenceTarget] = Seq(source, sink) + + override def update(renames: RenameMap): Seq[Annotation] = { + val sources = renames.get(source).getOrElse(Seq(source)) + val sinks = renames.get(sink).getOrElse(Seq(sink)) + val paths = sources flatMap { s => sinks.map((s, _)) } + paths.collect { + case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink) + } + } +} + case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) extends Annotation { override def update(renames: RenameMap): Seq[Annotation] = { val newSources: Seq[IsComponent] = sources.flatMap { s => renames.get(s).getOrElse(Seq(s.toTarget)) }.collect {case x: IsComponent if x.isLocal => x} @@ -36,12 +55,12 @@ case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) e /** Finds and detects combinational logic loops in a circuit, if any * exist. Returns the input circuit with no modifications. - * + * * @throws CombLoopException if a loop is found * @note Input form: Low FIRRTL * @note Output form: Low FIRRTL (identity transform) * @note The pass looks for loops through combinational-read memories - * @note The pass cannot find loops that pass through ExtModules + * @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules * @note The pass will throw exceptions on "false paths" */ class CheckCombLoops extends Transform with RegisteredTransform { @@ -69,6 +88,8 @@ class CheckCombLoops extends Transform with RegisteredTransform { private case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None) private def toLogicNode(e: Expression): LogicNode = e match { + case idx: WSubIndex => + toLogicNode(idx.expr) case r: WRef => LogicNode(r.name) case s: WSubField => @@ -95,29 +116,29 @@ class CheckCombLoops extends Transform with RegisteredTransform { private def getStmtDeps( simplifiedModules: mutable.Map[String,DiGraph[LogicNode]], deps: MutableDiGraph[LogicNode])(s: Statement): Unit = s match { - case Connect(_,loc,expr) => - val lhs = toLogicNode(loc) - if (deps.contains(lhs)) { - getExprDeps(deps, lhs)(expr) - } - case w: DefWire => - deps.addVertex(LogicNode(w.name)) - case n: DefNode => - val lhs = LogicNode(n.name) - deps.addVertex(lhs) - getExprDeps(deps, lhs)(n.value) - case m: DefMemory if (m.readLatency == 0) => - for (rp <- m.readers) { - val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp))) - deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp)))) - deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp)))) - } - case i: WDefInstance => - val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) - iGraph.getVertices.foreach(deps.addVertex(_)) - iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } }) - case _ => - s.foreach(getStmtDeps(simplifiedModules,deps)) + case Connect(_,loc,expr) => + val lhs = toLogicNode(loc) + if (deps.contains(lhs)) { + getExprDeps(deps, lhs)(expr) + } + case w: DefWire => + deps.addVertex(LogicNode(w.name)) + case n: DefNode => + val lhs = LogicNode(n.name) + deps.addVertex(lhs) + getExprDeps(deps, lhs)(n.value) + case m: DefMemory if (m.readLatency == 0) => + for (rp <- m.readers) { + val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp))) + deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp)))) + deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp)))) + } + case i: WDefInstance => + val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) + iGraph.getVertices.foreach(deps.addVertex(_)) + iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } }) + case _ => + s.foreach(getStmtDeps(simplifiedModules,deps)) } /* @@ -129,7 +150,7 @@ class CheckCombLoops extends Transform with RegisteredTransform { private def expandInstancePaths( m: String, moduleGraphs: mutable.Map[String,DiGraph[LogicNode]], - moduleDeps: Map[String, Map[String,String]], + moduleDeps: Map[String, Map[String,String]], prefix: Seq[String], path: Seq[LogicNode]): Seq[String] = { def absNodeName(prefix: Seq[String], n: LogicNode) = @@ -173,51 +194,65 @@ class CheckCombLoops extends Transform with RegisteredTransform { * module is converted to a netlist and analyzed locally, with its * subinstances represented by trivial, simplified subgraphs. The * overall outline of the process is: - * + * * 1. Create a graph of module instance dependances * 2. Linearize this acyclic graph - * + * * 3. Generate a local netlist; replace any instances with * simplified subgraphs representing connectivity of their IOs - * + * * 4. Check for nontrivial strongly connected components - * + * * 5. Create a reduced representation of the netlist with only the * module IOs as nodes, where output X (which must be a ground type, * as only low FIRRTL is supported) will have an edge to input Y if * and only if it combinationally depends on input Y. Associate this * reduced graph with the module for future use. */ - private def run(c: Circuit): (Circuit, Seq[Annotation]) = { + private def run(state: CircuitState) = { + val c = state.circuit val errors = new Errors() - /* TODO(magyar): deal with exmodules! No pass warnings currently - * exist. Maybe warn when iterating through modules. - */ + val extModulePaths = state.annotations.groupBy { + case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module) + case ann: Annotation => CircuitTarget(c.main) + } val moduleMap = c.modules.map({m => (m.name,m) }).toMap val iGraph = new InstanceGraph(c).graph val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] - for (m <- topoSortedModules) { - val internalDeps = new MutableDiGraph[LogicNode] - m.ports.foreach({ p => internalDeps.addVertex(LogicNode(p.name)) }) - m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps)) - val moduleGraph = DiGraph(internalDeps) - moduleGraphs(m.name) = moduleGraph - simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify((m.ports map { p => LogicNode(p.name) }).toSet) - // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs! - for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) { - errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name))) - } - for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) { - val sccSubgraph = moduleGraph.subgraph(scc.toSet) - val cycle = findCycleInSCC(sccSubgraph) - (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) }) - val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse) - errors.append(new CombLoopException(m.info, m.name, expandedCycle)) - } + topoSortedModules.foreach { + case em: ExtModule => + val portSet = em.ports.map(p => LogicNode(p.name)).toSet + val extModuleDeps = new MutableDiGraph[LogicNode] + portSet.foreach(extModuleDeps.addVertex(_)) + extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect { + case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref)) + } + moduleGraphs(em.name) = DiGraph(extModuleDeps).simplify(portSet) + simplifiedModuleGraphs(em.name) = moduleGraphs(em.name) + case m: Module => + val portSet = m.ports.map(p => LogicNode(p.name)).toSet + val internalDeps = new MutableDiGraph[LogicNode] + portSet.foreach(internalDeps.addVertex(_)) + m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps)) + val moduleGraph = DiGraph(internalDeps) + moduleGraphs(m.name) = moduleGraph + simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify(portSet) + // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs! + for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) { + errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name))) + } + for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) { + val sccSubgraph = moduleGraph.subgraph(scc.toSet) + val cycle = findCycleInSCC(sccSubgraph) + (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) }) + val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse) + errors.append(new CombLoopException(m.info, m.name, expandedCycle)) + } + case m => throwInternalError(s"Module ${m.name} has unrecognized type") } val mn = ModuleName(c.main, CircuitName(c.main)) val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty => @@ -225,8 +260,17 @@ class CheckCombLoops extends Transform with RegisteredTransform { val sources = tos.map(x => ComponentName(x.name, mn)) CombinationalPath(sink, sources.toSeq) } - errors.trigger() - (c, annos.toSeq) + (state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs) + } + + /** + * Returns a Map from Module name to port connectivity + */ + def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = { + val (result, errors, connectivity) = run(state) + connectivity.map { + case (k, v) => (k, v.transformNodes(ln => ln.name)) + } } def execute(state: CircuitState): CircuitState = { @@ -235,8 +279,9 @@ class CheckCombLoops extends Transform with RegisteredTransform { logger.warn("Skipping Combinational Loop Detection") state } else { - val (result, annos) = run(state.circuit) - CircuitState(result, outputForm, state.annotations ++ annos, state.renames) + val (result, errors, connectivity) = run(state) + errors.trigger() + result } } } diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index da7f1a46..8a273476 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -67,7 +67,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } object FoldADD extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = (c1, c2) match { + def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match { case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) } @@ -137,6 +137,13 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } } + private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match { + case UIntLiteral(v, IntWidth(w)) => + val shl = DoPrim(Shl, Seq(e.args.head), Seq(v), UnknownType) + pad(PrimOps.set_primop_type(shl), e.tpe) + case _ => e + } + private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head case x => e.args.head match { @@ -148,6 +155,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { } } + private def foldDynamicShiftRight(e: DoPrim) = e.args.last match { + case UIntLiteral(v, IntWidth(w)) => + val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType) + pad(PrimOps.set_primop_type(shr), e.tpe) + case _ => e + } + + private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { def isUInt(e: Expression): Boolean = e.tpe match { @@ -221,7 +236,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { private def constPropPrim(e: DoPrim): Expression = e.op match { case Shl => foldShiftLeft(e) + case Dshl => foldDynamicShiftLeft(e) case Shr => foldShiftRight(e) + case Dshr => foldDynamicShiftRight(e) case Cat => foldConcat(e) case Add => FoldADD(e) case And => FoldAND(e) @@ -277,6 +294,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = { val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) val propagated = old match { @@ -290,7 +308,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths { constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) case x => x } - propagated + // We're done when the Expression no longer changes + if (propagated eq old) propagated + else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) } /** Constant propagate a Module diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala index 55828e0a..8c36bb6d 100644 --- a/src/main/scala/firrtl/transforms/GroupComponents.scala +++ b/src/main/scala/firrtl/transforms/GroupComponents.scala @@ -4,7 +4,7 @@ import firrtl._ import firrtl.Mappers._ import firrtl.ir._ import firrtl.annotations.{Annotation, ComponentName} -import firrtl.passes.{InferTypes, LowerTypes, MemPortUtils} +import firrtl.passes.{InferTypes, LowerTypes, MemPortUtils, ResolveKinds} import firrtl.Utils.kind import firrtl.graph.{DiGraph, MutableDiGraph} @@ -62,7 +62,7 @@ class GroupComponents extends firrtl.Transform { case other => Seq(other) } val cs = state.copy(circuit = state.circuit.copy(modules = newModules)) - val csx = InferTypes.execute(cs) + val csx = ResolveKinds.execute(InferTypes.execute(cs)) csx } @@ -119,6 +119,11 @@ class GroupComponents extends firrtl.Transform { } } + // Unused nodes are not reachable from any group nor the root--add them to root group + for ((v, _) <- deps.getEdgeMap) { + reachableNodes.getOrElseUpdate(v, mutable.Set("")) + } + // Add nodes who are reached by a single group, to that group reachableNodes.foreach { case (node, membership) => if(membership.size == 1) { @@ -307,7 +312,9 @@ class GroupComponents extends firrtl.Transform { } def onStmt(stmt: Statement): Unit = stmt match { case w: WDefInstance => - case h: IsDeclaration => h map onExpr(WRef(h.name)) + case h: IsDeclaration => + bidirGraph.addVertex(h.name) + h map onExpr(WRef(h.name)) case Attach(_, exprs) => // Add edge between each expression exprs.tail map onExpr(getWRef(exprs.head)) case Connect(_, loc, expr) => diff --git a/src/main/scala/firrtl/transforms/IdentityTransform.scala b/src/main/scala/firrtl/transforms/IdentityTransform.scala new file mode 100644 index 00000000..a39ca4b7 --- /dev/null +++ b/src/main/scala/firrtl/transforms/IdentityTransform.scala @@ -0,0 +1,17 @@ +// See LICENSE for license details. + +package firrtl.transforms + +import firrtl.{CircuitForm, CircuitState, Transform} + +/** Transform that applies an identity function. This returns an unmodified [[CircuitState]]. + * @param form the input and output [[CircuitForm]] + */ +class IdentityTransform(form: CircuitForm) extends Transform { + + final override def inputForm: CircuitForm = form + final override def outputForm: CircuitForm = form + + final def execute(state: CircuitState): CircuitState = state + +} diff --git a/src/main/scala/firrtl/util/ClassUtils.scala b/src/main/scala/firrtl/util/ClassUtils.scala new file mode 100644 index 00000000..1b388035 --- /dev/null +++ b/src/main/scala/firrtl/util/ClassUtils.scala @@ -0,0 +1,19 @@ +package firrtl.util + +object ClassUtils { + /** Determine if a named class is loaded. + * + * @param name - name of the class: "foo.bar" or "org.foo.bar" + * @return true if the class has been loaded (is accessible), false otherwise. + */ + def isClassLoaded(name: String): Boolean = { + val found = try { + Class.forName(name, false, getClass.getClassLoader) != null + } catch { + case e: ClassNotFoundException => false + case x: Throwable => throw x + } +// println(s"isClassLoaded: %s $name".format(if (found) "found" else "didn't find")) + found + } +} diff --git a/src/main/scala/firrtl/util/TestOptions.scala b/src/main/scala/firrtl/util/TestOptions.scala new file mode 100644 index 00000000..9ee99f8c --- /dev/null +++ b/src/main/scala/firrtl/util/TestOptions.scala @@ -0,0 +1,13 @@ +package firrtl.util + +import ClassUtils.isClassLoaded + +object TestOptions { + // Our timing is inaccurate if we're running tests under coverage. + // If any of the classes known to be associated with evaluating coverage are loaded, + // assume we're running tests under coverage. + // NOTE: We assume we need only ask the class loader that loaded us. + // If it was loaded by another class loader (outside of our hierarchy), it wouldn't be available to us. + val coverageClasses = List("scoverage.Platform", "com.intellij.rt.coverage.instrumentation.TouchCounter") + val accurateTiming = !coverageClasses.exists(isClassLoaded(_)) +} diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala index 8fc7dda9..98472f14 100644 --- a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala +++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala @@ -165,6 +165,92 @@ class CheckCombLoopsSpec extends SimpleTransformSpec { } } + "Combinational loop through an annotated ExtModule" should "throw an exception" in { + val input = """circuit hasloops : + | extmodule blackbox : + | input in : UInt<1> + | output out : UInt<1> + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of blackbox + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin + + val mt = ModuleTarget("hasloops", "blackbox") + val annos = AnnotationSeq(Seq(ExtModulePathAnnotation(mt.ref("in"), mt.ref("out")))) + val writer = new java.io.StringWriter + intercept[CheckCombLoops.CombLoopException] { + compile(CircuitState(parse(input), ChirrtlForm, annos), writer) + } + } + + "Loop-free circuit with ExtModulePathAnnotations" should "not throw an exception" in { + val input = """circuit hasnoloops : + | extmodule blackbox : + | input in1 : UInt<1> + | input in2 : UInt<1> + | output out1 : UInt<1> + | output out2 : UInt<1> + | module hasnoloops : + | input clk : Clock + | input a : UInt<1> + | output b : UInt<1> + | wire x : UInt<1> + | inst inner of blackbox + | inner.in1 <= a + | x <= inner.out1 + | inner.in2 <= x + | b <= inner.out2 + |""".stripMargin + + val mt = ModuleTarget("hasnoloops", "blackbox") + val annos = AnnotationSeq(Seq( + ExtModulePathAnnotation(mt.ref("in1"), mt.ref("out1")), + ExtModulePathAnnotation(mt.ref("in2"), mt.ref("out2")))) + val writer = new java.io.StringWriter + compile(CircuitState(parse(input), ChirrtlForm, annos), writer) + } + + "Combinational loop through an output RHS reference" should "throw an exception" in { + val input = """circuit hasloops : + | module thru : + | input in : UInt<1> + | output tmp : UInt<1> + | output out : UInt<1> + | tmp <= in + | out <= tmp + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of thru + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin + + val writer = new java.io.StringWriter + intercept[CheckCombLoops.CombLoopException] { + compile(CircuitState(parse(input), ChirrtlForm), writer) + } + } + "Multiple simple loops in one SCC" should "throw an exception" in { val input = """circuit hasloops : | module hasloops : diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index 74d39286..a4473fe7 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -108,6 +108,23 @@ circuit foo : parse(res.getEmittedCircuit.value) } + "An mport that refers to an undefined memory" should "have a helpful error message" in { + val input = + """circuit testTestModule : + | module testTestModule : + | input clock : Clock + | input reset : UInt<1> + | output io : {flip in : UInt<10>, out : UInt<10>} + | + | node _T_10 = bits(io.in, 1, 0) + | read mport _T_11 = m[_T_10], clock + | io.out <= _T_11""".stripMargin + + intercept[PassException]{ + (new LowFirrtlCompiler).compile(CircuitState(parse(input), ChirrtlForm), Seq()).circuit + }.getMessage should startWith ("Undefined memory m referenced by mport _T_11") + } + ignore should "Memories should not have validif on port clocks when declared in a when" in { val input = """;buildInfoPackage: chisel3, version: 3.0-SNAPSHOT, scalaVersion: 2.11.11, sbtVersion: 0.13.16, builtAtString: 2017-10-06 20:55:20.367, builtAtMillis: 1507323320367 diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 603ddc25..8a69fcaa 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -734,6 +734,24 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { """.stripMargin (parse(exec(input))) should be(parse(check)) } + + // Optimizing this mux gives: z <= pad(UInt<2>(0), 4) + // Thus this checks that we then optimize that pad + "ConstProp" should "optimize nested Expressions" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<4> + | z <= mux(UInt(1), UInt<2>(0), UInt<4>(0)) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<4> + | z <= UInt<4>("h0") + """.stripMargin + (parse(exec(input))) should be(parse(check)) + } } // More sophisticated tests of the full compiler @@ -1104,6 +1122,77 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { | z <= _T_61""".stripMargin execute(input, check, Seq.empty) } + + behavior of "ConstProp" + + it should "optimize shl of constants" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<7> + | z <= shl(UInt(5), 4) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<7> + | z <= UInt<7>("h50") + """.stripMargin + execute(input, check, Seq.empty) + } + + it should "optimize shr of constants" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<1> + | z <= shr(UInt(5), 2) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<1> + | z <= UInt<1>("h1") + """.stripMargin + execute(input, check, Seq.empty) + } + + // Due to #866, we need dshl optimized away or it'll become a dshlw and error in parsing + // Include cat to verify width is correct + it should "optimize dshl of constant" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<8> + | node n = dshl(UInt<1>(0), UInt<2>(0)) + | z <= cat(UInt<4>("hf"), n) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<8> + | z <= UInt<8>("hf0") + """.stripMargin + execute(input, check, Seq.empty) + } + + // Include cat and constants to verify width is correct + it should "optimize dshr of constant" in { + val input = + """circuit Top : + | module Top : + | output z : UInt<8> + | node n = dshr(UInt<4>(0), UInt<2>(2)) + | z <= cat(UInt<4>("hf"), n) + """.stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<8> + | z <= UInt<8>("hf0") + """.stripMargin + execute(input, check, Seq.empty) + } } diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala index 43d1e733..561f0a84 100644 --- a/src/test/scala/firrtlTests/UniquifySpec.scala +++ b/src/test/scala/firrtlTests/UniquifySpec.scala @@ -12,6 +12,7 @@ import firrtl._ import firrtl.annotations._ import firrtl.annotations.TargetToken._ import firrtl.transforms.DontTouchAnnotation +import firrtl.util.TestOptions class UniquifySpec extends FirrtlFlatSpec { @@ -283,4 +284,31 @@ class UniquifySpec extends FirrtlFlatSpec { executeTest(input, expected) } + + it should "quickly rename deep bundles" in { + // We use a fixed time to determine if this test passed or failed. + // This test would pass under normal conditions, but would fail during coverage tests. + // Since executions times vary significantly under coverage testing, we check a global + // to see if timing measurements are accurate enough to enforce the timing checks. + val maxMs = 8000.0 + + def mkType(i: Int): String = { + if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" + } + + val depth = 500 + + val input = + s"""circuit Test: + | module Test : + | input in: ${mkType(depth)} + | output out: ${mkType(depth)} + | out <= in + |""".stripMargin + + val (renameMs, _) = Utils.time(compileToVerilog(input)) + + if (TestOptions.accurateTiming) + renameMs shouldBe < (maxMs) + } } diff --git a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala index 3a32ec71..c54e02e3 100644 --- a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala +++ b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala @@ -3,6 +3,10 @@ package transforms import firrtl.annotations.{CircuitName, ComponentName, ModuleName} import firrtl.transforms.{GroupAnnotation, GroupComponents} +import firrtl._ +import firrtl.ir._ + +import FirrtlCheckers._ class GroupComponentsSpec extends LowTransformSpec { def transform = new GroupComponents() @@ -42,6 +46,43 @@ class GroupComponentsSpec extends LowTransformSpec { """.stripMargin execute(input, check, groups) } + "Grouping" should "work even when there are unused nodes" in { + val input = + s"""circuit $top : + | module $top : + | input in: UInt<16> + | output out: UInt<16> + | node n = UInt<16>("h0") + | wire w : UInt<16> + | wire a : UInt<16> + | wire b : UInt<16> + | a <= UInt<16>("h0") + | b <= a + | w <= in + | out <= w + """.stripMargin + val groups = Seq( + GroupAnnotation(Seq(topComp("w")), "Child", "inst", Some("_OUT"), Some("_IN")) + ) + val check = + s"""circuit Top : + | module $top : + | input in: UInt<16> + | output out: UInt<16> + | inst inst of Child + | node n = UInt<16>("h0") + | inst.in_IN <= in + | node a = UInt<16>("h0") + | node b = a + | out <= inst.w_OUT + | module Child : + | input in_IN : UInt<16> + | output w_OUT : UInt<16> + | node w = in_IN + | w_OUT <= w + """.stripMargin + execute(input, check, groups) + } "The two sets of instances" should "be grouped" in { val input = @@ -288,3 +329,35 @@ class GroupComponentsSpec extends LowTransformSpec { execute(input, check, groups) } } + +class GroupComponentsIntegrationSpec extends FirrtlFlatSpec { + def topComp(name: String): ComponentName = ComponentName(name, ModuleName("Top", CircuitName("Top"))) + "Grouping" should "properly set kinds" in { + val input = + """circuit Top : + | module Top : + | input clk: Clock + | input data: UInt<16> + | output out: UInt<16> + | reg r: UInt<16>, clk + | r <= data + | out <= r + """.stripMargin + val groups = Seq( + GroupAnnotation(Seq(topComp("r")), "MyModule", "inst", Some("_OUT"), Some("_IN")) + ) + val result = (new VerilogCompiler).compileAndEmit( + CircuitState(parse(input), ChirrtlForm, groups), + Seq(new GroupComponents) + ) + result should containTree { + case Connect(_, WSubField(WRef("inst",_, InstanceKind,_), "data_IN", _,_), WRef("data",_,_,_)) => true + } + result should containTree { + case Connect(_, WSubField(WRef("inst",_, InstanceKind,_), "clk_IN", _,_), WRef("clk",_,_,_)) => true + } + result should containTree { + case Connect(_, WRef("out",_,_,_), WSubField(WRef("inst",_, InstanceKind,_), "r_OUT", _,_)) => true + } + } +} |
