diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/passes | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/passes')
58 files changed, 2592 insertions, 2120 deletions
diff --git a/src/main/scala/firrtl/passes/CInferMDir.scala b/src/main/scala/firrtl/passes/CInferMDir.scala index b4819751..1fe8d57c 100644 --- a/src/main/scala/firrtl/passes/CInferMDir.scala +++ b/src/main/scala/firrtl/passes/CInferMDir.scala @@ -18,60 +18,61 @@ object CInferMDir extends Pass { def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match { case e: Reference => - mports get e.name match { + mports.get(e.name) match { case None => - case Some(p) => mports(e.name) = (p, dir) match { - case (MInfer, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") - case (MInfer, MWrite) => MWrite - case (MInfer, MRead) => MRead - case (MInfer, MReadWrite) => MReadWrite - case (MWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") - case (MWrite, MWrite) => MWrite - case (MWrite, MRead) => MReadWrite - case (MWrite, MReadWrite) => MReadWrite - case (MRead, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") - case (MRead, MWrite) => MReadWrite - case (MRead, MRead) => MRead - case (MRead, MReadWrite) => MReadWrite - case (MReadWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") - case (MReadWrite, MWrite) => MReadWrite - case (MReadWrite, MRead) => MReadWrite - case (MReadWrite, MReadWrite) => MReadWrite - } + case Some(p) => + mports(e.name) = (p, dir) match { + case (MInfer, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") + case (MInfer, MWrite) => MWrite + case (MInfer, MRead) => MRead + case (MInfer, MReadWrite) => MReadWrite + case (MWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") + case (MWrite, MWrite) => MWrite + case (MWrite, MRead) => MReadWrite + case (MWrite, MReadWrite) => MReadWrite + case (MRead, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") + case (MRead, MWrite) => MReadWrite + case (MRead, MRead) => MRead + case (MRead, MReadWrite) => MReadWrite + case (MReadWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") + case (MReadWrite, MWrite) => MReadWrite + case (MReadWrite, MRead) => MReadWrite + case (MReadWrite, MReadWrite) => MReadWrite + } } e case e: SubAccess => infer_mdir_e(mports, dir)(e.expr) infer_mdir_e(mports, MRead)(e.index) // index can't be a write port e - case e => e map infer_mdir_e(mports, dir) + case e => e.map(infer_mdir_e(mports, dir)) } def infer_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match { case sx: CDefMPort => - mports(sx.name) = sx.direction - sx map infer_mdir_e(mports, MRead) + mports(sx.name) = sx.direction + sx.map(infer_mdir_e(mports, MRead)) case sx: Connect => - infer_mdir_e(mports, MRead)(sx.expr) - infer_mdir_e(mports, MWrite)(sx.loc) - sx + infer_mdir_e(mports, MRead)(sx.expr) + infer_mdir_e(mports, MWrite)(sx.loc) + sx case sx: PartialConnect => - infer_mdir_e(mports, MRead)(sx.expr) - infer_mdir_e(mports, MWrite)(sx.loc) - sx - case sx => sx map infer_mdir_s(mports) map infer_mdir_e(mports, MRead) + infer_mdir_e(mports, MRead)(sx.expr) + infer_mdir_e(mports, MWrite)(sx.loc) + sx + case sx => sx.map(infer_mdir_s(mports)).map(infer_mdir_e(mports, MRead)) } def set_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match { - case sx: CDefMPort => sx copy (direction = mports(sx.name)) - case sx => sx map set_mdir_s(mports) + case sx: CDefMPort => sx.copy(direction = mports(sx.name)) + case sx => sx.map(set_mdir_s(mports)) } def infer_mdir(m: DefModule): DefModule = { val mports = new MPortDirMap - m map infer_mdir_s(mports) map set_mdir_s(mports) + m.map(infer_mdir_s(mports)).map(set_mdir_s(mports)) } def run(c: Circuit): Circuit = - c copy (modules = c.modules map infer_mdir) + c.copy(modules = c.modules.map(infer_mdir)) } diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 9903f445..97d614c1 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -12,9 +12,7 @@ object CheckChirrtl extends Pass with CheckHighFormLike { override def prerequisites = Dependency[CheckScalaVersion] :: Nil override val optionalPrerequisiteOf = firrtl.stage.Forms.ChirrtlForm ++ - Seq( Dependency(CInferTypes), - Dependency(CInferMDir), - Dependency(RemoveCHIRRTL) ) + Seq(Dependency(CInferTypes), Dependency(CInferMDir), Dependency(RemoveCHIRRTL)) override def invalidates(a: Transform) = false diff --git a/src/main/scala/firrtl/passes/CheckFlows.scala b/src/main/scala/firrtl/passes/CheckFlows.scala index 3a9cc212..bc455a20 100644 --- a/src/main/scala/firrtl/passes/CheckFlows.scala +++ b/src/main/scala/firrtl/passes/CheckFlows.scala @@ -13,79 +13,87 @@ object CheckFlows extends Pass { override def prerequisites = Dependency(passes.ResolveFlows) +: firrtl.stage.Forms.WorkingIR override def optionalPrerequisiteOf = - Seq( Dependency[passes.InferBinaryPoints], - Dependency[passes.TrimIntervals], - Dependency[passes.InferWidths], - Dependency[transforms.InferResets] ) + Seq( + Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals], + Dependency[passes.InferWidths], + Dependency[transforms.InferResets] + ) override def invalidates(a: Transform) = false type FlowMap = collection.mutable.HashMap[String, Flow] implicit def toStr(g: Flow): String = g match { - case SourceFlow => "source" - case SinkFlow => "sink" + case SourceFlow => "source" + case SinkFlow => "sink" case UnknownFlow => "unknown" - case DuplexFlow => "duplex" + case DuplexFlow => "duplex" } - class WrongFlow(info:Info, mname: String, expr: String, wrong: Flow, right: Flow) extends PassException( - s"$info: [module $mname] Expression $expr is used as a $wrong but can only be used as a $right.") + class WrongFlow(info: Info, mname: String, expr: String, wrong: Flow, right: Flow) + extends PassException( + s"$info: [module $mname] Expression $expr is used as a $wrong but can only be used as a $right." + ) - def run (c:Circuit): Circuit = { + def run(c: Circuit): Circuit = { val errors = new Errors() def get_flow(e: Expression, flows: FlowMap): Flow = e match { - case (e: WRef) => flows(e.name) + case (e: WRef) => flows(e.name) case (e: WSubIndex) => get_flow(e.expr, flows) case (e: WSubAccess) => get_flow(e.expr, flows) - case (e: WSubField) => e.expr.tpe match {case t: BundleType => - val f = (t.fields find (_.name == e.name)).get - times(get_flow(e.expr, flows), f.flip) - } + case (e: WSubField) => + e.expr.tpe match { + case t: BundleType => + val f = (t.fields.find(_.name == e.name)).get + times(get_flow(e.expr, flows), f.flip) + } case _ => SourceFlow } def flip_q(t: Type): Boolean = { def flip_rec(t: Type, f: Orientation): Boolean = t match { - case tx:BundleType => tx.fields exists ( - field => flip_rec(field.tpe, times(f, field.flip)) - ) + case tx: BundleType => tx.fields.exists(field => flip_rec(field.tpe, times(f, field.flip))) case tx: VectorType => flip_rec(tx.tpe, f) case tx => f == Flip } flip_rec(t, Default) } - def check_flow(info:Info, mname: String, flows: FlowMap, desired: Flow)(e:Expression): Unit = { - val flow = get_flow(e,flows) + def check_flow(info: Info, mname: String, flows: FlowMap, desired: Flow)(e: Expression): Unit = { + val flow = get_flow(e, flows) (flow, desired) match { case (SourceFlow, SinkFlow) => errors.append(new WrongFlow(info, mname, e.serialize, desired, flow)) - case (SinkFlow, SourceFlow) => kind(e) match { - case PortKind | InstanceKind if !flip_q(e.tpe) => // OK! - case _ => - errors.append(new WrongFlow(info, mname, e.serialize, desired, flow)) - } + case (SinkFlow, SourceFlow) => + kind(e) match { + case PortKind | InstanceKind if !flip_q(e.tpe) => // OK! + case _ => + errors.append(new WrongFlow(info, mname, e.serialize, desired, flow)) + } case _ => } - } + } - def check_flows_e (info:Info, mname: String, flows: FlowMap)(e:Expression): Unit = { + def check_flows_e(info: Info, mname: String, flows: FlowMap)(e: Expression): Unit = { e match { - case e: Mux => e foreach check_flow(info, mname, flows, SourceFlow) - case e: DoPrim => e.args foreach check_flow(info, mname, flows, SourceFlow) + case e: Mux => e.foreach(check_flow(info, mname, flows, SourceFlow)) + case e: DoPrim => e.args.foreach(check_flow(info, mname, flows, SourceFlow)) case _ => } - e foreach check_flows_e(info, mname, flows) + e.foreach(check_flows_e(info, mname, flows)) } def check_flows_s(minfo: Info, mname: String, flows: FlowMap)(s: Statement): Unit = { - val info = get_info(s) match { case NoInfo => minfo case x => x } + val info = get_info(s) match { + case NoInfo => minfo + case x => x + } s match { - case (s: DefWire) => flows(s.name) = DuplexFlow + case (s: DefWire) => flows(s.name) = DuplexFlow case (s: DefRegister) => flows(s.name) = DuplexFlow - case (s: DefMemory) => flows(s.name) = SourceFlow + case (s: DefMemory) => flows(s.name) = SourceFlow case (s: WDefInstance) => flows(s.name) = SourceFlow case (s: DefNode) => check_flow(info, mname, flows, SourceFlow)(s.value) @@ -94,7 +102,7 @@ object CheckFlows extends Pass { check_flow(info, mname, flows, SinkFlow)(s.loc) check_flow(info, mname, flows, SourceFlow)(s.expr) case (s: Print) => - s.args foreach check_flow(info, mname, flows, SourceFlow) + s.args.foreach(check_flow(info, mname, flows, SourceFlow)) check_flow(info, mname, flows, SourceFlow)(s.en) check_flow(info, mname, flows, SourceFlow)(s.clk) case (s: PartialConnect) => @@ -111,14 +119,14 @@ object CheckFlows extends Pass { check_flow(info, mname, flows, SourceFlow)(s.en) case _ => } - s foreach check_flows_e(info, mname, flows) - s foreach check_flows_s(minfo, mname, flows) + s.foreach(check_flows_e(info, mname, flows)) + s.foreach(check_flows_s(minfo, mname, flows)) } for (m <- c.modules) { val flows = new FlowMap - flows ++= (m.ports map (p => p.name -> to_flow(p.direction))) - m foreach check_flows_s(m.info, m.name, flows) + flows ++= (m.ports.map(p => p.name -> to_flow(p.direction))) + m.foreach(check_flows_s(m.info, m.name, flows)) } errors.trigger() c diff --git a/src/main/scala/firrtl/passes/CheckHighForm.scala b/src/main/scala/firrtl/passes/CheckHighForm.scala index 2f706d35..559c9060 100644 --- a/src/main/scala/firrtl/passes/CheckHighForm.scala +++ b/src/main/scala/firrtl/passes/CheckHighForm.scala @@ -27,66 +27,71 @@ trait CheckHighFormLike { this: Pass => scopes.find(_.contains(port.mem)).getOrElse(scopes.head) += port.name } def legalDecl(name: String): Boolean = !moduleNS.contains(name) - def legalRef(name: String): Boolean = scopes.exists(_.contains(name)) + def legalRef(name: String): Boolean = scopes.exists(_.contains(name)) def childScope(): ScopeView = new ScopeView(moduleNS, new NameSet +: scopes) } // Custom Exceptions - class NotUniqueException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Reference $name does not have a unique name.") - class InvalidLOCException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Invalid connect to an expression that is not a reference or a WritePort.") - class NegUIntException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] UIntLiteral cannot be negative.") - class UndeclaredReferenceException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Reference $name is not declared.") - class PoisonWithFlipException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Poison $name cannot be a bundle type with flips.") - class MemWithFlipException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Memory $name cannot be a bundle type with flips.") - class IllegalMemLatencyException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Memory $name must have non-negative read latency and positive write latency.") - class RegWithFlipException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Register $name cannot be a bundle type with flips.") - class InvalidAccessException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Invalid access to non-reference.") - class ModuleNameNotUniqueException(info: Info, mname: String) extends PassException( - s"$info: Repeat definition of module $mname") - class DefnameConflictException(info: Info, mname: String, defname: String) extends PassException( - s"$info: defname $defname of extmodule $mname conflicts with an existing module") - class DefnameDifferentPortsException(info: Info, mname: String, defname: String) extends PassException( - s"""$info: ports of extmodule $mname with defname $defname are different for an extmodule with the same defname""") - class ModuleNotDefinedException(info: Info, mname: String, name: String) extends PassException( - s"$info: Module $name is not defined.") - class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int) extends PassException( - s"$info: [module $mname] Primop $op requires $n expression arguments.") - class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int) extends PassException( - s"$info: [module $mname] Primop $op requires $n integer arguments.") - class NegWidthException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Width cannot be negative.") - class NegVecSizeException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Vector type size cannot be negative.") - class NegMemSizeException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Memory size cannot be negative or zero.") - class BadPrintfException(info: Info, mname: String, x: Char) extends PassException( - s"$info: [module $mname] Bad printf format: " + "\"%" + x + "\"") - class BadPrintfTrailingException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Bad printf format: trailing " + "\"%\"") - class BadPrintfIncorrectNumException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Bad printf format: incorrect number of arguments") - class InstanceLoop(info: Info, mname: String, loop: String) extends PassException( - s"$info: [module $mname] Has instance loop $loop") - class NoTopModuleException(info: Info, name: String) extends PassException( - s"$info: A single module must be named $name.") - class NegArgException(info: Info, mname: String, op: String, value: BigInt) extends PassException( - s"$info: [module $mname] Primop $op argument $value < 0.") - class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: BigInt, msb: BigInt) extends PassException( - s"$info: [module $mname] Primop $op lsb $lsb > $msb.") - class ResetInputException(info: Info, mname: String, expr: Expression) extends PassException( - s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}") - class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression) extends PassException( - s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}") - + class NotUniqueException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Reference $name does not have a unique name.") + class InvalidLOCException(info: Info, mname: String) + extends PassException( + s"$info: [module $mname] Invalid connect to an expression that is not a reference or a WritePort." + ) + class NegUIntException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] UIntLiteral cannot be negative.") + class UndeclaredReferenceException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Reference $name is not declared.") + class PoisonWithFlipException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Poison $name cannot be a bundle type with flips.") + class MemWithFlipException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Memory $name cannot be a bundle type with flips.") + class IllegalMemLatencyException(info: Info, mname: String, name: String) + extends PassException( + s"$info: [module $mname] Memory $name must have non-negative read latency and positive write latency." + ) + class RegWithFlipException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Register $name cannot be a bundle type with flips.") + class InvalidAccessException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Invalid access to non-reference.") + class ModuleNameNotUniqueException(info: Info, mname: String) + extends PassException(s"$info: Repeat definition of module $mname") + class DefnameConflictException(info: Info, mname: String, defname: String) + extends PassException(s"$info: defname $defname of extmodule $mname conflicts with an existing module") + class DefnameDifferentPortsException(info: Info, mname: String, defname: String) + extends PassException( + s"""$info: ports of extmodule $mname with defname $defname are different for an extmodule with the same defname""" + ) + class ModuleNotDefinedException(info: Info, mname: String, name: String) + extends PassException(s"$info: Module $name is not defined.") + class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int) + extends PassException(s"$info: [module $mname] Primop $op requires $n expression arguments.") + class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int) + extends PassException(s"$info: [module $mname] Primop $op requires $n integer arguments.") + class NegWidthException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Width cannot be negative.") + class NegVecSizeException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Vector type size cannot be negative.") + class NegMemSizeException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Memory size cannot be negative or zero.") + class BadPrintfException(info: Info, mname: String, x: Char) + extends PassException(s"$info: [module $mname] Bad printf format: " + "\"%" + x + "\"") + class BadPrintfTrailingException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Bad printf format: trailing " + "\"%\"") + class BadPrintfIncorrectNumException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Bad printf format: incorrect number of arguments") + class InstanceLoop(info: Info, mname: String, loop: String) + extends PassException(s"$info: [module $mname] Has instance loop $loop") + class NoTopModuleException(info: Info, name: String) + extends PassException(s"$info: A single module must be named $name.") + class NegArgException(info: Info, mname: String, op: String, value: BigInt) + extends PassException(s"$info: [module $mname] Primop $op argument $value < 0.") + class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: BigInt, msb: BigInt) + extends PassException(s"$info: [module $mname] Primop $op lsb $lsb > $msb.") + class ResetInputException(info: Info, mname: String, expr: Expression) + extends PassException(s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}") + class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression) + extends PassException(s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}") // Is Chirrtl allowed for this check? If not, return an error def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] @@ -94,12 +99,12 @@ trait CheckHighFormLike { this: Pass => def run(c: Circuit): Circuit = { val errors = new Errors() val moduleGraph = new ModuleGraph - val moduleNames = (c.modules map (_.name)).toSet + val moduleNames = (c.modules.map(_.name)).toSet val intModuleNames = c.modules.view.collect({ case m: Module => m.name }).toSet - c.modules.groupBy(_.name).filter(_._2.length > 1).flatMap(_._2).foreach { - m => errors.append(new ModuleNameNotUniqueException(m.info, m.name)) + c.modules.groupBy(_.name).filter(_._2.length > 1).flatMap(_._2).foreach { m => + errors.append(new ModuleNameNotUniqueException(m.info, m.name)) } /** Strip all widths from types */ @@ -110,16 +115,18 @@ trait CheckHighFormLike { this: Pass => val extmoduleCollidingPorts = c.modules.collect { case a: ExtModule => a - }.groupBy(a => (a.defname, a.params.nonEmpty)).map { - /* There are no parameters, so all ports must match exactly. */ - case (k@ (_, false), a) => - k -> a.map(_.copy(info=NoInfo)).map(_.ports.map(_.copy(info=NoInfo))).toSet - /* If there are parameters, then only port names must match because parameters could parameterize widths. - * This means that this check cannot produce false positives, but can have false negatives. - */ - case (k@ (_, true), a) => - k -> a.map(_.copy(info=NoInfo)).map(_.ports.map(_.copy(info=NoInfo).mapType(stripWidth))).toSet - }.filter(_._2.size > 1) + }.groupBy(a => (a.defname, a.params.nonEmpty)) + .map { + /* There are no parameters, so all ports must match exactly. */ + case (k @ (_, false), a) => + k -> a.map(_.copy(info = NoInfo)).map(_.ports.map(_.copy(info = NoInfo))).toSet + /* If there are parameters, then only port names must match because parameters could parameterize widths. + * This means that this check cannot produce false positives, but can have false negatives. + */ + case (k @ (_, true), a) => + k -> a.map(_.copy(info = NoInfo)).map(_.ports.map(_.copy(info = NoInfo).mapType(stripWidth))).toSet + } + .filter(_._2.size > 1) c.modules.collect { case a: ExtModule => @@ -129,7 +136,8 @@ trait CheckHighFormLike { this: Pass => case _ => } a match { - case ExtModule(info, name, _, defname, params) if extmoduleCollidingPorts.contains((defname, params.nonEmpty)) => + case ExtModule(info, name, _, defname, params) + if extmoduleCollidingPorts.contains((defname, params.nonEmpty)) => errors.append(new DefnameDifferentPortsException(info, name, defname)) case _ => } @@ -147,14 +155,14 @@ trait CheckHighFormLike { this: Pass => } def nonNegativeConsts(): Unit = { - e.consts.filter(_ < 0).foreach { - negC => errors.append(new NegArgException(info, mname, e.op.toString, negC)) + e.consts.filter(_ < 0).foreach { negC => + errors.append(new NegArgException(info, mname, e.op.toString, negC)) } } e.op match { - case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq | - Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | Clip | Wrap | Squeeze => + case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq | Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | + Clip | Wrap | Squeeze => correctNum(Option(2), 0) case AsUInt | AsSInt | AsClock | AsAsyncReset | Cvt | Neq | Not => correctNum(Option(1), 0) @@ -175,7 +183,7 @@ trait CheckHighFormLike { this: Pass => case AsInterval => correctNum(Option(1), 3) case Andr | Orr | Xorr | Neg => - correctNum(None,0) + correctNum(None, 0) } } @@ -208,12 +216,12 @@ trait CheckHighFormLike { this: Pass => } def checkHighFormT(info: Info, mname: => String)(t: Type): Unit = { - t foreach checkHighFormT(info, mname) + t.foreach(checkHighFormT(info, mname)) t match { case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname)) case _: IntervalType => - case _ => t foreach checkHighFormW(info, mname) + case _ => t.foreach(checkHighFormW(info, mname)) } } @@ -235,12 +243,12 @@ trait CheckHighFormLike { this: Pass => errors.append(new NegUIntException(info, mname)) case ex: DoPrim => checkHighFormPrimop(info, mname, ex) case _: Reference | _: WRef | _: UIntLiteral | _: Mux | _: ValidIf => - case ex: SubAccess => validSubexp(info, mname)(ex.expr) + case ex: SubAccess => validSubexp(info, mname)(ex.expr) case ex: WSubAccess => validSubexp(info, mname)(ex.expr) - case ex => ex foreach validSubexp(info, mname) + case ex => ex.foreach(validSubexp(info, mname)) } - e foreach checkHighFormW(info, mname + "/" + e.serialize) - e foreach checkHighFormE(info, mname, names) + e.foreach(checkHighFormW(info, mname + "/" + e.serialize)) + e.foreach(checkHighFormE(info, mname, names)) } def checkName(info: Info, mname: String, names: ScopeView)(name: String): Unit = { @@ -253,14 +261,17 @@ trait CheckHighFormLike { this: Pass => if (!moduleNames(child)) errors.append(new ModuleNotDefinedException(info, parent, child)) // Check to see if a recursive module instantiation has occured - val childToParent = moduleGraph add (parent, child) + val childToParent = moduleGraph.add(parent, child) if (childToParent.nonEmpty) - errors.append(new InstanceLoop(info, parent, childToParent mkString "->")) + errors.append(new InstanceLoop(info, parent, childToParent.mkString("->"))) } def checkHighFormS(minfo: Info, mname: String, names: ScopeView)(s: Statement): Unit = { - val info = get_info(s) match {case NoInfo => minfo case x => x} - s foreach checkName(info, mname, names) + val info = get_info(s) match { + case NoInfo => minfo + case x => x + } + s.foreach(checkName(info, mname, names)) s match { case DefRegister(info, name, tpe, _, reset, init) => if (hasFlip(tpe)) @@ -272,24 +283,24 @@ trait CheckHighFormLike { this: Pass => errors.append(new MemWithFlipException(info, mname, sx.name)) if (sx.depth <= 0) errors.append(new NegMemSizeException(info, mname)) - case sx: DefInstance => checkInstance(info, mname, sx.module) - case sx: WDefInstance => checkInstance(info, mname, sx.module) - case sx: Connect => checkValidLoc(info, mname, sx.loc) - case sx: PartialConnect => checkValidLoc(info, mname, sx.loc) - case sx: Print => checkFstring(info, mname, sx.string, sx.args.length) - case _: CDefMemory => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) } + case sx: DefInstance => checkInstance(info, mname, sx.module) + case sx: WDefInstance => checkInstance(info, mname, sx.module) + case sx: Connect => checkValidLoc(info, mname, sx.loc) + case sx: PartialConnect => checkValidLoc(info, mname, sx.loc) + case sx: Print => checkFstring(info, mname, sx.string, sx.args.length) + case _: CDefMemory => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) } case mport: CDefMPort => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) } names.expandMPortVisibility(mport) case sx => // Do Nothing } - s foreach checkHighFormT(info, mname) - s foreach checkHighFormE(info, mname, names) + s.foreach(checkHighFormT(info, mname)) + s.foreach(checkHighFormE(info, mname, names)) s match { - case Conditionally(_,_, conseq, alt) => + case Conditionally(_, _, conseq, alt) => checkHighFormS(minfo, mname, names.childScope())(conseq) checkHighFormS(minfo, mname, names.childScope())(alt) - case _ => s foreach checkHighFormS(minfo, mname, names) + case _ => s.foreach(checkHighFormS(minfo, mname, names)) } } @@ -313,10 +324,10 @@ trait CheckHighFormLike { this: Pass => def checkHighFormM(m: DefModule): Unit = { val names = ScopeView() - m foreach checkHighFormP(m.name, names) - m foreach checkHighFormS(m.info, m.name, names) + m.foreach(checkHighFormP(m.name, names)) + m.foreach(checkHighFormS(m.info, m.name, names)) m match { - case _: Module => + case _: Module => case ext: ExtModule => for ((port, expr) <- findBadResetTypePorts(ext, Output)) { errors.append(new ResetExtModuleOutputException(port.info, ext.name, expr)) @@ -324,7 +335,7 @@ trait CheckHighFormLike { this: Pass => } } - c.modules foreach checkHighFormM + c.modules.foreach(checkHighFormM) c.modules.filter(_.name == c.main) match { case Seq(topMod) => for ((port, expr) <- findBadResetTypePorts(topMod, Input)) { @@ -342,21 +353,23 @@ object CheckHighForm extends Pass with CheckHighFormLike { override def prerequisites = firrtl.stage.Forms.WorkingIR override def optionalPrerequisiteOf = - Seq( Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.ResolveFlows), - Dependency[passes.InferWidths], - Dependency[transforms.InferResets] ) + Seq( + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.ResolveFlows), + Dependency[passes.InferWidths], + Dependency[transforms.InferResets] + ) override def invalidates(a: Transform) = false - class IllegalChirrtlMemException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.") + class IllegalChirrtlMemException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.") def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] = { val memName = s match { case cm: CDefMemory => cm.name - case cp: CDefMPort => cp.mem + case cp: CDefMPort => cp.mem } Some(new IllegalChirrtlMemException(info, mname, memName)) } diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 4a5577f9..96057831 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -22,10 +22,11 @@ object CheckInitialization extends Pass { private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression]) - class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException( - s"$info : [module $mname] Reference $name is not fully initialized.\n" + - trace.map(s => s" ${get_info(s)} : ${s.serialize}").mkString("\n") - ) + class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) + extends PassException( + s"$info : [module $mname] Reference $name is not fully initialized.\n" + + trace.map(s => s" ${get_info(s)} : ${s.serialize}").mkString("\n") + ) private def getTrace(expr: WrappedExpression, voidExprs: Map[WrappedExpression, VoidExpr]): Seq[Statement] = { @tailrec @@ -81,7 +82,7 @@ object CheckInitialization extends Pass { case node: DefNode => // Ignore nodes case decl: IsDeclaration => val trace = getTrace(expr, voidExprs.toMap) - errors append new RefNotInitializedException(decl.info, m.name, decl.name, trace) + errors.append(new RefNotInitializedException(decl.info, m.name, decl.name, trace)) } } } diff --git a/src/main/scala/firrtl/passes/CheckTypes.scala b/src/main/scala/firrtl/passes/CheckTypes.scala index c94928a1..956c1134 100644 --- a/src/main/scala/firrtl/passes/CheckTypes.scala +++ b/src/main/scala/firrtl/passes/CheckTypes.scala @@ -16,92 +16,105 @@ object CheckTypes extends Pass { override def prerequisites = Dependency(InferTypes) +: firrtl.stage.Forms.WorkingIR override def optionalPrerequisiteOf = - Seq( Dependency(passes.ResolveFlows), - Dependency(passes.CheckFlows), - Dependency[passes.InferWidths], - Dependency(passes.CheckWidths) ) + Seq( + Dependency(passes.ResolveFlows), + Dependency(passes.CheckFlows), + Dependency[passes.InferWidths], + Dependency(passes.CheckWidths) + ) override def invalidates(a: Transform) = false // Custom Exceptions - class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname ] Subfield $name is not in bundle.") - class SubfieldOnNonBundle(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Subfield $name is accessed on a non-bundle.") - class IndexTooLarge(info: Info, mname: String, value: Int) extends PassException( - s"$info: [module $mname] Index with value $value is too large.") - class IndexOnNonVector(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Index illegal on non-vector type.") - class AccessIndexNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Access index must be a UInt type.") - class IndexNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Index is not of UIntType.") - class EnableNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Enable is not of UIntType.") + class SubfieldNotInBundle(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname ] Subfield $name is not in bundle.") + class SubfieldOnNonBundle(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Subfield $name is accessed on a non-bundle.") + class IndexTooLarge(info: Info, mname: String, value: Int) + extends PassException(s"$info: [module $mname] Index with value $value is too large.") + class IndexOnNonVector(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Index illegal on non-vector type.") + class AccessIndexNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Access index must be a UInt type.") + class IndexNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Index is not of UIntType.") + class EnableNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Enable is not of UIntType.") class InvalidConnect(info: Info, mname: String, con: String, lhs: Expression, rhs: Expression) extends PassException({ - val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}" - val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}" - s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe" - }) - class InvalidRegInit(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Type of init must match type of DefRegister.") - class PrintfArgNotGround(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Printf arguments must be either UIntType or SIntType.") - class ReqClk(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Requires a clock typed signal.") - class RegReqClk(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Register $name requires a clock typed signal.") - class EnNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Enable must be a UIntType typed signal.") - class PredNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Predicate not a UIntType.") - class OpNotGround(info: Info, mname: String, op: String) extends PassException( - s"$info: [module $mname] Primop $op cannot operate on non-ground types.") - class OpNotUInt(info: Info, mname: String, op: String, e: String) extends PassException( - s"$info: [module $mname] Primop $op requires argument $e to be a UInt type.") - class OpNotAllUInt(info: Info, mname: String, op: String) extends PassException( - s"$info: [module $mname] Primop $op requires all arguments to be UInt type.") - class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException( - s"$info: [module $mname] Primop $op requires all operands to have the same type.") - class OpNoMixFix(info:Info, mname: String, op: String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type.") - class OpNotCorrectType(info:Info, mname: String, op: String, tpes: Seq[String]) extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.") - class OpNotAnalog(info: Info, mname: String, exp: String) extends PassException( - s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.") - class NodePassiveType(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Node must be a passive type.") - class MuxSameType(info: Info, mname: String, t1: String, t2: String) extends PassException( - s"$info: [module $mname] Must mux between equivalent types: $t1 != $t2.") - class MuxPassiveTypes(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Must mux between passive types.") - class MuxCondUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] A mux condition must be of type UInt.") - class MuxClock(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Firrtl does not support muxing clocks.") - class ValidIfPassiveTypes(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Must validif a passive type.") - class ValidIfCondUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] A validif condition must be of type UInt.") - class IllegalAnalogDeclaration(info: Info, mname: String, decName: String) extends PassException( - s"$info: [module $mname] Cannot declare a reg, node, or memory with an Analog type: $decName.") - class IllegalAttachExp(info: Info, mname: String, expName: String) extends PassException( - s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.") - class IllegalResetType(info: Info, mname: String, exp: String) extends PassException( - s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp.") - class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException( - s"$info: [module $mname] Uninferred type: $exp." - ) + val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}" + val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}" + s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe" + }) + class InvalidRegInit(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Type of init must match type of DefRegister.") + class PrintfArgNotGround(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Printf arguments must be either UIntType or SIntType.") + class ReqClk(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Requires a clock typed signal.") + class RegReqClk(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Register $name requires a clock typed signal.") + class EnNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Enable must be a UIntType typed signal.") + class PredNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Predicate not a UIntType.") + class OpNotGround(info: Info, mname: String, op: String) + extends PassException(s"$info: [module $mname] Primop $op cannot operate on non-ground types.") + class OpNotUInt(info: Info, mname: String, op: String, e: String) + extends PassException(s"$info: [module $mname] Primop $op requires argument $e to be a UInt type.") + class OpNotAllUInt(info: Info, mname: String, op: String) + extends PassException(s"$info: [module $mname] Primop $op requires all arguments to be UInt type.") + class OpNotAllSameType(info: Info, mname: String, op: String) + extends PassException(s"$info: [module $mname] Primop $op requires all operands to have the same type.") + class OpNoMixFix(info: Info, mname: String, op: String) + extends PassException( + s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type." + ) + class OpNotCorrectType(info: Info, mname: String, op: String, tpes: Seq[String]) + extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.") + class OpNotAnalog(info: Info, mname: String, exp: String) + extends PassException(s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.") + class NodePassiveType(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Node must be a passive type.") + class MuxSameType(info: Info, mname: String, t1: String, t2: String) + extends PassException(s"$info: [module $mname] Must mux between equivalent types: $t1 != $t2.") + class MuxPassiveTypes(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Must mux between passive types.") + class MuxCondUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] A mux condition must be of type UInt.") + class MuxClock(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Firrtl does not support muxing clocks.") + class ValidIfPassiveTypes(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Must validif a passive type.") + class ValidIfCondUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] A validif condition must be of type UInt.") + class IllegalAnalogDeclaration(info: Info, mname: String, decName: String) + extends PassException( + s"$info: [module $mname] Cannot declare a reg, node, or memory with an Analog type: $decName." + ) + class IllegalAttachExp(info: Info, mname: String, expName: String) + extends PassException( + s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName." + ) + class IllegalResetType(info: Info, mname: String, exp: String) + extends PassException( + s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp." + ) + class IllegalUnknownType(info: Info, mname: String, exp: String) + extends PassException( + s"$info: [module $mname] Uninferred type: $exp." + ) def fits(bigger: Constraint, smaller: Constraint): Boolean = (bigger, smaller) match { case (IsKnown(v1), IsKnown(v2)) if v1 < v2 => false - case _ => true + case _ => true } def legalResetType(tpe: Type): Boolean = tpe match { case UIntType(IntWidth(w)) if w == 1 => true - case AsyncResetType => true - case ResetType => true - case UIntType(UnknownWidth) => + case AsyncResetType => true + case ResetType => true + case UIntType(UnknownWidth) => // cannot catch here, though width may ultimately be wrong true case _ => false @@ -118,13 +131,13 @@ object CheckTypes extends Pass { fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point) case (_: AnalogType, _: AnalogType) => true case (AsyncResetType, AsyncResetType) => flip1 == flip2 - case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2 - case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2 + case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2 + case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2 case (t1: BundleType, t2: BundleType) => - val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())( - (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) ))) - t2.fields forall (f2 => - t1_fields get f2.name match { + val t1_fields = + (t1.fields.foldLeft(Map[String, (Type, Orientation)]()))((map, f1) => map + (f1.name -> ((f1.tpe, f1.flip)))) + t2.fields.forall(f2 => + t1_fields.get(f2.name) match { case None => true case Some((f1_tpe, f1_flip)) => bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip)) @@ -155,79 +168,155 @@ object CheckTypes extends Pass { def ut: UIntType = UIntType(UnknownWidth) def st: SIntType = SIntType(UnknownWidth) - def run (c:Circuit) : Circuit = { + def run(c: Circuit): Circuit = { val errors = new Errors() def passive(t: Type): Boolean = t match { - case _: UIntType |_: SIntType => true + case _: UIntType | _: SIntType => true case tx: VectorType => passive(tx.tpe) - case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe)) + case tx: BundleType => tx.fields.forall(x => x.flip == Default && passive(x.tpe)) case tx => true } def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = { - def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean, okInterval: Boolean): Unit = { + def checkAllTypes( + exprs: Seq[Expression], + okUInt: Boolean, + okSInt: Boolean, + okClock: Boolean, + okFix: Boolean, + okAsync: Boolean, + okInterval: Boolean + ): Unit = { exprs.foldLeft((false, false, false, false, false, false)) { - case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => expr.tpe match { - case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval) - case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval) - case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval) - case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval) - case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval) - case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true) - case UnknownType => - errors.append(new IllegalUnknownType(info, mname, e.serialize)) - (isUInt, isSInt, isClock, isFix, isAsync, isInterval) - case other => throwInternalError(s"Illegal Type: ${other.serialize}") - } + case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => + expr.tpe match { + case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval) + case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval) + case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval) + case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval) + case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval) + case i: IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true) + case UnknownType => + errors.append(new IllegalUnknownType(info, mname, e.serialize)) + (isUInt, isSInt, isClock, isFix, isAsync, isInterval) + case other => throwInternalError(s"Illegal Type: ${other.serialize}") + } } match { // (UInt, SInt, Clock, Fixed, Async, Interval) - case (isAll, false, false, false, false, false) if isAll == okUInt => - case (false, isAll, false, false, false, false) if isAll == okSInt => - case (false, false, isAll, false, false, false) if isAll == okClock => - case (false, false, false, isAll, false, false) if isAll == okFix => - case (false, false, false, false, isAll, false) if isAll == okAsync => + case (isAll, false, false, false, false, false) if isAll == okUInt => + case (false, isAll, false, false, false, false) if isAll == okSInt => + case (false, false, isAll, false, false, false) if isAll == okClock => + case (false, false, false, isAll, false, false) if isAll == okFix => + case (false, false, false, false, isAll, false) if isAll == okAsync => case (false, false, false, false, false, isAll) if isAll == okInterval => - case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize))) + case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize))) } } e.op match { case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset | AsInterval => - // All types are ok + // All types are ok case Dshl | Dshr => - checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) - checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false, okInterval=false) + checkAllTypes( + Seq(e.args.head), + okUInt = true, + okSInt = true, + okClock = false, + okFix = true, + okAsync = false, + okInterval = true + ) + checkAllTypes( + Seq(e.args(1)), + okUInt = true, + okSInt = false, + okClock = false, + okFix = false, + okAsync = false, + okInterval = false + ) case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + checkAllTypes( + e.args, + okUInt = true, + okSInt = true, + okClock = false, + okFix = true, + okAsync = false, + okInterval = true + ) case Pad | Bits | Head | Tail => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=false) + checkAllTypes( + e.args, + okUInt = true, + okSInt = true, + okClock = false, + okFix = true, + okAsync = false, + okInterval = false + ) case Shl | Shr | Cat => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + checkAllTypes( + e.args, + okUInt = true, + okSInt = true, + okClock = false, + okFix = true, + okAsync = false, + okInterval = true + ) case IncP | DecP | SetP => - checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false, okInterval=true) + checkAllTypes( + e.args, + okUInt = false, + okSInt = false, + okClock = false, + okFix = true, + okAsync = false, + okInterval = true + ) case Wrap | Clip | Squeeze => - checkAllTypes(e.args, okUInt = false, okSInt = false, okClock = false, okFix = false, okAsync=false, okInterval = true) + checkAllTypes( + e.args, + okUInt = false, + okSInt = false, + okClock = false, + okFix = false, + okAsync = false, + okInterval = true + ) case _ => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false, okInterval=false) + checkAllTypes( + e.args, + okUInt = true, + okSInt = true, + okClock = false, + okFix = false, + okAsync = false, + okInterval = false + ) } } - def check_types_e(info:Info, mname: String)(e: Expression): Unit = { + def check_types_e(info: Info, mname: String)(e: Expression): Unit = { e match { - case (e: WSubField) => e.expr.tpe match { - case (t: BundleType) => t.fields find (_.name == e.name) match { - case Some(_) => - case None => errors.append(new SubfieldNotInBundle(info, mname, e.name)) + case (e: WSubField) => + e.expr.tpe match { + case (t: BundleType) => + t.fields.find(_.name == e.name) match { + case Some(_) => + case None => errors.append(new SubfieldNotInBundle(info, mname, e.name)) + } + case _ => errors.append(new SubfieldOnNonBundle(info, mname, e.name)) + } + case (e: WSubIndex) => + e.expr.tpe match { + case (t: VectorType) if e.value < t.size => + case (t: VectorType) => + errors.append(new IndexTooLarge(info, mname, e.value)) + case _ => + errors.append(new IndexOnNonVector(info, mname)) } - case _ => errors.append(new SubfieldOnNonBundle(info, mname, e.name)) - } - case (e: WSubIndex) => e.expr.tpe match { - case (t: VectorType) if e.value < t.size => - case (t: VectorType) => - errors.append(new IndexTooLarge(info, mname, e.value)) - case _ => - errors.append(new IndexOnNonVector(info, mname)) - } case (e: WSubAccess) => e.expr.tpe match { case _: VectorType => @@ -256,11 +345,14 @@ object CheckTypes extends Pass { } case _ => } - e foreach check_types_e(info, mname) + e.foreach(check_types_e(info, mname)) } def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = { - val info = get_info(s) match { case NoInfo => minfo case x => x } + val info = get_info(s) match { + case NoInfo => minfo + case x => x + } s match { case sx: Connect if !validConnect(sx) => val conMsg = sx.copy(info = NoInfo).serialize @@ -270,7 +362,7 @@ object CheckTypes extends Pass { errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr)) case sx: DefRegister => sx.tpe match { - case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) + case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname)) case t if !validConnect(sx.tpe, sx.init.tpe) => val conMsg = sx.copy(info = NoInfo).serialize @@ -285,11 +377,12 @@ object CheckTypes extends Pass { } case sx: Conditionally if wt(sx.pred.tpe) != wt(ut) => errors.append(new PredNotUInt(info, mname)) - case sx: DefNode => sx.value.tpe match { - case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) - case t if !passive(sx.value.tpe) => errors.append(new NodePassiveType(info, mname)) - case t => - } + case sx: DefNode => + sx.value.tpe match { + case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) + case t if !passive(sx.value.tpe) => errors.append(new NodePassiveType(info, mname)) + case t => + } case sx: Attach => for (e <- sx.exprs) { e.tpe match { @@ -298,14 +391,14 @@ object CheckTypes extends Pass { } kind(e) match { case (InstanceKind | PortKind | WireKind) => - case _ => errors.append(new IllegalAttachExp(info, mname, e.serialize)) + case _ => errors.append(new IllegalAttachExp(info, mname, e.serialize)) } } case sx: Stop => if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) case sx: Print => - if (sx.args exists (x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st))) + if (sx.args.exists(x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st))) errors.append(new PrintfArgNotGround(info, mname)) if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) @@ -313,17 +406,18 @@ object CheckTypes extends Pass { if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) if (wt(sx.pred.tpe) != wt(ut)) errors.append(new PredNotUInt(info, mname)) if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) - case sx: DefMemory => sx.dataType match { - case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) - case t => - } + case sx: DefMemory => + sx.dataType match { + case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) + case t => + } case _ => } - s foreach check_types_e(info, mname) - s foreach check_types_s(info, mname) + s.foreach(check_types_e(info, mname)) + s.foreach(check_types_s(info, mname)) } - c.modules foreach (m => m foreach check_types_s(m.info, m.name)) + c.modules.foreach(m => m.foreach(check_types_s(m.info, m.name))) errors.trigger() c } diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index a7729ef8..f7fefa87 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -22,43 +22,49 @@ object CheckWidths extends Pass { /** The maximum allowed width for any circuit element */ val MaxWidth = 1000000 val DshlMaxWidth = getUIntWidth(MaxWidth) - class UninferredWidth (info: Info, target: String) extends PassException( - s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?) - |$target""".stripMargin) - class UninferredBound (info: Info, target: String, bound: String) extends PassException( - s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?) - |$target""".stripMargin) - class InvalidRange (info: Info, target: String, i: IntervalType) extends PassException( - s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?) - |$target""".stripMargin) - class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException( - s"$info : [target $mname] Width too small for constant $b.") - class WidthTooBig(info: Info, mname: String, b: BigInt) extends PassException( - s"$info : [target $mname] Width $b greater than max allowed width of $MaxWidth bits") - class DshlTooBig(info: Info, mname: String) extends PassException( - s"$info : [target $mname] Width of dshl shift amount must be less than $DshlMaxWidth bits.") - class MultiBitAsClock(info: Info, mname: String) extends PassException( - s"$info : [target $mname] Cannot cast a multi-bit signal to a Clock.") - class MultiBitAsAsyncReset(info: Info, mname: String) extends PassException( - s"$info : [target $mname] Cannot cast a multi-bit signal to an AsyncReset.") - class NegWidthException(info:Info, mname: String) extends PassException( - s"$info: [target $mname] Width cannot be negative or zero.") - class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String) extends PassException( - s"$info: [target $mname] High bit $hi in bits operator is larger than input width $width in $exp.") - class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( - s"$info: [target $mname] Parameter $n in head operator is larger than input width $width.") - class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( - s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.") - class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) extends PassException( - s"$info: [target $mname] Attach source $source and expression $eName must have identical widths.") + class UninferredWidth(info: Info, target: String) + extends PassException(s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?) + |$target""".stripMargin) + class UninferredBound(info: Info, target: String, bound: String) + extends PassException(s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?) + |$target""".stripMargin) + class InvalidRange(info: Info, target: String, i: IntervalType) + extends PassException(s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?) + |$target""".stripMargin) + class WidthTooSmall(info: Info, mname: String, b: BigInt) + extends PassException(s"$info : [target $mname] Width too small for constant $b.") + class WidthTooBig(info: Info, mname: String, b: BigInt) + extends PassException(s"$info : [target $mname] Width $b greater than max allowed width of $MaxWidth bits") + class DshlTooBig(info: Info, mname: String) + extends PassException( + s"$info : [target $mname] Width of dshl shift amount must be less than $DshlMaxWidth bits." + ) + class MultiBitAsClock(info: Info, mname: String) + extends PassException(s"$info : [target $mname] Cannot cast a multi-bit signal to a Clock.") + class MultiBitAsAsyncReset(info: Info, mname: String) + extends PassException(s"$info : [target $mname] Cannot cast a multi-bit signal to an AsyncReset.") + class NegWidthException(info: Info, mname: String) + extends PassException(s"$info: [target $mname] Width cannot be negative or zero.") + class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String) + extends PassException( + s"$info: [target $mname] High bit $hi in bits operator is larger than input width $width in $exp." + ) + class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) + extends PassException(s"$info: [target $mname] Parameter $n in head operator is larger than input width $width.") + class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) + extends PassException(s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.") + class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) + extends PassException( + s"$info: [target $mname] Attach source $source and expression $eName must have identical widths." + ) class DisjointSqueeze(info: Info, mname: String, squeeze: DoPrim) - extends PassException({ - val toSqz = squeeze.args.head.serialize - val toSqzTpe = squeeze.args.head.tpe.serialize - val sqzTo = squeeze.args(1).serialize - val sqzToTpe = squeeze.args(1).tpe.serialize - s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe" - }) + extends PassException({ + val toSqz = squeeze.args.head.serialize + val toSqzTpe = squeeze.args.head.tpe.serialize + val sqzTo = squeeze.args(1).serialize + val sqzToTpe = squeeze.args(1).tpe.serialize + s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe" + }) def run(c: Circuit): Circuit = { val errors = new Errors() @@ -77,35 +83,35 @@ object CheckWidths extends Pass { def hasWidth(tpe: Type): Boolean = tpe match { case GroundType(IntWidth(w)) => true - case GroundType(_) => false - case _ => throwInternalError(s"hasWidth - $tpe") + case GroundType(_) => false + case _ => throwInternalError(s"hasWidth - $tpe") } def check_width_t(info: Info, target: Target)(t: Type): Unit = { t match { case tt: BundleType => tt.fields.foreach(check_width_f(info, target)) //Supports when l = u (if closed) - case i@IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i - case i:IntervalType if i.range == Some(Nil) => + case i @ IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i + case i: IntervalType if i.range == Some(Nil) => errors.append(new InvalidRange(info, target.prettyPrint(" "), i)) i - case i@IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u => + case i @ IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u => errors.append(new InvalidRange(info, target.prettyPrint(" "), i)) i - case i@IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i - case i@IntervalType(_: IsKnown, _, _) => + case i @ IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i + case i @ IntervalType(_: IsKnown, _, _) => errors.append(new UninferredBound(info, target.prettyPrint(" "), "upper")) i - case i@IntervalType(_, _: IsKnown, _) => + case i @ IntervalType(_, _: IsKnown, _) => errors.append(new UninferredBound(info, target.prettyPrint(" "), "lower")) i - case i@IntervalType(_, _, _) => + case i @ IntervalType(_, _, _) => errors.append(new UninferredBound(info, target.prettyPrint(" "), "lower")) errors.append(new UninferredBound(info, target.prettyPrint(" "), "upper")) i - case tt => tt foreach check_width_t(info, target) + case tt => tt.foreach(check_width_t(info, target)) } - t foreach check_width_w(info, target, t) + t.foreach(check_width_w(info, target, t)) } def check_width_f(info: Info, target: Target)(f: Field): Unit = @@ -120,7 +126,8 @@ object CheckWidths extends Pass { errors.append(new WidthTooSmall(info, target.serialize, v)) case e @ DoPrim(op, Seq(a, b), _, tpe) => (op, a.tpe, b.tpe) match { - case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) => + case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) + if (ua < lb) || (ub < la) => errors.append(new DisjointSqueeze(info, target.serialize, e)) case (Dshl, at, bt) if (hasWidth(at) && bitWidth(bt) >= DshlMaxWidth) => errors.append(new DshlTooBig(info, target.serialize)) @@ -159,7 +166,6 @@ object CheckWidths extends Pass { } } - def check_width_e_dfs(info: Info, target: Target, expr: Expression): Unit = { val stack = collection.mutable.ArrayStack(expr) def push(e: Expression): Unit = stack.push(e) @@ -171,25 +177,31 @@ object CheckWidths extends Pass { } def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Unit = { - val info = get_info(s) match { case NoInfo => minfo case x => x } - val subRef = s match { case sx: HasName => target.ref(sx.name) case _ => target } - s foreach check_width_e(info, target, 4) - s foreach check_width_s(info, target) - s foreach check_width_t(info, subRef) + val info = get_info(s) match { + case NoInfo => minfo + case x => x + } + val subRef = s match { + case sx: HasName => target.ref(sx.name) + case _ => target + } + s.foreach(check_width_e(info, target, 4)) + s.foreach(check_width_s(info, target)) + s.foreach(check_width_t(info, subRef)) s match { case Attach(infox, exprs) => - exprs.tail.foreach ( e => + exprs.tail.foreach(e => if (bitWidth(e.tpe) != bitWidth(exprs.head.tpe)) errors.append(new AttachWidthsNotEqual(infox, target.serialize, e.serialize, exprs.head.serialize)) ) case sx: DefRegister => sx.reset.tpe match { case UIntType(IntWidth(w)) if w == 1 => - case AsyncResetType => - case ResetType => - case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) + case AsyncResetType => + case ResetType => + case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) } - if(!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) { + if (!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) { val conMsg = sx.copy(info = NoInfo).serialize errors.append(new CheckTypes.InvalidConnect(info, target.module, conMsg, WRef(sx), sx.init)) } @@ -197,14 +209,15 @@ object CheckWidths extends Pass { } } - def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target.ref(p.name))(p.tpe) + def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = + check_width_t(p.info, target.ref(p.name))(p.tpe) def check_width_m(circuit: CircuitTarget)(m: DefModule): Unit = { - m foreach check_width_p(m.info, circuit.module(m.name)) - m foreach check_width_s(m.info, circuit.module(m.name)) + m.foreach(check_width_p(m.info, circuit.module(m.name))) + m.foreach(check_width_s(m.info, circuit.module(m.name))) } - c.modules foreach check_width_m(CircuitTarget(c.main)) + c.modules.foreach(check_width_m(CircuitTarget(c.main))) errors.trigger() c } diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 544f90a6..55a9c53a 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -10,15 +10,16 @@ import firrtl.options.Dependency object CommonSubexpressionElimination extends Pass { override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats] ) + Seq( + Dependency(firrtl.passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[firrtl.transforms.CombineCats] + ) override def optionalPrerequisiteOf = - Seq( Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform) = false @@ -27,24 +28,26 @@ object CommonSubexpressionElimination extends Pass { val nodes = collection.mutable.HashMap[String, Expression]() def eliminateNodeRef(e: Expression): Expression = e match { - case WRef(name, tpe, kind, flow) => nodes get name match { - case Some(expression) => expressions get expression match { - case Some(cseName) if cseName != name => - WRef(cseName, tpe, kind, flow) + case WRef(name, tpe, kind, flow) => + nodes.get(name) match { + case Some(expression) => + expressions.get(expression) match { + case Some(cseName) if cseName != name => + WRef(cseName, tpe, kind, flow) + case _ => e + } case _ => e } - case _ => e - } - case _ => e map eliminateNodeRef + case _ => e.map(eliminateNodeRef) } def eliminateNodeRefs(s: Statement): Statement = { - s map eliminateNodeRef match { + s.map(eliminateNodeRef) match { case x: DefNode => nodes(x.name) = x.value expressions.getOrElseUpdate(x.value, x.name) x - case other => other map eliminateNodeRefs + case other => other.map(eliminateNodeRefs) } } @@ -54,7 +57,7 @@ object CommonSubexpressionElimination extends Pass { def run(c: Circuit): Circuit = { val modulesx = c.modules.map { case m: ExtModule => m - case m: Module => Module(m.info, m.name, m.ports, cse(m.body)) + case m: Module => Module(m.info, m.name, m.ports, cse(m.body)) } Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 4a426209..baf7d4d5 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -7,7 +7,7 @@ import firrtl.PrimOps._ import firrtl.ir._ import firrtl._ import firrtl.Mappers._ -import firrtl.Utils.{sub_type, module_type, field_type, max, throwInternalError} +import firrtl.Utils.{field_type, max, module_type, sub_type, throwInternalError} import firrtl.options.Dependency /** Replaces FixedType with SIntType, and correctly aligns all binary points @@ -15,71 +15,74 @@ import firrtl.options.Dependency object ConvertFixedToSInt extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses), - Dependency[ExpandWhensAndCheck], - Dependency[RemoveIntervals] ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck], + Dependency[RemoveIntervals] + ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform) = false def alignArg(e: Expression, point: BigInt): Expression = e.tpe match { case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p) - if((point - p) > 0) { + if ((point - p) > 0) { DoPrim(Shl, Seq(e), Seq(point - p), UnknownType) } else if (point - p < 0) { DoPrim(Shr, Seq(e), Seq(p - point), UnknownType) } else e case FixedType(w, p) => throwInternalError(s"alignArg: shouldn't be here - $e") - case _ => e + case _ => e } def calcPoint(es: Seq[Expression]): BigInt = es.map(_.tpe match { case FixedType(IntWidth(w), IntWidth(p)) => p - case _ => BigInt(0) + case _ => BigInt(0) }).reduce(max(_, _)) def toSIntType(t: Type): Type = t match { case FixedType(IntWidth(w), IntWidth(p)) => SIntType(IntWidth(w)) - case FixedType(w, p) => throwInternalError(s"toSIntType: shouldn't be here - $t") - case _ => t map toSIntType + case FixedType(w, p) => throwInternalError(s"toSIntType: shouldn't be here - $t") + case _ => t.map(toSIntType) } def run(c: Circuit): Circuit = { - val moduleTypes = mutable.HashMap[String,Type]() - def onModule(m:DefModule) : DefModule = { - val types = mutable.HashMap[String,Type]() - def updateExpType(e:Expression): Expression = e match { - case DoPrim(Mul, args, consts, tpe) => e map updateExpType - case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) map updateExpType - case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType - case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType - case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType + val moduleTypes = mutable.HashMap[String, Type]() + def onModule(m: DefModule): DefModule = { + val types = mutable.HashMap[String, Type]() + def updateExpType(e: Expression): Expression = e match { + case DoPrim(Mul, args, consts, tpe) => e.map(updateExpType) + case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe).map(updateExpType) + case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe).map(updateExpType) + case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe).map(updateExpType) + case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p).map(updateExpType) case DoPrim(op, args, consts, tpe) => val point = calcPoint(args) val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType) - newExp map updateExpType match { + newExp.map(updateExpType) match { case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) - case e => e + case e => e } case Mux(cond, tval, fval, tpe) => val point = calcPoint(Seq(tval, fval)) val newExp = Mux(cond, alignArg(tval, point), alignArg(fval, point), UnknownType) - newExp map updateExpType + newExp.map(updateExpType) case e: UIntLiteral => e case e: SIntLiteral => e - case _ => e map updateExpType match { - case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe) - case WRef(name, tpe, k, g) => WRef(name, types(name), k, g) - case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g) - case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g) - case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g) - } + case _ => + e.map(updateExpType) match { + case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe) + case WRef(name, tpe, k, g) => WRef(name, types(name), k, g) + case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g) + case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g) + case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g) + } } def updateStmtType(s: Statement): Statement = s match { case DefRegister(info, name, tpe, clock, reset, init) => val newType = toSIntType(tpe) types(name) = newType - DefRegister(info, name, newType, clock, reset, init) map updateExpType + DefRegister(info, name, newType, clock, reset, init).map(updateExpType) case DefWire(info, name, tpe) => val newType = toSIntType(tpe) types(name) = newType @@ -101,37 +104,34 @@ object ConvertFixedToSInt extends Pass { case Connect(info, loc, exp) => val point = calcPoint(Seq(loc)) val newExp = alignArg(exp, point) - Connect(info, loc, newExp) map updateExpType + Connect(info, loc, newExp).map(updateExpType) case PartialConnect(info, loc, exp) => val point = calcPoint(Seq(loc)) val newExp = alignArg(exp, point) - PartialConnect(info, loc, newExp) map updateExpType + PartialConnect(info, loc, newExp).map(updateExpType) // check Connect case, need to shl - case s => (s map updateStmtType) map updateExpType + case s => (s.map(updateStmtType)).map(updateExpType) } m.ports.foreach(p => types(p.name) = p.tpe) m match { - case Module(info, name, ports, body) => Module(info,name,ports,updateStmtType(body)) - case m:ExtModule => m + case Module(info, name, ports, body) => Module(info, name, ports, updateStmtType(body)) + case m: ExtModule => m } } - val newModules = for(m <- c.modules) yield { - val newPorts = m.ports.map(p => Port(p.info,p.name,p.direction,toSIntType(p.tpe))) + val newModules = for (m <- c.modules) yield { + val newPorts = m.ports.map(p => Port(p.info, p.name, p.direction, toSIntType(p.tpe))) m match { - case Module(info, name, ports, body) => Module(info,name,newPorts,body) - case ext: ExtModule => ext.copy(ports = newPorts) + case Module(info, name, ports, body) => Module(info, name, newPorts, body) + case ext: ExtModule => ext.copy(ports = newPorts) } } newModules.foreach(m => moduleTypes(m.name) = module_type(m)) /* @todo This should be moved outside */ - (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main )) + (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main)) } } - - - // vim: set ts=4 sw=4 et: diff --git a/src/main/scala/firrtl/passes/ExpandConnects.scala b/src/main/scala/firrtl/passes/ExpandConnects.scala index d28e6399..4f849c5a 100644 --- a/src/main/scala/firrtl/passes/ExpandConnects.scala +++ b/src/main/scala/firrtl/passes/ExpandConnects.scala @@ -9,8 +9,7 @@ import firrtl.Mappers._ object ExpandConnects extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped + Seq(Dependency(PullMuxes), Dependency(ReplaceAccesses)) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform) = a match { case ResolveFlows => true @@ -19,62 +18,65 @@ object ExpandConnects extends Pass { def run(c: Circuit): Circuit = { def expand_connects(m: Module): Module = { - val flows = collection.mutable.LinkedHashMap[String,Flow]() + val flows = collection.mutable.LinkedHashMap[String, Flow]() def expand_s(s: Statement): Statement = { - def set_flow(e: Expression): Expression = e map set_flow match { + def set_flow(e: Expression): Expression = e.map(set_flow) match { case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name)) case ex: WSubField => val f = get_field(ex.expr.tpe, ex.name) val flowx = times(flow(ex.expr), f.flip) WSubField(ex.expr, ex.name, ex.tpe, flowx) - case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr)) + case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr)) case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr)) case ex => ex } s match { - case sx: DefWire => flows(sx.name) = DuplexFlow; sx - case sx: DefRegister => flows(sx.name) = DuplexFlow; sx + case sx: DefWire => flows(sx.name) = DuplexFlow; sx + case sx: DefRegister => flows(sx.name) = DuplexFlow; sx case sx: WDefInstance => flows(sx.name) = SourceFlow; sx - case sx: DefMemory => flows(sx.name) = SourceFlow; sx + case sx: DefMemory => flows(sx.name) = SourceFlow; sx case sx: DefNode => flows(sx.name) = SourceFlow; sx case sx: IsInvalid => - val invalids = create_exps(sx.expr).flatMap { case expx => - flow(set_flow(expx)) match { + val invalids = create_exps(sx.expr).flatMap { + case expx => + flow(set_flow(expx)) match { case DuplexFlow => Some(IsInvalid(sx.info, expx)) - case SinkFlow => Some(IsInvalid(sx.info, expx)) - case _ => None - } + case SinkFlow => Some(IsInvalid(sx.info, expx)) + case _ => None + } } invalids.size match { - case 0 => EmptyStmt - case 1 => invalids.head - case _ => Block(invalids) + case 0 => EmptyStmt + case 1 => invalids.head + case _ => Block(invalids) } case sx: Connect => val locs = create_exps(sx.loc) val exps = create_exps(sx.expr) - Block(locs.zip(exps).map { case (locx, expx) => - to_flip(flow(locx)) match { + Block(locs.zip(exps).map { + case (locx, expx) => + to_flip(flow(locx)) match { case Default => Connect(sx.info, locx, expx) - case Flip => Connect(sx.info, expx, locx) - } + case Flip => Connect(sx.info, expx, locx) + } }) case sx: PartialConnect => val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default) val locs = create_exps(sx.loc) val exps = create_exps(sx.expr) - val stmts = ls map { case (x, y) => - locs(x).tpe match { - case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y))) - case _ => - to_flip(flow(locs(x))) match { - case Default => Connect(sx.info, locs(x), exps(y)) - case Flip => Connect(sx.info, exps(y), locs(x)) - } - } + val stmts = ls.map { + case (x, y) => + locs(x).tpe match { + case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y))) + case _ => + to_flip(flow(locs(x))) match { + case Default => Connect(sx.info, locs(x), exps(y)) + case Flip => Connect(sx.info, exps(y), locs(x)) + } + } } Block(stmts) - case sx => sx map expand_s + case sx => sx.map(expand_s) } } @@ -83,8 +85,8 @@ object ExpandConnects extends Pass { } val modulesx = c.modules.map { - case (m: ExtModule) => m - case (m: Module) => expand_connects(m) + case (m: ExtModule) => m + case (m: Module) => expand_connects(m) } Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index ab7f02db..14d5d3ef 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -28,21 +28,23 @@ import collection.mutable object ExpandWhens extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Resolved + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses) + ) ++ firrtl.stage.Forms.Resolved override def invalidates(a: Transform): Boolean = a match { case CheckInitialization | ResolveKinds | InferTypes => true - case _ => false + case _ => false } /** Returns circuit with when and last connection semantics resolved */ def run(c: Circuit): Circuit = { - val modulesx = c.modules map { + val modulesx = c.modules.map { case m: ExtModule => m - case m: Module => onModule(m) + case m: Module => onModule(m) } Circuit(c.info, modulesx, c.main) } @@ -74,13 +76,12 @@ object ExpandWhens extends Pass { // Does an expression contain WVoid inserted in this pass? def containsVoid(e: Expression): Boolean = e match { - case WVoid => true + case WVoid => true case ValidIf(_, value, _) => memoizedVoid(value) - case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) - case _ => false + case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) + case _ => false } - // Memoizes the node that holds a particular expression, if any val nodes = new NodeLookup @@ -95,18 +96,15 @@ object ExpandWhens extends Pass { * @param p predicate so far, used to update simulation constructs * @param s statement to expand */ - def expandWhens(netlist: Netlist, - defaults: Defaults, - p: Expression) - (s: Statement): Statement = s match { + def expandWhens(netlist: Netlist, defaults: Defaults, p: Expression)(s: Statement): Statement = s match { // For each non-register declaration, update netlist with value WVoid for each sink reference // Return self, unchanged case stmt @ (_: DefNode | EmptyStmt) => stmt case w: DefWire => - netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow) map (ref => we(ref) -> WVoid)) + netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow).map(ref => we(ref) -> WVoid)) w case w: DefMemory => - netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow) map (ref => we(ref) -> WVoid)) + netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow).map(ref => we(ref) -> WVoid)) w case w: WDefInstance => netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> WVoid)) @@ -151,82 +149,88 @@ object ExpandWhens extends Pass { // Process combined maps because we only want to create 1 mux for each node // present in the conseq and/or alt - val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) => - // Defaults in netlist get priority over those in defaults - val default = netlist get lvalue match { - case Some(v) => Some(v) - case None => getDefault(lvalue, defaults) - } - // info0 and info1 correspond to Mux infos, use info0 only if ValidIf - val (res, info0, info1) = default match { - case Some(defaultValue) => - val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue)) - val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue)) - (trueValue, falseValue) match { - case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo) - case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo) - case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo) - case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo) - } - case None => - // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt - (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo) - } + val memos = (conseqNetlist ++ altNetlist).map { + case (lvalue, _) => + // Defaults in netlist get priority over those in defaults + val default = netlist.get(lvalue) match { + case Some(v) => Some(v) + case None => getDefault(lvalue, defaults) + } + // info0 and info1 correspond to Mux infos, use info0 only if ValidIf + val (res, info0, info1) = default match { + case Some(defaultValue) => + val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue)) + val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue)) + (trueValue, falseValue) match { + case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo) + case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo) + case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo) + case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo) + } + case None => + // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt + (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo) + } - res match { - // Don't create a node to hold mux trees with void values - // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches) - case e if containsVoid(e) => - netlist(lvalue) = e - memoizedVoid += e // remember that this was void - EmptyStmt - case _: ValidIf | _: Mux | _: DoPrim => nodes get res match { - case Some(name) => - netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) + res match { + // Don't create a node to hold mux trees with void values + // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches) + case e if containsVoid(e) => + netlist(lvalue) = e + memoizedVoid += e // remember that this was void + EmptyStmt + case _: ValidIf | _: Mux | _: DoPrim => + nodes.get(res) match { + case Some(name) => + netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) + EmptyStmt + case None => + val name = namespace.newTemp + nodes(res) = name + netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) + // Use MultiInfo constructor to preserve NoInfos + val info = new MultiInfo(List(sx.info, info0, info1)) + DefNode(info, name, res) + } + case _ => + netlist(lvalue) = res EmptyStmt - case None => - val name = namespace.newTemp - nodes(res) = name - netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) - // Use MultiInfo constructor to preserve NoInfos - val info = new MultiInfo(List(sx.info, info0, info1)) - DefNode(info, name, res) } - case _ => - netlist(lvalue) = res - EmptyStmt - } } Block(Seq(conseqStmt, altStmt) ++ memos) - case block: Block => block map expandWhens(netlist, defaults, p) + case block: Block => block.map(expandWhens(netlist, defaults, p)) case _ => throwInternalError() } val netlist = new Netlist // Add ports to netlist - netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) => - getSinkRefs(name, tpe, to_flow(dir)) map (ref => we(ref) -> WVoid) + netlist ++= (m.ports.flatMap { + case Port(_, name, dir, tpe) => + getSinkRefs(name, tpe, to_flow(dir)).map(ref => we(ref) -> WVoid) }) // Do traversal and construct mutable datastructures val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet - val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ - combineAttaches(attaches.toSeq) ++ simlist) + val newBody = Block( + Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ + combineAttaches(attaches.toSeq) ++ simlist + ) Module(m.info, m.name, m.ports, newBody) } - /** Returns all references to all sink leaf subcomponents of a reference */ private def getSinkRefs(n: String, t: Type, g: Flow): Seq[Expression] = { val exps = create_exps(WRef(n, t, ExpKind, g)) - exps.flatMap { case exp => - exp.tpe match { - case AnalogType(w) => None - case _ => flow(exp) match { - case (DuplexFlow | SinkFlow) => Some(exp) - case _ => None + exps.flatMap { + case exp => + exp.tpe match { + case AnalogType(w) => None + case _ => + flow(exp) match { + case (DuplexFlow | SinkFlow) => Some(exp) + case _ => None + } } - } } } @@ -238,7 +242,7 @@ object ExpandWhens extends Pass { def handleInvalid(k: WrappedExpression, info: Info): Statement = if (attached.contains(k)) EmptyStmt else IsInvalid(info, k.e1) netlist.map { - case (k, WInvalid) => handleInvalid(k, NoInfo) + case (k, WInvalid) => handleInvalid(k, NoInfo) case (k, InfoExpr(info, WInvalid)) => handleInvalid(k, info) case (k, v) => val (info, expr) = unwrap(v) @@ -261,7 +265,7 @@ object ExpandWhens extends Pass { case Seq() => // None of these expressions is present in the attachMap AttachAcc(exprs, attachMap.size) case accs => // At least one expression present in the attachMap - val sorted = accs sortBy (_.idx) + val sorted = accs.sortBy(_.idx) AttachAcc((sorted.map(_.exprs) :+ exprs).flatten.distinct, sorted.head.idx) } attachMap ++= acc.exprs.map(_ -> acc) @@ -274,10 +278,11 @@ object ExpandWhens extends Pass { private def getDefault(lvalue: WrappedExpression, defaults: Defaults): Option[Expression] = { defaults match { case Nil => None - case head :: tail => head get lvalue match { - case Some(p) => Some(p) - case None => getDefault(lvalue, tail) - } + case head :: tail => + head.get(lvalue) match { + case Some(p) => Some(p) + case None => getDefault(lvalue, tail) + } } } @@ -290,10 +295,12 @@ object ExpandWhens extends Pass { class ExpandWhensAndCheck extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses) + ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform): Boolean = a match { case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true @@ -301,6 +308,6 @@ class ExpandWhensAndCheck extends Transform with DependencyAPIMigration { } override def execute(a: CircuitState): CircuitState = - Seq(ExpandWhens, CheckInitialization).foldLeft(a){ case (acc, tx) => tx.transform(acc) } + Seq(ExpandWhens, CheckInitialization).foldLeft(a) { case (acc, tx) => tx.transform(acc) } } diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala index a16205a7..f393d8a5 100644 --- a/src/main/scala/firrtl/passes/InferBinaryPoints.scala +++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala @@ -13,9 +13,7 @@ import firrtl.options.Dependency class InferBinaryPoints extends Pass { override def prerequisites = - Seq( Dependency(ResolveKinds), - Dependency(InferTypes), - Dependency(ResolveFlows) ) + Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows)) override def optionalPrerequisiteOf = Seq.empty @@ -23,12 +21,12 @@ class InferBinaryPoints extends Pass { private val constraintSolver = new ConstraintSolver() - private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { - case (UIntType(w1), UIntType(w2)) => - case (SIntType(w1), SIntType(w2)) => - case (ClockType, ClockType) => - case (ResetType, _) => - case (_, ResetType) => + private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1, t2) match { + case (UIntType(w1), UIntType(w2)) => + case (SIntType(w1), SIntType(w2)) => + case (ClockType, ClockType) => + case (ResetType, _) => + case (_, ResetType) => case (AsyncResetType, AsyncResetType) => case (FixedType(w1, p1), FixedType(w2, p2)) => constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) @@ -36,78 +34,86 @@ class InferBinaryPoints extends Pass { constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) case (AnalogType(w1), AnalogType(w2)) => case (t1: BundleType, t2: BundleType) => - (t1.fields zip t2.fields) foreach { case (f1, f2) => - (f1.flip, f2.flip) match { - case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) - case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) - case _ => sys.error("Shouldn't be here") - } + (t1.fields.zip(t2.fields)).foreach { + case (f1, f2) => + (f1.flip, f2.flip) match { + case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) + case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) + case _ => sys.error("Shouldn't be here") + } } case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe) case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other") } - private def addDecConstraints(t: Type): Type = t map addDecConstraints - private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match { + private def addDecConstraints(t: Type): Type = t.map(addDecConstraints) + private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s.map(addDecConstraints) match { case c: Connect => val n = get_size(c.loc.tpe) val locs = create_exps(c.loc) val exps = create_exps(c.expr) - (locs zip exps) foreach { case (loc, exp) => - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } + (locs.zip(exps)).foreach { + case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } c case pc: PartialConnect => val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) val locs = create_exps(pc.loc) val exps = create_exps(pc.expr) - ls foreach { case (x, y) => - val loc = locs(x) - val exp = exps(y) - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } + ls.foreach { + case (x, y) => + val loc = locs(x) + val exp = exps(y) + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } pc case r: DefRegister => - addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) + addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) r - case x => x map addStmtConstraints(mt) + case x => x.map(addStmtConstraints(mt)) } private def fixWidth(w: Width): Width = constraintSolver.get(w) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => w - case _ => sys.error("Shouldn't be here") + case None => w + case _ => sys.error("Shouldn't be here") } - private def fixType(t: Type): Type = t map fixType map fixWidth match { + private def fixType(t: Type): Type = t.map(fixType).map(fixWidth) match { case IntervalType(l, u, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => p - case _ => sys.error("Shouldn't be here") + case None => p + case _ => sys.error("Shouldn't be here") } IntervalType(l, u, px) case FixedType(w, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => p - case _ => sys.error("Shouldn't be here") + case None => p + case _ => sys.error("Shouldn't be here") } FixedType(w, px) case x => x } - private def fixStmt(s: Statement): Statement = s map fixStmt map fixType - private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe)) - def run (c: Circuit): Circuit = { + private def fixStmt(s: Statement): Statement = s.map(fixStmt).map(fixType) + private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe)) + def run(c: Circuit): Circuit = { val ct = CircuitTarget(c.main) - c.modules foreach (m => m map addStmtConstraints(ct.module(m.name))) - c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)}) + c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name)))) + c.modules.foreach(_.ports.foreach { p => addDecConstraints(p.tpe) }) constraintSolver.solve() - InferTypes.run(c.copy(modules = c.modules map (_ - map fixPort - map fixStmt))) + InferTypes.run( + c.copy(modules = + c.modules.map( + _.map(fixPort) + .map(fixStmt) + ) + ) + ) } } diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 6cc9f2b9..4d14e7ff 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -23,16 +23,16 @@ object InferTypes extends Pass { def remove_unknowns_b(b: Bound): Bound = b match { case UnknownBound => VarBound(namespace.newName("b")) - case k => k + case k => k } def remove_unknowns_w(w: Width): Width = w match { case UnknownWidth => VarWidth(namespace.newName("w")) - case wx => wx + case wx => wx } def remove_unknowns(t: Type): Type = { - t map remove_unknowns map remove_unknowns_w match { + t.map(remove_unknowns).map(remove_unknowns_w) match { case IntervalType(l, u, p) => IntervalType(remove_unknowns_b(l), remove_unknowns_b(u), p) case x => x @@ -41,18 +41,18 @@ object InferTypes extends Pass { // we first need to remove the unknown widths and bounds from all ports, // as their type will determine the module types - val portsKnown = c.modules.map(_.map{ p: Port => p.copy(tpe = remove_unknowns(p.tpe)) }) + val portsKnown = c.modules.map(_.map { p: Port => p.copy(tpe = remove_unknowns(p.tpe)) }) val mtypes = portsKnown.map(m => m.name -> module_type(m)).toMap def infer_types_e(types: TypeLookup)(e: Expression): Expression = - e map infer_types_e(types) match { - case e: WRef => e copy (tpe = types(e.name)) - case e: WSubField => e copy (tpe = field_type(e.expr.tpe, e.name)) - case e: WSubIndex => e copy (tpe = sub_type(e.expr.tpe)) - case e: WSubAccess => e copy (tpe = sub_type(e.expr.tpe)) - case e: DoPrim => PrimOps.set_primop_type(e) - case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval)) - case e: ValidIf => e copy (tpe = e.value.tpe) + e.map(infer_types_e(types)) match { + case e: WRef => e.copy(tpe = types(e.name)) + case e: WSubField => e.copy(tpe = field_type(e.expr.tpe, e.name)) + case e: WSubIndex => e.copy(tpe = sub_type(e.expr.tpe)) + case e: WSubAccess => e.copy(tpe = sub_type(e.expr.tpe)) + case e: DoPrim => PrimOps.set_primop_type(e) + case e: Mux => e.copy(tpe = mux_type_and_widths(e.tval, e.fval)) + case e: ValidIf => e.copy(tpe = e.value.tpe) case e @ (_: UIntLiteral | _: SIntLiteral) => e } @@ -60,37 +60,37 @@ object InferTypes extends Pass { case sx: WDefInstance => val t = mtypes(sx.module) types(sx.name) = t - sx copy (tpe = t) + sx.copy(tpe = t) case sx: DefWire => val t = remove_unknowns(sx.tpe) types(sx.name) = t - sx copy (tpe = t) + sx.copy(tpe = t) case sx: DefNode => - val sxx = (sx map infer_types_e(types)).asInstanceOf[DefNode] + val sxx = (sx.map(infer_types_e(types))).asInstanceOf[DefNode] val t = remove_unknowns(sxx.value.tpe) types(sx.name) = t sxx case sx: DefRegister => val t = remove_unknowns(sx.tpe) types(sx.name) = t - sx copy (tpe = t) map infer_types_e(types) + sx.copy(tpe = t).map(infer_types_e(types)) case sx: DefMemory => // we need to remove the unknowns from the data type so that all ports get the same VarWidth val knownDataType = sx.copy(dataType = remove_unknowns(sx.dataType)) types(sx.name) = MemPortUtils.memType(knownDataType) knownDataType - case sx => sx map infer_types_s(types) map infer_types_e(types) + case sx => sx.map(infer_types_s(types)).map(infer_types_e(types)) } def infer_types_p(types: TypeLookup)(p: Port): Port = { val t = remove_unknowns(p.tpe) types(p.name) = t - p copy (tpe = t) + p.copy(tpe = t) } def infer_types(m: DefModule): DefModule = { val types = new TypeLookup - m map infer_types_p(types) map infer_types_s(types) + m.map(infer_types_p(types)).map(infer_types_s(types)) } c.copy(modules = portsKnown.map(infer_types)) @@ -108,45 +108,45 @@ object CInferTypes extends Pass { private type TypeLookup = collection.mutable.HashMap[String, Type] def run(c: Circuit): Circuit = { - val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap - - def infer_types_e(types: TypeLookup)(e: Expression) : Expression = - e map infer_types_e(types) match { - case (e: Reference) => e copy (tpe = types.getOrElse(e.name, UnknownType)) - case (e: SubField) => e copy (tpe = field_type(e.expr.tpe, e.name)) - case (e: SubIndex) => e copy (tpe = sub_type(e.expr.tpe)) - case (e: SubAccess) => e copy (tpe = sub_type(e.expr.tpe)) - case (e: DoPrim) => PrimOps.set_primop_type(e) - case (e: Mux) => e copy (tpe = mux_type(e.tval, e.fval)) - case (e: ValidIf) => e copy (tpe = e.value.tpe) - case e @ (_: UIntLiteral | _: SIntLiteral) => e + val mtypes = (c.modules.map(m => m.name -> module_type(m))).toMap + + def infer_types_e(types: TypeLookup)(e: Expression): Expression = + e.map(infer_types_e(types)) match { + case (e: Reference) => e.copy(tpe = types.getOrElse(e.name, UnknownType)) + case (e: SubField) => e.copy(tpe = field_type(e.expr.tpe, e.name)) + case (e: SubIndex) => e.copy(tpe = sub_type(e.expr.tpe)) + case (e: SubAccess) => e.copy(tpe = sub_type(e.expr.tpe)) + case (e: DoPrim) => PrimOps.set_primop_type(e) + case (e: Mux) => e.copy(tpe = mux_type(e.tval, e.fval)) + case (e: ValidIf) => e.copy(tpe = e.value.tpe) + case e @ (_: UIntLiteral | _: SIntLiteral) => e } def infer_types_s(types: TypeLookup)(s: Statement): Statement = s match { case sx: DefRegister => types(sx.name) = sx.tpe - sx map infer_types_e(types) + sx.map(infer_types_e(types)) case sx: DefWire => types(sx.name) = sx.tpe sx case sx: DefNode => - val sxx = (sx map infer_types_e(types)).asInstanceOf[DefNode] + val sxx = (sx.map(infer_types_e(types))).asInstanceOf[DefNode] types(sxx.name) = sxx.value.tpe sxx case sx: DefMemory => types(sx.name) = MemPortUtils.memType(sx) sx case sx: CDefMPort => - val t = types getOrElse(sx.mem, UnknownType) + val t = types.getOrElse(sx.mem, UnknownType) types(sx.name) = t - sx copy (tpe = t) + sx.copy(tpe = t) case sx: CDefMemory => types(sx.name) = sx.tpe sx case sx: DefInstance => types(sx.name) = mtypes(sx.module) sx - case sx => sx map infer_types_s(types) map infer_types_e(types) + case sx => sx.map(infer_types_s(types)).map(infer_types_e(types)) } def infer_types_p(types: TypeLookup)(p: Port): Port = { @@ -156,9 +156,9 @@ object CInferTypes extends Pass { def infer_types(m: DefModule): DefModule = { val types = new TypeLookup - m map infer_types_p(types) map infer_types_s(types) + m.map(infer_types_p(types)).map(infer_types_s(types)) } - c copy (modules = c.modules map infer_types) + c.copy(modules = c.modules.map(infer_types)) } } diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 3720523b..eae9690f 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -14,7 +14,7 @@ import firrtl.options.Dependency object InferWidths { def apply(): InferWidths = new InferWidths() - def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver) + def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver) def execute(state: CircuitState): CircuitState = new InferWidths().execute(state) } @@ -22,12 +22,14 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = { val newLoc :: newExp :: Nil = Seq(loc, exp).map { target => renameMap.get(target) match { - case None => Some(target) - case Some(Seq()) => None + case None => Some(target) + case Some(Seq()) => None case Some(Seq(one)) => Some(one) case Some(many) => - throw new Exception(s"Target below is an AggregateType, which " + - "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + throw new Exception( + s"Target below is an AggregateType, which " + + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint() + ) } } @@ -60,28 +62,31 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg * * Uses firrtl.constraint package to infer widths */ -class InferWidths extends Transform - with ResolvedAnnotationPaths - with DependencyAPIMigration { +class InferWidths extends Transform with ResolvedAnnotationPaths with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.ResolveFlows), - Dependency[passes.InferBinaryPoints], - Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR + Seq( + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.ResolveFlows), + Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals] + ) ++ firrtl.stage.Forms.WorkingIR override def invalidates(a: Transform) = false val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) - private def addTypeConstraints - (r1: ReferenceTarget, r2: ReferenceTarget) - (t1: Type, t2: Type) - (implicit constraintSolver: ConstraintSolver) - : Unit = (t1,t2) match { + private def addTypeConstraints( + r1: ReferenceTarget, + r2: ReferenceTarget + )(t1: Type, + t2: Type + )( + implicit constraintSolver: ConstraintSolver + ): Unit = (t1, t2) match { case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) - case (ClockType, ClockType) => + case (ClockType, ClockType) => case (FixedType(w1, p1), FixedType(w2, p2)) => constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) @@ -93,101 +98,119 @@ class InferWidths extends Transform constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) constraintSolver.addGeq(w2, w1, r1.prettyPrint(""), r2.prettyPrint("")) case (t1: BundleType, t2: BundleType) => - (t1.fields zip t2.fields) foreach { case (f1, f2) => - (f1.flip, f2.flip) match { - case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) - case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) - case _ => sys.error("Shouldn't be here") - } + (t1.fields.zip(t2.fields)).foreach { + case (f1, f2) => + (f1.flip, f2.flip) match { + case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) + case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) + case _ => sys.error("Shouldn't be here") + } } case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe) case (AsyncResetType, AsyncResetType) => Nil - case (ResetType, _) => Nil - case (_, ResetType) => Nil + case (ResetType, _) => Nil + case (_, ResetType) => Nil } - private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver) - : Expression = e map addExpConstraints match { - case m@Mux(p, tVal, fVal, t) => - constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W") - m - case other => other - } + private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver): Expression = + e.map(addExpConstraints) match { + case m @ Mux(p, tVal, fVal, t) => + constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W") + m + case other => other + } - private def addStmtConstraints(mt: ModuleTarget)(s: Statement)(implicit constraintSolver: ConstraintSolver) - : Statement = s map addExpConstraints match { + private def addStmtConstraints( + mt: ModuleTarget + )(s: Statement + )( + implicit constraintSolver: ConstraintSolver + ): Statement = s.map(addExpConstraints) match { case c: Connect => val n = get_size(c.loc.tpe) val locs = create_exps(c.loc) val exps = create_exps(c.expr) - (locs zip exps).foreach { case (loc, exp) => - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } - } + (locs.zip(exps)).foreach { + case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } + } c case pc: PartialConnect => val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) val locs = create_exps(pc.loc) val exps = create_exps(pc.expr) - ls foreach { case (x, y) => - val loc = locs(x) - val exp = exps(y) - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } + ls.foreach { + case (x, y) => + val loc = locs(x) + val exp = exps(y) + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } pc case r: DefRegister => - if (r.reset.tpe != AsyncResetType ) { + if (r.reset.tpe != AsyncResetType) { addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1))) } addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) r - case a@Attach(_, exprs) => - val widths = exprs map (e => (e, getWidth(e.tpe))) + case a @ Attach(_, exprs) => + val widths = exprs.map(e => (e, getWidth(e.tpe))) val maxWidth = IsMax(widths.map(x => width2constraint(x._2))) - widths.foreach { case (e, w) => - constraintSolver.addGeq(w, CalcWidth(maxWidth), Target.asTarget(mt)(e).prettyPrint(""), mt.ref(a.serialize).prettyPrint("")) + widths.foreach { + case (e, w) => + constraintSolver.addGeq( + w, + CalcWidth(maxWidth), + Target.asTarget(mt)(e).prettyPrint(""), + mt.ref(a.serialize).prettyPrint("") + ) } a case c: Conditionally => addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1))) - c map addStmtConstraints(mt) - case x => x map addStmtConstraints(mt) + c.map(addStmtConstraints(mt)) + case x => x.map(addStmtConstraints(mt)) } private def fixWidth(w: Width)(implicit constraintSolver: ConstraintSolver): Width = constraintSolver.get(w) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => w - case _ => sys.error("Shouldn't be here") + case None => w + case _ => sys.error("Shouldn't be here") } - private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t map fixType map fixWidth match { + private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t.map(fixType).map(fixWidth) match { case IntervalType(l, u, p) => val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match { case (Some(x: Bound), Some(y: Bound)) => (x, y) case (None, None) => (l, u) - case x => sys.error(s"Shouldn't be here: $x") - + case x => sys.error(s"Shouldn't be here: $x") } IntervalType(lx, ux, fixWidth(p)) case FixedType(w, p) => FixedType(w, fixWidth(p)) - case x => x + case x => x } - private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = s map fixStmt map fixType + private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = + s.map(fixStmt).map(fixType) private def fixPort(p: Port)(implicit constraintSolver: ConstraintSolver): Port = { Port(p.info, p.name, p.direction, fixType(p.tpe)) } - def run (c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = { + def run(c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = { val ct = CircuitTarget(c.main) - c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name))) + c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name)))) constraintSolver.solve() - val ret = InferTypes.run(c.copy(modules = c.modules map (_ - map fixPort - map fixStmt))) + val ret = InferTypes.run( + c.copy(modules = + c.modules.map( + _.map(fixPort) + .map(fixStmt) + ) + ) + ) constraintSolver.clear() ret } @@ -200,15 +223,16 @@ class InferWidths extends Transform def getDeclTypes(modName: String)(stmt: Statement): Unit = { val pairOpt = stmt match { - case w: DefWire => Some(w.name -> w.tpe) - case r: DefRegister => Some(r.name -> r.tpe) - case n: DefNode => Some(n.name -> n.value.tpe) + case w: DefWire => Some(w.name -> w.tpe) + case r: DefRegister => Some(r.name -> r.tpe) + case n: DefNode => Some(n.name -> n.value.tpe) case i: WDefInstance => Some(i.name -> i.tpe) - case m: DefMemory => Some(m.name -> MemPortUtils.memType(m)) + case m: DefMemory => Some(m.name -> MemPortUtils.memType(m)) case other => None } - pairOpt.foreach { case (ref, tpe) => - typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe) + pairOpt.foreach { + case (ref, tpe) => + typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe) } stmt.foreachStmt(getDeclTypes(modName)) } @@ -223,14 +247,20 @@ class InferWidths extends Transform } state.annotations.foreach { - case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => - val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target => - val baseType = typeMap.getOrElse(target.copy(component = Seq.empty), - throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint())) + case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => + val locType :: expType :: Nil = Seq(anno.loc, anno.exp).map { target => + val baseType = typeMap.getOrElse( + target.copy(component = Seq.empty), + throw new Exception( + s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint() + ) + ) val leafType = target.componentType(baseType) if (leafType.isInstanceOf[AggregateType]) { - throw new Exception(s"Target below is an AggregateType, which " + - "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + throw new Exception( + s"Target below is an AggregateType, which " + + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint() + ) } leafType diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index ad963b19..316878fb 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -32,89 +32,100 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe override def invalidates(a: Transform): Boolean = a == ResolveKinds - private [firrtl] val inlineDelim: String = "_" + private[firrtl] val inlineDelim: String = "_" val options = Seq( new ShellOption[Seq[String]]( longOption = "inline", - toAnnotationSeq = (a: Seq[String]) => a.map { value => - value.split('.') match { - case Array(circuit) => - InlineAnnotation(CircuitName(circuit)) - case Array(circuit, module) => - InlineAnnotation(ModuleName(module, CircuitName(circuit))) - case Array(circuit, module, inst) => - InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit)))) - } - } :+ RunFirrtlTransformAnnotation(new InlineInstances), + toAnnotationSeq = (a: Seq[String]) => + a.map { value => + value.split('.') match { + case Array(circuit) => + InlineAnnotation(CircuitName(circuit)) + case Array(circuit, module) => + InlineAnnotation(ModuleName(module, CircuitName(circuit))) + case Array(circuit, module, inst) => + InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit)))) + } + } :+ RunFirrtlTransformAnnotation(new InlineInstances), helpText = "Inline selected modules", shortOption = Some("fil"), - helpValueName = Some("<circuit>[.<module>[.<instance>]][,...]") ) ) - - private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = - anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) { - case ((modNames, instNames), ann) => ann match { - case InlineAnnotation(CircuitName(c)) => - (circuit.modules.collect { - case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) - }.toSet, instNames) - case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) - case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) - case _ => (modNames, instNames) - } - } - - def execute(state: CircuitState): CircuitState = { - // TODO Add error check for more than one annotation for inlining - val (modNames, instNames) = collectAnns(state.circuit, state.annotations) - if (modNames.nonEmpty || instNames.nonEmpty) { - run(state.circuit, modNames, instNames, state.annotations) - } else { - state - } - } - - // Checks the following properties: - // 1) All annotated modules exist - // 2) All annotated modules are InModules (can be inlined) - // 3) All annotated instances exist, and their modules can be inline - def check(c: Circuit, moduleNames: Set[ModuleName], instanceNames: Set[ComponentName]): Unit = { - val errors = mutable.ArrayBuffer[PassException]() - val moduleMap = InstanceKeyGraph(c).moduleMap - def checkExists(name: String): Unit = - if (!moduleMap.contains(name)) - errors += new PassException(s"Annotated module does not exist: $name") - def checkExternal(name: String): Unit = moduleMap(name) match { - case m: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: $name") - case _ => - } - def checkInstance(cn: ComponentName): Unit = { - var containsCN = false - def onStmt(name: String)(s: Statement): Statement = { - s match { - case WDefInstance(_, inst_name, module_name, tpe) => - if (name == inst_name) { - containsCN = true - checkExternal(module_name) - } - case _ => + helpValueName = Some("<circuit>[.<module>[.<instance>]][,...]") + ) + ) + + private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) { + case ((modNames, instNames), ann) => + ann match { + case InlineAnnotation(CircuitName(c)) => + ( + circuit.modules.collect { + case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) + }.toSet, + instNames + ) + case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) + case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) + case _ => (modNames, instNames) + } + } + + def execute(state: CircuitState): CircuitState = { + // TODO Add error check for more than one annotation for inlining + val (modNames, instNames) = collectAnns(state.circuit, state.annotations) + if (modNames.nonEmpty || instNames.nonEmpty) { + run(state.circuit, modNames, instNames, state.annotations) + } else { + state + } + } + + // Checks the following properties: + // 1) All annotated modules exist + // 2) All annotated modules are InModules (can be inlined) + // 3) All annotated instances exist, and their modules can be inline + def check(c: Circuit, moduleNames: Set[ModuleName], instanceNames: Set[ComponentName]): Unit = { + val errors = mutable.ArrayBuffer[PassException]() + val moduleMap = InstanceKeyGraph(c).moduleMap + def checkExists(name: String): Unit = + if (!moduleMap.contains(name)) + errors += new PassException(s"Annotated module does not exist: $name") + def checkExternal(name: String): Unit = moduleMap(name) match { + case m: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: $name") + case _ => + } + def checkInstance(cn: ComponentName): Unit = { + var containsCN = false + def onStmt(name: String)(s: Statement): Statement = { + s match { + case WDefInstance(_, inst_name, module_name, tpe) => + if (name == inst_name) { + containsCN = true + checkExternal(module_name) } - s map onStmt(name) - } - onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[Module].body) - if (!containsCN) errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}") + case _ => + } + s.map(onStmt(name)) } + onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[Module].body) + if (!containsCN) errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}") + } - moduleNames.foreach{mn => checkExists(mn.name)} - if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) - moduleNames.foreach{mn => checkExternal(mn.name)} - if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) - instanceNames.foreach{cn => checkInstance(cn)} - if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) - } - + moduleNames.foreach { mn => checkExists(mn.name) } + if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) + moduleNames.foreach { mn => checkExternal(mn.name) } + if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) + instanceNames.foreach { cn => checkInstance(cn) } + if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) + } - def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName], annos: AnnotationSeq): CircuitState = { + def run( + c: Circuit, + modsToInline: Set[ModuleName], + instsToInline: Set[ComponentName], + annos: AnnotationSeq + ): CircuitState = { def getInstancesOf(c: Circuit, modules: Set[String]): Set[(OfModule, Instance)] = c.modules.foldLeft(Set[(OfModule, Instance)]()) { (set, d) => d match { @@ -125,7 +136,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe case WDefInstance(info, instName, moduleName, instTpe) if modules.contains(moduleName) => instances += (OfModule(m.name) -> Instance(instName)) s - case sx => sx map findInstances + case sx => sx.map(findInstances) } findInstances(m.body) instances.toSet ++ set @@ -135,7 +146,8 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe // Check annotations and circuit match up check(c, modsToInline, instsToInline) val flatModules = modsToInline.map(m => m.name) - val flatInstances: Set[(OfModule, Instance)] = instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules) + val flatInstances: Set[(OfModule, Instance)] = + instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules) val iGraph = InstanceKeyGraph(c) val namespaceMap = collection.mutable.Map[String, Namespace]() // Map of Module name to Map of instance name to Module name @@ -144,11 +156,13 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe /** Add a prefix to all declarations updating a [[Namespace]] and appending to a [[RenameMap]] */ def appendNamePrefix( currentModule: IsModule, - nextModule: IsModule, - prefix: String, - ns: Namespace, - renames: mutable.HashMap[String, String], - renameMap: RenameMap)(s: Statement): Statement = { + nextModule: IsModule, + prefix: String, + ns: Namespace, + renames: mutable.HashMap[String, String], + renameMap: RenameMap + )(s: Statement + ): Statement = { def onName(ofModuleOpt: Option[String])(name: String) = { if (prefix.nonEmpty && !ns.tryName(prefix + name)) { throw new Exception(s"Inlining failed. Inlined name '${prefix + name}' already exists") @@ -164,25 +178,29 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe } s match { - case s: WDefInstance => s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) - case other => s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) + case s: WDefInstance => + s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) + case other => + s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) } } /** Modify all references */ def appendRefPrefix( currentModule: IsModule, - renames: mutable.HashMap[String, String])(s: Statement): Statement = { - def onExpr(e: Expression): Expression = e match { - case wr@ WRef(name, _, _, _) => - renames.get(name) match { - case Some(prefixedName) => wr.copy(name = prefixedName) - case None => wr - } - case ex => ex.map(onExpr) - } - s.map(onExpr).map(appendRefPrefix(currentModule, renames)) + renames: mutable.HashMap[String, String] + )(s: Statement + ): Statement = { + def onExpr(e: Expression): Expression = e match { + case wr @ WRef(name, _, _, _) => + renames.get(name) match { + case Some(prefixedName) => wr.copy(name = prefixedName) + case None => wr + } + case ex => ex.map(onExpr) } + s.map(onExpr).map(appendRefPrefix(currentModule, renames)) + } val cache = mutable.HashMap.empty[ModuleTarget, Statement] @@ -194,16 +212,19 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe val (renamesMap, renamesSeq) = { val mutableDiGraph = new MutableDiGraph[(OfModule, Instance)] // compute instance graph - instMaps.foreach { case (grandParentOfMod, parents) => - parents.foreach { case (parentInst, parentOfMod) => - val from = grandParentOfMod -> parentInst - mutableDiGraph.addVertex(from) - instMaps(parentOfMod).foreach { case (childInst, _) => - val to = parentOfMod -> childInst - mutableDiGraph.addVertex(to) - mutableDiGraph.addEdge(from, to) + instMaps.foreach { + case (grandParentOfMod, parents) => + parents.foreach { + case (parentInst, parentOfMod) => + val from = grandParentOfMod -> parentInst + mutableDiGraph.addVertex(from) + instMaps(parentOfMod).foreach { + case (childInst, _) => + val to = parentOfMod -> childInst + mutableDiGraph.addVertex(to) + mutableDiGraph.addEdge(from, to) + } } - } } val diGraph = DiGraph(mutableDiGraph) @@ -226,10 +247,12 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe } def fixupRefs( - instMap: collection.Map[Instance, OfModule], - currentModule: IsModule)(e: Expression): Expression = { + instMap: collection.Map[Instance, OfModule], + currentModule: IsModule + )(e: Expression + ): Expression = { e match { - case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) => + case wsf @ WSubField(wr @ WRef(ref, _, InstanceKind, _), field, tpe, gen) => val inst = currentModule.instOf(ref, instMap(Instance(ref)).value) val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance)) val port = inst.ref(field) @@ -242,12 +265,12 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe } case None => wsf } - case wr@ WRef(name, _, InstanceKind, _) => + case wr @ WRef(name, _, InstanceKind, _) => val inst = currentModule.instOf(name, instMap(Instance(name)).value) val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance)) val comp = currentModule.ref(name) renamesOpt.flatMap(_.get(comp)).getOrElse(Seq(comp)) match { - case Seq(car: ReferenceTarget) => wr.copy(name=car.ref) + case Seq(car: ReferenceTarget) => wr.copy(name = car.ref) } case ex => ex.map(fixupRefs(instMap, currentModule)) } @@ -258,7 +281,8 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe val ns = namespaceMap.getOrElseUpdate(currentModuleName, Namespace(iGraph.moduleMap(currentModuleName))) val instMap = instMaps(OfModule(currentModuleName)) s match { - case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) => + case wDef @ WDefInstance(_, instName, modName, _) + if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) => val renames = renamesMap(OfModule(currentModuleName) -> Instance(instName)) val toInline = iGraph.moduleMap(modName) match { case m: ExtModule => throw new PassException(s"Cannot inline external module ${m.name}") @@ -269,7 +293,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe val bodyx = { val module = currentModule.copy(module = modName) - cache.getOrElseUpdate(module, Block(ports :+ toInline.body) map onStmt(module)) + cache.getOrElseUpdate(module, Block(ports :+ toInline.body).map(onStmt(module))) } val names = "" +: Uniquify @@ -294,14 +318,14 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe renamedBody case sx => sx - .map(fixupRefs(instMap, currentModule)) - .map(onStmt(currentModule)) + .map(fixupRefs(instMap, currentModule)) + .map(onStmt(currentModule)) } } val flatCircuit = c.copy(modules = c.modules.flatMap { case m if flatModules.contains(m.name) => None - case m => + case m => Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main))))) }) diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala index 8b7b733a..5d59e075 100644 --- a/src/main/scala/firrtl/passes/Legalize.scala +++ b/src/main/scala/firrtl/passes/Legalize.scala @@ -1,11 +1,11 @@ package firrtl.passes import firrtl.PrimOps._ -import firrtl.Utils.{BoolType, error, zero} +import firrtl.Utils.{error, zero, BoolType} import firrtl.ir._ import firrtl.options.Dependency import firrtl.transforms.ConstantPropagation -import firrtl.{Transform, bitWidth} +import firrtl.{bitWidth, Transform} import firrtl.Mappers._ // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts @@ -62,30 +62,31 @@ object Legalize extends Pass { } else { val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w))) val expr = t match { - case UIntType(_) => bits - case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) + case UIntType(_) => bits + case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t) } Connect(c.info, c.loc, expr) } } - def run (c: Circuit): Circuit = { - def legalizeE(expr: Expression): Expression = expr map legalizeE match { - case prim: DoPrim => prim.op match { - case Shr => legalizeShiftRight(prim) - case Pad => legalizePad(prim) - case Bits | Head | Tail => legalizeBitExtract(prim) - case _ => prim - } + def run(c: Circuit): Circuit = { + def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match { + case prim: DoPrim => + prim.op match { + case Shr => legalizeShiftRight(prim) + case Pad => legalizePad(prim) + case Bits | Head | Tail => legalizeBitExtract(prim) + case _ => prim + } case e => e // respect pre-order traversal } - def legalizeS (s: Statement): Statement = { + def legalizeS(s: Statement): Statement = { val legalizedStmt = s match { case c: Connect => legalizeConnect(c) case _ => s } - legalizedStmt map legalizeS map legalizeE + legalizedStmt.map(legalizeS).map(legalizeE) } - c copy (modules = c.modules map (_ map legalizeS)) + c.copy(modules = c.modules.map(_.map(legalizeS))) } } diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index ace4f3e8..ad608cec 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -3,8 +3,26 @@ package firrtl.passes import firrtl.analyses.{InstanceKeyGraph, SymbolTable} -import firrtl.annotations.{CircuitTarget, MemoryInitAnnotation, MemoryRandomInitAnnotation, ModuleTarget, ReferenceTarget} -import firrtl.{CircuitForm, CircuitState, DependencyAPIMigration, InstanceKind, Kind, MemKind, PortKind, RenameMap, Transform, UnknownForm, Utils} +import firrtl.annotations.{ + CircuitTarget, + MemoryInitAnnotation, + MemoryRandomInitAnnotation, + ModuleTarget, + ReferenceTarget +} +import firrtl.{ + CircuitForm, + CircuitState, + DependencyAPIMigration, + InstanceKind, + Kind, + MemKind, + PortKind, + RenameMap, + Transform, + UnknownForm, + Utils +} import firrtl.ir._ import firrtl.options.Dependency import firrtl.stage.TransformManager.TransformDependency @@ -20,18 +38,19 @@ import scala.collection.mutable object LowerTypes extends Transform with DependencyAPIMigration { override def prerequisites: Seq[TransformDependency] = Seq( Dependency(RemoveAccesses), // we require all SubAccess nodes to have been removed - Dependency(CheckTypes), // we require all types to be correct - Dependency(InferTypes), // we require instance types to be resolved (i.e., DefInstance.tpe != UnknownType) - Dependency(ExpandConnects) // we require all PartialConnect nodes to have been expanded + Dependency(CheckTypes), // we require all types to be correct + Dependency(InferTypes), // we require instance types to be resolved (i.e., DefInstance.tpe != UnknownType) + Dependency(ExpandConnects) // we require all PartialConnect nodes to have been expanded ) - override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq.empty + override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq.empty override def invalidates(a: Transform): Boolean = a match { case ResolveFlows => true // we generate UnknownFlow for now (could be fixed) - case _ => false + case _ => false } /** Delimiter used in lowering names */ val delim = "_" + /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name * @param e [[firrtl.ir.Expression]] made up of _only_ [[firrtl.WRef]], [[firrtl.WSubField]], and [[firrtl.WSubIndex]] * @return Lowered name of e @@ -39,8 +58,8 @@ object LowerTypes extends Transform with DependencyAPIMigration { */ def loweredName(e: Expression): String = e match { case e: Reference => e.name - case e: SubField => s"${loweredName(e.expr)}$delim${e.name}" - case e: SubIndex => s"${loweredName(e.expr)}$delim${e.value}" + case e: SubField => s"${loweredName(e.expr)}$delim${e.name}" + case e: SubIndex => s"${loweredName(e.expr)}$delim${e.value}" } def loweredName(s: Seq[String]): String = s.mkString(delim) @@ -48,7 +67,7 @@ object LowerTypes extends Transform with DependencyAPIMigration { // When memories are lowered to ground type, we have to fix the init annotation or error on it. val (memInitAnnos, otherAnnos) = state.annotations.partition { case _: MemoryRandomInitAnnotation => false - case _: MemoryInitAnnotation => true + case _: MemoryInitAnnotation => true case _ => false } val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule) @@ -61,14 +80,18 @@ object LowerTypes extends Transform with DependencyAPIMigration { val newAnnos = otherAnnos ++ resultAndRenames.flatMap(_._3) // chain module renames in topological order - val moduleRenames = resultAndRenames.map{ case(m,r, _) => m.name -> r }.toMap + val moduleRenames = resultAndRenames.map { case (m, r, _) => m.name -> r }.toMap val moduleOrderBottomUp = InstanceKeyGraph(result).moduleOrder.reverseIterator - val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a,b) => a.andThen(b)) + val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a, b) => a.andThen(b)) state.copy(circuit = result, renames = Some(renames), annotations = newAnnos) } - private def onModule(c: CircuitTarget, m: DefModule, memoryInit: Seq[MemoryInitAnnotation]): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = { + private def onModule( + c: CircuitTarget, + m: DefModule, + memoryInit: Seq[MemoryInitAnnotation] + ): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = { val renameMap = RenameMap() val ref = c.module(m.name) @@ -86,26 +109,36 @@ object LowerTypes extends Transform with DependencyAPIMigration { } // We lower ports in a separate pass in order to ensure that statements inside the module do not influence port names. - private def lowerPorts(ref: ModuleTarget, m: DefModule, renameMap: RenameMap): - (DefModule, Seq[(String, Seq[Reference])]) = { + private def lowerPorts( + ref: ModuleTarget, + m: DefModule, + renameMap: RenameMap + ): (DefModule, Seq[(String, Seq[Reference])]) = { val namespace = mutable.HashSet[String]() ++ m.ports.map(_.name) val loweredPortsAndRefs = m.ports.flatMap { p => - val fieldsAndRefs = DestructTypes.destruct(ref, Field(p.name, Utils.to_flip(p.direction), p.tpe), namespace, renameMap, Set()) - fieldsAndRefs.map { case (f, ref) => - (Port(p.info, f.name, Utils.to_dir(f.flip), f.tpe), ref -> Seq(Reference(f.name, f.tpe, PortKind))) + val fieldsAndRefs = + DestructTypes.destruct(ref, Field(p.name, Utils.to_flip(p.direction), p.tpe), namespace, renameMap, Set()) + fieldsAndRefs.map { + case (f, ref) => + (Port(p.info, f.name, Utils.to_dir(f.flip), f.tpe), ref -> Seq(Reference(f.name, f.tpe, PortKind))) } } val newM = m match { - case e : ExtModule => e.copy(ports = loweredPortsAndRefs.map(_._1)) - case mod: Module => mod.copy(ports = loweredPortsAndRefs.map(_._1)) + case e: ExtModule => e.copy(ports = loweredPortsAndRefs.map(_._1)) + case mod: Module => mod.copy(ports = loweredPortsAndRefs.map(_._1)) } (newM, loweredPortsAndRefs.map(_._2)) } - private def onStatement(s: Statement)(implicit symbols: LoweringTable, memInit: Seq[MemoryInitAnnotation]): Statement = s match { + private def onStatement( + s: Statement + )( + implicit symbols: LoweringTable, + memInit: Seq[MemoryInitAnnotation] + ): Statement = s match { // declarations - case d : DefWire => - Block(symbols.lower(d.name, d.tpe, firrtl.WireKind).map { case (name, tpe, _) => d.copy(name=name, tpe=tpe) }) + case d: DefWire => + Block(symbols.lower(d.name, d.tpe, firrtl.WireKind).map { case (name, tpe, _) => d.copy(name = name, tpe = tpe) }) case d @ DefRegister(info, _, _, clock, reset, _) => // clock and reset are always of ground type val loweredClock = onExpression(clock) @@ -113,41 +146,41 @@ object LowerTypes extends Transform with DependencyAPIMigration { // It is important to first lower the declaration, because the reset can refer to the register itself! val loweredRegs = symbols.lower(d.name, d.tpe, firrtl.RegKind) val inits = Utils.create_exps(d.init).map(onExpression) - Block( - loweredRegs.zip(inits).map { case ((name, tpe, _), init) => + Block(loweredRegs.zip(inits).map { + case ((name, tpe, _), init) => DefRegister(info, name, tpe, loweredClock, loweredReset, init) }) - case d : DefNode => + case d: DefNode => val values = Utils.create_exps(d.value).map(onExpression) - Block( - symbols.lower(d.name, d.value.tpe, firrtl.NodeKind).zip(values).map{ case((name, tpe, _), value) => + Block(symbols.lower(d.name, d.value.tpe, firrtl.NodeKind).zip(values).map { + case ((name, tpe, _), value) => assert(tpe == value.tpe) DefNode(d.info, name, value) }) - case d : DefMemory => + case d: DefMemory => // TODO: as an optimization, we could just skip ground type memories here. // This would require that we don't error in getReferences() but instead return the old reference. val mems = symbols.lower(d) - if(mems.length > 1 && memInit.exists(_.target.ref == d.name)) { + if (mems.length > 1 && memInit.exists(_.target.ref == d.name)) { val mod = memInit.find(_.target.ref == d.name).get.target.encapsulatingModule val msg = s"[module $mod] Cannot initialize memory ${d.name} of non ground type ${d.dataType.serialize}" throw new RuntimeException(msg) } Block(mems) - case d : DefInstance => symbols.lower(d) + case d: DefInstance => symbols.lower(d) // connections case Connect(info, loc, expr) => - if(!expr.tpe.isInstanceOf[GroundType]) { + if (!expr.tpe.isInstanceOf[GroundType]) { throw new RuntimeException(s"LowerTypes expects Connects to have been expanded! ${expr.tpe.serialize}") } val rhs = onExpression(expr) // We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated. val lhs = symbols.getReferences(loc.asInstanceOf[RefLikeExpression]) Block(lhs.map(loc => Connect(info, loc, rhs))) - case p : PartialConnect => + case p: PartialConnect => throw new RuntimeException(s"LowerTypes expects PartialConnects to be resolved! $p") case IsInvalid(info, expr) => - if(!expr.tpe.isInstanceOf[GroundType]) { + if (!expr.tpe.isInstanceOf[GroundType]) { throw new RuntimeException(s"LowerTypes expects IsInvalids to have been expanded! ${expr.tpe.serialize}") } // We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated. @@ -172,15 +205,18 @@ object LowerTypes extends Transform with DependencyAPIMigration { // Holds the first level of the module-level namespace. // (i.e. everything that can be addressed directly by a Reference node) private class LoweringSymbolTable extends SymbolTable { - def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(name) + def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(name) def declareInstance(name: String, module: String): Unit = symbols.append(name) private val symbols = mutable.ArrayBuffer[String]() def getSymbolNames: Iterable[String] = symbols } // Lowers types and keeps track of references to lowered types. -private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m: ModuleTarget, - portNameToExprs: Seq[(String, Seq[Reference])]) { +private class LoweringTable( + table: LoweringSymbolTable, + renameMap: RenameMap, + m: ModuleTarget, + portNameToExprs: Seq[(String, Seq[Reference])]) { private val portNames: Set[String] = portNameToExprs.map(_._2.head.name).toSet private val namespace = mutable.HashSet[String]() ++ table.getSymbolNames // Serialized old access string to new ground type reference. @@ -196,10 +232,11 @@ private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m: nameToExprs ++= refs.map { case (name, r) => name -> List(r) } newInst } + /** used to lower nodes, registers and wires */ def lower(name: String, tpe: Type, kind: Kind, flip: Orientation = Default): Seq[(String, Type, Orientation)] = { val fieldsAndRefs = DestructTypes.destruct(m, Field(name, flip, tpe), namespace, renameMap, portNames) - nameToExprs ++= fieldsAndRefs.map{ case (f, ref) => ref -> List(Reference(f.name, f.tpe, kind)) } + nameToExprs ++= fieldsAndRefs.map { case (f, ref) => ref -> List(Reference(f.name, f.tpe, kind)) } fieldsAndRefs.map { case (f, _) => (f.name, f.tpe, f.flip) } } def lower(p: Port): Seq[Port] = { @@ -211,10 +248,10 @@ private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m: // We could just use FirrtlNode.serialize here, but we want to make sure there are not SubAccess nodes left. private def serialize(expr: RefLikeExpression): String = expr match { - case Reference(name, _, _, _) => name - case SubField(expr, name, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "." + name + case Reference(name, _, _, _) => name + case SubField(expr, name, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "." + name case SubIndex(expr, index, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "[" + index.toString + "]" - case a : SubAccess => + case a: SubAccess => throw new RuntimeException(s"LowerTypes expects all SubAccesses to have been expanded! ${a.serialize}") } } @@ -230,13 +267,18 @@ private object DestructTypes { * - generates a list of all old reference name that now refer to the particular ground type field * - updates namespace with all possibly conflicting names */ - def destruct(m: ModuleTarget, ref: Field, namespace: Namespace, renameMap: RenameMap, reserved: Set[String]): - Seq[(Field, String)] = { + def destruct( + m: ModuleTarget, + ref: Field, + namespace: Namespace, + renameMap: RenameMap, + reserved: Set[String] + ): Seq[(Field, String)] = { // field renames (uniquify) are computed bottom up val (rename, _) = uniquify(ref, namespace, reserved) // early exit for ground types that do not need renaming - if(ref.tpe.isInstanceOf[GroundType] && rename.isEmpty) { + if (ref.tpe.isInstanceOf[GroundType] && rename.isEmpty) { return List((ref, ref.name)) } @@ -253,8 +295,13 @@ private object DestructTypes { * Note that the list of fields is only of the child fields, and needs a SubField node * instead of a flat Reference when turning them into access expressions. */ - def destructInstance(m: ModuleTarget, instance: DefInstance, namespace: Namespace, renameMap: RenameMap, - reserved: Set[String]): (DefInstance, Seq[(String, SubField)]) = { + def destructInstance( + m: ModuleTarget, + instance: DefInstance, + namespace: Namespace, + renameMap: RenameMap, + reserved: Set[String] + ): (DefInstance, Seq[(String, SubField)]) = { val (rename, _) = uniquify(Field(instance.name, Default, instance.tpe), namespace, reserved) val newName = rename.map(_.name).getOrElse(instance.name) @@ -266,14 +313,14 @@ private object DestructTypes { } // rename all references to the instance if necessary - if(newName != instance.name) { + if (newName != instance.name) { renameMap.record(m.instOf(instance.name, instance.module), m.instOf(newName, instance.module)) } // The ports do not need to be explicitly renamed here. They are renamed when the module ports are lowered. val newInstance = instance.copy(name = newName, tpe = BundleType(children.map(_._1))) val instanceRef = Reference(newName, newInstance.tpe, InstanceKind) - val refs = children.map{ case(c,r) => extractGroundTypeRefString(r) -> SubField(instanceRef, c.name, c.tpe) } + val refs = children.map { case (c, r) => extractGroundTypeRefString(r) -> SubField(instanceRef, c.name, c.tpe) } (newInstance, refs) } @@ -285,8 +332,13 @@ private object DestructTypes { * e.g. ("mem_a.r.clk", "mem.r.clk") and ("mem_b.r.clk", "mem.r.clk") * Thus it is appropriate to groupBy old reference string instead of just inserting into a hash table. */ - def destructMemory(m: ModuleTarget, mem: DefMemory, namespace: Namespace, renameMap: RenameMap, - reserved: Set[String]): (Seq[DefMemory], Seq[(String, SubField)]) = { + def destructMemory( + m: ModuleTarget, + mem: DefMemory, + namespace: Namespace, + renameMap: RenameMap, + reserved: Set[String] + ): (Seq[DefMemory], Seq[(String, SubField)]) = { // Uniquify the lowered memory names: When memories get split up into ground types, the access order is changes. // E.g. `mem.r.data.x` becomes `mem_x.r.data`. // This is why we need to create the new bundle structure before we can resolve any name clashes. @@ -301,48 +353,50 @@ private object DestructTypes { // the "old dummy field" is used as a template for the new memory port types val oldDummyField = Field("dummy", Default, MemPortUtils.memType(mem.copy(dataType = BoolType))) - val newMemAndSubFields = res.map { case (field, refs) => - val newMem = mem.copy(name = field.name, dataType = field.tpe) - val newMemRef = m.ref(field.name) - val memWasRenamed = field.name != mem.name // false iff the dataType was a GroundType - if(memWasRenamed) { renameMap.record(oldMemRef, newMemRef) } - - val newMemReference = Reference(field.name, MemPortUtils.memType(newMem), MemKind) - val refSuffixes = refs.map(_.component).filterNot(_.isEmpty) - - val subFields = oldDummyField.tpe.asInstanceOf[BundleType].fields.flatMap { port => - val oldPortRef = oldMemRef.field(port.name) - val newPortRef = newMemRef.field(port.name) - - val newPortType = newMemReference.tpe.asInstanceOf[BundleType].fields.find(_.name == port.name).get.tpe - val newPortAccess = SubField(newMemReference, port.name, newPortType) - - port.tpe.asInstanceOf[BundleType].fields.map { portField => - val isDataField = portField.name == "data" || portField.name == "wdata" || portField.name == "rdata" - val isMaskField = portField.name == "mask" || portField.name == "wmask" - val isDataOrMaskField = isDataField || isMaskField - val oldFieldRefs = if(memWasRenamed && isDataOrMaskField) { - // there might have been multiple different fields which now alias to the same lowered field. - val oldPortFieldBaseRef = oldPortRef.field(portField.name) - refSuffixes.map(s => oldPortFieldBaseRef.copy(component = oldPortFieldBaseRef.component ++ s)) - } else { - List(oldPortRef.field(portField.name)) + val newMemAndSubFields = res.map { + case (field, refs) => + val newMem = mem.copy(name = field.name, dataType = field.tpe) + val newMemRef = m.ref(field.name) + val memWasRenamed = field.name != mem.name // false iff the dataType was a GroundType + if (memWasRenamed) { renameMap.record(oldMemRef, newMemRef) } + + val newMemReference = Reference(field.name, MemPortUtils.memType(newMem), MemKind) + val refSuffixes = refs.map(_.component).filterNot(_.isEmpty) + + val subFields = oldDummyField.tpe.asInstanceOf[BundleType].fields.flatMap { port => + val oldPortRef = oldMemRef.field(port.name) + val newPortRef = newMemRef.field(port.name) + + val newPortType = newMemReference.tpe.asInstanceOf[BundleType].fields.find(_.name == port.name).get.tpe + val newPortAccess = SubField(newMemReference, port.name, newPortType) + + port.tpe.asInstanceOf[BundleType].fields.map { portField => + val isDataField = portField.name == "data" || portField.name == "wdata" || portField.name == "rdata" + val isMaskField = portField.name == "mask" || portField.name == "wmask" + val isDataOrMaskField = isDataField || isMaskField + val oldFieldRefs = if (memWasRenamed && isDataOrMaskField) { + // there might have been multiple different fields which now alias to the same lowered field. + val oldPortFieldBaseRef = oldPortRef.field(portField.name) + refSuffixes.map(s => oldPortFieldBaseRef.copy(component = oldPortFieldBaseRef.component ++ s)) + } else { + List(oldPortRef.field(portField.name)) + } + + val newPortType = if (isDataField) { newMem.dataType } + else { portField.tpe } + val newPortFieldAccess = SubField(newPortAccess, portField.name, newPortType) + + // record renames only for the data field which is the only port field of non-ground type + val newPortFieldRef = newPortRef.field(portField.name) + if (memWasRenamed && isDataOrMaskField) { + oldFieldRefs.foreach { o => renameMap.record(o, newPortFieldRef) } + } + + val oldFieldStringRef = extractGroundTypeRefString(oldFieldRefs) + (oldFieldStringRef, newPortFieldAccess) } - - val newPortType = if(isDataField) { newMem.dataType } else { portField.tpe } - val newPortFieldAccess = SubField(newPortAccess, portField.name, newPortType) - - // record renames only for the data field which is the only port field of non-ground type - val newPortFieldRef = newPortRef.field(portField.name) - if(memWasRenamed && isDataOrMaskField) { - oldFieldRefs.foreach { o => renameMap.record(o, newPortFieldRef) } - } - - val oldFieldStringRef = extractGroundTypeRefString(oldFieldRefs) - (oldFieldStringRef, newPortFieldAccess) } - } - (newMem, subFields) + (newMem, subFields) } (newMemAndSubFields.map(_._1), newMemAndSubFields.flatMap(_._2)) @@ -356,22 +410,30 @@ private object DestructTypes { Field(mem.name, Default, BundleType(fields)) } - private def recordRenames(fieldToRefs: Seq[(Field, Seq[ReferenceTarget])], renameMap: RenameMap, parent: ParentRef): - Unit = { + private def recordRenames( + fieldToRefs: Seq[(Field, Seq[ReferenceTarget])], + renameMap: RenameMap, + parent: ParentRef + ): Unit = { // TODO: if we group by ReferenceTarget, we could reduce the number of calls to `record`. Is it worth it? - fieldToRefs.foreach { case(field, refs) => - val fieldRef = parent.ref(field.name) - refs.foreach{ r => renameMap.record(r, fieldRef) } + fieldToRefs.foreach { + case (field, refs) => + val fieldRef = parent.ref(field.name) + refs.foreach { r => renameMap.record(r, fieldRef) } } } private def extractGroundTypeRefString(refs: Seq[ReferenceTarget]): String = { - if (refs.isEmpty) { "" } else { + if (refs.isEmpty) { "" } + else { // Since we depend on ExpandConnects any reference we encounter will be of ground type // and thus the one with the longest access path. - refs.reduceLeft((x, y) => if (x.component.length > y.component.length) x else y) + refs + .reduceLeft((x, y) => if (x.component.length > y.component.length) x else y) // convert references to strings relative to the module - .serialize.dropWhile(_ != '>').tail + .serialize + .dropWhile(_ != '>') + .tail } } @@ -385,14 +447,19 @@ private object DestructTypes { * @return a sequence of ground type fields with new names and, for each field, * a sequence of old references that should to be renamed to point to the particular field */ - private def destruct(prefix: String, oldParent: ParentRef, oldField: Field, - isVecField: Boolean, rename: Option[RenameNode]): Seq[(Field, Seq[ReferenceTarget])] = { + private def destruct( + prefix: String, + oldParent: ParentRef, + oldField: Field, + isVecField: Boolean, + rename: Option[RenameNode] + ): Seq[(Field, Seq[ReferenceTarget])] = { val newName = rename.map(_.name).getOrElse(oldField.name) val oldRef = oldParent.ref(oldField.name, isVecField) oldField.tpe match { - case _ : GroundType => List((oldField.copy(name = prefix + newName), List(oldRef))) - case _ : BundleType | _ : VectorType => + case _: GroundType => List((oldField.copy(name = prefix + newName), List(oldRef))) + case _: BundleType | _: VectorType => val newPrefix = prefix + newName + LowerTypes.delim val isVecField = oldField.tpe.isInstanceOf[VectorType] val fields = getFields(oldField.tpe) @@ -401,7 +468,7 @@ private object DestructTypes { destruct(newPrefix, RefParentRef(oldRef), f, isVecField, rename.flatMap(_.children.get(f.name))) } // the bundle/vec reference refers to all children - children.map{ case(c, r) => (c, r :+ oldRef) } + children.map { case (c, r) => (c, r :+ oldRef) } } } @@ -409,7 +476,8 @@ private object DestructTypes { /** Implements the core functionality of the old Uniquify pass: rename bundle fields and top-level references * where necessary in order to avoid name clashes when lowering aggregate type with the `_` delimiter. - * We don't actually do the rename here but just calculate a rename tree. */ + * We don't actually do the rename here but just calculate a rename tree. + */ private def uniquify(ref: Field, namespace: Namespace, reserved: Set[String]): (Option[RenameNode], Seq[String]) = { // ensure that there are no name clashes with the list of reserved (port) names val newRefName = findValidPrefix(ref.name, reserved.contains) @@ -426,23 +494,23 @@ private object DestructTypes { // We added f.name in previous map, delete if we change it val renamed = prefix != ref.name if (renamed) { - if(!reserved.contains(ref.name)) namespace -= ref.name + if (!reserved.contains(ref.name)) namespace -= ref.name namespace += prefix } val suffixes = renamedFieldNames.map(f => prefix + LowerTypes.delim + f) val anyChildRenamed = renamedFields.exists(_._1.isDefined) - val rename = if(renamed || anyChildRenamed){ - val children = renamedFields.map(_._1).zip(fields).collect{ case (Some(r), f) => f.name -> r }.toMap + val rename = if (renamed || anyChildRenamed) { + val children = renamedFields.map(_._1).zip(fields).collect { case (Some(r), f) => f.name -> r }.toMap Some(RenameNode(prefix, children)) } else { None } (rename, suffixes :+ prefix) - case v : VectorType=> + case v: VectorType => // if Vecs are to be lowered, we can just treat them like a bundle uniquify(ref.copy(tpe = vecToBundle(v)), namespace, reserved) - case _ : GroundType => - if(newRefName == ref.name) { + case _: GroundType => + if (newRefName == ref.name) { (None, List(ref.name)) } else { (Some(RenameNode(newRefName, Map())), List(newRefName)) @@ -452,22 +520,23 @@ private object DestructTypes { } /** Appends delim to prefix until no collisions of prefix + elts in names We don't add an _ in the collision check - * because elts could be Seq("") In this case, we're just really checking if prefix itself collides */ + * because elts could be Seq("") In this case, we're just really checking if prefix itself collides + */ @tailrec private def findValidPrefix(prefix: String, inNamespace: String => Boolean, elts: Seq[String] = List("")): String = { elts.find(elt => inNamespace(prefix + elt)) match { case Some(_) => findValidPrefix(prefix + "_", inNamespace, elts) - case None => prefix + case None => prefix } } private def getFields(tpe: Type): Seq[Field] = tpe match { case BundleType(fields) => fields - case v : VectorType => vecToBundle(v).fields + case v: VectorType => vecToBundle(v).fields } private def vecToBundle(v: VectorType): BundleType = { - BundleType(( 0 until v.size).map(i => Field(i.toString, Default, v.tpe))) + BundleType((0 until v.size).map(i => Field(i.toString, Default, v.tpe))) } /** Used to abstract over module and reference parents. @@ -480,6 +549,7 @@ private object DestructTypes { } private case class RefParentRef(r: ReferenceTarget) extends ParentRef { override def ref(name: String, asVecField: Boolean): ReferenceTarget = - if(asVecField) { r.index(name.toInt) } else { r.field(name) } + if (asVecField) { r.index(name.toInt) } + else { r.field(name) } } } diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index ca5c2544..79560605 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -15,23 +15,21 @@ object PadWidths extends Pass { override def prerequisites = ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq + ++ firrtl.stage.Forms.LowForm + - Dependency(firrtl.passes.Legalize) + + Dependency(firrtl.passes.RemoveValidIf)).toSeq override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation]) override def optionalPrerequisiteOf = - Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { case _: firrtl.transforms.ConstantPropagation | Legalize => true case _ => false } - private def width(t: Type): Int = bitWidth(t).toInt + private def width(t: Type): Int = bitWidth(t).toInt private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width private def fixup(i: Int)(e: Expression) = { @@ -54,31 +52,31 @@ object PadWidths extends Pass { } // Recursive, updates expression so children exp's have correct widths - private def onExp(e: Expression): Expression = e map onExp match { + private def onExp(e: Expression): Expression = e.map(onExp) match { case Mux(cond, tval, fval, tpe) => Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe) - case ex: ValidIf => ex copy (value = fixup(width(ex.tpe))(ex.value)) - case ex: DoPrim => ex.op match { - case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | - Add | Sub | Mul | Div | Rem | Shr => - // sensitive ops - ex map fixup((ex.args map width foldLeft 0)(math.max)) - case Dshl => - // special case as args aren't all same width - ex copy (op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1))) - case _ => ex - } + case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value)) + case ex: DoPrim => + ex.op match { + case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Mul | Div | Rem | Shr => + // sensitive ops + ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max))) + case Dshl => + // special case as args aren't all same width + ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1))) + case _ => ex + } case ex => ex } // Recursive. Fixes assignments and register initialization widths - private def onStmt(s: Statement): Statement = s map onExp match { + private def onStmt(s: Statement): Statement = s.map(onExp) match { case sx: Connect => - sx copy (expr = fixup(width(sx.loc))(sx.expr)) + sx.copy(expr = fixup(width(sx.loc))(sx.expr)) case sx: DefRegister => - sx copy (init = fixup(width(sx.tpe))(sx.init)) - case sx => sx map onStmt + sx.copy(init = fixup(width(sx.tpe))(sx.init)) + case sx => sx.map(onStmt) } - def run(c: Circuit): Circuit = c copy (modules = c.modules map (_ map onStmt)) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(_.map(onStmt))) } diff --git a/src/main/scala/firrtl/passes/Pass.scala b/src/main/scala/firrtl/passes/Pass.scala index 036bd06a..b5eac4ed 100644 --- a/src/main/scala/firrtl/passes/Pass.scala +++ b/src/main/scala/firrtl/passes/Pass.scala @@ -8,7 +8,7 @@ import firrtl.{CircuitState, FirrtlUserException, Transform} * Has an [[UnknownForm]], because larger [[Transform]] should specify form */ trait Pass extends Transform with DependencyAPIMigration { - def run(c: Circuit): Circuit + def run(c: Circuit): Circuit def execute(state: CircuitState): CircuitState = state.copy(circuit = run(state.circuit)) } diff --git a/src/main/scala/firrtl/passes/PullMuxes.scala b/src/main/scala/firrtl/passes/PullMuxes.scala index b805b5fc..27543d63 100644 --- a/src/main/scala/firrtl/passes/PullMuxes.scala +++ b/src/main/scala/firrtl/passes/PullMuxes.scala @@ -11,38 +11,50 @@ object PullMuxes extends Pass { override def invalidates(a: Transform) = false def run(c: Circuit): Circuit = { - def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { - case ex: WSubField => ex.expr match { - case exx: Mux => Mux(exx.cond, - WSubField(exx.tval, ex.name, ex.tpe, ex.flow), - WSubField(exx.fval, ex.name, ex.tpe, ex.flow), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex: WSubIndex => ex.expr match { - case exx: Mux => Mux(exx.cond, - WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow), - WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex: WSubAccess => ex.expr match { - case exx: Mux => Mux(exx.cond, - WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow), - WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex => ex - } - def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e - val modulesx = c.modules.map { - case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body)) - case (m:ExtModule) => m - } - Circuit(c.info, modulesx, c.main) - } + def pull_muxes_e(e: Expression): Expression = e.map(pull_muxes_e) match { + case ex: WSubField => + ex.expr match { + case exx: Mux => + Mux( + exx.cond, + WSubField(exx.tval, ex.name, ex.tpe, ex.flow), + WSubField(exx.fval, ex.name, ex.tpe, ex.flow), + ex.tpe + ) + case exx: ValidIf => ValidIf(exx.cond, WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex: WSubIndex => + ex.expr match { + case exx: Mux => + Mux( + exx.cond, + WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow), + WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), + ex.tpe + ) + case exx: ValidIf => ValidIf(exx.cond, WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex: WSubAccess => + ex.expr match { + case exx: Mux => + Mux( + exx.cond, + WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow), + WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), + ex.tpe + ) + case exx: ValidIf => ValidIf(exx.cond, WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex => ex + } + def pull_muxes(s: Statement): Statement = s.map(pull_muxes).map(pull_muxes_e) + val modulesx = c.modules.map { + case (m: Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body)) + case (m: ExtModule) => m + } + Circuit(c.info, modulesx, c.main) + } } diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index 18db5939..015346ff 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -2,7 +2,7 @@ package firrtl.passes -import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubIndex, WSubField} +import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubField, WSubIndex} import firrtl.PrimOps.{And, Eq} import firrtl.ir._ import firrtl.Mappers._ @@ -17,10 +17,12 @@ import scala.collection.mutable object RemoveAccesses extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ZeroLengthVecs), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ZeroLengthVecs), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects) + ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform): Boolean = a match { case Uniquify | ResolveKinds | ResolveFlows => true @@ -28,8 +30,8 @@ object RemoveAccesses extends Pass { } private def AND(e1: Expression, e2: Expression) = - if(e1 == one) e2 - else if(e2 == one) e1 + if (e1 == one) e2 + else if (e2 == one) e1 else DoPrim(And, Seq(e1, e2), Nil, BoolType) private def EQV(e1: Expression, e2: Expression): Expression = @@ -45,30 +47,35 @@ object RemoveAccesses extends Pass { * Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1))) */ private def getLocations(e: Expression): Seq[Location] = e match { - case e: WRef => create_exps(e).map(Location(_,one)) + case e: WRef => create_exps(e).map(Location(_, one)) case e: WSubIndex => val ls = getLocations(e.expr) val start = get_point(e) val end = start + get_size(e.tpe) val stride = get_size(e.expr.tpe) - for ((l, i) <- ls.zipWithIndex - if ((i % stride) >= start) & ((i % stride) < end)) yield l + for ( + (l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end) + ) yield l case e: WSubField => val ls = getLocations(e.expr) val start = get_point(e) val end = start + get_size(e.tpe) val stride = get_size(e.expr.tpe) - for ((l, i) <- ls.zipWithIndex - if ((i % stride) >= start) & ((i % stride) < end)) yield l + for ( + (l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end) + ) yield l case e: WSubAccess => val ls = getLocations(e.expr) val stride = get_size(e.tpe) val wrap = e.expr.tpe.asInstanceOf[VectorType].size - ls.zipWithIndex map {case (l, i) => - val c = (i / stride) % wrap - val basex = l.base - val guardx = AND(l.guard,EQV(UIntLiteral(c),e.index)) - Location(basex,guardx) + ls.zipWithIndex.map { + case (l, i) => + val c = (i / stride) % wrap + val basex = l.base + val guardx = AND(l.guard, EQV(UIntLiteral(c), e.index)) + Location(basex, guardx) } } @@ -78,10 +85,10 @@ object RemoveAccesses extends Pass { var ret: Boolean = false def rec_has_access(e: Expression): Expression = { e match { - case _ : WSubAccess => ret = true + case _: WSubAccess => ret = true case _ => } - e map rec_has_access + e.map(rec_has_access) } rec_has_access(e) ret @@ -90,7 +97,7 @@ object RemoveAccesses extends Pass { // This improves the performance of this pass private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]() private def create_exps(e: Expression) = - createExpsCache getOrElseUpdate (e, firrtl.Utils.create_exps(e)) + createExpsCache.getOrElseUpdate(e, firrtl.Utils.create_exps(e)) def run(c: Circuit): Circuit = { def remove_m(m: Module): Module = { @@ -105,21 +112,21 @@ object RemoveAccesses extends Pass { */ val stmts = mutable.ArrayBuffer[Statement]() def removeSource(e: Expression): Expression = e match { - case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(e) => + case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) => val rs = getLocations(e) - rs find (x => x.guard != one) match { + rs.find(x => x.guard != one) match { case None => throwInternalError(s"removeSource: shouldn't be here - $e") case Some(_) => val (wire, temp) = create_temp(e) val temps = create_exps(temp) def getTemp(i: Int) = temps(i % temps.size) stmts += wire - rs.zipWithIndex foreach { + rs.zipWithIndex.foreach { case (x, i) if i < temps.size => - stmts += IsInvalid(get_info(s),getTemp(i)) - stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt) + stmts += IsInvalid(get_info(s), getTemp(i)) + stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) case (x, i) => - stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt) + stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) } temp } @@ -129,14 +136,16 @@ object RemoveAccesses extends Pass { /** Replaces a subaccess in a given sink expression */ def removeSink(info: Info, loc: Expression): Expression = loc match { - case (_: WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(loc) => + case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(loc) => val ls = getLocations(loc) - if (ls.size == 1 & weq(ls.head.guard,one)) loc + if (ls.size == 1 & weq(ls.head.guard, one)) loc else { val (wire, temp) = create_temp(loc) stmts += wire - ls foreach (x => stmts += - Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt)) + ls.foreach(x => + stmts += + Conditionally(info, x.guard, Connect(info, x.base, temp), EmptyStmt) + ) temp } case _ => loc @@ -150,7 +159,7 @@ object RemoveAccesses extends Pass { case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow)) //case w: WSubIndex => removeSource(w) //case w: WSubField => removeSource(w) - case x => x map fixSource + case x => x.map(fixSource) } /** Recursively walks a sink expression and fixes all subaccesses @@ -159,13 +168,13 @@ object RemoveAccesses extends Pass { */ def fixSink(e: Expression): Expression = e match { case w: WSubAccess => WSubAccess(fixSink(w.expr), fixSource(w.index), w.tpe, w.flow) - case x => x map fixSink + case x => x.map(fixSink) } val sx = s match { case Connect(info, loc, exp) => Connect(info, removeSink(info, fixSink(loc)), fixSource(exp)) - case sxx => sxx map fixSource map onStmt + case sxx => sxx.map(fixSource).map(onStmt) } stmts += sx if (stmts.size != 1) Block(stmts.toSeq) else stmts(0) @@ -173,9 +182,9 @@ object RemoveAccesses extends Pass { Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body))) } - c copy (modules = c.modules map { + c.copy(modules = c.modules.map { case m: ExtModule => m - case m: Module => remove_m(m) + case m: Module => remove_m(m) }) } } diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 61fd6258..624138ab 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -17,8 +17,7 @@ case class DataRef(exp: Expression, source: String, sink: String, mask: String, object RemoveCHIRRTL extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.ChirrtlForm ++ - Seq( Dependency(passes.CInferTypes), - Dependency(passes.CInferMDir) ) + Seq(Dependency(passes.CInferTypes), Dependency(passes.CInferMDir)) override def invalidates(a: Transform) = false @@ -31,10 +30,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match { case _: GroundType => Seq(ex) - case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) => - exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex) - case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => - exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex) + case t: BundleType => + (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq( + ex + ) + case t: VectorType => + ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) => + exps ++ create_all_exps(SubIndex(ex, i, t.tpe)) + ) ++ Seq(ex) case UnknownType => Seq(ex) } @@ -42,17 +45,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { case ex: Mux => val e1s = create_exps(ex.tval) val e2s = create_exps(ex.fval) - (e1s zip e2s) map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) } + (e1s.zip(e2s)).map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) } case ex: ValidIf => - create_exps(ex.value) map (e1 => ValidIf(ex.cond, e1, e1.tpe)) - case ex => ex.tpe match { - case _: GroundType => Seq(ex) - case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) => - exps ++ create_exps(SubField(ex, f.name, f.tpe))) - case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => - exps ++ create_exps(SubIndex(ex, i, t.tpe))) - case UnknownType => Seq(ex) - } + create_exps(ex.value).map(e1 => ValidIf(ex.cond, e1, e1.tpe)) + case ex => + ex.tpe match { + case _: GroundType => Seq(ex) + case t: BundleType => + (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_exps(SubField(ex, f.name, f.tpe))) + case t: VectorType => + ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) => exps ++ create_exps(SubIndex(ex, i, t.tpe))) + case UnknownType => Seq(ex) + } } private def EMPs: MPorts = MPorts(ArrayBuffer[MPort](), ArrayBuffer[MPort](), ArrayBuffer[MPort]()) @@ -61,40 +65,48 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { s match { case sx: CDefMemory if sx.seq => smems += sx.name case sx: CDefMPort => - val p = mports getOrElse (sx.mem, EMPs) + val p = mports.getOrElse(sx.mem, EMPs) sx.direction match { - case MRead => p.readers += MPort(sx.name, sx.exps(1)) - case MWrite => p.writers += MPort(sx.name, sx.exps(1)) + case MRead => p.readers += MPort(sx.name, sx.exps(1)) + case MWrite => p.writers += MPort(sx.name, sx.exps(1)) case MReadWrite => p.readwriters += MPort(sx.name, sx.exps(1)) - case MInfer => // direction may not be inferred if it's not being used + case MInfer => // direction may not be inferred if it's not being used } mports(sx.mem) = p case _ => } - s map collect_smems_and_mports(mports, smems) + s.map(collect_smems_and_mports(mports, smems)) } - def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap, - refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match { + def collect_refs( + mports: MPortMap, + smems: SeqMemSet, + types: MPortTypeMap, + refs: DataRefMap, + raddrs: AddrMap, + renames: RenameMap + )(s: Statement + ): Statement = s match { case sx: CDefMemory => types(sx.name) = sx.tpe - val taddr = UIntType(IntWidth(1 max getUIntWidth(sx.size - 1))) + val taddr = UIntType(IntWidth(1.max(getUIntWidth(sx.size - 1)))) val tdata = sx.tpe - def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap (r => Seq( - IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)), - IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType)) - )) - def set_enable(vec: scala.collection.Seq[MPort], en: String) = vec.toSeq.map (r => - Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero) + def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap(r => + Seq( + IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)), + IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType)) + ) ) + def set_enable(vec: scala.collection.Seq[MPort], en: String) = + vec.toSeq.map(r => Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero)) def set_write(vec: scala.collection.Seq[MPort], data: String, mask: String) = vec.toSeq.flatMap { r => val tmask = createMask(sx.tpe) val portRef = SubField(Reference(sx.name, ut), r.name, ut) Seq(IsInvalid(sx.info, SubField(portRef, data, tdata)), IsInvalid(sx.info, SubField(portRef, mask, tmask))) } - val rds = (mports getOrElse (sx.name, EMPs)).readers - val wrs = (mports getOrElse (sx.name, EMPs)).writers - val rws = (mports getOrElse (sx.name, EMPs)).readwriters + val rds = (mports.getOrElse(sx.name, EMPs)).readers + val wrs = (mports.getOrElse(sx.name, EMPs)).writers + val rws = (mports.getOrElse(sx.name, EMPs)).readwriters val stmts = set_poison(rds) ++ set_enable(rds, "en") ++ set_poison(wrs) ++ @@ -104,8 +116,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { set_enable(rws, "wmode") ++ set_enable(rws, "en") ++ set_write(rws, "wdata", "wmask") - val mem = DefMemory(sx.info, sx.name, sx.tpe, sx.size, 1, if (sx.seq) 1 else 0, - rds.map(_.name).toSeq, wrs.map(_.name).toSeq, rws.map(_.name).toSeq, sx.readUnderWrite) + val mem = DefMemory( + sx.info, + sx.name, + sx.tpe, + sx.size, + 1, + if (sx.seq) 1 else 0, + rds.map(_.name).toSeq, + wrs.map(_.name).toSeq, + rws.map(_.name).toSeq, + sx.readUnderWrite + ) Block(mem +: stmts) case sx: CDefMPort => types.get(sx.mem) match { @@ -130,8 +152,8 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { val es = create_all_exps(WRef(sx.name, sx.tpe)) val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe)) val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe)) - ((es zip rs) zip ws) map { - case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize)) + ((es.zip(rs)).zip(ws)).map { + case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize)) } case MWrite => refs(sx.name) = DataRef(portRef, "data", "data", "mask", rdwrite = false) @@ -142,7 +164,7 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") val es = create_all_exps(WRef(sx.name, sx.tpe)) val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe)) - (es zip ws) map { + (es.zip(ws)).map { case (e, w) => renames.rename(e.serialize, w.serialize) } case MRead => @@ -157,63 +179,69 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") val es = create_all_exps(WRef(sx.name, sx.tpe)) val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe)) - (es zip rs) map { + (es.zip(rs)).map { case (e, r) => renames.rename(e.serialize, r.serialize) } case MInfer => // do nothing if it's not being used } - Block(List() ++ - (addrs.map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++ - (clks map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++ - (ens map (x => Connect(sx.info,SubField(portRef, x, ut), one))) ++ - masks.map(lhs => Connect(sx.info, lhs, zero)) + Block( + List() ++ + (addrs.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++ + (clks.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++ + (ens.map(x => Connect(sx.info, SubField(portRef, x, ut), one))) ++ + masks.map(lhs => Connect(sx.info, lhs, zero)) ) - case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames) + case sx => sx.map(collect_refs(mports, smems, types, refs, raddrs, renames)) } def get_mask(refs: DataRefMap)(e: Expression): Expression = - e map get_mask(refs) match { - case ex: Reference => refs get ex.name match { - case None => ex - case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe)) - } + e.map(get_mask(refs)) match { + case ex: Reference => + refs.get(ex.name) match { + case None => ex + case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe)) + } case ex => ex } def remove_chirrtl_s(refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = { var has_write_mport = false var has_readwrite_mport: Option[Expression] = None - var has_read_mport: Option[Expression] = None + var has_read_mport: Option[Expression] = None def remove_chirrtl_e(g: Flow)(e: Expression): Expression = e match { - case Reference(name, tpe, _, _) => refs get name match { - case Some(p) => g match { - case SinkFlow => - has_write_mport = true - if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType)) - SubField(p.exp, p.sink, tpe) - case SourceFlow => - SubField(p.exp, p.source, tpe) - } - case None => g match { - case SinkFlow => raddrs get name match { - case Some(en) => has_read_mport = Some(en) ; e - case None => e - } - case SourceFlow => e + case Reference(name, tpe, _, _) => + refs.get(name) match { + case Some(p) => + g match { + case SinkFlow => + has_write_mport = true + if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType)) + SubField(p.exp, p.sink, tpe) + case SourceFlow => + SubField(p.exp, p.source, tpe) + } + case None => + g match { + case SinkFlow => + raddrs.get(name) match { + case Some(en) => has_read_mport = Some(en); e + case None => e + } + case SourceFlow => e + } } - } - case SubAccess(expr, index, tpe, _) => SubAccess( - remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe) - case ex => ex map remove_chirrtl_e(g) - } - s match { + case SubAccess(expr, index, tpe, _) => + SubAccess(remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe) + case ex => ex.map(remove_chirrtl_e(g)) + } + s match { case DefNode(info, name, value) => val valuex = remove_chirrtl_e(SourceFlow)(value) val sx = DefNode(info, name, valuex) // Check node is used for read port address remove_chirrtl_e(SinkFlow)(Reference(name, value.tpe)) has_read_mport match { - case None => sx + case None => sx case Some(en) => Block(sx, Connect(info, en, one)) } case Connect(info, loc, expr) => @@ -222,14 +250,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { val sx = Connect(info, locx, rocx) val stmts = ArrayBuffer[Statement]() has_read_mport match { - case None => + case None => case Some(en) => stmts += Connect(info, en, one) } if (has_write_mport) { val locs = create_exps(get_mask(refs)(loc)) - stmts ++= (locs map (x => Connect(info, x, one))) + stmts ++= (locs.map(x => Connect(info, x, one))) has_readwrite_mport match { - case None => + case None => case Some(wmode) => stmts += Connect(info, wmode, one) } } @@ -240,20 +268,20 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { val sx = PartialConnect(info, locx, rocx) val stmts = ArrayBuffer[Statement]() has_read_mport match { - case None => + case None => case Some(en) => stmts += Connect(info, en, one) } if (has_write_mport) { val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default) val locs = create_exps(get_mask(refs)(loc)) - stmts ++= (ls map { case (x, _) => Connect(info, locs(x), one) }) + stmts ++= (ls.map { case (x, _) => Connect(info, locs(x), one) }) has_readwrite_mport match { - case None => + case None => case Some(wmode) => stmts += Connect(info, wmode, one) } } if (stmts.isEmpty) sx else Block(sx +: stmts.toSeq) - case sx => sx map remove_chirrtl_s(refs, raddrs) map remove_chirrtl_e(SourceFlow) + case sx => sx.map(remove_chirrtl_s(refs, raddrs)).map(remove_chirrtl_e(SourceFlow)) } } @@ -264,16 +292,16 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { val refs = new DataRefMap val raddrs = new AddrMap renames.setModule(m.name) - (m map collect_smems_and_mports(mports, smems) - map collect_refs(mports, smems, types, refs, raddrs, renames) - map remove_chirrtl_s(refs, raddrs)) + (m.map(collect_smems_and_mports(mports, smems)) + .map(collect_refs(mports, smems, types, refs, raddrs, renames)) + .map(remove_chirrtl_s(refs, raddrs))) } def execute(state: CircuitState): CircuitState = { val c = state.circuit val renames = RenameMap() renames.setCircuit(c.main) - val result = c copy (modules = c.modules map remove_chirrtl_m(renames)) + val result = c.copy(modules = c.modules.map(remove_chirrtl_m(renames))) state.copy(circuit = result, renames = Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/RemoveEmpty.scala b/src/main/scala/firrtl/passes/RemoveEmpty.scala index eabf667c..eb25dcc4 100644 --- a/src/main/scala/firrtl/passes/RemoveEmpty.scala +++ b/src/main/scala/firrtl/passes/RemoveEmpty.scala @@ -15,7 +15,7 @@ object RemoveEmpty extends Pass with DependencyAPIMigration { private def onModule(m: DefModule): DefModule = { m match { - case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body)) + case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body)) case m: ExtModule => m } } diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala index 7059526c..657b4356 100644 --- a/src/main/scala/firrtl/passes/RemoveIntervals.scala +++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala @@ -13,14 +13,13 @@ import firrtl.options.Dependency import scala.math.BigDecimal.RoundingMode._ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) - extends PassException({ - val toWrap = wrap.args.head.serialize - val toWrapTpe = wrap.args.head.tpe.serialize - val wrapTo = wrap.args(1).serialize - val wrapToTpe = wrap.args(1).tpe.serialize - s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe" - }) - + extends PassException({ + val toWrap = wrap.args.head.serialize + val toWrapTpe = wrap.args.head.tpe.serialize + val wrapTo = wrap.args(1).serialize + val wrapToTpe = wrap.args(1).tpe.serialize + s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe" + }) /** Replaces IntervalType with SIntType, three AST walks: * 1) Align binary points @@ -39,48 +38,50 @@ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) class RemoveIntervals extends Pass { override def prerequisites: Seq[Dependency[Transform]] = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses), - Dependency[ExpandWhensAndCheck] ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck] + ) ++ firrtl.stage.Forms.Deduped override def invalidates(transform: Transform): Boolean = { transform match { case InferTypes | ResolveKinds => true - case _ => false + case _ => false } } def run(c: Circuit): Circuit = { val alignedCircuit = c val errors = new Errors() - val wiredCircuit = alignedCircuit map makeWireModule - val replacedCircuit = wiredCircuit map replaceModuleInterval(errors) + val wiredCircuit = alignedCircuit.map(makeWireModule) + val replacedCircuit = wiredCircuit.map(replaceModuleInterval(errors)) errors.trigger() replacedCircuit } /* Replace interval types */ private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule = - m map replaceStmtInterval(errors, m.name) map replacePortInterval + m.map(replaceStmtInterval(errors, m.name)).map(replacePortInterval) private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = { val info = s match { case h: HasInfo => h.info case _ => NoInfo } - s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname) + s.map(replaceTypeInterval).map(replaceStmtInterval(errors, mname)).map(replaceExprInterval(errors, info, mname)) } private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match { case _: WRef | _: WSubIndex | _: WSubField => e case o => - o map replaceExprInterval(errors, info, mname) match { + o.map(replaceExprInterval(errors, info, mname)) match { case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe) - case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) - case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) + case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) + case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) => // Output interval (pre-calculated) val clipLo = tpe.minAdjusted.get @@ -94,13 +95,13 @@ class RemoveIntervals extends Pass { val ltOpt = clipLo <= inLow (gtOpt, ltOpt) match { // input range within output range -> no optimization - case (true, true) => a1 + case (true, true) => a1 case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1) case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1) - case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1)) + case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1)) } - case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) => + case sqz @ DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) => // Using (conditional) reassign interval w/o adding mux val a1tpe = a1.tpe.asInstanceOf[IntervalType] val a2tpe = a2.tpe.asInstanceOf[IntervalType] @@ -117,54 +118,55 @@ class RemoveIntervals extends Pass { val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2))) DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2))) } - case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match { - // If a2 type is Interval wrap around range. If UInt, wrap around width - case t: IntervalType => - // Need to match binary points before getting *adjusted! - val (wrapLo, wrapHi) = t.copy(point = tpe.point) match { - case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get) - case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType") - } - val (inLo, inHi) = a1.tpe match { - case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) - case _ => sys.error("Shouldn't be here") - } - // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap) - val range = wrapHi - wrapLo - val ltOpt = Add(a1, (range + 1).S) - val gtOpt = Sub(a1, (range + 1).S) - // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it. - // If x < wl - // output: wh - (wl - x) + 1 AKA x + r + 1 - // worst case: wh - (wl - xl) + 1 = wl - // -> xl + wr + 1 = wl - // If x > wh - // output: wl + (x - wh) - 1 AKA x - r - 1 - // worst case: wl + (xh - wh) - 1 = wh - // -> xh - wr - 1 = wh - val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S) - (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match { - case (true, true, _, _) => a1 - case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1) - case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1) - // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work) - case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1)) - case _ => - errors.append(new WrapWithRemainder(info, mname, w)) - default - } - case _ => sys.error("Shouldn't be here") - } + case w @ DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => + a2.tpe match { + // If a2 type is Interval wrap around range. If UInt, wrap around width + case t: IntervalType => + // Need to match binary points before getting *adjusted! + val (wrapLo, wrapHi) = t.copy(point = tpe.point) match { + case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get) + case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType") + } + val (inLo, inHi) = a1.tpe match { + case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) + case _ => sys.error("Shouldn't be here") + } + // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap) + val range = wrapHi - wrapLo + val ltOpt = Add(a1, (range + 1).S) + val gtOpt = Sub(a1, (range + 1).S) + // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it. + // If x < wl + // output: wh - (wl - x) + 1 AKA x + r + 1 + // worst case: wh - (wl - xl) + 1 = wl + // -> xl + wr + 1 = wl + // If x > wh + // output: wl + (x - wh) - 1 AKA x - r - 1 + // worst case: wl + (xh - wh) - 1 = wh + // -> xh - wr - 1 = wh + val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S) + (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match { + case (true, true, _, _) => a1 + case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1) + case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1) + // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work) + case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1)) + case _ => + errors.append(new WrapWithRemainder(info, mname, w)) + default + } + case _ => sys.error("Shouldn't be here") + } case other => other } } - private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval) private def replaceTypeInterval(t: Type): Type = t match { - case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width) + case i @ IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width) case i: IntervalType => sys.error(s"Shouldn't be here: $i") - case v => v map replaceTypeInterval + case v => v.map(replaceTypeInterval) } /** Replace Interval Nodes with Interval Wires @@ -174,15 +176,16 @@ class RemoveIntervals extends Pass { * @param m module to replace nodes with wire + connection * @return */ - private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt + private def makeWireModule(m: DefModule): DefModule = m.map(makeWireStmt) private def makeWireStmt(s: Statement): Statement = s match { - case DefNode(info, name, value) => value.tpe match { - case IntervalType(l, u, p) => - val newType = IntervalType(l, u, p) - Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value))) - case other => s - } - case other => other map makeWireStmt + case DefNode(info, name, value) => + value.tpe match { + case IntervalType(l, u, p) => + val newType = IntervalType(l, u, p) + Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value))) + case other => s + } + case other => other.map(makeWireStmt) } } diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 895cb10f..7e82b37b 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -26,14 +26,13 @@ object RemoveValidIf extends Pass { case ClockType => ClockZero case _: FixedType => FixedZero case AsyncResetType => AsyncZero - case other => throwInternalError(s"Unexpected type $other") + case other => throwInternalError(s"Unexpected type $other") } override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = - Seq( Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { case Legalize | _: firrtl.transforms.ConstantPropagation => true @@ -42,24 +41,25 @@ object RemoveValidIf extends Pass { // Recursive. Removes ValidIfs private def onExp(e: Expression): Expression = { - e map onExp match { + e.map(onExp) match { case ValidIf(_, value, _) => value - case x => x + case x => x } } // Recursive. Replaces IsInvalid with connecting zero - private def onStmt(s: Statement): Statement = s map onStmt map onExp match { - case invalid @ IsInvalid(info, loc) => loc.tpe match { - case _: AnalogType => EmptyStmt - case tpe => Connect(info, loc, getGroundZero(tpe)) - } + private def onStmt(s: Statement): Statement = s.map(onStmt).map(onExp) match { + case invalid @ IsInvalid(info, loc) => + loc.tpe match { + case _: AnalogType => EmptyStmt + case tpe => Connect(info, loc, getGroundZero(tpe)) + } case other => other } private def onModule(m: DefModule): DefModule = { m match { - case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body)) + case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body)) case m: ExtModule => m } } diff --git a/src/main/scala/firrtl/passes/ReplaceAccesses.scala b/src/main/scala/firrtl/passes/ReplaceAccesses.scala index e31d9410..4a3cd697 100644 --- a/src/main/scala/firrtl/passes/ReplaceAccesses.scala +++ b/src/main/scala/firrtl/passes/ReplaceAccesses.scala @@ -18,15 +18,16 @@ object ReplaceAccesses extends Pass { override def invalidates(a: Transform) = false def run(c: Circuit): Circuit = { - def onStmt(s: Statement): Statement = s map onStmt map onExp - def onExp(e: Expression): Expression = e match { - case WSubAccess(ex, UIntLiteral(value, _), t, g) => ex.tpe match { - case VectorType(_, len) if (value < len) => WSubIndex(onExp(ex), value.toInt, t, g) - case _ => e map onExp - } - case _ => e map onExp + def onStmt(s: Statement): Statement = s.map(onStmt).map(onExp) + def onExp(e: Expression): Expression = e match { + case WSubAccess(ex, UIntLiteral(value, _), t, g) => + ex.tpe match { + case VectorType(_, len) if (value < len) => WSubIndex(onExp(ex), value.toInt, t, g) + case _ => e.map(onExp) + } + case _ => e.map(onExp) } - c copy (modules = c.modules map (_ map onStmt)) + c.copy(modules = c.modules.map(_.map(onStmt))) } } diff --git a/src/main/scala/firrtl/passes/ResolveFlows.scala b/src/main/scala/firrtl/passes/ResolveFlows.scala index 85a0a26f..48b9479c 100644 --- a/src/main/scala/firrtl/passes/ResolveFlows.scala +++ b/src/main/scala/firrtl/passes/ResolveFlows.scala @@ -14,17 +14,22 @@ object ResolveFlows extends Pass { override def invalidates(a: Transform) = false def resolve_e(g: Flow)(e: Expression): Expression = e match { - case ex: WRef => ex copy (flow = g) - case WSubField(exp, name, tpe, _) => WSubField( - Utils.field_flip(exp.tpe, name) match { - case Default => resolve_e(g)(exp) - case Flip => resolve_e(Utils.swap(g))(exp) - }, name, tpe, g) + case ex: WRef => ex.copy(flow = g) + case WSubField(exp, name, tpe, _) => + WSubField( + Utils.field_flip(exp.tpe, name) match { + case Default => resolve_e(g)(exp) + case Flip => resolve_e(Utils.swap(g))(exp) + }, + name, + tpe, + g + ) case WSubIndex(exp, value, tpe, _) => WSubIndex(resolve_e(g)(exp), value, tpe, g) case WSubAccess(exp, index, tpe, _) => WSubAccess(resolve_e(g)(exp), resolve_e(SourceFlow)(index), tpe, g) - case _ => e map resolve_e(g) + case _ => e.map(resolve_e(g)) } def resolve_s(s: Statement): Statement = s match { @@ -35,11 +40,11 @@ object ResolveFlows extends Pass { Connect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr)) case PartialConnect(info, loc, expr) => PartialConnect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr)) - case sx => sx map resolve_e(SourceFlow) map resolve_s + case sx => sx.map(resolve_e(SourceFlow)).map(resolve_s) } - def resolve_flow(m: DefModule): DefModule = m map resolve_s + def resolve_flow(m: DefModule): DefModule = m.map(resolve_s) def run(c: Circuit): Circuit = - c copy (modules = c.modules map resolve_flow) + c.copy(modules = c.modules.map(resolve_flow)) } diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala index 67360b74..fcbac163 100644 --- a/src/main/scala/firrtl/passes/ResolveKinds.scala +++ b/src/main/scala/firrtl/passes/ResolveKinds.scala @@ -20,21 +20,21 @@ object ResolveKinds extends Pass { } def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match { - case ex: WRef => ex copy (kind = kinds(ex.name)) - case _ => e map resolve_expr(kinds) + case ex: WRef => ex.copy(kind = kinds(ex.name)) + case _ => e.map(resolve_expr(kinds)) } def resolve_stmt(kinds: KindMap)(s: Statement): Statement = { s match { - case sx: DefWire => kinds(sx.name) = WireKind - case sx: DefNode => kinds(sx.name) = NodeKind - case sx: DefRegister => kinds(sx.name) = RegKind + case sx: DefWire => kinds(sx.name) = WireKind + case sx: DefNode => kinds(sx.name) = NodeKind + case sx: DefRegister => kinds(sx.name) = RegKind case sx: WDefInstance => kinds(sx.name) = InstanceKind - case sx: DefMemory => kinds(sx.name) = MemKind + case sx: DefMemory => kinds(sx.name) = MemKind case _ => } s.map(resolve_stmt(kinds)) - .map(resolve_expr(kinds)) + .map(resolve_expr(kinds)) } def resolve_kinds(m: DefModule): DefModule = { @@ -44,5 +44,5 @@ object ResolveKinds extends Pass { } def run(c: Circuit): Circuit = - c copy (modules = c.modules map resolve_kinds) + c.copy(modules = c.modules.map(resolve_kinds)) } diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index c536cd5d..a65f8921 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -7,7 +7,7 @@ import firrtl.{SystemVerilogEmitter, Transform, VerilogEmitter} import firrtl.ir._ import firrtl.options.Dependency import firrtl.Mappers._ -import firrtl.Utils.{kind, flow, get_info} +import firrtl.Utils.{flow, get_info, kind} // Datastructures import scala.collection.mutable @@ -17,65 +17,63 @@ import scala.collection.mutable object SplitExpressions extends Pass { override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( Dependency(firrtl.passes.RemoveValidIf), - Dependency(firrtl.passes.memlib.VerilogMemDelays) ) + Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays)) override def optionalPrerequisiteOf = - Seq( Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform) = a match { case ResolveKinds => true case _ => false } - private def onModule(m: Module): Module = { - val namespace = Namespace(m) - def onStmt(s: Statement): Statement = { - val v = mutable.ArrayBuffer[Statement]() - // Splits current expression if needed - // Adds named temporaries to v - def split(e: Expression): Expression = e match { - case e: DoPrim => - val name = namespace.newTemp - v += DefNode(get_info(s), name, e) - WRef(name, e.tpe, kind(e), flow(e)) - case e: Mux => - val name = namespace.newTemp - v += DefNode(get_info(s), name, e) - WRef(name, e.tpe, kind(e), flow(e)) - case e: ValidIf => - val name = namespace.newTemp - v += DefNode(get_info(s), name, e) - WRef(name, e.tpe, kind(e), flow(e)) - case _ => e - } - - // Recursive. Splits compound nodes - def onExp(e: Expression): Expression = - e map onExp match { - case ex: DoPrim => ex map split - case ex => ex - } + private def onModule(m: Module): Module = { + val namespace = Namespace(m) + def onStmt(s: Statement): Statement = { + val v = mutable.ArrayBuffer[Statement]() + // Splits current expression if needed + // Adds named temporaries to v + def split(e: Expression): Expression = e match { + case e: DoPrim => + val name = namespace.newTemp + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), flow(e)) + case e: Mux => + val name = namespace.newTemp + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), flow(e)) + case e: ValidIf => + val name = namespace.newTemp + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), flow(e)) + case _ => e + } - s map onExp match { - case x: Block => x map onStmt - case EmptyStmt => EmptyStmt - case x => - v += x - v.size match { - case 1 => v.head - case _ => Block(v.toSeq) - } + // Recursive. Splits compound nodes + def onExp(e: Expression): Expression = + e.map(onExp) match { + case ex: DoPrim => ex.map(split) + case ex => ex } + + s.map(onExp) match { + case x: Block => x.map(onStmt) + case EmptyStmt => EmptyStmt + case x => + v += x + v.size match { + case 1 => v.head + case _ => Block(v.toSeq) + } } - Module(m.info, m.name, m.ports, onStmt(m.body)) - } - def run(c: Circuit): Circuit = { - val modulesx = c.modules map { - case m: Module => onModule(m) - case m: ExtModule => m - } - Circuit(c.info, modulesx, c.main) - } + } + Module(m.info, m.name, m.ports, onStmt(m.body)) + } + def run(c: Circuit): Circuit = { + val modulesx = c.modules.map { + case m: Module => onModule(m) + case m: ExtModule => m + } + Circuit(c.info, modulesx, c.main) + } } diff --git a/src/main/scala/firrtl/passes/ToWorkingIR.scala b/src/main/scala/firrtl/passes/ToWorkingIR.scala index c271302a..03faaf3c 100644 --- a/src/main/scala/firrtl/passes/ToWorkingIR.scala +++ b/src/main/scala/firrtl/passes/ToWorkingIR.scala @@ -6,5 +6,5 @@ import firrtl.Transform object ToWorkingIR extends Pass { override def prerequisites = firrtl.stage.Forms.MinimalHighForm override def invalidates(a: Transform) = false - def run(c:Circuit): Circuit = c + def run(c: Circuit): Circuit = c } diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala index 822a8125..0a05bd4e 100644 --- a/src/main/scala/firrtl/passes/TrimIntervals.scala +++ b/src/main/scala/firrtl/passes/TrimIntervals.scala @@ -23,10 +23,7 @@ import firrtl.Transform class TrimIntervals extends Pass { override def prerequisites = - Seq( Dependency(ResolveKinds), - Dependency(InferTypes), - Dependency(ResolveFlows), - Dependency[InferBinaryPoints] ) + Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows), Dependency[InferBinaryPoints]) override def optionalPrerequisiteOf = Seq.empty @@ -34,48 +31,51 @@ class TrimIntervals extends Pass { def run(c: Circuit): Circuit = { // Open -> closed - val firstPass = InferTypes.run(c map replaceModuleInterval) + val firstPass = InferTypes.run(c.map(replaceModuleInterval)) // Align binary points and adjust range accordingly (loss of precision changes range) - firstPass map alignModuleBP + firstPass.map(alignModuleBP) } /* Replace interval types */ - private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval + private def replaceModuleInterval(m: DefModule): DefModule = m.map(replaceStmtInterval).map(replacePortInterval) - private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval + private def replaceStmtInterval(s: Statement): Statement = s.map(replaceTypeInterval).map(replaceStmtInterval) - private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval) private def replaceTypeInterval(t: Type): Type = t match { - case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) => + case i @ IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) => IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p)) case i: IntervalType => i - case v => v map replaceTypeInterval + case v => v.map(replaceTypeInterval) } /* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */ - private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP - - private def alignStmtBP(s: Statement): Statement = s map alignExpBP match { - case c@Connect(info, loc, expr) => loc.tpe match { - case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr)) - case _ => c - } - case c@PartialConnect(info, loc, expr) => loc.tpe match { - case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr)) - case _ => c - } - case other => other map alignStmtBP + private def alignModuleBP(m: DefModule): DefModule = m.map(alignStmtBP) + + private def alignStmtBP(s: Statement): Statement = s.map(alignExpBP) match { + case c @ Connect(info, loc, expr) => + loc.tpe match { + case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr)) + case _ => c + } + case c @ PartialConnect(info, loc, expr) => + loc.tpe match { + case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr)) + case _ => c + } + case other => other.map(alignStmtBP) } // Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned // Note - Mul does not need its binary points aligned, because multiplication is cool like that - private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/) + private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq /*, Wrap, Clip, Squeeze*/ ) - private def alignExpBP(e: Expression): Expression = e map alignExpBP match { + private def alignExpBP(e: Expression): Expression = e.map(alignExpBP) match { case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg) - case DoPrim(o, args, consts, t) if opsToFix.contains(o) && - (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size => + case DoPrim(o, args, consts, t) + if opsToFix.contains(o) && + (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size => val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t) case Mux(cond, tval, fval, t: IntervalType) => @@ -85,9 +85,9 @@ class TrimIntervals extends Pass { } private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match { case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e - case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current => + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current => DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired))) - case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current => + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current => val shiftAmt = current - desired val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt) val shiftMul = Closed(BigDecimal(1) / shiftGain) diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index b9cd32fa..10198b33 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -2,7 +2,6 @@ package firrtl.passes - import scala.annotation.tailrec import firrtl._ import firrtl.ir._ @@ -35,12 +34,11 @@ import MemPortUtils.memType object Uniquify extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(ResolveKinds), - Dependency(InferTypes) ) ++ firrtl.stage.Forms.WorkingIR + Seq(Dependency(ResolveKinds), Dependency(InferTypes)) ++ firrtl.stage.Forms.WorkingIR override def invalidates(a: Transform): Boolean = a match { case ResolveKinds | InferTypes => true - case _ => false + case _ => false } private case class UniquifyException(msg: String) extends FirrtlInternalException(msg) @@ -55,12 +53,13 @@ object Uniquify extends Transform with DependencyAPIMigration { */ @tailrec def findValidPrefix( - prefix: String, - elts: Seq[String], - namespace: collection.mutable.HashSet[String]): String = { - elts find (elt => namespace.contains(prefix + elt)) match { + prefix: String, + elts: Seq[String], + namespace: collection.mutable.HashSet[String] + ): String = { + elts.find(elt => namespace.contains(prefix + elt)) match { case Some(_) => findValidPrefix(prefix + "_", elts, namespace) - case None => prefix + case None => prefix } } @@ -70,16 +69,16 @@ object Uniquify extends Transform with DependencyAPIMigration { * => foo, foo bar, foo bar 0, foo bar 1, foo bar 0 a, foo bar 0 b, foo bar 1 a, foo bar 1 b, foo c * }}} */ - private [firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match { + private[firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match { case t: BundleType => - t.fields flatMap { f => - (enumerateNames(f.tpe) map (f.name +: _)) ++ Seq(Seq(f.name)) + t.fields.flatMap { f => + (enumerateNames(f.tpe).map(f.name +: _)) ++ Seq(Seq(f.name)) } case t: VectorType => - ((0 until t.size) map (i => Seq(i.toString))) ++ - ((0 until t.size) flatMap { i => - enumerateNames(t.tpe) map (i.toString +: _) - }) + ((0 until t.size).map(i => Seq(i.toString))) ++ + ((0 until t.size).flatMap { i => + enumerateNames(t.tpe).map(i.toString +: _) + }) case _ => Seq() } @@ -87,27 +86,38 @@ object Uniquify extends Transform with DependencyAPIMigration { def stmtToType(s: Statement)(implicit sinfo: Info, mname: String): BundleType = { // Recursive helper def recStmtToType(s: Statement): Seq[Field] = s match { - case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe)) + case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe)) case sx: DefRegister => Seq(Field(sx.name, Default, sx.tpe)) case sx: WDefInstance => Seq(Field(sx.name, Default, sx.tpe)) - case sx: DefMemory => sx.dataType match { - case (_: UIntType | _: SIntType | _: FixedType) => - Seq(Field(sx.name, Default, memType(sx))) - case tpe: BundleType => - val newFields = tpe.fields map ( f => - DefMemory(sx.info, f.name, f.tpe, sx.depth, sx.writeLatency, - sx.readLatency, sx.readers, sx.writers, sx.readwriters) - ) flatMap recStmtToType - Seq(Field(sx.name, Default, BundleType(newFields))) - case tpe: VectorType => - val newFields = (0 until tpe.size) map ( i => - sx.copy(name = i.toString, dataType = tpe.tpe) - ) flatMap recStmtToType - Seq(Field(sx.name, Default, BundleType(newFields))) - } - case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe)) + case sx: DefMemory => + sx.dataType match { + case (_: UIntType | _: SIntType | _: FixedType) => + Seq(Field(sx.name, Default, memType(sx))) + case tpe: BundleType => + val newFields = tpe.fields + .map(f => + DefMemory( + sx.info, + f.name, + f.tpe, + sx.depth, + sx.writeLatency, + sx.readLatency, + sx.readers, + sx.writers, + sx.readwriters + ) + ) + .flatMap(recStmtToType) + Seq(Field(sx.name, Default, BundleType(newFields))) + case tpe: VectorType => + val newFields = + (0 until tpe.size).map(i => sx.copy(name = i.toString, dataType = tpe.tpe)).flatMap(recStmtToType) + Seq(Field(sx.name, Default, BundleType(newFields))) + } + case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe)) case sx: Conditionally => recStmtToType(sx.conseq) ++ recStmtToType(sx.alt) - case sx: Block => (sx.stmts map recStmtToType).flatten + case sx: Block => (sx.stmts.map(recStmtToType)).flatten case sx => Seq() } BundleType(recStmtToType(s)) @@ -116,40 +126,44 @@ object Uniquify extends Transform with DependencyAPIMigration { // Accepts a Type and an initial namespace // Returns new Type with uniquified names private def uniquifyNames( - t: BundleType, - namespace: collection.mutable.HashSet[String]) - (implicit sinfo: Info, mname: String): BundleType = { + t: BundleType, + namespace: collection.mutable.HashSet[String] + )( + implicit sinfo: Info, + mname: String + ): BundleType = { def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): (Type, Seq[String]) = t match { case tx: BundleType => // First add everything - val newFieldsAndElts = tx.fields map { f => + val newFieldsAndElts = tx.fields.map { f => val newName = findValidPrefix(f.name, Seq(""), namespace) namespace += newName Field(newName, f.flip, f.tpe) - } map { f => f.tpe match { - case _: GroundType => (f, Seq[String](f.name)) - case _ => - val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet()) - // Need leading _ for findValidPrefix, it doesn't add _ for checks - val eltsNames: Seq[String] = eltsx map (e => "_" + e) - val prefix = findValidPrefix(f.name, eltsNames, namespace) - // We added f.name in previous map, delete if we change it - if (prefix != f.name) { - namespace -= f.name - namespace += prefix - } - val newElts: Seq[String] = eltsx map (e => LowerTypes.loweredName(prefix +: Seq(e))) - namespace ++= newElts - (Field(prefix, f.flip, tpe), prefix +: newElts) + }.map { f => + f.tpe match { + case _: GroundType => (f, Seq[String](f.name)) + case _ => + val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet()) + // Need leading _ for findValidPrefix, it doesn't add _ for checks + val eltsNames: Seq[String] = eltsx.map(e => "_" + e) + val prefix = findValidPrefix(f.name, eltsNames, namespace) + // We added f.name in previous map, delete if we change it + if (prefix != f.name) { + namespace -= f.name + namespace += prefix + } + val newElts: Seq[String] = eltsx.map(e => LowerTypes.loweredName(prefix +: Seq(e))) + namespace ++= newElts + (Field(prefix, f.flip, tpe), prefix +: newElts) } } val (newFields, elts) = newFieldsAndElts.unzip (BundleType(newFields), elts.flatten) case tx: VectorType => val (tpe, elts) = recUniquifyNames(tx.tpe, namespace) - val newElts = ((0 until tx.size) map (i => i.toString)) ++ - ((0 until tx.size) flatMap { i => - elts map (e => LowerTypes.loweredName(Seq(i.toString, e))) + val newElts = ((0 until tx.size).map(i => i.toString)) ++ + ((0 until tx.size).flatMap { i => + elts.map(e => LowerTypes.loweredName(Seq(i.toString, e))) }) (VectorType(tpe, tx.size), newElts) case tx => (tx, Nil) @@ -164,19 +178,26 @@ object Uniquify extends Transform with DependencyAPIMigration { // Creates a mapping from flattened references to members of $from -> // flattened references to members of $to private def createNameMapping( - from: Type, - to: Type) - (implicit sinfo: Info, mname: String): Map[String, NameMapNode] = { + from: Type, + to: Type + )( + implicit sinfo: Info, + mname: String + ): Map[String, NameMapNode] = { (from, to) match { case (fromx: BundleType, tox: BundleType) => - (fromx.fields zip tox.fields flatMap { case (f, t) => - val eltsMap = createNameMapping(f.tpe, t.tpe) - if ((f.name != t.name) || eltsMap.nonEmpty) { - Map(f.name -> NameMapNode(t.name, eltsMap)) - } else { - Map[String, NameMapNode]() - } - }).toMap + (fromx.fields + .zip(tox.fields) + .flatMap { + case (f, t) => + val eltsMap = createNameMapping(f.tpe, t.tpe) + if ((f.name != t.name) || eltsMap.nonEmpty) { + Map(f.name -> NameMapNode(t.name, eltsMap)) + } else { + Map[String, NameMapNode]() + } + }) + .toMap case (fromx: VectorType, tox: VectorType) => createNameMapping(fromx.tpe, tox.tpe) case (fromx, tox) => @@ -187,18 +208,19 @@ object Uniquify extends Transform with DependencyAPIMigration { // Maps names in expression to new uniquified names private def uniquifyNamesExp( - exp: Expression, - map: Map[String, NameMapNode]) - (implicit sinfo: Info, mname: String): Expression = { + exp: Expression, + map: Map[String, NameMapNode] + )( + implicit sinfo: Info, + mname: String + ): Expression = { // Recursive Helper - def rec(exp: Expression, m: Map[String, NameMapNode]): - (Expression, Map[String, NameMapNode]) = exp match { + def rec(exp: Expression, m: Map[String, NameMapNode]): (Expression, Map[String, NameMapNode]) = exp match { case e: WRef => if (m.contains(e.name)) { val node = m(e.name) (WRef(node.name, e.tpe, e.kind, e.flow), node.elts) - } - else (e, Map()) + } else (e, Map()) case e: WSubField => val (subExp, subMap) = rec(e.expr, m) val (retName, retMap) = @@ -218,18 +240,21 @@ object Uniquify extends Transform with DependencyAPIMigration { (WSubAccess(subExp, index, e.tpe, e.flow), subMap) case (_: UIntLiteral | _: SIntLiteral) => (exp, m) case (_: Mux | _: ValidIf | _: DoPrim) => - (exp map ((e: Expression) => uniquifyNamesExp(e, map)), m) + (exp.map((e: Expression) => uniquifyNamesExp(e, map)), m) } rec(exp, map)._1 } // Uses map to recursively rename fields of tpe private def uniquifyNamesType( - tpe: Type, - map: Map[String, NameMapNode]) - (implicit sinfo: Info, mname: String): Type = tpe match { + tpe: Type, + map: Map[String, NameMapNode] + )( + implicit sinfo: Info, + mname: String + ): Type = tpe match { case t: BundleType => - val newFields = t.fields map { f => + val newFields = t.fields.map { f => if (map.contains(f.name)) { val node = map(f.name) Field(node.name, f.flip, uniquifyNamesType(f.tpe, node.elts)) @@ -244,8 +269,11 @@ object Uniquify extends Transform with DependencyAPIMigration { } // Everything wrapped in run so that it's thread safe - @deprecated("The functionality of Uniquify is now part of LowerTypes." + - "Please file an issue with firrtl if you use Uniquify outside of the context of LowerTypes.", "Firrtl 1.4") + @deprecated( + "The functionality of Uniquify is now part of LowerTypes." + + "Please file an issue with firrtl if you use Uniquify outside of the context of LowerTypes.", + "Firrtl 1.4" + ) def execute(state: CircuitState): CircuitState = { val c = state.circuit val renames = RenameMap() @@ -263,22 +291,22 @@ object Uniquify extends Transform with DependencyAPIMigration { val nameMap = collection.mutable.HashMap[String, NameMapNode]() def uniquifyExp(e: Expression): Expression = e match { - case (_: WRef | _: WSubField | _: WSubIndex | _: WSubAccess ) => + case (_: WRef | _: WSubField | _: WSubIndex | _: WSubAccess) => uniquifyNamesExp(e, nameMap.toMap) - case e: Mux => e map uniquifyExp - case e: ValidIf => e map uniquifyExp + case e: Mux => e.map(uniquifyExp) + case e: ValidIf => e.map(uniquifyExp) case (_: UIntLiteral | _: SIntLiteral) => e - case e: DoPrim => e map uniquifyExp + case e: DoPrim => e.map(uniquifyExp) } def uniquifyStmt(s: Statement): Statement = { - s map uniquifyStmt map uniquifyExp match { + s.map(uniquifyStmt).map(uniquifyExp) match { case sx: DefWire => sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) val newType = uniquifyNamesType(sx.tpe, node.elts) - (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach { case (from, to) => renames.rename(from.serialize, to.serialize) } DefWire(sx.info, node.name, newType) @@ -290,7 +318,7 @@ object Uniquify extends Transform with DependencyAPIMigration { if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) val newType = uniquifyNamesType(sx.tpe, node.elts) - (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach { case (from, to) => renames.rename(from.serialize, to.serialize) } DefRegister(sx.info, node.name, newType, sx.clock, sx.reset, sx.init) @@ -302,7 +330,7 @@ object Uniquify extends Transform with DependencyAPIMigration { if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) val newType = portTypeMap(m.name) - (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach { case (from, to) => renames.rename(from.serialize, to.serialize) } WDefInstance(sx.info, node.name, sx.module, newType) @@ -317,7 +345,7 @@ object Uniquify extends Transform with DependencyAPIMigration { val mem = sx.copy(name = node.name, dataType = dataType) // Create new mapping to handle references to memory data fields val uniqueMemMap = createNameMapping(memType(sx), memType(mem)) - (Utils.create_exps(sx.name, memType(sx)) zip Utils.create_exps(node.name, memType(mem))) foreach { + (Utils.create_exps(sx.name, memType(sx)).zip(Utils.create_exps(node.name, memType(mem)))).foreach { case (from, to) => renames.rename(from.serialize, to.serialize) } nameMap(sx.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap) @@ -329,9 +357,12 @@ object Uniquify extends Transform with DependencyAPIMigration { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) - (Utils.create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) zip Utils.create_exps(node.name, sx.value.tpe)) foreach { - case (from, to) => renames.rename(from.serialize, to.serialize) - } + (Utils + .create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) + .zip(Utils.create_exps(node.name, sx.value.tpe))) + .foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } DefNode(sx.info, node.name, sx.value) } else { sx @@ -354,19 +385,18 @@ object Uniquify extends Transform with DependencyAPIMigration { mname = m.name m match { case m: ExtModule => m - case m: Module => + case m: Module => // Adds port names to namespace and namemap nameMap ++= portNameMap(m.name) - namespace ++= create_exps("", portTypeMap(m.name)) map - LowerTypes.loweredName map (_.tail) - m.copy(body = uniquifyBody(m.body) ) + namespace ++= create_exps("", portTypeMap(m.name)).map(LowerTypes.loweredName).map(_.tail) + m.copy(body = uniquifyBody(m.body)) } } def uniquifyPorts(renames: RenameMap)(m: DefModule): DefModule = { renames.setModule(m.name) def uniquifyPorts(ports: Seq[Port]): Seq[Port] = { - val portsType = BundleType(ports map { + val portsType = BundleType(ports.map { case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) }) val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet()) @@ -374,11 +404,12 @@ object Uniquify extends Transform with DependencyAPIMigration { portNameMap += (m.name -> localMap) portTypeMap += (m.name -> uniquePortsType) - ports zip uniquePortsType.fields map { case (p, f) => - (Utils.create_exps(p.name, p.tpe) zip Utils.create_exps(f.name, f.tpe)) foreach { - case (from, to) => renames.rename(from.serialize, to.serialize) - } - Port(p.info, f.name, p.direction, f.tpe) + ports.zip(uniquePortsType.fields).map { + case (p, f) => + (Utils.create_exps(p.name, p.tpe).zip(Utils.create_exps(f.name, f.tpe))).foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } + Port(p.info, f.name, p.direction, f.tpe) } } @@ -386,12 +417,12 @@ object Uniquify extends Transform with DependencyAPIMigration { mname = m.name m match { case m: ExtModule => m.copy(ports = uniquifyPorts(m.ports)) - case m: Module => m.copy(ports = uniquifyPorts(m.ports)) + case m: Module => m.copy(ports = uniquifyPorts(m.ports)) } } sinfo = c.info - val result = Circuit(c.info, c.modules map uniquifyPorts(renames) map uniquifyModule(renames), c.main) + val result = Circuit(c.info, c.modules.map(uniquifyPorts(renames)).map(uniquifyModule(renames)), c.main) state.copy(circuit = result, renames = Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala index 36eff379..0b046a5f 100644 --- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala +++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala @@ -12,28 +12,30 @@ import firrtl.options.Dependency import scala.collection.mutable /** - * Verilog has the width of (a % b) = Max(W(a), W(b)) - * FIRRTL has the width of (a % b) = Min(W(a), W(b)), which makes more sense, - * but nevertheless is a problem when emitting verilog - * - * This pass finds every instance of (a % b) and: - * 1) adds a temporary node equal to (a % b) with width Max(W(a), W(b)) - * 2) replaces the reference to (a % b) with a bitslice of the temporary node - * to get back down to width Min(W(a), W(b)) - * - * This is technically incorrect firrtl, but allows the verilog emitter - * to emit correct verilog without needing to add temporary nodes - */ + * Verilog has the width of (a % b) = Max(W(a), W(b)) + * FIRRTL has the width of (a % b) = Min(W(a), W(b)), which makes more sense, + * but nevertheless is a problem when emitting verilog + * + * This pass finds every instance of (a % b) and: + * 1) adds a temporary node equal to (a % b) with width Max(W(a), W(b)) + * 2) replaces the reference to (a % b) with a bitslice of the temporary node + * to get back down to width Min(W(a), W(b)) + * + * This is technically incorrect firrtl, but allows the verilog emitter + * to emit correct verilog without needing to add temporary nodes + */ object VerilogModulusCleanup extends Pass { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.FixAddingNegativeLiterals], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineCastsTransform], - Dependency[firrtl.transforms.LegalizeClocksTransform], - Dependency[firrtl.transforms.FlattenRegUpdate] ) + Seq( + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized @@ -51,32 +53,35 @@ object VerilogModulusCleanup extends Pass { case t => UnknownWidth } - def maxWidth(ws: Seq[Width]): Width = ws reduceLeft { (x,y) => (x,y) match { - case (IntWidth(x), IntWidth(y)) => IntWidth(x max y) - case (x, y) => UnknownWidth - }} + def maxWidth(ws: Seq[Width]): Width = ws.reduceLeft { (x, y) => + (x, y) match { + case (IntWidth(x), IntWidth(y)) => IntWidth(x.max(y)) + case (x, y) => UnknownWidth + } + } def verilogRemWidth(e: DoPrim)(tpe: Type): Type = { val newWidth = maxWidth(e.args.map(exp => getWidth(exp))) - tpe mapWidth (w => newWidth) + tpe.mapWidth(w => newWidth) } def removeRem(e: Expression): Expression = e match { - case e: DoPrim => e.op match { - case Rem => - val name = namespace.newTemp - val newType = e mapType verilogRemWidth(e) - v += DefNode(get_info(s), name, e mapType verilogRemWidth(e)) - val remRef = WRef(name, newType.tpe, kind(e), flow(e)) - val remWidth = bitWidth(e.tpe) - DoPrim(Bits, Seq(remRef), Seq(remWidth - 1, BigInt(0)), e.tpe) - case _ => e - } + case e: DoPrim => + e.op match { + case Rem => + val name = namespace.newTemp + val newType = e.mapType(verilogRemWidth(e)) + v += DefNode(get_info(s), name, e.mapType(verilogRemWidth(e))) + val remRef = WRef(name, newType.tpe, kind(e), flow(e)) + val remWidth = bitWidth(e.tpe) + DoPrim(Bits, Seq(remRef), Seq(remWidth - 1, BigInt(0)), e.tpe) + case _ => e + } case _ => e } - s map removeRem match { - case x: Block => x map onStmt + s.map(removeRem) match { + case x: Block => x.map(onStmt) case EmptyStmt => EmptyStmt case x => v += x @@ -90,8 +95,8 @@ object VerilogModulusCleanup extends Pass { } def run(c: Circuit): Circuit = { - val modules = c.modules map { - case m: Module => onModule(m) + val modules = c.modules.map { + case m: Module => onModule(m) case m: ExtModule => m } Circuit(c.info, modules, c.main) diff --git a/src/main/scala/firrtl/passes/VerilogPrep.scala b/src/main/scala/firrtl/passes/VerilogPrep.scala index 03d47cfc..eeb34fa9 100644 --- a/src/main/scala/firrtl/passes/VerilogPrep.scala +++ b/src/main/scala/firrtl/passes/VerilogPrep.scala @@ -21,15 +21,17 @@ import scala.collection.mutable object VerilogPrep extends Pass { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.FixAddingNegativeLiterals], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineCastsTransform], - Dependency[firrtl.transforms.LegalizeClocksTransform], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename] ) + Seq( + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized @@ -46,9 +48,9 @@ object VerilogPrep extends Pass { val sourceMap = mutable.HashMap.empty[WrappedExpression, Expression] lazy val namespace = Namespace(m) - def onStmt(stmt: Statement): Statement = stmt map onStmt match { + def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { case attach: Attach => - val wires = attach.exprs groupBy kind + val wires = attach.exprs.groupBy(kind) val sources = wires.getOrElse(PortKind, Seq.empty) ++ wires.getOrElse(WireKind, Seq.empty) val instPorts = wires.getOrElse(InstanceKind, Seq.empty) // Sanity check (Should be caught by CheckTypes) @@ -71,14 +73,14 @@ object VerilogPrep extends Pass { case s => s } - (m map onStmt, sourceMap.toMap) + (m.map(onStmt), sourceMap.toMap) } def run(c: Circuit): Circuit = { def lowerE(e: Expression): Expression = e match { case (_: WRef | _: WSubField) if kind(e) == InstanceKind => WRef(LowerTypes.loweredName(e), e.tpe, kind(e), flow(e)) - case _ => e map lowerE + case _ => e.map(lowerE) } def lowerS(attachMap: AttachSourceMap)(s: Statement): Statement = s match { @@ -96,12 +98,12 @@ object VerilogPrep extends Pass { }.unzip val newInst = WDefInstanceConnector(info, name, module, tpe, portCons) Block(wires.flatten :+ newInst) - case other => other map lowerS(attachMap) map lowerE + case other => other.map(lowerS(attachMap)).map(lowerE) } - val modulesx = c.modules map { mod => + val modulesx = c.modules.map { mod => val (modx, attachMap) = collectAndRemoveAttach(mod) - modx map lowerS(attachMap) + modx.map(lowerS(attachMap)) } c.copy(modules = modulesx) } diff --git a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala index 39c127de..e61780a4 100644 --- a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala +++ b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala @@ -17,10 +17,7 @@ import firrtl.options.Dependency */ object ZeroLengthVecs extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ResolveKinds), - Dependency(InferTypes), - Dependency(ExpandConnects) ) + Seq(Dependency(PullMuxes), Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ExpandConnects)) override def invalidates(a: Transform) = false @@ -28,8 +25,8 @@ object ZeroLengthVecs extends Pass { // interval type with the type alone unless you declare a component private def replaceWithDontCare(toReplace: Expression): Expression = { val default = toReplace.tpe match { - case UIntType(w) => UIntLiteral(0, w) - case SIntType(w) => SIntLiteral(0, w) + case UIntType(w) => UIntLiteral(0, w) + case SIntType(w) => SIntLiteral(0, w) case FixedType(w, p) => FixedLiteral(0, w, p) case it: IntervalType => val zeroType = IntervalType(Closed(0), Closed(0), IntWidth(0)) @@ -40,11 +37,11 @@ object ZeroLengthVecs extends Pass { } private def zeroLenDerivedRefLike(expr: Expression): Boolean = (expr, expr.tpe) match { - case (_, VectorType(_, 0)) => true - case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e) + case (_, VectorType(_, 0)) => true + case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e) case (WSubAccess(e, _, _, _), _) => zeroLenDerivedRefLike(e) - case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e) - case _ => false + case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e) + case _ => false } // The connects have all been lowered, so all aggregate-typed expressions are "grounded" by WSubField/WSubAccess/WSubIndex @@ -52,13 +49,13 @@ object ZeroLengthVecs extends Pass { private def dropZeroLenSubAccesses(expr: Expression): Expression = expr match { case _: WSubIndex | _: WSubAccess | _: WSubField => if (zeroLenDerivedRefLike(expr)) replaceWithDontCare(expr) else expr - case e => e map dropZeroLenSubAccesses + case e => e.map(dropZeroLenSubAccesses) } // Attach semantics: drop all zero-length-derived members of attach group, drop stmt if trivial private def onStmt(stmt: Statement): Statement = stmt match { case Connect(_, sink, _) if zeroLenDerivedRefLike(sink) => EmptyStmt - case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt + case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt case Attach(info, sinks) => val filtered = Attach(info, sinks.filterNot(zeroLenDerivedRefLike)) if (filtered.exprs.length < 2) EmptyStmt else filtered diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index 56d66ef0..82321f95 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -11,12 +11,14 @@ import firrtl.options.Dependency object ZeroWidth extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses), - Dependency[ExpandWhensAndCheck], - Dependency(ConvertFixedToSInt) ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck], + Dependency(ConvertFixedToSInt) + ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform): Boolean = a match { case InferTypes => true @@ -24,30 +26,41 @@ object ZeroWidth extends Transform with DependencyAPIMigration { } private def makeEmptyMemBundle(name: String): Field = - Field(name, Flip, BundleType(Seq( - Field("addr", Default, UIntType(IntWidth(0))), - Field("en", Default, UIntType(IntWidth(0))), - Field("clk", Default, UIntType(IntWidth(0))), - Field("data", Flip, UIntType(IntWidth(0))) - ))) + Field( + name, + Flip, + BundleType( + Seq( + Field("addr", Default, UIntType(IntWidth(0))), + Field("en", Default, UIntType(IntWidth(0))), + Field("clk", Default, UIntType(IntWidth(0))), + Field("data", Flip, UIntType(IntWidth(0))) + ) + ) + ) private def onEmptyMemStmt(s: Statement): Statement = s match { - case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) => removeZero(tpe) match { - case None => - DefWire(info, name, BundleType( - rs.map(r => makeEmptyMemBundle(r)) ++ - ws.map(w => makeEmptyMemBundle(w)) ++ - rws.map(rw => makeEmptyMemBundle(rw)) - )) - case Some(_) => d - } - case sx => sx map onEmptyMemStmt + case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) => + removeZero(tpe) match { + case None => + DefWire( + info, + name, + BundleType( + rs.map(r => makeEmptyMemBundle(r)) ++ + ws.map(w => makeEmptyMemBundle(w)) ++ + rws.map(rw => makeEmptyMemBundle(rw)) + ) + ) + case Some(_) => d + } + case sx => sx.map(onEmptyMemStmt) } private def onModuleEmptyMemStmt(m: DefModule): DefModule = { m match { case ext: ExtModule => ext - case in: Module => in.copy(body = onEmptyMemStmt(in.body)) + case in: Module => in.copy(body = onEmptyMemStmt(in.body)) } } @@ -59,20 +72,20 @@ object ZeroWidth extends Transform with DependencyAPIMigration { * This replaces memories with a DefWire() bundle that contains the address, en, * clk, and data fields implemented as zero width wires. Running the rest of the ZeroWidth * transform will remove these dangling references properly. - * */ def executeEmptyMemStmt(state: CircuitState): CircuitState = { val c = state.circuit - val result = c.copy(modules = c.modules map onModuleEmptyMemStmt) + val result = c.copy(modules = c.modules.map(onModuleEmptyMemStmt)) state.copy(circuit = result) } // This is slightly different and specialized version of create_exps, TODO unify? private def findRemovable(expr: => Expression, tpe: Type): Seq[Expression] = tpe match { - case GroundType(width) => width match { - case IntWidth(ZERO) => List(expr) - case _ => List.empty - } + case GroundType(width) => + width match { + case IntWidth(ZERO) => List(expr) + case _ => List.empty + } case BundleType(fields) => if (fields.isEmpty) List(expr) else fields.flatMap(f => findRemovable(WSubField(expr, f.name, f.tpe, SourceFlow), f.tpe)) @@ -95,7 +108,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration { t } x match { - case s: Statement => s map onType(s.name) + case s: Statement => s.map(onType(s.name)) case Port(_, name, _, t) => onType(name)(t) } removedNames @@ -103,14 +116,14 @@ object ZeroWidth extends Transform with DependencyAPIMigration { private[passes] def removeZero(t: Type): Option[Type] = t match { case GroundType(IntWidth(ZERO)) => None case BundleType(fields) => - fields map (f => (f, removeZero(f.tpe))) collect { + fields.map(f => (f, removeZero(f.tpe))).collect { case (Field(name, flip, _), Some(t)) => Field(name, flip, t) } match { case Nil => None case seq => Some(BundleType(seq)) } - case VectorType(t, size) => removeZero(t) map (VectorType(_, size)) - case x => Some(x) + case VectorType(t, size) => removeZero(t).map(VectorType(_, size)) + case x => Some(x) } private def onExp(e: Expression): Expression = e match { case DoPrim(Cat, args, consts, tpe) => @@ -118,26 +131,27 @@ object ZeroWidth extends Transform with DependencyAPIMigration { x.tpe match { case UIntType(IntWidth(ZERO)) => Seq.empty[Expression] case SIntType(IntWidth(ZERO)) => Seq.empty[Expression] - case other => Seq(x) + case other => Seq(x) } } nonZeros match { - case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1))) + case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1))) case Seq(x) => x - case seq => DoPrim(Cat, seq, consts, tpe) map onExp + case seq => DoPrim(Cat, seq, consts, tpe).map(onExp) } case DoPrim(Andr, Seq(x), _, _) if (bitWidth(x.tpe) == 0) => UIntLiteral(1) // nothing false - case other => other.tpe match { - case UIntType(IntWidth(ZERO)) => UIntLiteral(ZERO, IntWidth(BigInt(1))) - case SIntType(IntWidth(ZERO)) => SIntLiteral(ZERO, IntWidth(BigInt(1))) - case _ => e map onExp - } + case other => + other.tpe match { + case UIntType(IntWidth(ZERO)) => UIntLiteral(ZERO, IntWidth(BigInt(1))) + case SIntType(IntWidth(ZERO)) => SIntLiteral(ZERO, IntWidth(BigInt(1))) + case _ => e.map(onExp) + } } private def onStmt(renames: RenameMap)(s: Statement): Statement = s match { case d @ DefWire(info, name, tpe) => renames.delete(getRemoved(d)) removeZero(tpe) match { - case None => EmptyStmt + case None => EmptyStmt case Some(t) => DefWire(info, name, t) } case d @ DefRegister(info, name, tpe, clock, reset, init) => @@ -145,7 +159,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration { removeZero(tpe) match { case None => EmptyStmt case Some(t) => - DefRegister(info, name, t, onExp(clock), onExp(reset), onExp(init)) + DefRegister(info, name, t, onExp(clock), onExp(reset), onExp(init)) } case d: DefMemory => renames.delete(getRemoved(d)) @@ -154,25 +168,28 @@ object ZeroWidth extends Transform with DependencyAPIMigration { Utils.throwInternalError(s"private pass ZeroWidthMemRemove should have removed this memory: $d") case Some(t) => d.copy(dataType = t) } - case Connect(info, loc, exp) => removeZero(loc.tpe) match { - case None => EmptyStmt - case Some(t) => Connect(info, loc, onExp(exp)) - } - case IsInvalid(info, exp) => removeZero(exp.tpe) match { - case None => EmptyStmt - case Some(t) => IsInvalid(info, onExp(exp)) - } - case DefNode(info, name, value) => removeZero(value.tpe) match { - case None => EmptyStmt - case Some(t) => DefNode(info, name, onExp(value)) - } - case sx => sx map onStmt(renames) map onExp + case Connect(info, loc, exp) => + removeZero(loc.tpe) match { + case None => EmptyStmt + case Some(t) => Connect(info, loc, onExp(exp)) + } + case IsInvalid(info, exp) => + removeZero(exp.tpe) match { + case None => EmptyStmt + case Some(t) => IsInvalid(info, onExp(exp)) + } + case DefNode(info, name, value) => + removeZero(value.tpe) match { + case None => EmptyStmt + case Some(t) => DefNode(info, name, onExp(value)) + } + case sx => sx.map(onStmt(renames)).map(onExp) } private def onModule(renames: RenameMap)(m: DefModule): DefModule = { renames.setModule(m.name) // For each port, record deleted subcomponents - m.ports.foreach{p => renames.delete(getRemoved(p))} - val ports = m.ports map (p => (p, removeZero(p.tpe))) flatMap { + m.ports.foreach { p => renames.delete(getRemoved(p)) } + val ports = m.ports.map(p => (p, removeZero(p.tpe))).flatMap { case (Port(info, name, dir, _), Some(t)) => Seq(Port(info, name, dir, t)) case (Port(_, name, _, _), None) => renames.delete(name) @@ -180,7 +197,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration { } m match { case ext: ExtModule => ext.copy(ports = ports) - case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body)) + case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body)) } } def execute(state: CircuitState): CircuitState = { @@ -189,7 +206,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration { val c = InferTypes.run(executeEmptyMemStmt(state).circuit) val renames = RenameMap() renames.setCircuit(c.main) - val result = c.copy(modules = c.modules map onModule(renames)) + val result = c.copy(modules = c.modules.map(onModule(renames))) CircuitState(result, outputForm, state.annotations, Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala index c2323d4c..bfc03b51 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala @@ -13,8 +13,8 @@ import Utils._ import memlib.AnalysisUtils._ /** Starting with a top module, determine the clock origins of each child instance. - * Write the result to writer. - */ + * Write the result to writer. + */ class ClockList(top: String, writer: Writer) extends Pass { def run(c: Circuit): Circuit = { // Build useful datastructures @@ -29,7 +29,7 @@ class ClockList(top: String, writer: Writer) extends Pass { // Clock sources must be blackbox outputs and top's clock val partialSourceList = getSourceList(moduleMap)(lineages) - val sourceList = partialSourceList ++ moduleMap(top).ports.collect{ case Port(i, n, Input, ClockType) => n } + val sourceList = partialSourceList ++ moduleMap(top).ports.collect { case Port(i, n, Input, ClockType) => n } writer.append(s"Sourcelist: $sourceList \n") // Remove everything from the circuit, unless it has a clock type @@ -37,8 +37,9 @@ class ClockList(top: String, writer: Writer) extends Pass { val onlyClockCircuit = RemoveAllButClocks.run(c) // Inline the clock-only circuit up to the specified top module - val modulesToInline = (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet - val inlineTransform = new InlineInstances{ override val inlineDelim = "$" } + val modulesToInline = + (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet + val inlineTransform = new InlineInstances { override val inlineDelim = "$" } val inlinedCircuit = inlineTransform.run(onlyClockCircuit, modulesToInline, Set(), Seq()).circuit val topModule = inlinedCircuit.modules.find(_.name == top).getOrElse(throwInternalError("no top module")) @@ -49,13 +50,14 @@ class ClockList(top: String, writer: Writer) extends Pass { val origins = getOrigins(connects, "", moduleMap)(lineages) // If the clock origin is contained in the source list, label good (otherwise bad) - origins.foreach { case (instance, origin) => - val sep = if(instance == "") "" else "." - if(!sourceList.contains(origin.replace('.','$'))){ - outputBuffer.append(s"Bad Origin of $instance${sep}clock is $origin\n") - } else { - outputBuffer.append(s"Good Origin of $instance${sep}clock is $origin\n") - } + origins.foreach { + case (instance, origin) => + val sep = if (instance == "") "" else "." + if (!sourceList.contains(origin.replace('.', '$'))) { + outputBuffer.append(s"Bad Origin of $instance${sep}clock is $origin\n") + } else { + outputBuffer.append(s"Good Origin of $instance${sep}clock is $origin\n") + } } // Write to output file diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala index e6617857..468ba905 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala @@ -12,8 +12,7 @@ import memlib._ import firrtl.options.{RegisteredTransform, ShellOption} import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} -case class ClockListAnnotation(target: ModuleName, outputConfig: String) extends - SingleTargetAnnotation[ModuleName] { +case class ClockListAnnotation(target: ModuleName, outputConfig: String) extends SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = ClockListAnnotation(n, outputConfig) } @@ -44,7 +43,7 @@ Usage: ) passOptions.get(InputConfigFileName) match { case Some(x) => error("Unneeded input config file name!" + usage) - case None => + case None => } val target = ModuleName(passModule, CircuitName(passCircuit)) ClockListAnnotation(target, outputConfig) @@ -53,18 +52,20 @@ Usage: class ClockListTransform extends Transform with DependencyAPIMigration with RegisteredTransform { - override def prerequisites = Forms.LowForm - override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Forms.LowEmitters + override def prerequisites = Forms.LowForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Forms.LowEmitters val options = Seq( new ShellOption[String]( longOption = "list-clocks", - toAnnotationSeq = (a: String) => Seq( passes.clocklist.ClockListAnnotation.parse(a), - RunFirrtlTransformAnnotation(new ClockListTransform) ), + toAnnotationSeq = (a: String) => + Seq(passes.clocklist.ClockListAnnotation.parse(a), RunFirrtlTransformAnnotation(new ClockListTransform)), helpText = "List which signal drives each clock of every descendent of specified modules", shortOption = Some("clks"), - helpValueName = Some("-c:<circuit>:-m:<module>:-o:<filename>") ) ) + helpValueName = Some("-c:<circuit>:-m:<module>:-o:<filename>") + ) + ) def passSeq(top: String, writer: Writer): Seq[Pass] = Seq(new ClockList(top, writer)) diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala index b77629fc..00e07588 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala @@ -10,45 +10,56 @@ import Utils._ import memlib.AnalysisUtils._ object ClockListUtils { + /** Returns a list of clock outputs from instances of external modules - */ + */ def getSourceList(moduleMap: Map[String, DefModule])(lin: Lineage): Seq[String] = { - val s = lin.foldLeft(Seq[String]()){case (sL, (i, l)) => - val sLx = getSourceList(moduleMap)(l) - val sLxx = sLx map (i + "$" + _) - sL ++ sLxx + val s = lin.foldLeft(Seq[String]()) { + case (sL, (i, l)) => + val sLx = getSourceList(moduleMap)(l) + val sLxx = sLx.map(i + "$" + _) + sL ++ sLxx } val sourceList = moduleMap(lin.name) match { case ExtModule(i, n, ports, dn, p) => - val portExps = ports.flatMap{p => create_exps(WRef(p.name, p.tpe, PortKind, to_flow(p.direction)))} + val portExps = ports.flatMap { p => create_exps(WRef(p.name, p.tpe, PortKind, to_flow(p.direction))) } portExps.filter(e => (e.tpe == ClockType) && (flow(e) == SinkFlow)).map(_.serialize) case _ => Nil } val sx = sourceList ++ s sx } + /** Returns a map from instance name to its clock origin. - * Child instances are not included if they share the same clock as their parent - */ - def getOrigins(connects: Connects, me: String, moduleMap: Map[String, DefModule])(lin: Lineage): Map[String, String] = { - val sep = if(me == "") "" else "$" + * Child instances are not included if they share the same clock as their parent + */ + def getOrigins( + connects: Connects, + me: String, + moduleMap: Map[String, DefModule] + )(lin: Lineage + ): Map[String, String] = { + val sep = if (me == "") "" else "$" // Get origins from all children - val childrenOrigins = lin.foldLeft(Map[String, String]()){case (o, (i, l)) => - o ++ getOrigins(connects, me + sep + i, moduleMap)(l) + val childrenOrigins = lin.foldLeft(Map[String, String]()) { + case (o, (i, l)) => + o ++ getOrigins(connects, me + sep + i, moduleMap)(l) } // If I have a clock, get it val clockOpt = moduleMap(lin.name) match { - case Module(i, n, ports, b) => ports.collectFirst { case p if p.name == "clock" => me + sep + "clock" } + case Module(i, n, ports, b) => ports.collectFirst { case p if p.name == "clock" => me + sep + "clock" } case ExtModule(i, n, ports, dn, p) => None } // Return new origins with direct children removed, if they match my clock clockOpt match { case Some(clock) => val myOrigin = getOrigin(connects, clock).serialize - childrenOrigins.foldLeft(Map(me -> myOrigin)) { case (o, (childInstance, childOrigin)) => - val childrenInstances = lin.children.map { case (instance, _) => me + sep + instance } - // If direct child shares my origin, omit it - if(childOrigin == myOrigin && childrenInstances.contains(childInstance)) o else o + (childInstance -> childOrigin) + childrenOrigins.foldLeft(Map(me -> myOrigin)) { + case (o, (childInstance, childOrigin)) => + val childrenInstances = lin.children.map { case (instance, _) => me + sep + instance } + // If direct child shares my origin, omit it + if (childOrigin == myOrigin && childrenInstances.contains(childInstance)) o + else o + (childInstance -> childOrigin) } case None => childrenOrigins } diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala index 6eb8c138..d72bc293 100644 --- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala +++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala @@ -9,22 +9,22 @@ import Utils._ import Mappers._ /** Remove all statements and ports (except instances/whens/blocks) whose - * expressions do not relate to ground types. - */ + * expressions do not relate to ground types. + */ object RemoveAllButClocks extends Pass { - def onStmt(s: Statement): Statement = (s map onStmt) match { - case DefWire(i, n, ClockType) => s + def onStmt(s: Statement): Statement = (s.map(onStmt)) match { + case DefWire(i, n, ClockType) => s case DefNode(i, n, value) if value.tpe == ClockType => s - case Connect(i, l, r) if l.tpe == ClockType => s - case sx: WDefInstance => sx - case sx: DefInstance => sx - case sx: Block => sx + case Connect(i, l, r) if l.tpe == ClockType => s + case sx: WDefInstance => sx + case sx: DefInstance => sx + case sx: Block => sx case sx: Conditionally => sx case _ => EmptyStmt } def onModule(m: DefModule): DefModule = m match { - case Module(i, n, ps, b) => Module(i, n, ps.filter(_.tpe == ClockType), squashEmpty(onStmt(b))) + case Module(i, n, ps, b) => Module(i, n, ps.filter(_.tpe == ClockType), squashEmpty(onStmt(b))) case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps.filter(_.tpe == ClockType), dn, p) } - def run(c: Circuit): Circuit = c.copy(modules = c.modules map onModule) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(onModule)) } diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 14bd9e44..d237c36a 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -19,8 +19,9 @@ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform import CustomYAMLProtocol._ val configs = r.parse[Config] val oldAnnos = state.annotations - val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { case ((annos, pins), config) => - (annos, pins :+ config.pin.name) + val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { + case ((annos, pins), config) => + (annos, pins :+ config.pin.name) } state.copy(annotations = PinAnnotation(pins.toSeq) +: as) } diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 4847a698..e290633e 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -10,12 +10,11 @@ import firrtl.PrimOps._ import firrtl.Utils.{one, zero, BoolType} import firrtl.options.{HasShellOptions, ShellOption} import MemPortUtils.memPortField -import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin} +import firrtl.passes.memlib.AnalysisUtils.{getConnects, getOrigin, Connects} import WrappedExpression.weq import annotations._ import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} - case object InferReadWriteAnnotation extends NoTargetAnnotation // This pass examine the enable signals of the read & write ports of memories @@ -40,12 +39,13 @@ object InferReadWritePass extends Pass { getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval) // Visit each term of AND operation case DoPrim(op, args, consts, tpe) if op == And => - e +: (args flatMap getProductTerms(connects)) + e +: (args.flatMap(getProductTerms(connects))) // Visit connected nodes to references - case _: WRef | _: WSubField | _: WSubIndex => connects get e match { - case None => Seq(e) - case Some(ex) => e +: getProductTerms(connects)(ex) - } + case _: WRef | _: WSubField | _: WSubIndex => + connects.get(e) match { + case None => Seq(e) + case Some(ex) => e +: getProductTerms(connects)(ex) + } // Otherwise just return itself case _ => Seq(e) } @@ -58,96 +58,103 @@ object InferReadWritePass extends Pass { // b ?= Eq(a, 0) or b ?= Eq(0, a) case (_, DoPrim(Eq, args, _, _)) => weq(args.head, a) && weq(args(1), zero) || - weq(args(1), a) && weq(args.head, zero) + weq(args(1), a) && weq(args.head, zero) // a ?= Eq(b, 0) or b ?= Eq(0, a) case (DoPrim(Eq, args, _, _), _) => weq(args.head, b) && weq(args(1), zero) || - weq(args(1), b) && weq(args.head, zero) + weq(args(1), b) && weq(args.head, zero) case _ => false } - def replaceExp(repl: Netlist)(e: Expression): Expression = - e map replaceExp(repl) match { - case ex: WSubField => repl getOrElse (ex.serialize, ex) + e.map(replaceExp(repl)) match { + case ex: WSubField => repl.getOrElse(ex.serialize, ex) case ex => ex } def replaceStmt(repl: Netlist)(s: Statement): Statement = - s map replaceStmt(repl) map replaceExp(repl) match { + s.map(replaceStmt(repl)).map(replaceExp(repl)) match { case Connect(_, EmptyExpression, _) => EmptyStmt - case sx => sx + case sx => sx } - def inferReadWriteStmt(connects: Connects, - repl: Netlist, - stmts: Statements) - (s: Statement): Statement = s match { + def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match { // infer readwrite ports only for non combinational memories case mem: DefMemory if mem.readLatency > 0 => val readers = new PortSet val writers = new PortSet val readwriters = collection.mutable.ArrayBuffer[String]() val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) - for (w <- mem.writers ; r <- mem.readers) { + for { + w <- mem.writers + r <- mem.readers + } { val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en")) val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en")) - val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b))) + val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b))) val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) { - val rw = namespace newName "rw" + val rw = namespace.newName("rw") val rwExp = WSubField(WRef(mem.name), rw) readwriters += rw readers += r writers += w - repl(memPortField(mem, r, "clk")) = EmptyExpression - repl(memPortField(mem, r, "en")) = EmptyExpression + repl(memPortField(mem, r, "clk")) = EmptyExpression + repl(memPortField(mem, r, "en")) = EmptyExpression repl(memPortField(mem, r, "addr")) = EmptyExpression repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata") - repl(memPortField(mem, w, "clk")) = EmptyExpression - repl(memPortField(mem, w, "en")) = EmptyExpression + repl(memPortField(mem, w, "clk")) = EmptyExpression + repl(memPortField(mem, w, "en")) = EmptyExpression repl(memPortField(mem, w, "addr")) = EmptyExpression repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata") repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask") stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get) stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk) - stmts += Connect(NoInfo, WSubField(rwExp, "en"), - DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), - connects(memPortField(mem, w, "en"))), Nil, BoolType)) - stmts += Connect(NoInfo, WSubField(rwExp, "addr"), - Mux(connects(memPortField(mem, w, "en")), - connects(memPortField(mem, w, "addr")), - connects(memPortField(mem, r, "addr")), UnknownType)) + stmts += Connect( + NoInfo, + WSubField(rwExp, "en"), + DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType) + ) + stmts += Connect( + NoInfo, + WSubField(rwExp, "addr"), + Mux( + connects(memPortField(mem, w, "en")), + connects(memPortField(mem, w, "addr")), + connects(memPortField(mem, r, "addr")), + UnknownType + ) + ) } } - if (readwriters.isEmpty) mem else mem copy ( - readers = mem.readers filterNot readers, - writers = mem.writers filterNot writers, - readwriters = mem.readwriters ++ readwriters) - case sx => sx map inferReadWriteStmt(connects, repl, stmts) + if (readwriters.isEmpty) mem + else + mem.copy( + readers = mem.readers.filterNot(readers), + writers = mem.writers.filterNot(writers), + readwriters = mem.readwriters ++ readwriters + ) + case sx => sx.map(inferReadWriteStmt(connects, repl, stmts)) } def inferReadWrite(m: DefModule) = { val connects = getConnects(m) val repl = new Netlist val stmts = new Statements - (m map inferReadWriteStmt(connects, repl, stmts) - map replaceStmt(repl)) match { + (m.map(inferReadWriteStmt(connects, repl, stmts)) + .map(replaceStmt(repl))) match { case m: ExtModule => m - case m: Module => m copy (body = Block(m.body +: stmts.toSeq)) + case m: Module => m.copy(body = Block(m.body +: stmts.toSeq)) } } - def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite) + def run(c: Circuit) = c.copy(modules = c.modules.map(inferReadWrite)) } // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite extends Transform - with DependencyAPIMigration - with SeqTransformBased - with HasShellOptions { +class InferReadWrite extends Transform with DependencyAPIMigration with SeqTransformBased with HasShellOptions { override def prerequisites = Forms.MidForm override def optionalPrerequisites = Seq.empty @@ -159,7 +166,9 @@ class InferReadWrite extends Transform longOption = "infer-rw", toAnnotationSeq = (_: Unit) => Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)), helpText = "Enable read/write port inference for memories", - shortOption = Some("firw") ) ) + shortOption = Some("firw") + ) + ) def transforms = Seq( InferReadWritePass, diff --git a/src/main/scala/firrtl/passes/memlib/MemConf.scala b/src/main/scala/firrtl/passes/memlib/MemConf.scala index 3809c47c..871a1093 100644 --- a/src/main/scala/firrtl/passes/memlib/MemConf.scala +++ b/src/main/scala/firrtl/passes/memlib/MemConf.scala @@ -3,7 +3,6 @@ package firrtl.passes package memlib - sealed abstract class MemPort(val name: String) { override def toString = name } case object ReadPort extends MemPort("read") @@ -19,22 +18,27 @@ object MemPort { def apply(s: String): Option[MemPort] = MemPort.all.find(_.name == s) def fromString(s: String): Map[MemPort, Int] = { - s.split(",").toSeq.map(MemPort.apply).map(_ match { - case Some(x) => x - case _ => throw new Exception(s"Error parsing MemPort string : ${s}") - }).groupBy(identity).mapValues(_.size).toMap + s.split(",") + .toSeq + .map(MemPort.apply) + .map(_ match { + case Some(x) => x + case _ => throw new Exception(s"Error parsing MemPort string : ${s}") + }) + .groupBy(identity) + .mapValues(_.size) + .toMap } } case class MemConf( - name: String, - depth: BigInt, - width: Int, - ports: Map[MemPort, Int], - maskGranularity: Option[Int] -) { + name: String, + depth: BigInt, + width: Int, + ports: Map[MemPort, Int], + maskGranularity: Option[Int]) { - private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") } mkString (",") + private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") }.mkString(",") private def maskGranStr = maskGranularity.map((p) => s"mask_gran $p").getOrElse("") // Assert that all of the entries in the port map are greater than zero to make it easier to compare two of these case classes @@ -49,21 +53,34 @@ object MemConf { val regex = raw"\s*name\s+(\w+)\s+depth\s+(\d+)\s+width\s+(\d+)\s+ports\s+([^\s]+)\s+(?:mask_gran\s+(\d+))?\s*".r def fromString(s: String): Seq[MemConf] = { - s.split("\n").toSeq.map(_ match { - case MemConf.regex(name, depth, width, ports, maskGran) => Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt))) - case "" => None - case _ => throw new Exception(s"Error parsing MemConf string : ${s}") - }).flatten + s.split("\n") + .toSeq + .map(_ match { + case MemConf.regex(name, depth, width, ports, maskGran) => + Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt))) + case "" => None + case _ => throw new Exception(s"Error parsing MemConf string : ${s}") + }) + .flatten } - def apply(name: String, depth: BigInt, width: Int, readPorts: Int, writePorts: Int, readWritePorts: Int, maskGranularity: Option[Int]): MemConf = { + def apply( + name: String, + depth: BigInt, + width: Int, + readPorts: Int, + writePorts: Int, + readWritePorts: Int, + maskGranularity: Option[Int] + ): MemConf = { val ports: Seq[(MemPort, Int)] = (if (maskGranularity.isEmpty) { - (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++ - (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts)) - } else { - (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++ - (if (readWritePorts == 0) Seq() else Seq(MaskedReadWritePort -> readWritePorts)) - }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts)) + (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++ + (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts)) + } else { + (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++ + (if (readWritePorts == 0) Seq() + else Seq(MaskedReadWritePort -> readWritePorts)) + }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts)) new MemConf(name, depth, width, ports.toMap, maskGranularity) } } diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala index 3731ea86..c8cd3e8d 100644 --- a/src/main/scala/firrtl/passes/memlib/MemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala @@ -19,38 +19,38 @@ object DefAnnotatedMemory { m.readwriters, m.readUnderWrite, None, // mask granularity annotation - None // No reference yet to another memory + None // No reference yet to another memory ) } } case class DefAnnotatedMemory( - info: Info, - name: String, - dataType: Type, - depth: BigInt, - writeLatency: Int, - readLatency: Int, - readers: Seq[String], - writers: Seq[String], - readwriters: Seq[String], - readUnderWrite: ReadUnderWrite.Value, - maskGran: Option[BigInt], - memRef: Option[(String, String)] /* (Module, Mem) */ - //pins: Seq[Pin], - ) extends Statement with IsDeclaration { + info: Info, + name: String, + dataType: Type, + depth: BigInt, + writeLatency: Int, + readLatency: Int, + readers: Seq[String], + writers: Seq[String], + readwriters: Seq[String], + readUnderWrite: ReadUnderWrite.Value, + maskGran: Option[BigInt], + memRef: Option[(String, String)] /* (Module, Mem) */ + //pins: Seq[Pin], +) extends Statement + with IsDeclaration { override def serialize: String = this.toMem.serialize - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = this - def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def toMem = DefMemory(info, name, dataType, depth, - writeLatency, readLatency, readers, writers, - readwriters, readUnderWrite) - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = f(dataType) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def toMem = + DefMemory(info, name, dataType, depth, writeLatency, readLatency, readers, writers, readwriters, readUnderWrite) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(dataType) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } diff --git a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala index f0c9ebf4..1db132f7 100644 --- a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala +++ b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala @@ -7,8 +7,7 @@ import firrtl.options.{RegisteredLibrary, ShellOption} class MemLibOptions extends RegisteredLibrary { val name: String = "MemLib Options" - val options: Seq[ShellOption[_]] = Seq( new InferReadWrite, - new ReplSeqMem ) + val options: Seq[ShellOption[_]] = Seq(new InferReadWrite, new ReplSeqMem) .flatMap(_.options) } diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala index b6a9a23d..f153fa2b 100644 --- a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala @@ -11,12 +11,12 @@ import MemPortUtils.{MemPortMap} object MemTransformUtils { /** Replaces references to old memory port names with new memory port names - */ + */ def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = { //TODO(izraelevitz): check speed def updateRef(e: Expression): Expression = { - val ex = e map updateRef - repl getOrElse (ex.serialize, ex) + val ex = e.map(updateRef) + repl.getOrElse(ex.serialize, ex) } def hasEmptyExpr(stmt: Statement): Boolean = { @@ -24,16 +24,16 @@ object MemTransformUtils { def testEmptyExpr(e: Expression): Expression = { e match { case EmptyExpression => foundEmpty = true - case _ => + case _ => } - e map testEmptyExpr // map must return; no foreach + e.map(testEmptyExpr) // map must return; no foreach } - stmt map testEmptyExpr + stmt.map(testEmptyExpr) foundEmpty } def updateStmtRefs(s: Statement): Statement = - s map updateStmtRefs map updateRef match { + s.map(updateStmtRefs).map(updateRef) match { case c: Connect if hasEmptyExpr(c) => EmptyStmt case s => s } @@ -42,6 +42,6 @@ object MemTransformUtils { } def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem) - def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField = + def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField = MemPortUtils.memPortField(s.toMem, p, f) } diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala index 69c6b284..f325c0ba 100644 --- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala @@ -7,19 +7,19 @@ import firrtl.ir._ import firrtl.Utils._ /** Given a mask, return a bitmask corresponding to the desired datatype. - * Requirements: - * - The mask type and datatype must be equivalent, except any ground type in - * datatype must be matched by a 1-bit wide UIntType. - * - The mask must be a reference, subfield, or subindex - * The bitmask is a series of concatenations of the single mask bit over the - * length of the corresponding ground type, e.g.: - *{{{ - * wire mask: {x: UInt<1>, y: UInt<1>} - * wire data: {x: UInt<2>, y: SInt<2>} - * // this would return: - * cat(cat(mask.x, mask.x), cat(mask.y, mask.y)) - * }}} - */ + * Requirements: + * - The mask type and datatype must be equivalent, except any ground type in + * datatype must be matched by a 1-bit wide UIntType. + * - The mask must be a reference, subfield, or subindex + * The bitmask is a series of concatenations of the single mask bit over the + * length of the corresponding ground type, e.g.: + * {{{ + * wire mask: {x: UInt<1>, y: UInt<1>} + * wire data: {x: UInt<2>, y: SInt<2>} + * // this would return: + * cat(cat(mask.x, mask.x), cat(mask.y, mask.y)) + * }}} + */ object toBitMask { def apply(mask: Expression, dataType: Type): Expression = mask match { case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType) @@ -28,12 +28,13 @@ object toBitMask { private def hiermask(mask: Expression, dataType: Type): Expression = (mask.tpe, dataType) match { case (mt: VectorType, dt: VectorType) => - seqCat((0 until mt.size).reverse map { i => + seqCat((0 until mt.size).reverse.map { i => hiermask(WSubIndex(mask, i, mt.tpe, UnknownFlow), dt.tpe) }) case (mt: BundleType, dt: BundleType) => - seqCat((mt.fields zip dt.fields) map { case (mf, df) => - hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe) + seqCat((mt.fields.zip(dt.fields)).map { + case (mf, df) => + hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe) }) case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) => seqCat(List.fill(bitWidth(dt).intValue)(mask)) @@ -44,7 +45,7 @@ object toBitMask { object createMask { def apply(dt: Type): Type = dt match { case t: VectorType => VectorType(apply(t.tpe), t.size) - case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe)))) + case t: BundleType => BundleType(t.fields.map(f => f.copy(tpe = apply(f.tpe)))) case GroundType(w) if w == IntWidth(0) => UIntType(IntWidth(0)) case t: GroundType => BoolType } @@ -56,27 +57,33 @@ object MemPortUtils { type Modules = collection.mutable.ArrayBuffer[DefModule] def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq( - Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))), + Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1).max(1)))), Field("en", Default, BoolType), Field("clk", Default, ClockType) ) // Todo: merge it with memToBundle def memType(mem: DefMemory): BundleType = { - val rType = BundleType(defaultPortSeq(mem) :+ - Field("data", Flip, mem.dataType)) - val wType = BundleType(defaultPortSeq(mem) ++ Seq( - Field("data", Default, mem.dataType), - Field("mask", Default, createMask(mem.dataType)))) - val rwType = BundleType(defaultPortSeq(mem) ++ Seq( - Field("rdata", Flip, mem.dataType), - Field("wmode", Default, BoolType), - Field("wdata", Default, mem.dataType), - Field("wmask", Default, createMask(mem.dataType)))) + val rType = BundleType( + defaultPortSeq(mem) :+ + Field("data", Flip, mem.dataType) + ) + val wType = BundleType( + defaultPortSeq(mem) ++ Seq(Field("data", Default, mem.dataType), Field("mask", Default, createMask(mem.dataType))) + ) + val rwType = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("rdata", Flip, mem.dataType), + Field("wmode", Default, BoolType), + Field("wdata", Default, mem.dataType), + Field("wmask", Default, createMask(mem.dataType)) + ) + ) BundleType( - (mem.readers map (Field(_, Flip, rType))) ++ - (mem.writers map (Field(_, Flip, wType))) ++ - (mem.readwriters map (Field(_, Flip, rwType)))) + (mem.readers.map(Field(_, Flip, rType))) ++ + (mem.writers.map(Field(_, Flip, wType))) ++ + (mem.readwriters.map(Field(_, Flip, rwType))) + ) } def memPortField(s: DefMemory, p: String, f: String): WSubField = { diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala index c51a0adc..30529119 100644 --- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala +++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala @@ -9,27 +9,27 @@ import firrtl.Mappers._ import MemPortUtils._ import MemTransformUtils._ - /** Changes memory port names to standard port names (i.e. RW0 instead T_408) - */ + */ object RenameAnnotatedMemoryPorts extends Pass { + /** Renames memory ports to a standard naming scheme: - * - R0, R1, ... for each read port - * - W0, W1, ... for each write port - * - RW0, RW1, ... for each readwrite port - */ + * - R0, R1, ... for each read port + * - W0, W1, ... for each write port + * - RW0, RW1, ... for each readwrite port + */ def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = { - val rports = m.readers.indices map (i => s"R$i") - val wports = m.writers.indices map (i => s"W$i") - val rwports = m.readwriters.indices map (i => s"RW$i") - m copy (readers = rports, writers = wports, readwriters = rwports) + val rports = m.readers.indices.map(i => s"R$i") + val wports = m.writers.indices.map(i => s"W$i") + val rwports = m.readwriters.indices.map(i => s"RW$i") + m.copy(readers = rports, writers = wports, readwriters = rwports) } /** Maps the serialized form of all memory port field names to the - * corresponding new memory port field Expression. - * E.g.: - * - ("m.read.addr") becomes (m.R0.addr) - */ + * corresponding new memory port field Expression. + * E.g.: + * - ("m.read.addr") becomes (m.R0.addr) + */ def getMemPortMap(m: DefAnnotatedMemory, memPortMap: MemPortMap): Unit = { val defaultFields = Seq("addr", "en", "clk") val rFields = defaultFields :+ "data" @@ -37,7 +37,10 @@ object RenameAnnotatedMemoryPorts extends Pass { val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask") def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit = - for ((p, i) <- ports.zipWithIndex; f <- fields) { + for { + (p, i) <- ports.zipWithIndex + f <- fields + } { val newPort = WSubField(WRef(m.name), newPortKind + i) val field = WSubField(newPort, f) memPortMap(s"${m.name}.$p.$f") = field @@ -55,16 +58,16 @@ object RenameAnnotatedMemoryPorts extends Pass { val updatedMem = createMemProto(m) getMemPortMap(m, memPortMap) updatedMem - case s => s map updateMemStmts(memPortMap) + case s => s.map(updateMemStmts(memPortMap)) } /** Replaces candidate memories and their references with standard port names - */ + */ def updateMemMods(m: DefModule) = { val memPortMap = new MemPortMap - (m map updateMemStmts(memPortMap) - map updateStmtRefs(memPortMap)) + (m.map(updateMemStmts(memPortMap)) + .map(updateStmtRefs(memPortMap))) } - def run(c: Circuit) = c copy (modules = c.modules map updateMemMods) + def run(c: Circuit) = c.copy(modules = c.modules.map(updateMemMods)) } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index bfbc163a..fc381e88 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -13,7 +13,6 @@ import firrtl.annotations._ import firrtl.stage.Forms import wiring._ - /** Annotates the name of the pins to add for WiringTransform */ case class PinAnnotation(pins: Seq[String]) extends NoTargetAnnotation @@ -35,14 +34,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM /** Return true if mask granularity is per bit, false if per byte or unspecified */ private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match { - case None => false + case None => false case Some(v) => v == 1 } private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) + defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType) + ) private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) + defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)) + ) /** Catch incorrect memory instantiations when there are masked memories with unsupported aggregate types. * @@ -82,7 +83,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM ) private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match { - case None => Nil + case None => Nil case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType))) case Some(_) => { checkMaskDatatype(mem) @@ -111,7 +112,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM Field("wdata", Default, flattenType(mem.dataType)), Field("rdata", Flip, flattenType(mem.dataType)) ) ++ (mem.maskGran match { - case None => Nil + case None => Nil case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType))) case Some(_) => { checkMaskDatatype(mem) @@ -122,32 +123,34 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM def memToBundle(s: DefAnnotatedMemory) = BundleType( s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))) + s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToBundle(s))) + ) def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType( s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))) + s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))) + ) /** Creates a wrapper module and external module to replace a candidate memory - * The wrapper module has the same type as the memory it replaces - * The external module - */ + * The wrapper module has the same type as the memory it replaces + * The external module + */ def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = { assert(m.dataType != UnknownType) val wrapperIoType = memToBundle(m) - val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val wrapperIoPorts = wrapperIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) // Creates a type with the write/readwrite masks omitted if necessary val bbIoType = memToFlattenBundle(m) - val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val bbIoPorts = bbIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) val bbRef = WRef(m.name, bbIoType) val hasMask = m.maskGran.isDefined val fillMask = getFillWMask(m) def portRef(p: String) = WRef(p, field_type(wrapperIoType, p)) val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++ - (m.readers flatMap (r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++ - (m.writers flatMap (w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++ - (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask))) + (m.readers.flatMap(r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++ + (m.writers.flatMap(w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++ + (m.readwriters.flatMap(rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask))) val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts)) val bb = ExtModule(NoInfo, m.name, bbIoPorts, m.name, Seq.empty) // TODO: Annotate? -- use actual annotation map @@ -160,16 +163,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM // TODO(shunshou): get rid of copy pasta // Connects the clk, en, and addr fields from the wrapperPort to the bbPort def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] = - Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f)) + Seq("clk", "en", "addr").map(f => connectFields(bbPort, f, wrapperPort, f)) // Generates mask bits (concatenates an aggregate to ground type) // depending on mask granularity (# bits = data width / mask granularity) def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression = if (fillMask) toBitMask(mask, dataType) else toBits(mask) - def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = + def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = defaultConnects(wrapperPort, bbPort) :+ - fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data")) + fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data")) def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { val wrapperData = WSubField(wrapperPort, "data") @@ -177,11 +180,12 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM Connect(NoInfo, WSubField(bbPort, "data"), toBits(wrapperData)) hasMask match { case false => defaultSeq - case true => defaultSeq :+ Connect( - NoInfo, - WSubField(bbPort, "mask"), - maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) - ) + case true => + defaultSeq :+ Connect( + NoInfo, + WSubField(bbPort, "mask"), + maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) + ) } } @@ -190,61 +194,67 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq( fromBits(WSubField(wrapperPort, "rdata"), WSubField(bbPort, "rdata")), connectFields(bbPort, "wmode", wrapperPort, "wmode"), - Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData))) + Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData)) + ) hasMask match { case false => defaultSeq - case true => defaultSeq :+ Connect( - NoInfo, - WSubField(bbPort, "wmask"), - maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) - ) + case true => + defaultSeq :+ Connect( + NoInfo, + WSubField(bbPort, "wmask"), + maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) + ) } } /** Mapping from (module, memory name) pairs to blackbox names */ private type NameMap = collection.mutable.HashMap[(String, String), String] + /** Construct NameMap by assigning unique names for each memory blackbox */ def constructNameMap(namespace: Namespace, nameMap: NameMap, mname: String)(s: Statement): Statement = { s match { - case m: DefAnnotatedMemory => m.memRef match { - case None => nameMap(mname -> m.name) = namespace newName m.name - case Some(_) => - } + case m: DefAnnotatedMemory => + m.memRef match { + case None => nameMap(mname -> m.name) = namespace.newName(m.name) + case Some(_) => + } case _ => } - s map constructNameMap(namespace, nameMap, mname) + s.map(constructNameMap(namespace, nameMap, mname)) } - def updateMemStmts(namespace: Namespace, - nameMap: NameMap, - mname: String, - memPortMap: MemPortMap, - memMods: Modules) - (s: Statement): Statement = s match { + def updateMemStmts( + namespace: Namespace, + nameMap: NameMap, + mname: String, + memPortMap: MemPortMap, + memMods: Modules + )(s: Statement + ): Statement = s match { case m: DefAnnotatedMemory => if (m.maskGran.isEmpty) { - m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression } - m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression } + m.writers.foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression } + m.readwriters.foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression } } m.memRef match { case None => // prototype mem val newWrapperName = nameMap(mname -> m.name) - val newMemBBName = namespace newName s"${newWrapperName}_ext" - val newMem = m copy (name = newMemBBName) + val newMemBBName = namespace.newName(s"${newWrapperName}_ext") + val newMem = m.copy(name = newMemBBName) memMods ++= createMemModule(newMem, newWrapperName) WDefInstance(m.info, m.name, newWrapperName, UnknownType) case Some((module, mem)) => WDefInstance(m.info, m.name, nameMap(module -> mem), UnknownType) } - case sx => sx map updateMemStmts(namespace, nameMap, mname, memPortMap, memMods) + case sx => sx.map(updateMemStmts(namespace, nameMap, mname, memPortMap, memMods)) } def updateMemMods(namespace: Namespace, nameMap: NameMap, memMods: Modules)(m: DefModule) = { val memPortMap = new MemPortMap - (m map updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods) - map updateStmtRefs(memPortMap)) + (m.map(updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods)) + .map(updateStmtRefs(memPortMap))) } def execute(state: CircuitState): CircuitState = { @@ -252,15 +262,15 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM val namespace = Namespace(c) val memMods = new Modules val nameMap = new NameMap - c.modules map (m => m map constructNameMap(namespace, nameMap, m.name)) - val modules = c.modules map updateMemMods(namespace, nameMap, memMods) + c.modules.map(m => m.map(constructNameMap(namespace, nameMap, m.name))) + val modules = c.modules.map(updateMemMods(namespace, nameMap, memMods)) // print conf writer.serialize() val pannos = state.annotations.collect { case a: PinAnnotation => a } val pins = pannos match { - case Seq() => Nil + case Seq() => Nil case Seq(PinAnnotation(pins)) => pins - case _ => throwInternalError("Something went wrong") + case _ => throwInternalError("Something went wrong") } val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) => seq ++ memMods.collect { diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 87321ea0..79e07640 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -7,7 +7,7 @@ import firrtl._ import firrtl.annotations._ import firrtl.options.{HasShellOptions, ShellOption} import Utils.error -import java.io.{File, CharArrayWriter, PrintWriter} +import java.io.{CharArrayWriter, File, PrintWriter} import wiring._ import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} @@ -50,7 +50,15 @@ class ConfWriter(filename: String) { // assert that we don't overflow going from BigInt to Int conversion require(bitWidth(m.dataType) <= Int.MaxValue) m.maskGran.foreach { case x => require(x <= Int.MaxValue) } - val conf = MemConf(m.name, m.depth, bitWidth(m.dataType).toInt, m.readers.length, m.writers.length, m.readwriters.length, m.maskGran.map(_.toInt)) + val conf = MemConf( + m.name, + m.depth, + bitWidth(m.dataType).toInt, + m.readers.length, + m.writers.length, + m.readwriters.length, + m.maskGran.map(_.toInt) + ) outputBuffer.append(conf.toString) } def serialize() = { @@ -113,27 +121,31 @@ class ReplSeqMem extends Transform with HasShellOptions with DependencyAPIMigrat val options = Seq( new ShellOption[String]( longOption = "repl-seq-mem", - toAnnotationSeq = (a: String) => Seq( passes.memlib.ReplSeqMemAnnotation.parse(a), - RunFirrtlTransformAnnotation(new ReplSeqMem) ), + toAnnotationSeq = + (a: String) => Seq(passes.memlib.ReplSeqMemAnnotation.parse(a), RunFirrtlTransformAnnotation(new ReplSeqMem)), helpText = "Blackbox and emit a configuration file for each sequential memory", shortOption = Some("frsq"), - helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>") ) ) + helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>") + ) + ) def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = - Seq(new SimpleMidTransform(Legalize), - new SimpleMidTransform(ToMemIR), - new SimpleMidTransform(ResolveMaskGranularity), - new SimpleMidTransform(RenameAnnotatedMemoryPorts), - new ResolveMemoryReference, - new CreateMemoryAnnotations(inConfigFile), - new ReplaceMemMacros(outConfigFile), - new WiringTransform, - new SimpleMidTransform(RemoveEmpty), - new SimpleMidTransform(CheckInitialization), - new SimpleMidTransform(InferTypes), - Uniquify, - new SimpleMidTransform(ResolveKinds), - new SimpleMidTransform(ResolveFlows)) + Seq( + new SimpleMidTransform(Legalize), + new SimpleMidTransform(ToMemIR), + new SimpleMidTransform(ResolveMaskGranularity), + new SimpleMidTransform(RenameAnnotatedMemoryPorts), + new ResolveMemoryReference, + new CreateMemoryAnnotations(inConfigFile), + new ReplaceMemMacros(outConfigFile), + new WiringTransform, + new SimpleMidTransform(RemoveEmpty), + new SimpleMidTransform(CheckInitialization), + new SimpleMidTransform(InferTypes), + Uniquify, + new SimpleMidTransform(ResolveKinds), + new SimpleMidTransform(ResolveFlows) + ) def execute(state: CircuitState): CircuitState = { val annos = state.annotations.collect { case a: ReplSeqMemAnnotation => a } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 41c47dce..434c7602 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -28,10 +28,10 @@ object AnalysisUtils { connects(value.serialize) = WInvalid case _ => // do nothing } - s map getConnects(connects) + s.map(getConnects(connects)) } val connects = new Connects - m map getConnects(connects) + m.map(getConnects(connects)) connects } @@ -56,8 +56,8 @@ object AnalysisUtils { else if (weq(tvOrigin, fvOrigin)) tvOrigin else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin else e - case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one - case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero + case DoPrim(PrimOps.Or, args, consts, tpe) if args.exists(weq(_, one)) => one + case DoPrim(PrimOps.And, args, consts, tpe) if args.exists(weq(_, zero)) => zero case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => val extractionWidth = (msb - lsb) + 1 val nodeWidth = bitWidth(args.head.tpe) @@ -69,10 +69,10 @@ object AnalysisUtils { case ValidIf(cond, value, _) => getOrigin(connects)(value) // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => - connects get e.serialize match { - case Some(ex) => getOrigin(connects)(ex) - case None => e - } + connects.get(e.serialize) match { + case Some(ex) => getOrigin(connects)(ex) + case None => e + } case _ => e } } @@ -90,10 +90,9 @@ object ResolveMaskGranularity extends Pass { */ def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { val wenOrigin = getOrigin(connects)(wen) - val wmaskOrigin = connects.keys filter - (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)} + val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { s: String => getOrigin(connects, s) } // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) - val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one)) + val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one)) if (redundantMask) None else Some(wmaskOrigin.size) } @@ -103,18 +102,17 @@ object ResolveMaskGranularity extends Pass { def updateStmts(connects: Connects)(s: Statement): Statement = s match { case m: DefAnnotatedMemory => val dataBits = bitWidth(m.dataType) - val rwMasks = m.readwriters map (rw => - getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) - val wMasks = m.writers map (w => - getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) + val rwMasks = + m.readwriters.map(rw => getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) + val wMasks = m.writers.map(w => getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) val maskGran = (rwMasks ++ wMasks).head match { - case None => None + case None => None case Some(maskBits) => Some(dataBits / maskBits) } m.copy(maskGran = maskGran) - case sx => sx map updateStmts(connects) + case sx => sx.map(updateStmts(connects)) } - def annotateModMems(m: DefModule): DefModule = m map updateStmts(getConnects(m)) - def run(c: Circuit): Circuit = c copy (modules = c.modules map annotateModMems) + def annotateModMems(m: DefModule): DefModule = m.map(updateStmts(getConnects(m))) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(annotateModMems)) } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala index b5ff10c6..e80e0c4a 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -14,7 +14,7 @@ case class NoDedupMemAnnotation(target: ComponentName) extends SingleTargetAnnot } /** Resolves annotation ref to memories that exactly match (except name) another memory - */ + */ class ResolveMemoryReference extends Transform with DependencyAPIMigration { override def prerequisites = Forms.MidForm @@ -45,10 +45,12 @@ class ResolveMemoryReference extends Transform with DependencyAPIMigration { /** If a candidate memory is identical except for name to another, add an * annotation that references the name of the other memory. */ - def updateMemStmts(mname: String, - existingMems: AnnotatedMemories, - noDedupMap: Map[String, Set[String]]) - (s: Statement): Statement = s match { + def updateMemStmts( + mname: String, + existingMems: AnnotatedMemories, + noDedupMap: Map[String, Set[String]] + )(s: Statement + ): Statement = s match { // If not dedupable, no need to add to existing (since nothing can dedup with it) // We just return the DefAnnotatedMemory as is in the default case below case m: DefAnnotatedMemory if dedupable(noDedupMap, mname, m.name) => diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala index 554a3572..9fe7f852 100644 --- a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala @@ -14,16 +14,17 @@ import firrtl.ir._ * - undefined read-under-write behavior */ object ToMemIR extends Pass { + /** Only annotate memories that are candidates for memory macro replacements * i.e. rw, w + r (read, write 1 cycle delay) and read-under-write "undefined." */ import ReadUnderWrite._ def updateStmts(s: Statement): Statement = s match { - case m @ DefMemory(_,_,_,_,1,1,r,w,rw,Undefined) if (w.length + rw.length) == 1 && r.length <= 1 => + case m @ DefMemory(_, _, _, _, 1, 1, r, w, rw, Undefined) if (w.length + rw.length) == 1 && r.length <= 1 => DefAnnotatedMemory(m) - case sx => sx map updateStmts + case sx => sx.map(updateStmts) } - def annotateModMems(m: DefModule) = m map updateStmts - def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) + def annotateModMems(m: DefModule) = m.map(updateStmts) + def run(c: Circuit) = c.copy(modules = c.modules.map(annotateModMems)) } diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index dd644323..a2b14343 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -24,19 +24,19 @@ object MemDelayAndReadwriteTransformer { case class SplitStatements(decls: Seq[Statement], conns: Seq[Connect]) // Utilities for generating hardware - def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType) - def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType) - def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r) - def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe)) + def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType) + def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType) + def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r) + def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe)) // Utilities for working with WithValid groups def connect(l: WithValid, r: WithValid): Seq[Connect] = { - val paired = (l.valid +: l.payload) zip (r.valid +: r.payload) + val paired = (l.valid +: l.payload).zip(r.valid +: r.payload) paired.map { case (le, re) => connect(le, re) } } def condConnect(l: WithValid, r: WithValid): Seq[Connect] = { - connect(l.valid, r.valid) +: (l.payload zip r.payload).map { case (le, re) => condConnect(r.valid)(le, re) } + connect(l.valid, r.valid) +: (l.payload.zip(r.payload)).map { case (le, re) => condConnect(r.valid)(le, re) } } // Internal representation of a pipeline stage with an associated valid signal @@ -47,20 +47,23 @@ object MemDelayAndReadwriteTransformer { private def flatName(e: Expression) = metaChars.replaceAllIn(e.serialize, "_") // Pipeline a group of signals with an associated valid signal. Gate registers when possible. - def pipelineWithValid(ns: Namespace)( - clock: Expression, - depth: Int, - src: WithValid, - nameTemplate: Option[WithValid] = None): (WithValid, Seq[Statement], Seq[Connect]) = { + def pipelineWithValid( + ns: Namespace + )(clock: Expression, + depth: Int, + src: WithValid, + nameTemplate: Option[WithValid] = None + ): (WithValid, Seq[Statement], Seq[Connect]) = { def asReg(e: Expression) = DefRegister(NoInfo, e.serialize, e.tpe, clock, zero, e) val template = nameTemplate.getOrElse(src) - val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { case prev => - def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind) - val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef)) - val regs = (ref.valid +: ref.payload).map(asReg) - PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref))) + val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { + case prev => + def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind) + val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef)) + val regs = (ref.valid +: ref.payload).map(asReg) + PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref))) } (stages.last.ref, stages.flatMap(_.stmts.decls), stages.flatMap(_.stmts.conns)) } @@ -84,10 +87,10 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { private def findMemConns(s: Statement): Unit = s match { case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(we(loc)) = expr - case _ => s.foreach(findMemConns) + case _ => s.foreach(findMemConns) } - private def swapMemRefs(e: Expression): Expression = e map swapMemRefs match { + private def swapMemRefs(e: Expression): Expression = e.map(swapMemRefs) match { case sf: WSubField => exprReplacements.getOrElse(we(sf), sf) case ex => ex } @@ -105,51 +108,57 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { val rRespDelay = if (mem.readUnderWrite == ReadUnderWrite.Old) mem.readLatency else 0 val wCmdDelay = mem.writeLatency - 1 - val readStmts = (mem.readers ++ mem.readwriters).map { case r => - def oldDriver(f: String) = netlist(we(memPortField(mem, r, f))) - def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f) - val clk = oldDriver("clk") - - // Pack sources of read command inputs into WithValid object -> different for readwriter - val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en") - val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr"))) - val cmdSink = WithValid(newField("en"), Seq(newField("addr"))) - val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) - val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) - - // Pipeline read response using *last* command pipe stage enable as the valid signal - val resp = WithValid(cmdPiped.valid, Seq(newField("data"))) - val respPipeNameTemplate = Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names - val (respPiped, respDecls, respConns) = pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate) - - // Make sure references to the read data get appropriately substituted - val oldRDataName = if (rMap.contains(r)) "rdata" else "data" - exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head - - // Return all statements; they're separated so connects can go after all declarations - SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns) + val readStmts = (mem.readers ++ mem.readwriters).map { + case r => + def oldDriver(f: String) = netlist(we(memPortField(mem, r, f))) + def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f) + val clk = oldDriver("clk") + + // Pack sources of read command inputs into WithValid object -> different for readwriter + val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en") + val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr"))) + val cmdSink = WithValid(newField("en"), Seq(newField("addr"))) + val (cmdPiped, cmdDecls, cmdConns) = + pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Pipeline read response using *last* command pipe stage enable as the valid signal + val resp = WithValid(cmdPiped.valid, Seq(newField("data"))) + val respPipeNameTemplate = + Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names + val (respPiped, respDecls, respConns) = + pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate) + + // Make sure references to the read data get appropriately substituted + val oldRDataName = if (rMap.contains(r)) "rdata" else "data" + exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns) } - val writeStmts = (mem.writers ++ mem.readwriters).map { case w => - def oldDriver(f: String) = netlist(we(memPortField(mem, w, f))) - def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f) - val clk = oldDriver("clk") - - // Pack sources of write command inputs into WithValid object -> different for readwriter - val cmdSrc = if (wMap.contains(w)) { - val en = AND(oldDriver("en"), oldDriver("wmode")) - WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata"))) - } else { - WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data"))) - } - - // Pipeline write command, connect to memory - val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data"))) - val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) - val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) - - // Return all statements; they're separated so connects can go after all declarations - SplitStatements(cmdDecls, cmdConns ++ cmdPortConns) + val writeStmts = (mem.writers ++ mem.readwriters).map { + case w => + def oldDriver(f: String) = netlist(we(memPortField(mem, w, f))) + def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f) + val clk = oldDriver("clk") + + // Pack sources of write command inputs into WithValid object -> different for readwriter + val cmdSrc = if (wMap.contains(w)) { + val en = AND(oldDriver("en"), oldDriver("wmode")) + WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata"))) + } else { + WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data"))) + } + + // Pipeline write command, connect to memory + val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data"))) + val (cmdPiped, cmdDecls, cmdConns) = + pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls, cmdConns ++ cmdPortConns) } newConns ++= (readStmts ++ writeStmts).flatMap(_.conns) @@ -171,8 +180,7 @@ object VerilogMemDelays extends Pass { override def prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf) override val optionalPrerequisiteOf = - Seq( Dependency[VerilogEmitter], - Dependency[SystemVerilogEmitter] ) + Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { case _: transforms.ConstantPropagation | ResolveFlows => true @@ -180,5 +188,5 @@ object VerilogMemDelays extends Pass { } def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed - def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) } diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala index a43adfe2..b5f91e7b 100644 --- a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala @@ -6,7 +6,6 @@ import net.jcazevedo.moultingyaml._ import java.io.{CharArrayWriter, File, PrintWriter} import firrtl.FileUtils - object CustomYAMLProtocol extends DefaultYamlProtocol { // bottom depends on top implicit val _pin = yamlFormat1(Pin) @@ -20,17 +19,15 @@ case class Source(name: String, module: String) case class Top(name: String) case class Config(pin: Pin, source: Source, top: Top) - class YamlFileReader(file: String) { - def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { + def parse[A](implicit reader: YamlReader[A]): Seq[A] = { if (new File(file).exists) { val yamlString = FileUtils.getText(file) - yamlString.parseYamls flatMap (x => - try Some(reader read x) + yamlString.parseYamls.flatMap(x => + try Some(reader.read(x)) catch { case e: Exception => None } ) - } - else sys.error("Yaml file doesn't exist!") + } else sys.error("Yaml file doesn't exist!") } } @@ -38,11 +35,11 @@ class YamlFileWriter(file: String) { val outputBuffer = new CharArrayWriter val separator = "--- \n" def append(in: YamlValue): Unit = { - outputBuffer append s"$separator${in.prettyPrint}" + outputBuffer.append(s"$separator${in.prettyPrint}") } def dump(): Unit = { val outputFile = new PrintWriter(file) - outputFile write outputBuffer.toString + outputFile.write(outputBuffer.toString) outputFile.close() } } diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala index 3f74e5d2..a69b7797 100644 --- a/src/main/scala/firrtl/passes/wiring/Wiring.scala +++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala @@ -18,8 +18,7 @@ import firrtl.graph.EulerTour case class WiringInfo(source: ComponentName, sinks: Seq[Named], pin: String) /** A data store of wiring names */ -case class WiringNames(compName: String, source: String, sinks: Seq[Named], - pin: String) +case class WiringNames(compName: String, source: String, sinks: Seq[Named], pin: String) /** Pass that computes and applies a sequence of wiring modifications * @@ -28,31 +27,39 @@ case class WiringNames(compName: String, source: String, sinks: Seq[Named], */ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { def run(c: Circuit): Circuit = analyze(c) - .foldLeft(c){ - case (cx, (tpe, modsMap)) => cx.copy( - modules = cx.modules map onModule(tpe, modsMap)) } + .foldLeft(c) { + case (cx, (tpe, modsMap)) => cx.copy(modules = cx.modules.map(onModule(tpe, modsMap))) + } /** Converts multiple units of wiring information to module modifications */ private def analyze(c: Circuit): Seq[(Type, Map[String, Modifications])] = { val names = wiSeq - .map ( wi => (wi.source, wi.sinks, wi.pin) match { - case (ComponentName(comp, ModuleName(source,_)), sinks, pin) => - WiringNames(comp, source, sinks, pin) }) + .map(wi => + (wi.source, wi.sinks, wi.pin) match { + case (ComponentName(comp, ModuleName(source, _)), sinks, pin) => + WiringNames(comp, source, sinks, pin) + } + ) val portNames = mutable.Seq.fill(names.size)(Map[String, String]()) - c.modules.foreach{ m => + c.modules.foreach { m => val ns = Namespace(m) - names.zipWithIndex.foreach{ case (WiringNames(c, so, si, p), i) => - portNames(i) = portNames(i) + - ( m.name -> { - if (si.exists(getModuleName(_) == m.name)) ns.newName(p) - else ns.newName(tokenize(c) filterNot ("[]." contains _) mkString "_") - })}} + names.zipWithIndex.foreach { + case (WiringNames(c, so, si, p), i) => + portNames(i) = portNames(i) + + (m.name -> { + if (si.exists(getModuleName(_) == m.name)) ns.newName(p) + else ns.newName(tokenize(c).filterNot("[]." contains _).mkString("_")) + }) + } + } val iGraph = InstanceKeyGraph(c) - names.zip(portNames).map{ case(WiringNames(comp, so, si, _), pn) => - computeModifications(c, iGraph, comp, so, si, pn) } + names.zip(portNames).map { + case (WiringNames(comp, so, si, _), pn) => + computeModifications(c, iGraph, comp, so, si, pn) + } } /** Converts a single unit of wiring information to module modifications @@ -69,19 +76,20 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { * @return a tuple of the component type and a map of module names * to pending modifications */ - private def computeModifications(c: Circuit, - iGraph: InstanceKeyGraph, - compName: String, - source: String, - sinks: Seq[Named], - portNames: Map[String, String]): - (Type, Map[String, Modifications]) = { + private def computeModifications( + c: Circuit, + iGraph: InstanceKeyGraph, + compName: String, + source: String, + sinks: Seq[Named], + portNames: Map[String, String] + ): (Type, Map[String, Modifications]) = { val sourceComponentType = getType(c, source, compName) - val sinkComponents: Map[String, Seq[String]] = sinks - .collect{ case ComponentName(c, ModuleName(m, _)) => (c, m) } - .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]){ - case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)) ) } + val sinkComponents: Map[String, Seq[String]] = sinks.collect { case ComponentName(c, ModuleName(m, _)) => (c, m) } + .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]) { + case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil))) + } // Determine "ownership" of sources to sinks via minimum distance val owners = sinksToSourcesSeq(sinks, source, iGraph) @@ -95,86 +103,88 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { def makeWire(m: Modifications, portName: String): Modifications = m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire)))) def makeWireC(m: Modifications, portName: String, c: (String, String)): Modifications = - m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))), cons = (m.cons :+ c).distinct ) + m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))), cons = (m.cons :+ c).distinct) val tour = EulerTour(iGraph.graph, iGraph.top) // Finds the lowest common ancestor instances for two module names in a design def lowestCommonAncestor(moduleA: Seq[InstanceKey], moduleB: Seq[InstanceKey]): Seq[InstanceKey] = tour.rmq(moduleA, moduleB) - owners.foreach { case (sink, source) => - val lca = lowestCommonAncestor(sink, source) - - // Compute metadata along Sink to LCA paths. - sink.drop(lca.size - 1).sliding(2).toList.reverse.foreach { - case Seq(InstanceKey(_,pm), InstanceKey(ci,cm)) => - val to = s"$ci.${portNames(cm)}" - val from = s"${portNames(pm)}" - meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) - meta(cm) = meta(cm).copy( - addPortOrWire = Some((portNames(cm), DecInput)) - ) - // Case where the sink is the LCA - case Seq(InstanceKey(_,pm)) => - // Case where the source is also the LCA - if (source.drop(lca.size).isEmpty) { - meta(pm) = makeWire(meta(pm), portNames(pm)) - } else { - val InstanceKey(ci,cm) = source.drop(lca.size).head - val to = s"${portNames(pm)}" - val from = s"$ci.${portNames(cm)}" - meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) - } - } + owners.foreach { + case (sink, source) => + val lca = lowestCommonAncestor(sink, source) - // Compute metadata for the Sink - sink.last match { case InstanceKey( _, m) => - if (sinkComponents.contains(m)) { - val from = s"${portNames(m)}" - sinkComponents(m).foreach( to => - meta(m) = meta(m).copy( - cons = (meta(m).cons :+( (to, from) )).distinct + // Compute metadata along Sink to LCA paths. + sink.drop(lca.size - 1).sliding(2).toList.reverse.foreach { + case Seq(InstanceKey(_, pm), InstanceKey(ci, cm)) => + val to = s"$ci.${portNames(cm)}" + val from = s"${portNames(pm)}" + meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) + meta(cm) = meta(cm).copy( + addPortOrWire = Some((portNames(cm), DecInput)) ) - ) + // Case where the sink is the LCA + case Seq(InstanceKey(_, pm)) => + // Case where the source is also the LCA + if (source.drop(lca.size).isEmpty) { + meta(pm) = makeWire(meta(pm), portNames(pm)) + } else { + val InstanceKey(ci, cm) = source.drop(lca.size).head + val to = s"${portNames(pm)}" + val from = s"$ci.${portNames(cm)}" + meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) + } } - } - // Compute metadata for the Source - source.last match { case InstanceKey( _, m) => - val to = s"${portNames(m)}" - val from = compName - meta(m) = meta(m).copy( - cons = (meta(m).cons :+( (to, from) )).distinct - ) - } + // Compute metadata for the Sink + sink.last match { + case InstanceKey(_, m) => + if (sinkComponents.contains(m)) { + val from = s"${portNames(m)}" + sinkComponents(m).foreach(to => + meta(m) = meta(m).copy( + cons = (meta(m).cons :+ ((to, from))).distinct + ) + ) + } + } - // Compute metadata along Source to LCA path - source.drop(lca.size - 1).sliding(2).toList.reverse.map { - case Seq(InstanceKey(_,pm), InstanceKey(ci,cm)) => { - val to = s"${portNames(pm)}" - val from = s"$ci.${portNames(cm)}" - meta(pm) = meta(pm).copy( - cons = (meta(pm).cons :+( (to, from) )).distinct - ) - meta(cm) = meta(cm).copy( - addPortOrWire = Some((portNames(cm), DecOutput)) - ) + // Compute metadata for the Source + source.last match { + case InstanceKey(_, m) => + val to = s"${portNames(m)}" + val from = compName + meta(m) = meta(m).copy( + cons = (meta(m).cons :+ ((to, from))).distinct + ) } - // Case where the source is the LCA - case Seq(InstanceKey(_,pm)) => { - // Case where the sink is also the LCA. We do nothing here, - // as we've created the connecting wire above - if (sink.drop(lca.size).isEmpty) { - } else { - val InstanceKey(ci,cm) = sink.drop(lca.size).head - val to = s"$ci.${portNames(cm)}" - val from = s"${portNames(pm)}" + + // Compute metadata along Source to LCA path + source.drop(lca.size - 1).sliding(2).toList.reverse.map { + case Seq(InstanceKey(_, pm), InstanceKey(ci, cm)) => { + val to = s"${portNames(pm)}" + val from = s"$ci.${portNames(cm)}" meta(pm) = meta(pm).copy( - cons = (meta(pm).cons :+( (to, from) )).distinct + cons = (meta(pm).cons :+ ((to, from))).distinct ) + meta(cm) = meta(cm).copy( + addPortOrWire = Some((portNames(cm), DecOutput)) + ) + } + // Case where the source is the LCA + case Seq(InstanceKey(_, pm)) => { + // Case where the sink is also the LCA. We do nothing here, + // as we've created the connecting wire above + if (sink.drop(lca.size).isEmpty) {} else { + val InstanceKey(ci, cm) = sink.drop(lca.size).head + val to = s"$ci.${portNames(cm)}" + val from = s"${portNames(pm)}" + meta(pm) = meta(pm).copy( + cons = (meta(pm).cons :+ ((to, from))).distinct + ) + } } } - } } (sourceComponentType, meta.toMap) } @@ -189,20 +199,22 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { val ports = mutable.ArrayBuffer[Port]() l.addPortOrWire match { case None => - case Some((s, dt)) => dt match { - case DecInput => ports += Port(NoInfo, s, Input, t) - case DecOutput => ports += Port(NoInfo, s, Output, t) - case DecWire => defines += DefWire(NoInfo, s, t) - } + case Some((s, dt)) => + dt match { + case DecInput => ports += Port(NoInfo, s, Input, t) + case DecOutput => ports += Port(NoInfo, s, Output, t) + case DecWire => defines += DefWire(NoInfo, s, t) + } } - connects ++= (l.cons map { case ((l, r)) => - Connect(NoInfo, toExp(l), toExp(r)) + connects ++= (l.cons.map { + case ((l, r)) => + Connect(NoInfo, toExp(l), toExp(r)) }) m match { case Module(i, n, ps, body) => val stmts = body match { case Block(sx) => sx - case s => Seq(s) + case s => Seq(s) } Module(i, n, ps ++ ports, Block(List() ++ defines ++ stmts ++ connects)) case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps ++ ports, dn, p) diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index 20fb1215..d6658f16 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -14,14 +14,12 @@ import firrtl.stage.Forms case class WiringException(msg: String) extends PassException(msg) /** A component, e.g. register etc. Must be declared only once under the TopAnnotation */ -case class SourceAnnotation(target: ComponentName, pin: String) extends - SingleTargetAnnotation[ComponentName] { +case class SourceAnnotation(target: ComponentName, pin: String) extends SingleTargetAnnotation[ComponentName] { def duplicate(n: ComponentName) = this.copy(target = n) } /** A module, e.g. ExtModule etc., that should add the input pin */ -case class SinkAnnotation(target: Named, pin: String) extends - SingleTargetAnnotation[Named] { +case class SinkAnnotation(target: Named, pin: String) extends SingleTargetAnnotation[Named] { def duplicate(n: Named) = this.copy(target = n) } @@ -76,8 +74,9 @@ class WiringTransform extends Transform with DependencyAPIMigration { (sources.size, sinks.size) match { case (0, p) => state case (s, p) if (p > 0) => - val wis = sources.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, source)) => - seq :+ WiringInfo(source, sinks(pin), pin) + val wis = sources.foldLeft(Seq[WiringInfo]()) { + case (seq, (pin, source)) => + seq :+ WiringInfo(source, sinks(pin), pin) } val annosx = state.annotations.filterNot(annos.toSet.contains) transforms(wis) diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala index c220692a..5e8f8616 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala @@ -25,54 +25,54 @@ case object DecWire extends DecKind /** Store of pending wiring information for a Module */ case class Modifications( addPortOrWire: Option[(String, DecKind)] = None, - cons: Seq[(String, String)] = Seq.empty) { + cons: Seq[(String, String)] = Seq.empty) { override def toString: String = serialize("") def serialize(tab: String): String = s""" - |$tab addPortOrWire: $addPortOrWire - |$tab cons: $cons - |""".stripMargin + |$tab addPortOrWire: $addPortOrWire + |$tab cons: $cons + |""".stripMargin } /** A lineage tree representing the instance hierarchy in a design */ @deprecated("Use DiGraph/InstanceGraph", "1.1.1") case class Lineage( - name: String, - children: Seq[(String, Lineage)] = Seq.empty, - source: Boolean = false, - sink: Boolean = false, - sourceParent: Boolean = false, - sinkParent: Boolean = false, - sharedParent: Boolean = false, - addPort: Option[(String, DecKind)] = None, - cons: Seq[(String, String)] = Seq.empty) { + name: String, + children: Seq[(String, Lineage)] = Seq.empty, + source: Boolean = false, + sink: Boolean = false, + sourceParent: Boolean = false, + sinkParent: Boolean = false, + sharedParent: Boolean = false, + addPort: Option[(String, DecKind)] = None, + cons: Seq[(String, String)] = Seq.empty) { def map(f: Lineage => Lineage): Lineage = - this.copy(children = children.map{ case (i, m) => (i, f(m)) }) + this.copy(children = children.map { case (i, m) => (i, f(m)) }) override def toString: String = shortSerialize("") def shortSerialize(tab: String): String = s""" - |$tab name: $name, - |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))} - |""".stripMargin + |$tab name: $name, + |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))} + |""".stripMargin def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B = this.children.foldLeft(z)(op) def serialize(tab: String): String = s""" - |$tab name: $name, - |$tab source: $source, - |$tab sink: $sink, - |$tab sourceParent: $sourceParent, - |$tab sinkParent: $sinkParent, - |$tab sharedParent: $sharedParent, - |$tab addPort: $addPort - |$tab cons: $cons - |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))} - |""".stripMargin + |$tab name: $name, + |$tab source: $source, + |$tab sink: $sink, + |$tab sourceParent: $sourceParent, + |$tab sinkParent: $sinkParent, + |$tab sharedParent: $sharedParent, + |$tab addPort: $addPort + |$tab cons: $cons + |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))} + |""".stripMargin } object WiringUtils { @@ -87,12 +87,12 @@ object WiringUtils { val childrenMap = new ChildrenMap() def getChildren(mname: String)(s: Statement): Unit = s match { case s: WDefInstance => - childrenMap(mname) = childrenMap(mname) :+( (s.name, s.module) ) + childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module)) case s: DefInstance => - childrenMap(mname) = childrenMap(mname) :+( (s.name, s.module) ) + childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module)) case s => s.foreach(getChildren(mname)) } - c.modules.foreach{ m => + c.modules.foreach { m => childrenMap(m.name) = Nil m.foreach(getChildren(m.name)) } @@ -103,7 +103,7 @@ object WiringUtils { */ @deprecated("Use DiGraph/InstanceGraph", "1.1.1") def getLineage(childrenMap: ChildrenMap, module: String): Lineage = - Lineage(module, childrenMap(module) map { case (i, m) => (i, getLineage(childrenMap, m)) } ) + Lineage(module, childrenMap(module).map { case (i, m) => (i, getLineage(childrenMap, m)) }) /** Return a map of sink instances to source instances that minimizes * distance @@ -114,22 +114,25 @@ object WiringUtils { * @return a map of sink instance names to source instance names * @throws WiringException if a sink is equidistant to two sources */ - @deprecated("This method can lead to non-determinism in your compiler pass and exposes internal details." + - " Please file an issue with firrtl if you have a use case!", "Firrtl 1.4") + @deprecated( + "This method can lead to non-determinism in your compiler pass and exposes internal details." + + " Please file an issue with firrtl if you have a use case!", + "Firrtl 1.4" + ) def sinksToSources(sinks: Seq[Named], source: String, i: InstanceGraph): Map[Seq[WDefInstance], Seq[WDefInstance]] = { // The order of owners influences the order of the results, it thus needs to be deterministic with a LinkedHashMap. val owners = new mutable.LinkedHashMap[Seq[WDefInstance], Vector[Seq[WDefInstance]]] val queue = new mutable.Queue[Seq[WDefInstance]] val visited = new mutable.HashMap[Seq[WDefInstance], Boolean].withDefaultValue(false) - val sourcePaths = i.fullHierarchy.collect { case (k,v) if k.module == source => v } + val sourcePaths = i.fullHierarchy.collect { case (k, v) if k.module == source => v } sourcePaths.flatten.foreach { l => queue.enqueue(l) owners(l) = Vector(l) } val sinkModuleNames = sinks.map(getModuleName).toSet - val sinkPaths = i.fullHierarchy.collect { case (k,v) if sinkModuleNames.contains(k.module) => v } + val sinkPaths = i.fullHierarchy.collect { case (k, v) if sinkModuleNames.contains(k.module) => v } // sinkInsts needs to have unique entries but is also iterated over which is why we use a LinkedHashSet val sinkInsts = mutable.LinkedHashSet() ++ sinkPaths.flatten @@ -156,8 +159,8 @@ object WiringUtils { // [todo] This is the critical section edges - .filter( e => !visited(e) && e.nonEmpty ) - .foreach{ v => + .filter(e => !visited(e) && e.nonEmpty) + .foreach { v => owners(v) = owners.getOrElse(v, Vector()) ++ owners(u) queue.enqueue(v) } @@ -167,8 +170,8 @@ object WiringUtils { // this should fail is if a sink is equidistant to two sources. sinkInsts.foreach { s => if (!owners.contains(s) || owners(s).size > 1) { - throw new WiringException( - s"Unable to determine source mapping for sink '${s.map(_.name)}'") } + throw new WiringException(s"Unable to determine source mapping for sink '${s.map(_.name)}'") + } } } @@ -184,21 +187,24 @@ object WiringUtils { * @return a map of sink instance names to source instance names * @throws WiringException if a sink is equidistant to two sources */ - private[firrtl] def sinksToSourcesSeq(sinks: Seq[Named], source: String, i: InstanceKeyGraph): - Seq[(Seq[InstanceKey], Seq[InstanceKey])] = { + private[firrtl] def sinksToSourcesSeq( + sinks: Seq[Named], + source: String, + i: InstanceKeyGraph + ): Seq[(Seq[InstanceKey], Seq[InstanceKey])] = { // The order of owners influences the order of the results, it thus needs to be deterministic with a LinkedHashMap. val owners = new mutable.LinkedHashMap[Seq[InstanceKey], Vector[Seq[InstanceKey]]] val queue = new mutable.Queue[Seq[InstanceKey]] val visited = new mutable.HashMap[Seq[InstanceKey], Boolean].withDefaultValue(false) - val sourcePaths = i.fullHierarchy.collect { case (k,v) if k.module == source => v } + val sourcePaths = i.fullHierarchy.collect { case (k, v) if k.module == source => v } sourcePaths.flatten.foreach { l => queue.enqueue(l) owners(l) = Vector(l) } val sinkModuleNames = sinks.map(getModuleName).toSet - val sinkPaths = i.fullHierarchy.collect { case (k,v) if sinkModuleNames.contains(k.module) => v } + val sinkPaths = i.fullHierarchy.collect { case (k, v) if sinkModuleNames.contains(k.module) => v } // sinkInsts needs to have unique entries but is also iterated over which is why we use a LinkedHashSet val sinkInsts = mutable.LinkedHashSet() ++ sinkPaths.flatten @@ -225,8 +231,8 @@ object WiringUtils { // [todo] This is the critical section edges - .filter( e => !visited(e) && e.nonEmpty ) - .foreach{ v => + .filter(e => !visited(e) && e.nonEmpty) + .foreach { v => owners(v) = owners.getOrElse(v, Vector()) ++ owners(u) queue.enqueue(v) } @@ -236,8 +242,8 @@ object WiringUtils { // this should fail is if a sink is equidistant to two sources. sinkInsts.foreach { s => if (!owners.contains(s) || owners(s).size > 1) { - throw new WiringException( - s"Unable to determine source mapping for sink '${s.map(_.name)}'") } + throw new WiringException(s"Unable to determine source mapping for sink '${s.map(_.name)}'") + } } } @@ -249,8 +255,7 @@ object WiringUtils { n match { case ModuleName(m, _) => m case ComponentName(_, ModuleName(m, _)) => m - case _ => throw new WiringException( - "Only Components or Modules have an associated Module name") + case _ => throw new WiringException("Only Components or Modules have an associated Module name") } } @@ -266,9 +271,9 @@ object WiringUtils { def getType(c: Circuit, module: String, comp: String): Type = { def getRoot(e: Expression): String = e match { case r: Reference => r.name - case i: SubIndex => getRoot(i.expr) + case i: SubIndex => getRoot(i.expr) case a: SubAccess => getRoot(a.expr) - case f: SubField => getRoot(f.expr) + case f: SubField => getRoot(f.expr) } val eComp = toExp(comp) val root = getRoot(eComp) @@ -289,11 +294,12 @@ object WiringUtils { case sx: DefMemory if sx.name == root => tpe = Some(MemPortUtils.memType(sx)) sx - case sx => sx map getType + case sx => sx.map(getType) + } + val m = c.modules.find(_.name == module).getOrElse { + throw new WiringException(s"Must have a module named $module") } - val m = c.modules find (_.name == module) getOrElse { - throw new WiringException(s"Must have a module named $module") } - tpe = m.ports find (_.name == root) map (_.tpe) + tpe = m.ports.find(_.name == root).map(_.tpe) m match { case Module(i, n, ps, b) => getType(b) case e: ExtModule => @@ -301,10 +307,10 @@ object WiringUtils { tpe match { case None => throw new WiringException(s"Didn't find $comp in $module!") case Some(t) => - def setType(e: Expression): Expression = e map setType match { + def setType(e: Expression): Expression = e.map(setType) match { case ex: Reference => ex.copy(tpe = t) - case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name)) - case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe)) + case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name)) + case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe)) case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe)) } setType(eComp).tpe |
