aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/passes
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/passes')
-rw-r--r--src/main/scala/firrtl/passes/CInferMDir.scala67
-rw-r--r--src/main/scala/firrtl/passes/CheckChirrtl.scala4
-rw-r--r--src/main/scala/firrtl/passes/CheckFlows.scala84
-rw-r--r--src/main/scala/firrtl/passes/CheckHighForm.scala227
-rw-r--r--src/main/scala/firrtl/passes/CheckInitialization.scala11
-rw-r--r--src/main/scala/firrtl/passes/CheckTypes.scala376
-rw-r--r--src/main/scala/firrtl/passes/CheckWidths.scala139
-rw-r--r--src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala37
-rw-r--r--src/main/scala/firrtl/passes/ConvertFixedToSInt.scala90
-rw-r--r--src/main/scala/firrtl/passes/ExpandConnects.scala66
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala173
-rw-r--r--src/main/scala/firrtl/passes/InferBinaryPoints.scala98
-rw-r--r--src/main/scala/firrtl/passes/InferTypes.scala76
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala190
-rw-r--r--src/main/scala/firrtl/passes/Inline.scala246
-rw-r--r--src/main/scala/firrtl/passes/Legalize.scala31
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala300
-rw-r--r--src/main/scala/firrtl/passes/PadWidths.scala46
-rw-r--r--src/main/scala/firrtl/passes/Pass.scala2
-rw-r--r--src/main/scala/firrtl/passes/PullMuxes.scala80
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala79
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala196
-rw-r--r--src/main/scala/firrtl/passes/RemoveEmpty.scala2
-rw-r--r--src/main/scala/firrtl/passes/RemoveIntervals.scala149
-rw-r--r--src/main/scala/firrtl/passes/RemoveValidIf.scala22
-rw-r--r--src/main/scala/firrtl/passes/ReplaceAccesses.scala17
-rw-r--r--src/main/scala/firrtl/passes/ResolveFlows.scala25
-rw-r--r--src/main/scala/firrtl/passes/ResolveKinds.scala16
-rw-r--r--src/main/scala/firrtl/passes/SplitExpressions.scala100
-rw-r--r--src/main/scala/firrtl/passes/ToWorkingIR.scala2
-rw-r--r--src/main/scala/firrtl/passes/TrimIntervals.scala58
-rw-r--r--src/main/scala/firrtl/passes/Uniquify.scala241
-rw-r--r--src/main/scala/firrtl/passes/VerilogModulusCleanup.scala81
-rw-r--r--src/main/scala/firrtl/passes/VerilogPrep.scala34
-rw-r--r--src/main/scala/firrtl/passes/ZeroLengthVecs.scala21
-rw-r--r--src/main/scala/firrtl/passes/ZeroWidth.scala137
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockList.scala26
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala19
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala45
-rw-r--r--src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala20
-rw-r--r--src/main/scala/firrtl/passes/memlib/DecorateMems.scala5
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala101
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemConf.scala65
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemIR.scala56
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemLibOptions.scala3
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala16
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemUtils.scala69
-rw-r--r--src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala43
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala118
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala50
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala36
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala12
-rw-r--r--src/main/scala/firrtl/passes/memlib/ToMemIR.scala9
-rw-r--r--src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala136
-rw-r--r--src/main/scala/firrtl/passes/memlib/YamlUtils.scala15
-rw-r--r--src/main/scala/firrtl/passes/wiring/Wiring.scala212
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringTransform.scala11
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringUtils.scala122
58 files changed, 2592 insertions, 2120 deletions
diff --git a/src/main/scala/firrtl/passes/CInferMDir.scala b/src/main/scala/firrtl/passes/CInferMDir.scala
index b4819751..1fe8d57c 100644
--- a/src/main/scala/firrtl/passes/CInferMDir.scala
+++ b/src/main/scala/firrtl/passes/CInferMDir.scala
@@ -18,60 +18,61 @@ object CInferMDir extends Pass {
def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match {
case e: Reference =>
- mports get e.name match {
+ mports.get(e.name) match {
case None =>
- case Some(p) => mports(e.name) = (p, dir) match {
- case (MInfer, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
- case (MInfer, MWrite) => MWrite
- case (MInfer, MRead) => MRead
- case (MInfer, MReadWrite) => MReadWrite
- case (MWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
- case (MWrite, MWrite) => MWrite
- case (MWrite, MRead) => MReadWrite
- case (MWrite, MReadWrite) => MReadWrite
- case (MRead, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
- case (MRead, MWrite) => MReadWrite
- case (MRead, MRead) => MRead
- case (MRead, MReadWrite) => MReadWrite
- case (MReadWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
- case (MReadWrite, MWrite) => MReadWrite
- case (MReadWrite, MRead) => MReadWrite
- case (MReadWrite, MReadWrite) => MReadWrite
- }
+ case Some(p) =>
+ mports(e.name) = (p, dir) match {
+ case (MInfer, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
+ case (MInfer, MWrite) => MWrite
+ case (MInfer, MRead) => MRead
+ case (MInfer, MReadWrite) => MReadWrite
+ case (MWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
+ case (MWrite, MWrite) => MWrite
+ case (MWrite, MRead) => MReadWrite
+ case (MWrite, MReadWrite) => MReadWrite
+ case (MRead, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
+ case (MRead, MWrite) => MReadWrite
+ case (MRead, MRead) => MRead
+ case (MRead, MReadWrite) => MReadWrite
+ case (MReadWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
+ case (MReadWrite, MWrite) => MReadWrite
+ case (MReadWrite, MRead) => MReadWrite
+ case (MReadWrite, MReadWrite) => MReadWrite
+ }
}
e
case e: SubAccess =>
infer_mdir_e(mports, dir)(e.expr)
infer_mdir_e(mports, MRead)(e.index) // index can't be a write port
e
- case e => e map infer_mdir_e(mports, dir)
+ case e => e.map(infer_mdir_e(mports, dir))
}
def infer_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match {
case sx: CDefMPort =>
- mports(sx.name) = sx.direction
- sx map infer_mdir_e(mports, MRead)
+ mports(sx.name) = sx.direction
+ sx.map(infer_mdir_e(mports, MRead))
case sx: Connect =>
- infer_mdir_e(mports, MRead)(sx.expr)
- infer_mdir_e(mports, MWrite)(sx.loc)
- sx
+ infer_mdir_e(mports, MRead)(sx.expr)
+ infer_mdir_e(mports, MWrite)(sx.loc)
+ sx
case sx: PartialConnect =>
- infer_mdir_e(mports, MRead)(sx.expr)
- infer_mdir_e(mports, MWrite)(sx.loc)
- sx
- case sx => sx map infer_mdir_s(mports) map infer_mdir_e(mports, MRead)
+ infer_mdir_e(mports, MRead)(sx.expr)
+ infer_mdir_e(mports, MWrite)(sx.loc)
+ sx
+ case sx => sx.map(infer_mdir_s(mports)).map(infer_mdir_e(mports, MRead))
}
def set_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match {
- case sx: CDefMPort => sx copy (direction = mports(sx.name))
- case sx => sx map set_mdir_s(mports)
+ case sx: CDefMPort => sx.copy(direction = mports(sx.name))
+ case sx => sx.map(set_mdir_s(mports))
}
def infer_mdir(m: DefModule): DefModule = {
val mports = new MPortDirMap
- m map infer_mdir_s(mports) map set_mdir_s(mports)
+ m.map(infer_mdir_s(mports)).map(set_mdir_s(mports))
}
def run(c: Circuit): Circuit =
- c copy (modules = c.modules map infer_mdir)
+ c.copy(modules = c.modules.map(infer_mdir))
}
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala
index 9903f445..97d614c1 100644
--- a/src/main/scala/firrtl/passes/CheckChirrtl.scala
+++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala
@@ -12,9 +12,7 @@ object CheckChirrtl extends Pass with CheckHighFormLike {
override def prerequisites = Dependency[CheckScalaVersion] :: Nil
override val optionalPrerequisiteOf = firrtl.stage.Forms.ChirrtlForm ++
- Seq( Dependency(CInferTypes),
- Dependency(CInferMDir),
- Dependency(RemoveCHIRRTL) )
+ Seq(Dependency(CInferTypes), Dependency(CInferMDir), Dependency(RemoveCHIRRTL))
override def invalidates(a: Transform) = false
diff --git a/src/main/scala/firrtl/passes/CheckFlows.scala b/src/main/scala/firrtl/passes/CheckFlows.scala
index 3a9cc212..bc455a20 100644
--- a/src/main/scala/firrtl/passes/CheckFlows.scala
+++ b/src/main/scala/firrtl/passes/CheckFlows.scala
@@ -13,79 +13,87 @@ object CheckFlows extends Pass {
override def prerequisites = Dependency(passes.ResolveFlows) +: firrtl.stage.Forms.WorkingIR
override def optionalPrerequisiteOf =
- Seq( Dependency[passes.InferBinaryPoints],
- Dependency[passes.TrimIntervals],
- Dependency[passes.InferWidths],
- Dependency[transforms.InferResets] )
+ Seq(
+ Dependency[passes.InferBinaryPoints],
+ Dependency[passes.TrimIntervals],
+ Dependency[passes.InferWidths],
+ Dependency[transforms.InferResets]
+ )
override def invalidates(a: Transform) = false
type FlowMap = collection.mutable.HashMap[String, Flow]
implicit def toStr(g: Flow): String = g match {
- case SourceFlow => "source"
- case SinkFlow => "sink"
+ case SourceFlow => "source"
+ case SinkFlow => "sink"
case UnknownFlow => "unknown"
- case DuplexFlow => "duplex"
+ case DuplexFlow => "duplex"
}
- class WrongFlow(info:Info, mname: String, expr: String, wrong: Flow, right: Flow) extends PassException(
- s"$info: [module $mname] Expression $expr is used as a $wrong but can only be used as a $right.")
+ class WrongFlow(info: Info, mname: String, expr: String, wrong: Flow, right: Flow)
+ extends PassException(
+ s"$info: [module $mname] Expression $expr is used as a $wrong but can only be used as a $right."
+ )
- def run (c:Circuit): Circuit = {
+ def run(c: Circuit): Circuit = {
val errors = new Errors()
def get_flow(e: Expression, flows: FlowMap): Flow = e match {
- case (e: WRef) => flows(e.name)
+ case (e: WRef) => flows(e.name)
case (e: WSubIndex) => get_flow(e.expr, flows)
case (e: WSubAccess) => get_flow(e.expr, flows)
- case (e: WSubField) => e.expr.tpe match {case t: BundleType =>
- val f = (t.fields find (_.name == e.name)).get
- times(get_flow(e.expr, flows), f.flip)
- }
+ case (e: WSubField) =>
+ e.expr.tpe match {
+ case t: BundleType =>
+ val f = (t.fields.find(_.name == e.name)).get
+ times(get_flow(e.expr, flows), f.flip)
+ }
case _ => SourceFlow
}
def flip_q(t: Type): Boolean = {
def flip_rec(t: Type, f: Orientation): Boolean = t match {
- case tx:BundleType => tx.fields exists (
- field => flip_rec(field.tpe, times(f, field.flip))
- )
+ case tx: BundleType => tx.fields.exists(field => flip_rec(field.tpe, times(f, field.flip)))
case tx: VectorType => flip_rec(tx.tpe, f)
case tx => f == Flip
}
flip_rec(t, Default)
}
- def check_flow(info:Info, mname: String, flows: FlowMap, desired: Flow)(e:Expression): Unit = {
- val flow = get_flow(e,flows)
+ def check_flow(info: Info, mname: String, flows: FlowMap, desired: Flow)(e: Expression): Unit = {
+ val flow = get_flow(e, flows)
(flow, desired) match {
case (SourceFlow, SinkFlow) =>
errors.append(new WrongFlow(info, mname, e.serialize, desired, flow))
- case (SinkFlow, SourceFlow) => kind(e) match {
- case PortKind | InstanceKind if !flip_q(e.tpe) => // OK!
- case _ =>
- errors.append(new WrongFlow(info, mname, e.serialize, desired, flow))
- }
+ case (SinkFlow, SourceFlow) =>
+ kind(e) match {
+ case PortKind | InstanceKind if !flip_q(e.tpe) => // OK!
+ case _ =>
+ errors.append(new WrongFlow(info, mname, e.serialize, desired, flow))
+ }
case _ =>
}
- }
+ }
- def check_flows_e (info:Info, mname: String, flows: FlowMap)(e:Expression): Unit = {
+ def check_flows_e(info: Info, mname: String, flows: FlowMap)(e: Expression): Unit = {
e match {
- case e: Mux => e foreach check_flow(info, mname, flows, SourceFlow)
- case e: DoPrim => e.args foreach check_flow(info, mname, flows, SourceFlow)
+ case e: Mux => e.foreach(check_flow(info, mname, flows, SourceFlow))
+ case e: DoPrim => e.args.foreach(check_flow(info, mname, flows, SourceFlow))
case _ =>
}
- e foreach check_flows_e(info, mname, flows)
+ e.foreach(check_flows_e(info, mname, flows))
}
def check_flows_s(minfo: Info, mname: String, flows: FlowMap)(s: Statement): Unit = {
- val info = get_info(s) match { case NoInfo => minfo case x => x }
+ val info = get_info(s) match {
+ case NoInfo => minfo
+ case x => x
+ }
s match {
- case (s: DefWire) => flows(s.name) = DuplexFlow
+ case (s: DefWire) => flows(s.name) = DuplexFlow
case (s: DefRegister) => flows(s.name) = DuplexFlow
- case (s: DefMemory) => flows(s.name) = SourceFlow
+ case (s: DefMemory) => flows(s.name) = SourceFlow
case (s: WDefInstance) => flows(s.name) = SourceFlow
case (s: DefNode) =>
check_flow(info, mname, flows, SourceFlow)(s.value)
@@ -94,7 +102,7 @@ object CheckFlows extends Pass {
check_flow(info, mname, flows, SinkFlow)(s.loc)
check_flow(info, mname, flows, SourceFlow)(s.expr)
case (s: Print) =>
- s.args foreach check_flow(info, mname, flows, SourceFlow)
+ s.args.foreach(check_flow(info, mname, flows, SourceFlow))
check_flow(info, mname, flows, SourceFlow)(s.en)
check_flow(info, mname, flows, SourceFlow)(s.clk)
case (s: PartialConnect) =>
@@ -111,14 +119,14 @@ object CheckFlows extends Pass {
check_flow(info, mname, flows, SourceFlow)(s.en)
case _ =>
}
- s foreach check_flows_e(info, mname, flows)
- s foreach check_flows_s(minfo, mname, flows)
+ s.foreach(check_flows_e(info, mname, flows))
+ s.foreach(check_flows_s(minfo, mname, flows))
}
for (m <- c.modules) {
val flows = new FlowMap
- flows ++= (m.ports map (p => p.name -> to_flow(p.direction)))
- m foreach check_flows_s(m.info, m.name, flows)
+ flows ++= (m.ports.map(p => p.name -> to_flow(p.direction)))
+ m.foreach(check_flows_s(m.info, m.name, flows))
}
errors.trigger()
c
diff --git a/src/main/scala/firrtl/passes/CheckHighForm.scala b/src/main/scala/firrtl/passes/CheckHighForm.scala
index 2f706d35..559c9060 100644
--- a/src/main/scala/firrtl/passes/CheckHighForm.scala
+++ b/src/main/scala/firrtl/passes/CheckHighForm.scala
@@ -27,66 +27,71 @@ trait CheckHighFormLike { this: Pass =>
scopes.find(_.contains(port.mem)).getOrElse(scopes.head) += port.name
}
def legalDecl(name: String): Boolean = !moduleNS.contains(name)
- def legalRef(name: String): Boolean = scopes.exists(_.contains(name))
+ def legalRef(name: String): Boolean = scopes.exists(_.contains(name))
def childScope(): ScopeView = new ScopeView(moduleNS, new NameSet +: scopes)
}
// Custom Exceptions
- class NotUniqueException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Reference $name does not have a unique name.")
- class InvalidLOCException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Invalid connect to an expression that is not a reference or a WritePort.")
- class NegUIntException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] UIntLiteral cannot be negative.")
- class UndeclaredReferenceException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Reference $name is not declared.")
- class PoisonWithFlipException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Poison $name cannot be a bundle type with flips.")
- class MemWithFlipException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Memory $name cannot be a bundle type with flips.")
- class IllegalMemLatencyException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Memory $name must have non-negative read latency and positive write latency.")
- class RegWithFlipException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Register $name cannot be a bundle type with flips.")
- class InvalidAccessException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Invalid access to non-reference.")
- class ModuleNameNotUniqueException(info: Info, mname: String) extends PassException(
- s"$info: Repeat definition of module $mname")
- class DefnameConflictException(info: Info, mname: String, defname: String) extends PassException(
- s"$info: defname $defname of extmodule $mname conflicts with an existing module")
- class DefnameDifferentPortsException(info: Info, mname: String, defname: String) extends PassException(
- s"""$info: ports of extmodule $mname with defname $defname are different for an extmodule with the same defname""")
- class ModuleNotDefinedException(info: Info, mname: String, name: String) extends PassException(
- s"$info: Module $name is not defined.")
- class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int) extends PassException(
- s"$info: [module $mname] Primop $op requires $n expression arguments.")
- class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int) extends PassException(
- s"$info: [module $mname] Primop $op requires $n integer arguments.")
- class NegWidthException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Width cannot be negative.")
- class NegVecSizeException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Vector type size cannot be negative.")
- class NegMemSizeException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Memory size cannot be negative or zero.")
- class BadPrintfException(info: Info, mname: String, x: Char) extends PassException(
- s"$info: [module $mname] Bad printf format: " + "\"%" + x + "\"")
- class BadPrintfTrailingException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Bad printf format: trailing " + "\"%\"")
- class BadPrintfIncorrectNumException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Bad printf format: incorrect number of arguments")
- class InstanceLoop(info: Info, mname: String, loop: String) extends PassException(
- s"$info: [module $mname] Has instance loop $loop")
- class NoTopModuleException(info: Info, name: String) extends PassException(
- s"$info: A single module must be named $name.")
- class NegArgException(info: Info, mname: String, op: String, value: BigInt) extends PassException(
- s"$info: [module $mname] Primop $op argument $value < 0.")
- class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: BigInt, msb: BigInt) extends PassException(
- s"$info: [module $mname] Primop $op lsb $lsb > $msb.")
- class ResetInputException(info: Info, mname: String, expr: Expression) extends PassException(
- s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}")
- class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression) extends PassException(
- s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}")
-
+ class NotUniqueException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Reference $name does not have a unique name.")
+ class InvalidLOCException(info: Info, mname: String)
+ extends PassException(
+ s"$info: [module $mname] Invalid connect to an expression that is not a reference or a WritePort."
+ )
+ class NegUIntException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] UIntLiteral cannot be negative.")
+ class UndeclaredReferenceException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Reference $name is not declared.")
+ class PoisonWithFlipException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Poison $name cannot be a bundle type with flips.")
+ class MemWithFlipException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Memory $name cannot be a bundle type with flips.")
+ class IllegalMemLatencyException(info: Info, mname: String, name: String)
+ extends PassException(
+ s"$info: [module $mname] Memory $name must have non-negative read latency and positive write latency."
+ )
+ class RegWithFlipException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Register $name cannot be a bundle type with flips.")
+ class InvalidAccessException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Invalid access to non-reference.")
+ class ModuleNameNotUniqueException(info: Info, mname: String)
+ extends PassException(s"$info: Repeat definition of module $mname")
+ class DefnameConflictException(info: Info, mname: String, defname: String)
+ extends PassException(s"$info: defname $defname of extmodule $mname conflicts with an existing module")
+ class DefnameDifferentPortsException(info: Info, mname: String, defname: String)
+ extends PassException(
+ s"""$info: ports of extmodule $mname with defname $defname are different for an extmodule with the same defname"""
+ )
+ class ModuleNotDefinedException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: Module $name is not defined.")
+ class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int)
+ extends PassException(s"$info: [module $mname] Primop $op requires $n expression arguments.")
+ class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int)
+ extends PassException(s"$info: [module $mname] Primop $op requires $n integer arguments.")
+ class NegWidthException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Width cannot be negative.")
+ class NegVecSizeException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Vector type size cannot be negative.")
+ class NegMemSizeException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Memory size cannot be negative or zero.")
+ class BadPrintfException(info: Info, mname: String, x: Char)
+ extends PassException(s"$info: [module $mname] Bad printf format: " + "\"%" + x + "\"")
+ class BadPrintfTrailingException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Bad printf format: trailing " + "\"%\"")
+ class BadPrintfIncorrectNumException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Bad printf format: incorrect number of arguments")
+ class InstanceLoop(info: Info, mname: String, loop: String)
+ extends PassException(s"$info: [module $mname] Has instance loop $loop")
+ class NoTopModuleException(info: Info, name: String)
+ extends PassException(s"$info: A single module must be named $name.")
+ class NegArgException(info: Info, mname: String, op: String, value: BigInt)
+ extends PassException(s"$info: [module $mname] Primop $op argument $value < 0.")
+ class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: BigInt, msb: BigInt)
+ extends PassException(s"$info: [module $mname] Primop $op lsb $lsb > $msb.")
+ class ResetInputException(info: Info, mname: String, expr: Expression)
+ extends PassException(s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}")
+ class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression)
+ extends PassException(s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}")
// Is Chirrtl allowed for this check? If not, return an error
def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException]
@@ -94,12 +99,12 @@ trait CheckHighFormLike { this: Pass =>
def run(c: Circuit): Circuit = {
val errors = new Errors()
val moduleGraph = new ModuleGraph
- val moduleNames = (c.modules map (_.name)).toSet
+ val moduleNames = (c.modules.map(_.name)).toSet
val intModuleNames = c.modules.view.collect({ case m: Module => m.name }).toSet
- c.modules.groupBy(_.name).filter(_._2.length > 1).flatMap(_._2).foreach {
- m => errors.append(new ModuleNameNotUniqueException(m.info, m.name))
+ c.modules.groupBy(_.name).filter(_._2.length > 1).flatMap(_._2).foreach { m =>
+ errors.append(new ModuleNameNotUniqueException(m.info, m.name))
}
/** Strip all widths from types */
@@ -110,16 +115,18 @@ trait CheckHighFormLike { this: Pass =>
val extmoduleCollidingPorts = c.modules.collect {
case a: ExtModule => a
- }.groupBy(a => (a.defname, a.params.nonEmpty)).map {
- /* There are no parameters, so all ports must match exactly. */
- case (k@ (_, false), a) =>
- k -> a.map(_.copy(info=NoInfo)).map(_.ports.map(_.copy(info=NoInfo))).toSet
- /* If there are parameters, then only port names must match because parameters could parameterize widths.
- * This means that this check cannot produce false positives, but can have false negatives.
- */
- case (k@ (_, true), a) =>
- k -> a.map(_.copy(info=NoInfo)).map(_.ports.map(_.copy(info=NoInfo).mapType(stripWidth))).toSet
- }.filter(_._2.size > 1)
+ }.groupBy(a => (a.defname, a.params.nonEmpty))
+ .map {
+ /* There are no parameters, so all ports must match exactly. */
+ case (k @ (_, false), a) =>
+ k -> a.map(_.copy(info = NoInfo)).map(_.ports.map(_.copy(info = NoInfo))).toSet
+ /* If there are parameters, then only port names must match because parameters could parameterize widths.
+ * This means that this check cannot produce false positives, but can have false negatives.
+ */
+ case (k @ (_, true), a) =>
+ k -> a.map(_.copy(info = NoInfo)).map(_.ports.map(_.copy(info = NoInfo).mapType(stripWidth))).toSet
+ }
+ .filter(_._2.size > 1)
c.modules.collect {
case a: ExtModule =>
@@ -129,7 +136,8 @@ trait CheckHighFormLike { this: Pass =>
case _ =>
}
a match {
- case ExtModule(info, name, _, defname, params) if extmoduleCollidingPorts.contains((defname, params.nonEmpty)) =>
+ case ExtModule(info, name, _, defname, params)
+ if extmoduleCollidingPorts.contains((defname, params.nonEmpty)) =>
errors.append(new DefnameDifferentPortsException(info, name, defname))
case _ =>
}
@@ -147,14 +155,14 @@ trait CheckHighFormLike { this: Pass =>
}
def nonNegativeConsts(): Unit = {
- e.consts.filter(_ < 0).foreach {
- negC => errors.append(new NegArgException(info, mname, e.op.toString, negC))
+ e.consts.filter(_ < 0).foreach { negC =>
+ errors.append(new NegArgException(info, mname, e.op.toString, negC))
}
}
e.op match {
- case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq |
- Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | Clip | Wrap | Squeeze =>
+ case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq | Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw |
+ Clip | Wrap | Squeeze =>
correctNum(Option(2), 0)
case AsUInt | AsSInt | AsClock | AsAsyncReset | Cvt | Neq | Not =>
correctNum(Option(1), 0)
@@ -175,7 +183,7 @@ trait CheckHighFormLike { this: Pass =>
case AsInterval =>
correctNum(Option(1), 3)
case Andr | Orr | Xorr | Neg =>
- correctNum(None,0)
+ correctNum(None, 0)
}
}
@@ -208,12 +216,12 @@ trait CheckHighFormLike { this: Pass =>
}
def checkHighFormT(info: Info, mname: => String)(t: Type): Unit = {
- t foreach checkHighFormT(info, mname)
+ t.foreach(checkHighFormT(info, mname))
t match {
case tx: VectorType if tx.size < 0 =>
errors.append(new NegVecSizeException(info, mname))
case _: IntervalType =>
- case _ => t foreach checkHighFormW(info, mname)
+ case _ => t.foreach(checkHighFormW(info, mname))
}
}
@@ -235,12 +243,12 @@ trait CheckHighFormLike { this: Pass =>
errors.append(new NegUIntException(info, mname))
case ex: DoPrim => checkHighFormPrimop(info, mname, ex)
case _: Reference | _: WRef | _: UIntLiteral | _: Mux | _: ValidIf =>
- case ex: SubAccess => validSubexp(info, mname)(ex.expr)
+ case ex: SubAccess => validSubexp(info, mname)(ex.expr)
case ex: WSubAccess => validSubexp(info, mname)(ex.expr)
- case ex => ex foreach validSubexp(info, mname)
+ case ex => ex.foreach(validSubexp(info, mname))
}
- e foreach checkHighFormW(info, mname + "/" + e.serialize)
- e foreach checkHighFormE(info, mname, names)
+ e.foreach(checkHighFormW(info, mname + "/" + e.serialize))
+ e.foreach(checkHighFormE(info, mname, names))
}
def checkName(info: Info, mname: String, names: ScopeView)(name: String): Unit = {
@@ -253,14 +261,17 @@ trait CheckHighFormLike { this: Pass =>
if (!moduleNames(child))
errors.append(new ModuleNotDefinedException(info, parent, child))
// Check to see if a recursive module instantiation has occured
- val childToParent = moduleGraph add (parent, child)
+ val childToParent = moduleGraph.add(parent, child)
if (childToParent.nonEmpty)
- errors.append(new InstanceLoop(info, parent, childToParent mkString "->"))
+ errors.append(new InstanceLoop(info, parent, childToParent.mkString("->")))
}
def checkHighFormS(minfo: Info, mname: String, names: ScopeView)(s: Statement): Unit = {
- val info = get_info(s) match {case NoInfo => minfo case x => x}
- s foreach checkName(info, mname, names)
+ val info = get_info(s) match {
+ case NoInfo => minfo
+ case x => x
+ }
+ s.foreach(checkName(info, mname, names))
s match {
case DefRegister(info, name, tpe, _, reset, init) =>
if (hasFlip(tpe))
@@ -272,24 +283,24 @@ trait CheckHighFormLike { this: Pass =>
errors.append(new MemWithFlipException(info, mname, sx.name))
if (sx.depth <= 0)
errors.append(new NegMemSizeException(info, mname))
- case sx: DefInstance => checkInstance(info, mname, sx.module)
- case sx: WDefInstance => checkInstance(info, mname, sx.module)
- case sx: Connect => checkValidLoc(info, mname, sx.loc)
- case sx: PartialConnect => checkValidLoc(info, mname, sx.loc)
- case sx: Print => checkFstring(info, mname, sx.string, sx.args.length)
- case _: CDefMemory => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) }
+ case sx: DefInstance => checkInstance(info, mname, sx.module)
+ case sx: WDefInstance => checkInstance(info, mname, sx.module)
+ case sx: Connect => checkValidLoc(info, mname, sx.loc)
+ case sx: PartialConnect => checkValidLoc(info, mname, sx.loc)
+ case sx: Print => checkFstring(info, mname, sx.string, sx.args.length)
+ case _: CDefMemory => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) }
case mport: CDefMPort =>
errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) }
names.expandMPortVisibility(mport)
case sx => // Do Nothing
}
- s foreach checkHighFormT(info, mname)
- s foreach checkHighFormE(info, mname, names)
+ s.foreach(checkHighFormT(info, mname))
+ s.foreach(checkHighFormE(info, mname, names))
s match {
- case Conditionally(_,_, conseq, alt) =>
+ case Conditionally(_, _, conseq, alt) =>
checkHighFormS(minfo, mname, names.childScope())(conseq)
checkHighFormS(minfo, mname, names.childScope())(alt)
- case _ => s foreach checkHighFormS(minfo, mname, names)
+ case _ => s.foreach(checkHighFormS(minfo, mname, names))
}
}
@@ -313,10 +324,10 @@ trait CheckHighFormLike { this: Pass =>
def checkHighFormM(m: DefModule): Unit = {
val names = ScopeView()
- m foreach checkHighFormP(m.name, names)
- m foreach checkHighFormS(m.info, m.name, names)
+ m.foreach(checkHighFormP(m.name, names))
+ m.foreach(checkHighFormS(m.info, m.name, names))
m match {
- case _: Module =>
+ case _: Module =>
case ext: ExtModule =>
for ((port, expr) <- findBadResetTypePorts(ext, Output)) {
errors.append(new ResetExtModuleOutputException(port.info, ext.name, expr))
@@ -324,7 +335,7 @@ trait CheckHighFormLike { this: Pass =>
}
}
- c.modules foreach checkHighFormM
+ c.modules.foreach(checkHighFormM)
c.modules.filter(_.name == c.main) match {
case Seq(topMod) =>
for ((port, expr) <- findBadResetTypePorts(topMod, Input)) {
@@ -342,21 +353,23 @@ object CheckHighForm extends Pass with CheckHighFormLike {
override def prerequisites = firrtl.stage.Forms.WorkingIR
override def optionalPrerequisiteOf =
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.ResolveFlows),
- Dependency[passes.InferWidths],
- Dependency[transforms.InferResets] )
+ Seq(
+ Dependency(passes.ResolveKinds),
+ Dependency(passes.InferTypes),
+ Dependency(passes.ResolveFlows),
+ Dependency[passes.InferWidths],
+ Dependency[transforms.InferResets]
+ )
override def invalidates(a: Transform) = false
- class IllegalChirrtlMemException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.")
+ class IllegalChirrtlMemException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.")
def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] = {
val memName = s match {
case cm: CDefMemory => cm.name
- case cp: CDefMPort => cp.mem
+ case cp: CDefMPort => cp.mem
}
Some(new IllegalChirrtlMemException(info, mname, memName))
}
diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala
index 4a5577f9..96057831 100644
--- a/src/main/scala/firrtl/passes/CheckInitialization.scala
+++ b/src/main/scala/firrtl/passes/CheckInitialization.scala
@@ -22,10 +22,11 @@ object CheckInitialization extends Pass {
private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression])
- class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException(
- s"$info : [module $mname] Reference $name is not fully initialized.\n" +
- trace.map(s => s" ${get_info(s)} : ${s.serialize}").mkString("\n")
- )
+ class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement])
+ extends PassException(
+ s"$info : [module $mname] Reference $name is not fully initialized.\n" +
+ trace.map(s => s" ${get_info(s)} : ${s.serialize}").mkString("\n")
+ )
private def getTrace(expr: WrappedExpression, voidExprs: Map[WrappedExpression, VoidExpr]): Seq[Statement] = {
@tailrec
@@ -81,7 +82,7 @@ object CheckInitialization extends Pass {
case node: DefNode => // Ignore nodes
case decl: IsDeclaration =>
val trace = getTrace(expr, voidExprs.toMap)
- errors append new RefNotInitializedException(decl.info, m.name, decl.name, trace)
+ errors.append(new RefNotInitializedException(decl.info, m.name, decl.name, trace))
}
}
}
diff --git a/src/main/scala/firrtl/passes/CheckTypes.scala b/src/main/scala/firrtl/passes/CheckTypes.scala
index c94928a1..956c1134 100644
--- a/src/main/scala/firrtl/passes/CheckTypes.scala
+++ b/src/main/scala/firrtl/passes/CheckTypes.scala
@@ -16,92 +16,105 @@ object CheckTypes extends Pass {
override def prerequisites = Dependency(InferTypes) +: firrtl.stage.Forms.WorkingIR
override def optionalPrerequisiteOf =
- Seq( Dependency(passes.ResolveFlows),
- Dependency(passes.CheckFlows),
- Dependency[passes.InferWidths],
- Dependency(passes.CheckWidths) )
+ Seq(
+ Dependency(passes.ResolveFlows),
+ Dependency(passes.CheckFlows),
+ Dependency[passes.InferWidths],
+ Dependency(passes.CheckWidths)
+ )
override def invalidates(a: Transform) = false
// Custom Exceptions
- class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname ] Subfield $name is not in bundle.")
- class SubfieldOnNonBundle(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Subfield $name is accessed on a non-bundle.")
- class IndexTooLarge(info: Info, mname: String, value: Int) extends PassException(
- s"$info: [module $mname] Index with value $value is too large.")
- class IndexOnNonVector(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Index illegal on non-vector type.")
- class AccessIndexNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Access index must be a UInt type.")
- class IndexNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Index is not of UIntType.")
- class EnableNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Enable is not of UIntType.")
+ class SubfieldNotInBundle(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname ] Subfield $name is not in bundle.")
+ class SubfieldOnNonBundle(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Subfield $name is accessed on a non-bundle.")
+ class IndexTooLarge(info: Info, mname: String, value: Int)
+ extends PassException(s"$info: [module $mname] Index with value $value is too large.")
+ class IndexOnNonVector(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Index illegal on non-vector type.")
+ class AccessIndexNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Access index must be a UInt type.")
+ class IndexNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Index is not of UIntType.")
+ class EnableNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Enable is not of UIntType.")
class InvalidConnect(info: Info, mname: String, con: String, lhs: Expression, rhs: Expression)
extends PassException({
- val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}"
- val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}"
- s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe"
- })
- class InvalidRegInit(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Type of init must match type of DefRegister.")
- class PrintfArgNotGround(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Printf arguments must be either UIntType or SIntType.")
- class ReqClk(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Requires a clock typed signal.")
- class RegReqClk(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Register $name requires a clock typed signal.")
- class EnNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Enable must be a UIntType typed signal.")
- class PredNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Predicate not a UIntType.")
- class OpNotGround(info: Info, mname: String, op: String) extends PassException(
- s"$info: [module $mname] Primop $op cannot operate on non-ground types.")
- class OpNotUInt(info: Info, mname: String, op: String, e: String) extends PassException(
- s"$info: [module $mname] Primop $op requires argument $e to be a UInt type.")
- class OpNotAllUInt(info: Info, mname: String, op: String) extends PassException(
- s"$info: [module $mname] Primop $op requires all arguments to be UInt type.")
- class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException(
- s"$info: [module $mname] Primop $op requires all operands to have the same type.")
- class OpNoMixFix(info:Info, mname: String, op: String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type.")
- class OpNotCorrectType(info:Info, mname: String, op: String, tpes: Seq[String]) extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.")
- class OpNotAnalog(info: Info, mname: String, exp: String) extends PassException(
- s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.")
- class NodePassiveType(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Node must be a passive type.")
- class MuxSameType(info: Info, mname: String, t1: String, t2: String) extends PassException(
- s"$info: [module $mname] Must mux between equivalent types: $t1 != $t2.")
- class MuxPassiveTypes(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Must mux between passive types.")
- class MuxCondUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] A mux condition must be of type UInt.")
- class MuxClock(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Firrtl does not support muxing clocks.")
- class ValidIfPassiveTypes(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Must validif a passive type.")
- class ValidIfCondUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] A validif condition must be of type UInt.")
- class IllegalAnalogDeclaration(info: Info, mname: String, decName: String) extends PassException(
- s"$info: [module $mname] Cannot declare a reg, node, or memory with an Analog type: $decName.")
- class IllegalAttachExp(info: Info, mname: String, expName: String) extends PassException(
- s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.")
- class IllegalResetType(info: Info, mname: String, exp: String) extends PassException(
- s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp.")
- class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException(
- s"$info: [module $mname] Uninferred type: $exp."
- )
+ val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}"
+ val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}"
+ s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe"
+ })
+ class InvalidRegInit(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Type of init must match type of DefRegister.")
+ class PrintfArgNotGround(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Printf arguments must be either UIntType or SIntType.")
+ class ReqClk(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Requires a clock typed signal.")
+ class RegReqClk(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Register $name requires a clock typed signal.")
+ class EnNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Enable must be a UIntType typed signal.")
+ class PredNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Predicate not a UIntType.")
+ class OpNotGround(info: Info, mname: String, op: String)
+ extends PassException(s"$info: [module $mname] Primop $op cannot operate on non-ground types.")
+ class OpNotUInt(info: Info, mname: String, op: String, e: String)
+ extends PassException(s"$info: [module $mname] Primop $op requires argument $e to be a UInt type.")
+ class OpNotAllUInt(info: Info, mname: String, op: String)
+ extends PassException(s"$info: [module $mname] Primop $op requires all arguments to be UInt type.")
+ class OpNotAllSameType(info: Info, mname: String, op: String)
+ extends PassException(s"$info: [module $mname] Primop $op requires all operands to have the same type.")
+ class OpNoMixFix(info: Info, mname: String, op: String)
+ extends PassException(
+ s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type."
+ )
+ class OpNotCorrectType(info: Info, mname: String, op: String, tpes: Seq[String])
+ extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.")
+ class OpNotAnalog(info: Info, mname: String, exp: String)
+ extends PassException(s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.")
+ class NodePassiveType(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Node must be a passive type.")
+ class MuxSameType(info: Info, mname: String, t1: String, t2: String)
+ extends PassException(s"$info: [module $mname] Must mux between equivalent types: $t1 != $t2.")
+ class MuxPassiveTypes(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Must mux between passive types.")
+ class MuxCondUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] A mux condition must be of type UInt.")
+ class MuxClock(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Firrtl does not support muxing clocks.")
+ class ValidIfPassiveTypes(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Must validif a passive type.")
+ class ValidIfCondUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] A validif condition must be of type UInt.")
+ class IllegalAnalogDeclaration(info: Info, mname: String, decName: String)
+ extends PassException(
+ s"$info: [module $mname] Cannot declare a reg, node, or memory with an Analog type: $decName."
+ )
+ class IllegalAttachExp(info: Info, mname: String, expName: String)
+ extends PassException(
+ s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName."
+ )
+ class IllegalResetType(info: Info, mname: String, exp: String)
+ extends PassException(
+ s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp."
+ )
+ class IllegalUnknownType(info: Info, mname: String, exp: String)
+ extends PassException(
+ s"$info: [module $mname] Uninferred type: $exp."
+ )
def fits(bigger: Constraint, smaller: Constraint): Boolean = (bigger, smaller) match {
case (IsKnown(v1), IsKnown(v2)) if v1 < v2 => false
- case _ => true
+ case _ => true
}
def legalResetType(tpe: Type): Boolean = tpe match {
case UIntType(IntWidth(w)) if w == 1 => true
- case AsyncResetType => true
- case ResetType => true
- case UIntType(UnknownWidth) =>
+ case AsyncResetType => true
+ case ResetType => true
+ case UIntType(UnknownWidth) =>
// cannot catch here, though width may ultimately be wrong
true
case _ => false
@@ -118,13 +131,13 @@ object CheckTypes extends Pass {
fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point)
case (_: AnalogType, _: AnalogType) => true
case (AsyncResetType, AsyncResetType) => flip1 == flip2
- case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2
- case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2
+ case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2
+ case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2
case (t1: BundleType, t2: BundleType) =>
- val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())(
- (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) )))
- t2.fields forall (f2 =>
- t1_fields get f2.name match {
+ val t1_fields =
+ (t1.fields.foldLeft(Map[String, (Type, Orientation)]()))((map, f1) => map + (f1.name -> ((f1.tpe, f1.flip))))
+ t2.fields.forall(f2 =>
+ t1_fields.get(f2.name) match {
case None => true
case Some((f1_tpe, f1_flip)) =>
bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip))
@@ -155,79 +168,155 @@ object CheckTypes extends Pass {
def ut: UIntType = UIntType(UnknownWidth)
def st: SIntType = SIntType(UnknownWidth)
- def run (c:Circuit) : Circuit = {
+ def run(c: Circuit): Circuit = {
val errors = new Errors()
def passive(t: Type): Boolean = t match {
- case _: UIntType |_: SIntType => true
+ case _: UIntType | _: SIntType => true
case tx: VectorType => passive(tx.tpe)
- case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe))
+ case tx: BundleType => tx.fields.forall(x => x.flip == Default && passive(x.tpe))
case tx => true
}
def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = {
- def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean, okInterval: Boolean): Unit = {
+ def checkAllTypes(
+ exprs: Seq[Expression],
+ okUInt: Boolean,
+ okSInt: Boolean,
+ okClock: Boolean,
+ okFix: Boolean,
+ okAsync: Boolean,
+ okInterval: Boolean
+ ): Unit = {
exprs.foldLeft((false, false, false, false, false, false)) {
- case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => expr.tpe match {
- case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval)
- case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval)
- case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval)
- case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval)
- case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval)
- case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true)
- case UnknownType =>
- errors.append(new IllegalUnknownType(info, mname, e.serialize))
- (isUInt, isSInt, isClock, isFix, isAsync, isInterval)
- case other => throwInternalError(s"Illegal Type: ${other.serialize}")
- }
+ case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) =>
+ expr.tpe match {
+ case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval)
+ case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval)
+ case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval)
+ case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval)
+ case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval)
+ case i: IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true)
+ case UnknownType =>
+ errors.append(new IllegalUnknownType(info, mname, e.serialize))
+ (isUInt, isSInt, isClock, isFix, isAsync, isInterval)
+ case other => throwInternalError(s"Illegal Type: ${other.serialize}")
+ }
} match {
// (UInt, SInt, Clock, Fixed, Async, Interval)
- case (isAll, false, false, false, false, false) if isAll == okUInt =>
- case (false, isAll, false, false, false, false) if isAll == okSInt =>
- case (false, false, isAll, false, false, false) if isAll == okClock =>
- case (false, false, false, isAll, false, false) if isAll == okFix =>
- case (false, false, false, false, isAll, false) if isAll == okAsync =>
+ case (isAll, false, false, false, false, false) if isAll == okUInt =>
+ case (false, isAll, false, false, false, false) if isAll == okSInt =>
+ case (false, false, isAll, false, false, false) if isAll == okClock =>
+ case (false, false, false, isAll, false, false) if isAll == okFix =>
+ case (false, false, false, false, isAll, false) if isAll == okAsync =>
case (false, false, false, false, false, isAll) if isAll == okInterval =>
- case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize)))
+ case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize)))
}
}
e.op match {
case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset | AsInterval =>
- // All types are ok
+ // All types are ok
case Dshl | Dshr =>
- checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
- checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false, okInterval=false)
+ checkAllTypes(
+ Seq(e.args.head),
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = true
+ )
+ checkAllTypes(
+ Seq(e.args(1)),
+ okUInt = true,
+ okSInt = false,
+ okClock = false,
+ okFix = false,
+ okAsync = false,
+ okInterval = false
+ )
case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ checkAllTypes(
+ e.args,
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = true
+ )
case Pad | Bits | Head | Tail =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=false)
+ checkAllTypes(
+ e.args,
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = false
+ )
case Shl | Shr | Cat =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ checkAllTypes(
+ e.args,
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = true
+ )
case IncP | DecP | SetP =>
- checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ checkAllTypes(
+ e.args,
+ okUInt = false,
+ okSInt = false,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = true
+ )
case Wrap | Clip | Squeeze =>
- checkAllTypes(e.args, okUInt = false, okSInt = false, okClock = false, okFix = false, okAsync=false, okInterval = true)
+ checkAllTypes(
+ e.args,
+ okUInt = false,
+ okSInt = false,
+ okClock = false,
+ okFix = false,
+ okAsync = false,
+ okInterval = true
+ )
case _ =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false, okInterval=false)
+ checkAllTypes(
+ e.args,
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = false,
+ okAsync = false,
+ okInterval = false
+ )
}
}
- def check_types_e(info:Info, mname: String)(e: Expression): Unit = {
+ def check_types_e(info: Info, mname: String)(e: Expression): Unit = {
e match {
- case (e: WSubField) => e.expr.tpe match {
- case (t: BundleType) => t.fields find (_.name == e.name) match {
- case Some(_) =>
- case None => errors.append(new SubfieldNotInBundle(info, mname, e.name))
+ case (e: WSubField) =>
+ e.expr.tpe match {
+ case (t: BundleType) =>
+ t.fields.find(_.name == e.name) match {
+ case Some(_) =>
+ case None => errors.append(new SubfieldNotInBundle(info, mname, e.name))
+ }
+ case _ => errors.append(new SubfieldOnNonBundle(info, mname, e.name))
+ }
+ case (e: WSubIndex) =>
+ e.expr.tpe match {
+ case (t: VectorType) if e.value < t.size =>
+ case (t: VectorType) =>
+ errors.append(new IndexTooLarge(info, mname, e.value))
+ case _ =>
+ errors.append(new IndexOnNonVector(info, mname))
}
- case _ => errors.append(new SubfieldOnNonBundle(info, mname, e.name))
- }
- case (e: WSubIndex) => e.expr.tpe match {
- case (t: VectorType) if e.value < t.size =>
- case (t: VectorType) =>
- errors.append(new IndexTooLarge(info, mname, e.value))
- case _ =>
- errors.append(new IndexOnNonVector(info, mname))
- }
case (e: WSubAccess) =>
e.expr.tpe match {
case _: VectorType =>
@@ -256,11 +345,14 @@ object CheckTypes extends Pass {
}
case _ =>
}
- e foreach check_types_e(info, mname)
+ e.foreach(check_types_e(info, mname))
}
def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = {
- val info = get_info(s) match { case NoInfo => minfo case x => x }
+ val info = get_info(s) match {
+ case NoInfo => minfo
+ case x => x
+ }
s match {
case sx: Connect if !validConnect(sx) =>
val conMsg = sx.copy(info = NoInfo).serialize
@@ -270,7 +362,7 @@ object CheckTypes extends Pass {
errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr))
case sx: DefRegister =>
sx.tpe match {
- case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
+ case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname))
case t if !validConnect(sx.tpe, sx.init.tpe) =>
val conMsg = sx.copy(info = NoInfo).serialize
@@ -285,11 +377,12 @@ object CheckTypes extends Pass {
}
case sx: Conditionally if wt(sx.pred.tpe) != wt(ut) =>
errors.append(new PredNotUInt(info, mname))
- case sx: DefNode => sx.value.tpe match {
- case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
- case t if !passive(sx.value.tpe) => errors.append(new NodePassiveType(info, mname))
- case t =>
- }
+ case sx: DefNode =>
+ sx.value.tpe match {
+ case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
+ case t if !passive(sx.value.tpe) => errors.append(new NodePassiveType(info, mname))
+ case t =>
+ }
case sx: Attach =>
for (e <- sx.exprs) {
e.tpe match {
@@ -298,14 +391,14 @@ object CheckTypes extends Pass {
}
kind(e) match {
case (InstanceKind | PortKind | WireKind) =>
- case _ => errors.append(new IllegalAttachExp(info, mname, e.serialize))
+ case _ => errors.append(new IllegalAttachExp(info, mname, e.serialize))
}
}
case sx: Stop =>
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
case sx: Print =>
- if (sx.args exists (x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st)))
+ if (sx.args.exists(x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st)))
errors.append(new PrintfArgNotGround(info, mname))
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
@@ -313,17 +406,18 @@ object CheckTypes extends Pass {
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
if (wt(sx.pred.tpe) != wt(ut)) errors.append(new PredNotUInt(info, mname))
if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
- case sx: DefMemory => sx.dataType match {
- case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
- case t =>
- }
+ case sx: DefMemory =>
+ sx.dataType match {
+ case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
+ case t =>
+ }
case _ =>
}
- s foreach check_types_e(info, mname)
- s foreach check_types_s(info, mname)
+ s.foreach(check_types_e(info, mname))
+ s.foreach(check_types_s(info, mname))
}
- c.modules foreach (m => m foreach check_types_s(m.info, m.name))
+ c.modules.foreach(m => m.foreach(check_types_s(m.info, m.name)))
errors.trigger()
c
}
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala
index a7729ef8..f7fefa87 100644
--- a/src/main/scala/firrtl/passes/CheckWidths.scala
+++ b/src/main/scala/firrtl/passes/CheckWidths.scala
@@ -22,43 +22,49 @@ object CheckWidths extends Pass {
/** The maximum allowed width for any circuit element */
val MaxWidth = 1000000
val DshlMaxWidth = getUIntWidth(MaxWidth)
- class UninferredWidth (info: Info, target: String) extends PassException(
- s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?)
- |$target""".stripMargin)
- class UninferredBound (info: Info, target: String, bound: String) extends PassException(
- s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?)
- |$target""".stripMargin)
- class InvalidRange (info: Info, target: String, i: IntervalType) extends PassException(
- s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?)
- |$target""".stripMargin)
- class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException(
- s"$info : [target $mname] Width too small for constant $b.")
- class WidthTooBig(info: Info, mname: String, b: BigInt) extends PassException(
- s"$info : [target $mname] Width $b greater than max allowed width of $MaxWidth bits")
- class DshlTooBig(info: Info, mname: String) extends PassException(
- s"$info : [target $mname] Width of dshl shift amount must be less than $DshlMaxWidth bits.")
- class MultiBitAsClock(info: Info, mname: String) extends PassException(
- s"$info : [target $mname] Cannot cast a multi-bit signal to a Clock.")
- class MultiBitAsAsyncReset(info: Info, mname: String) extends PassException(
- s"$info : [target $mname] Cannot cast a multi-bit signal to an AsyncReset.")
- class NegWidthException(info:Info, mname: String) extends PassException(
- s"$info: [target $mname] Width cannot be negative or zero.")
- class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String) extends PassException(
- s"$info: [target $mname] High bit $hi in bits operator is larger than input width $width in $exp.")
- class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException(
- s"$info: [target $mname] Parameter $n in head operator is larger than input width $width.")
- class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException(
- s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.")
- class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) extends PassException(
- s"$info: [target $mname] Attach source $source and expression $eName must have identical widths.")
+ class UninferredWidth(info: Info, target: String)
+ extends PassException(s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?)
+ |$target""".stripMargin)
+ class UninferredBound(info: Info, target: String, bound: String)
+ extends PassException(s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?)
+ |$target""".stripMargin)
+ class InvalidRange(info: Info, target: String, i: IntervalType)
+ extends PassException(s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?)
+ |$target""".stripMargin)
+ class WidthTooSmall(info: Info, mname: String, b: BigInt)
+ extends PassException(s"$info : [target $mname] Width too small for constant $b.")
+ class WidthTooBig(info: Info, mname: String, b: BigInt)
+ extends PassException(s"$info : [target $mname] Width $b greater than max allowed width of $MaxWidth bits")
+ class DshlTooBig(info: Info, mname: String)
+ extends PassException(
+ s"$info : [target $mname] Width of dshl shift amount must be less than $DshlMaxWidth bits."
+ )
+ class MultiBitAsClock(info: Info, mname: String)
+ extends PassException(s"$info : [target $mname] Cannot cast a multi-bit signal to a Clock.")
+ class MultiBitAsAsyncReset(info: Info, mname: String)
+ extends PassException(s"$info : [target $mname] Cannot cast a multi-bit signal to an AsyncReset.")
+ class NegWidthException(info: Info, mname: String)
+ extends PassException(s"$info: [target $mname] Width cannot be negative or zero.")
+ class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String)
+ extends PassException(
+ s"$info: [target $mname] High bit $hi in bits operator is larger than input width $width in $exp."
+ )
+ class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt)
+ extends PassException(s"$info: [target $mname] Parameter $n in head operator is larger than input width $width.")
+ class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt)
+ extends PassException(s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.")
+ class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String)
+ extends PassException(
+ s"$info: [target $mname] Attach source $source and expression $eName must have identical widths."
+ )
class DisjointSqueeze(info: Info, mname: String, squeeze: DoPrim)
- extends PassException({
- val toSqz = squeeze.args.head.serialize
- val toSqzTpe = squeeze.args.head.tpe.serialize
- val sqzTo = squeeze.args(1).serialize
- val sqzToTpe = squeeze.args(1).tpe.serialize
- s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe"
- })
+ extends PassException({
+ val toSqz = squeeze.args.head.serialize
+ val toSqzTpe = squeeze.args.head.tpe.serialize
+ val sqzTo = squeeze.args(1).serialize
+ val sqzToTpe = squeeze.args(1).tpe.serialize
+ s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe"
+ })
def run(c: Circuit): Circuit = {
val errors = new Errors()
@@ -77,35 +83,35 @@ object CheckWidths extends Pass {
def hasWidth(tpe: Type): Boolean = tpe match {
case GroundType(IntWidth(w)) => true
- case GroundType(_) => false
- case _ => throwInternalError(s"hasWidth - $tpe")
+ case GroundType(_) => false
+ case _ => throwInternalError(s"hasWidth - $tpe")
}
def check_width_t(info: Info, target: Target)(t: Type): Unit = {
t match {
case tt: BundleType => tt.fields.foreach(check_width_f(info, target))
//Supports when l = u (if closed)
- case i@IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i
- case i:IntervalType if i.range == Some(Nil) =>
+ case i @ IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i
+ case i: IntervalType if i.range == Some(Nil) =>
errors.append(new InvalidRange(info, target.prettyPrint(" "), i))
i
- case i@IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u =>
+ case i @ IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u =>
errors.append(new InvalidRange(info, target.prettyPrint(" "), i))
i
- case i@IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i
- case i@IntervalType(_: IsKnown, _, _) =>
+ case i @ IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i
+ case i @ IntervalType(_: IsKnown, _, _) =>
errors.append(new UninferredBound(info, target.prettyPrint(" "), "upper"))
i
- case i@IntervalType(_, _: IsKnown, _) =>
+ case i @ IntervalType(_, _: IsKnown, _) =>
errors.append(new UninferredBound(info, target.prettyPrint(" "), "lower"))
i
- case i@IntervalType(_, _, _) =>
+ case i @ IntervalType(_, _, _) =>
errors.append(new UninferredBound(info, target.prettyPrint(" "), "lower"))
errors.append(new UninferredBound(info, target.prettyPrint(" "), "upper"))
i
- case tt => tt foreach check_width_t(info, target)
+ case tt => tt.foreach(check_width_t(info, target))
}
- t foreach check_width_w(info, target, t)
+ t.foreach(check_width_w(info, target, t))
}
def check_width_f(info: Info, target: Target)(f: Field): Unit =
@@ -120,7 +126,8 @@ object CheckWidths extends Pass {
errors.append(new WidthTooSmall(info, target.serialize, v))
case e @ DoPrim(op, Seq(a, b), _, tpe) =>
(op, a.tpe, b.tpe) match {
- case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) =>
+ case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _))
+ if (ua < lb) || (ub < la) =>
errors.append(new DisjointSqueeze(info, target.serialize, e))
case (Dshl, at, bt) if (hasWidth(at) && bitWidth(bt) >= DshlMaxWidth) =>
errors.append(new DshlTooBig(info, target.serialize))
@@ -159,7 +166,6 @@ object CheckWidths extends Pass {
}
}
-
def check_width_e_dfs(info: Info, target: Target, expr: Expression): Unit = {
val stack = collection.mutable.ArrayStack(expr)
def push(e: Expression): Unit = stack.push(e)
@@ -171,25 +177,31 @@ object CheckWidths extends Pass {
}
def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Unit = {
- val info = get_info(s) match { case NoInfo => minfo case x => x }
- val subRef = s match { case sx: HasName => target.ref(sx.name) case _ => target }
- s foreach check_width_e(info, target, 4)
- s foreach check_width_s(info, target)
- s foreach check_width_t(info, subRef)
+ val info = get_info(s) match {
+ case NoInfo => minfo
+ case x => x
+ }
+ val subRef = s match {
+ case sx: HasName => target.ref(sx.name)
+ case _ => target
+ }
+ s.foreach(check_width_e(info, target, 4))
+ s.foreach(check_width_s(info, target))
+ s.foreach(check_width_t(info, subRef))
s match {
case Attach(infox, exprs) =>
- exprs.tail.foreach ( e =>
+ exprs.tail.foreach(e =>
if (bitWidth(e.tpe) != bitWidth(exprs.head.tpe))
errors.append(new AttachWidthsNotEqual(infox, target.serialize, e.serialize, exprs.head.serialize))
)
case sx: DefRegister =>
sx.reset.tpe match {
case UIntType(IntWidth(w)) if w == 1 =>
- case AsyncResetType =>
- case ResetType =>
- case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name))
+ case AsyncResetType =>
+ case ResetType =>
+ case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name))
}
- if(!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) {
+ if (!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) {
val conMsg = sx.copy(info = NoInfo).serialize
errors.append(new CheckTypes.InvalidConnect(info, target.module, conMsg, WRef(sx), sx.init))
}
@@ -197,14 +209,15 @@ object CheckWidths extends Pass {
}
}
- def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target.ref(p.name))(p.tpe)
+ def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit =
+ check_width_t(p.info, target.ref(p.name))(p.tpe)
def check_width_m(circuit: CircuitTarget)(m: DefModule): Unit = {
- m foreach check_width_p(m.info, circuit.module(m.name))
- m foreach check_width_s(m.info, circuit.module(m.name))
+ m.foreach(check_width_p(m.info, circuit.module(m.name)))
+ m.foreach(check_width_s(m.info, circuit.module(m.name)))
}
- c.modules foreach check_width_m(CircuitTarget(c.main))
+ c.modules.foreach(check_width_m(CircuitTarget(c.main)))
errors.trigger()
c
}
diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
index 544f90a6..55a9c53a 100644
--- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
+++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
@@ -10,15 +10,16 @@ import firrtl.options.Dependency
object CommonSubexpressionElimination extends Pass {
override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq( Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats] )
+ Seq(
+ Dependency(firrtl.passes.RemoveValidIf),
+ Dependency[firrtl.transforms.ConstantPropagation],
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions),
+ Dependency[firrtl.transforms.CombineCats]
+ )
override def optionalPrerequisiteOf =
- Seq( Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform) = false
@@ -27,24 +28,26 @@ object CommonSubexpressionElimination extends Pass {
val nodes = collection.mutable.HashMap[String, Expression]()
def eliminateNodeRef(e: Expression): Expression = e match {
- case WRef(name, tpe, kind, flow) => nodes get name match {
- case Some(expression) => expressions get expression match {
- case Some(cseName) if cseName != name =>
- WRef(cseName, tpe, kind, flow)
+ case WRef(name, tpe, kind, flow) =>
+ nodes.get(name) match {
+ case Some(expression) =>
+ expressions.get(expression) match {
+ case Some(cseName) if cseName != name =>
+ WRef(cseName, tpe, kind, flow)
+ case _ => e
+ }
case _ => e
}
- case _ => e
- }
- case _ => e map eliminateNodeRef
+ case _ => e.map(eliminateNodeRef)
}
def eliminateNodeRefs(s: Statement): Statement = {
- s map eliminateNodeRef match {
+ s.map(eliminateNodeRef) match {
case x: DefNode =>
nodes(x.name) = x.value
expressions.getOrElseUpdate(x.value, x.name)
x
- case other => other map eliminateNodeRefs
+ case other => other.map(eliminateNodeRefs)
}
}
@@ -54,7 +57,7 @@ object CommonSubexpressionElimination extends Pass {
def run(c: Circuit): Circuit = {
val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => Module(m.info, m.name, m.ports, cse(m.body))
+ case m: Module => Module(m.info, m.name, m.ports, cse(m.body))
}
Circuit(c.info, modulesx, c.main)
}
diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
index 4a426209..baf7d4d5 100644
--- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
+++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
@@ -7,7 +7,7 @@ import firrtl.PrimOps._
import firrtl.ir._
import firrtl._
import firrtl.Mappers._
-import firrtl.Utils.{sub_type, module_type, field_type, max, throwInternalError}
+import firrtl.Utils.{field_type, max, module_type, sub_type, throwInternalError}
import firrtl.options.Dependency
/** Replaces FixedType with SIntType, and correctly aligns all binary points
@@ -15,71 +15,74 @@ import firrtl.options.Dependency
object ConvertFixedToSInt extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses),
- Dependency[ExpandWhensAndCheck],
- Dependency[RemoveIntervals] ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses),
+ Dependency[ExpandWhensAndCheck],
+ Dependency[RemoveIntervals]
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform) = false
def alignArg(e: Expression, point: BigInt): Expression = e.tpe match {
case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p)
- if((point - p) > 0) {
+ if ((point - p) > 0) {
DoPrim(Shl, Seq(e), Seq(point - p), UnknownType)
} else if (point - p < 0) {
DoPrim(Shr, Seq(e), Seq(p - point), UnknownType)
} else e
case FixedType(w, p) => throwInternalError(s"alignArg: shouldn't be here - $e")
- case _ => e
+ case _ => e
}
def calcPoint(es: Seq[Expression]): BigInt =
es.map(_.tpe match {
case FixedType(IntWidth(w), IntWidth(p)) => p
- case _ => BigInt(0)
+ case _ => BigInt(0)
}).reduce(max(_, _))
def toSIntType(t: Type): Type = t match {
case FixedType(IntWidth(w), IntWidth(p)) => SIntType(IntWidth(w))
- case FixedType(w, p) => throwInternalError(s"toSIntType: shouldn't be here - $t")
- case _ => t map toSIntType
+ case FixedType(w, p) => throwInternalError(s"toSIntType: shouldn't be here - $t")
+ case _ => t.map(toSIntType)
}
def run(c: Circuit): Circuit = {
- val moduleTypes = mutable.HashMap[String,Type]()
- def onModule(m:DefModule) : DefModule = {
- val types = mutable.HashMap[String,Type]()
- def updateExpType(e:Expression): Expression = e match {
- case DoPrim(Mul, args, consts, tpe) => e map updateExpType
- case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) map updateExpType
- case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType
- case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType
- case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType
+ val moduleTypes = mutable.HashMap[String, Type]()
+ def onModule(m: DefModule): DefModule = {
+ val types = mutable.HashMap[String, Type]()
+ def updateExpType(e: Expression): Expression = e match {
+ case DoPrim(Mul, args, consts, tpe) => e.map(updateExpType)
+ case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe).map(updateExpType)
+ case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe).map(updateExpType)
+ case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe).map(updateExpType)
+ case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p).map(updateExpType)
case DoPrim(op, args, consts, tpe) =>
val point = calcPoint(args)
val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType)
- newExp map updateExpType match {
+ newExp.map(updateExpType) match {
case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe)
- case e => e
+ case e => e
}
case Mux(cond, tval, fval, tpe) =>
val point = calcPoint(Seq(tval, fval))
val newExp = Mux(cond, alignArg(tval, point), alignArg(fval, point), UnknownType)
- newExp map updateExpType
+ newExp.map(updateExpType)
case e: UIntLiteral => e
case e: SIntLiteral => e
- case _ => e map updateExpType match {
- case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe)
- case WRef(name, tpe, k, g) => WRef(name, types(name), k, g)
- case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g)
- case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g)
- case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g)
- }
+ case _ =>
+ e.map(updateExpType) match {
+ case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe)
+ case WRef(name, tpe, k, g) => WRef(name, types(name), k, g)
+ case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g)
+ case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g)
+ case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g)
+ }
}
def updateStmtType(s: Statement): Statement = s match {
case DefRegister(info, name, tpe, clock, reset, init) =>
val newType = toSIntType(tpe)
types(name) = newType
- DefRegister(info, name, newType, clock, reset, init) map updateExpType
+ DefRegister(info, name, newType, clock, reset, init).map(updateExpType)
case DefWire(info, name, tpe) =>
val newType = toSIntType(tpe)
types(name) = newType
@@ -101,37 +104,34 @@ object ConvertFixedToSInt extends Pass {
case Connect(info, loc, exp) =>
val point = calcPoint(Seq(loc))
val newExp = alignArg(exp, point)
- Connect(info, loc, newExp) map updateExpType
+ Connect(info, loc, newExp).map(updateExpType)
case PartialConnect(info, loc, exp) =>
val point = calcPoint(Seq(loc))
val newExp = alignArg(exp, point)
- PartialConnect(info, loc, newExp) map updateExpType
+ PartialConnect(info, loc, newExp).map(updateExpType)
// check Connect case, need to shl
- case s => (s map updateStmtType) map updateExpType
+ case s => (s.map(updateStmtType)).map(updateExpType)
}
m.ports.foreach(p => types(p.name) = p.tpe)
m match {
- case Module(info, name, ports, body) => Module(info,name,ports,updateStmtType(body))
- case m:ExtModule => m
+ case Module(info, name, ports, body) => Module(info, name, ports, updateStmtType(body))
+ case m: ExtModule => m
}
}
- val newModules = for(m <- c.modules) yield {
- val newPorts = m.ports.map(p => Port(p.info,p.name,p.direction,toSIntType(p.tpe)))
+ val newModules = for (m <- c.modules) yield {
+ val newPorts = m.ports.map(p => Port(p.info, p.name, p.direction, toSIntType(p.tpe)))
m match {
- case Module(info, name, ports, body) => Module(info,name,newPorts,body)
- case ext: ExtModule => ext.copy(ports = newPorts)
+ case Module(info, name, ports, body) => Module(info, name, newPorts, body)
+ case ext: ExtModule => ext.copy(ports = newPorts)
}
}
newModules.foreach(m => moduleTypes(m.name) = module_type(m))
/* @todo This should be moved outside */
- (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main ))
+ (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main))
}
}
-
-
-
// vim: set ts=4 sw=4 et:
diff --git a/src/main/scala/firrtl/passes/ExpandConnects.scala b/src/main/scala/firrtl/passes/ExpandConnects.scala
index d28e6399..4f849c5a 100644
--- a/src/main/scala/firrtl/passes/ExpandConnects.scala
+++ b/src/main/scala/firrtl/passes/ExpandConnects.scala
@@ -9,8 +9,7 @@ import firrtl.Mappers._
object ExpandConnects extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped
+ Seq(Dependency(PullMuxes), Dependency(ReplaceAccesses)) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform) = a match {
case ResolveFlows => true
@@ -19,62 +18,65 @@ object ExpandConnects extends Pass {
def run(c: Circuit): Circuit = {
def expand_connects(m: Module): Module = {
- val flows = collection.mutable.LinkedHashMap[String,Flow]()
+ val flows = collection.mutable.LinkedHashMap[String, Flow]()
def expand_s(s: Statement): Statement = {
- def set_flow(e: Expression): Expression = e map set_flow match {
+ def set_flow(e: Expression): Expression = e.map(set_flow) match {
case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name))
case ex: WSubField =>
val f = get_field(ex.expr.tpe, ex.name)
val flowx = times(flow(ex.expr), f.flip)
WSubField(ex.expr, ex.name, ex.tpe, flowx)
- case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr))
+ case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr))
case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr))
case ex => ex
}
s match {
- case sx: DefWire => flows(sx.name) = DuplexFlow; sx
- case sx: DefRegister => flows(sx.name) = DuplexFlow; sx
+ case sx: DefWire => flows(sx.name) = DuplexFlow; sx
+ case sx: DefRegister => flows(sx.name) = DuplexFlow; sx
case sx: WDefInstance => flows(sx.name) = SourceFlow; sx
- case sx: DefMemory => flows(sx.name) = SourceFlow; sx
+ case sx: DefMemory => flows(sx.name) = SourceFlow; sx
case sx: DefNode => flows(sx.name) = SourceFlow; sx
case sx: IsInvalid =>
- val invalids = create_exps(sx.expr).flatMap { case expx =>
- flow(set_flow(expx)) match {
+ val invalids = create_exps(sx.expr).flatMap {
+ case expx =>
+ flow(set_flow(expx)) match {
case DuplexFlow => Some(IsInvalid(sx.info, expx))
- case SinkFlow => Some(IsInvalid(sx.info, expx))
- case _ => None
- }
+ case SinkFlow => Some(IsInvalid(sx.info, expx))
+ case _ => None
+ }
}
invalids.size match {
- case 0 => EmptyStmt
- case 1 => invalids.head
- case _ => Block(invalids)
+ case 0 => EmptyStmt
+ case 1 => invalids.head
+ case _ => Block(invalids)
}
case sx: Connect =>
val locs = create_exps(sx.loc)
val exps = create_exps(sx.expr)
- Block(locs.zip(exps).map { case (locx, expx) =>
- to_flip(flow(locx)) match {
+ Block(locs.zip(exps).map {
+ case (locx, expx) =>
+ to_flip(flow(locx)) match {
case Default => Connect(sx.info, locx, expx)
- case Flip => Connect(sx.info, expx, locx)
- }
+ case Flip => Connect(sx.info, expx, locx)
+ }
})
case sx: PartialConnect =>
val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default)
val locs = create_exps(sx.loc)
val exps = create_exps(sx.expr)
- val stmts = ls map { case (x, y) =>
- locs(x).tpe match {
- case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y)))
- case _ =>
- to_flip(flow(locs(x))) match {
- case Default => Connect(sx.info, locs(x), exps(y))
- case Flip => Connect(sx.info, exps(y), locs(x))
- }
- }
+ val stmts = ls.map {
+ case (x, y) =>
+ locs(x).tpe match {
+ case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y)))
+ case _ =>
+ to_flip(flow(locs(x))) match {
+ case Default => Connect(sx.info, locs(x), exps(y))
+ case Flip => Connect(sx.info, exps(y), locs(x))
+ }
+ }
}
Block(stmts)
- case sx => sx map expand_s
+ case sx => sx.map(expand_s)
}
}
@@ -83,8 +85,8 @@ object ExpandConnects extends Pass {
}
val modulesx = c.modules.map {
- case (m: ExtModule) => m
- case (m: Module) => expand_connects(m)
+ case (m: ExtModule) => m
+ case (m: Module) => expand_connects(m)
}
Circuit(c.info, modulesx, c.main)
}
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index ab7f02db..14d5d3ef 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -28,21 +28,23 @@ import collection.mutable
object ExpandWhens extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Resolved
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses)
+ ) ++ firrtl.stage.Forms.Resolved
override def invalidates(a: Transform): Boolean = a match {
case CheckInitialization | ResolveKinds | InferTypes => true
- case _ => false
+ case _ => false
}
/** Returns circuit with when and last connection semantics resolved */
def run(c: Circuit): Circuit = {
- val modulesx = c.modules map {
+ val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => onModule(m)
+ case m: Module => onModule(m)
}
Circuit(c.info, modulesx, c.main)
}
@@ -74,13 +76,12 @@ object ExpandWhens extends Pass {
// Does an expression contain WVoid inserted in this pass?
def containsVoid(e: Expression): Boolean = e match {
- case WVoid => true
+ case WVoid => true
case ValidIf(_, value, _) => memoizedVoid(value)
- case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv)
- case _ => false
+ case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv)
+ case _ => false
}
-
// Memoizes the node that holds a particular expression, if any
val nodes = new NodeLookup
@@ -95,18 +96,15 @@ object ExpandWhens extends Pass {
* @param p predicate so far, used to update simulation constructs
* @param s statement to expand
*/
- def expandWhens(netlist: Netlist,
- defaults: Defaults,
- p: Expression)
- (s: Statement): Statement = s match {
+ def expandWhens(netlist: Netlist, defaults: Defaults, p: Expression)(s: Statement): Statement = s match {
// For each non-register declaration, update netlist with value WVoid for each sink reference
// Return self, unchanged
case stmt @ (_: DefNode | EmptyStmt) => stmt
case w: DefWire =>
- netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow) map (ref => we(ref) -> WVoid))
+ netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow).map(ref => we(ref) -> WVoid))
w
case w: DefMemory =>
- netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow) map (ref => we(ref) -> WVoid))
+ netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow).map(ref => we(ref) -> WVoid))
w
case w: WDefInstance =>
netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> WVoid))
@@ -151,82 +149,88 @@ object ExpandWhens extends Pass {
// Process combined maps because we only want to create 1 mux for each node
// present in the conseq and/or alt
- val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) =>
- // Defaults in netlist get priority over those in defaults
- val default = netlist get lvalue match {
- case Some(v) => Some(v)
- case None => getDefault(lvalue, defaults)
- }
- // info0 and info1 correspond to Mux infos, use info0 only if ValidIf
- val (res, info0, info1) = default match {
- case Some(defaultValue) =>
- val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue))
- val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue))
- (trueValue, falseValue) match {
- case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo)
- case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo)
- case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo)
- case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo)
- }
- case None =>
- // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
- (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo)
- }
+ val memos = (conseqNetlist ++ altNetlist).map {
+ case (lvalue, _) =>
+ // Defaults in netlist get priority over those in defaults
+ val default = netlist.get(lvalue) match {
+ case Some(v) => Some(v)
+ case None => getDefault(lvalue, defaults)
+ }
+ // info0 and info1 correspond to Mux infos, use info0 only if ValidIf
+ val (res, info0, info1) = default match {
+ case Some(defaultValue) =>
+ val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue))
+ val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue))
+ (trueValue, falseValue) match {
+ case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo)
+ case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo)
+ case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo)
+ case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo)
+ }
+ case None =>
+ // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
+ (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo)
+ }
- res match {
- // Don't create a node to hold mux trees with void values
- // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches)
- case e if containsVoid(e) =>
- netlist(lvalue) = e
- memoizedVoid += e // remember that this was void
- EmptyStmt
- case _: ValidIf | _: Mux | _: DoPrim => nodes get res match {
- case Some(name) =>
- netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
+ res match {
+ // Don't create a node to hold mux trees with void values
+ // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches)
+ case e if containsVoid(e) =>
+ netlist(lvalue) = e
+ memoizedVoid += e // remember that this was void
+ EmptyStmt
+ case _: ValidIf | _: Mux | _: DoPrim =>
+ nodes.get(res) match {
+ case Some(name) =>
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
+ EmptyStmt
+ case None =>
+ val name = namespace.newTemp
+ nodes(res) = name
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
+ // Use MultiInfo constructor to preserve NoInfos
+ val info = new MultiInfo(List(sx.info, info0, info1))
+ DefNode(info, name, res)
+ }
+ case _ =>
+ netlist(lvalue) = res
EmptyStmt
- case None =>
- val name = namespace.newTemp
- nodes(res) = name
- netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
- // Use MultiInfo constructor to preserve NoInfos
- val info = new MultiInfo(List(sx.info, info0, info1))
- DefNode(info, name, res)
}
- case _ =>
- netlist(lvalue) = res
- EmptyStmt
- }
}
Block(Seq(conseqStmt, altStmt) ++ memos)
- case block: Block => block map expandWhens(netlist, defaults, p)
+ case block: Block => block.map(expandWhens(netlist, defaults, p))
case _ => throwInternalError()
}
val netlist = new Netlist
// Add ports to netlist
- netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) =>
- getSinkRefs(name, tpe, to_flow(dir)) map (ref => we(ref) -> WVoid)
+ netlist ++= (m.ports.flatMap {
+ case Port(_, name, dir, tpe) =>
+ getSinkRefs(name, tpe, to_flow(dir)).map(ref => we(ref) -> WVoid)
})
// Do traversal and construct mutable datastructures
val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet
- val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++
- combineAttaches(attaches.toSeq) ++ simlist)
+ val newBody = Block(
+ Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++
+ combineAttaches(attaches.toSeq) ++ simlist
+ )
Module(m.info, m.name, m.ports, newBody)
}
-
/** Returns all references to all sink leaf subcomponents of a reference */
private def getSinkRefs(n: String, t: Type, g: Flow): Seq[Expression] = {
val exps = create_exps(WRef(n, t, ExpKind, g))
- exps.flatMap { case exp =>
- exp.tpe match {
- case AnalogType(w) => None
- case _ => flow(exp) match {
- case (DuplexFlow | SinkFlow) => Some(exp)
- case _ => None
+ exps.flatMap {
+ case exp =>
+ exp.tpe match {
+ case AnalogType(w) => None
+ case _ =>
+ flow(exp) match {
+ case (DuplexFlow | SinkFlow) => Some(exp)
+ case _ => None
+ }
}
- }
}
}
@@ -238,7 +242,7 @@ object ExpandWhens extends Pass {
def handleInvalid(k: WrappedExpression, info: Info): Statement =
if (attached.contains(k)) EmptyStmt else IsInvalid(info, k.e1)
netlist.map {
- case (k, WInvalid) => handleInvalid(k, NoInfo)
+ case (k, WInvalid) => handleInvalid(k, NoInfo)
case (k, InfoExpr(info, WInvalid)) => handleInvalid(k, info)
case (k, v) =>
val (info, expr) = unwrap(v)
@@ -261,7 +265,7 @@ object ExpandWhens extends Pass {
case Seq() => // None of these expressions is present in the attachMap
AttachAcc(exprs, attachMap.size)
case accs => // At least one expression present in the attachMap
- val sorted = accs sortBy (_.idx)
+ val sorted = accs.sortBy(_.idx)
AttachAcc((sorted.map(_.exprs) :+ exprs).flatten.distinct, sorted.head.idx)
}
attachMap ++= acc.exprs.map(_ -> acc)
@@ -274,10 +278,11 @@ object ExpandWhens extends Pass {
private def getDefault(lvalue: WrappedExpression, defaults: Defaults): Option[Expression] = {
defaults match {
case Nil => None
- case head :: tail => head get lvalue match {
- case Some(p) => Some(p)
- case None => getDefault(lvalue, tail)
- }
+ case head :: tail =>
+ head.get(lvalue) match {
+ case Some(p) => Some(p)
+ case None => getDefault(lvalue, tail)
+ }
}
}
@@ -290,10 +295,12 @@ object ExpandWhens extends Pass {
class ExpandWhensAndCheck extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses)
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true
@@ -301,6 +308,6 @@ class ExpandWhensAndCheck extends Transform with DependencyAPIMigration {
}
override def execute(a: CircuitState): CircuitState =
- Seq(ExpandWhens, CheckInitialization).foldLeft(a){ case (acc, tx) => tx.transform(acc) }
+ Seq(ExpandWhens, CheckInitialization).foldLeft(a) { case (acc, tx) => tx.transform(acc) }
}
diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
index a16205a7..f393d8a5 100644
--- a/src/main/scala/firrtl/passes/InferBinaryPoints.scala
+++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
@@ -13,9 +13,7 @@ import firrtl.options.Dependency
class InferBinaryPoints extends Pass {
override def prerequisites =
- Seq( Dependency(ResolveKinds),
- Dependency(InferTypes),
- Dependency(ResolveFlows) )
+ Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows))
override def optionalPrerequisiteOf = Seq.empty
@@ -23,12 +21,12 @@ class InferBinaryPoints extends Pass {
private val constraintSolver = new ConstraintSolver()
- private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match {
- case (UIntType(w1), UIntType(w2)) =>
- case (SIntType(w1), SIntType(w2)) =>
- case (ClockType, ClockType) =>
- case (ResetType, _) =>
- case (_, ResetType) =>
+ private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1, t2) match {
+ case (UIntType(w1), UIntType(w2)) =>
+ case (SIntType(w1), SIntType(w2)) =>
+ case (ClockType, ClockType) =>
+ case (ResetType, _) =>
+ case (_, ResetType) =>
case (AsyncResetType, AsyncResetType) =>
case (FixedType(w1, p1), FixedType(w2, p2)) =>
constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
@@ -36,78 +34,86 @@ class InferBinaryPoints extends Pass {
constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
case (AnalogType(w1), AnalogType(w2)) =>
case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields) foreach { case (f1, f2) =>
- (f1.flip, f2.flip) match {
- case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
- case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
- case _ => sys.error("Shouldn't be here")
- }
+ (t1.fields.zip(t2.fields)).foreach {
+ case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
+ }
}
case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other")
}
- private def addDecConstraints(t: Type): Type = t map addDecConstraints
- private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match {
+ private def addDecConstraints(t: Type): Type = t.map(addDecConstraints)
+ private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s.map(addDecConstraints) match {
case c: Connect =>
val n = get_size(c.loc.tpe)
val locs = create_exps(c.loc)
val exps = create_exps(c.expr)
- (locs zip exps) foreach { case (loc, exp) =>
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
+ (locs.zip(exps)).foreach {
+ case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
c
case pc: PartialConnect =>
val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
val locs = create_exps(pc.loc)
val exps = create_exps(pc.expr)
- ls foreach { case (x, y) =>
- val loc = locs(x)
- val exp = exps(y)
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
+ ls.foreach {
+ case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
pc
case r: DefRegister =>
- addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
+ addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
r
- case x => x map addStmtConstraints(mt)
+ case x => x.map(addStmtConstraints(mt))
}
private def fixWidth(w: Width): Width = constraintSolver.get(w) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => w
- case _ => sys.error("Shouldn't be here")
+ case None => w
+ case _ => sys.error("Shouldn't be here")
}
- private def fixType(t: Type): Type = t map fixType map fixWidth match {
+ private def fixType(t: Type): Type = t.map(fixType).map(fixWidth) match {
case IntervalType(l, u, p) =>
val px = constraintSolver.get(p) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => p
- case _ => sys.error("Shouldn't be here")
+ case None => p
+ case _ => sys.error("Shouldn't be here")
}
IntervalType(l, u, px)
case FixedType(w, p) =>
val px = constraintSolver.get(p) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => p
- case _ => sys.error("Shouldn't be here")
+ case None => p
+ case _ => sys.error("Shouldn't be here")
}
FixedType(w, px)
case x => x
}
- private def fixStmt(s: Statement): Statement = s map fixStmt map fixType
- private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe))
- def run (c: Circuit): Circuit = {
+ private def fixStmt(s: Statement): Statement = s.map(fixStmt).map(fixType)
+ private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe))
+ def run(c: Circuit): Circuit = {
val ct = CircuitTarget(c.main)
- c.modules foreach (m => m map addStmtConstraints(ct.module(m.name)))
- c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)})
+ c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name))))
+ c.modules.foreach(_.ports.foreach { p => addDecConstraints(p.tpe) })
constraintSolver.solve()
- InferTypes.run(c.copy(modules = c.modules map (_
- map fixPort
- map fixStmt)))
+ InferTypes.run(
+ c.copy(modules =
+ c.modules.map(
+ _.map(fixPort)
+ .map(fixStmt)
+ )
+ )
+ )
}
}
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala
index 6cc9f2b9..4d14e7ff 100644
--- a/src/main/scala/firrtl/passes/InferTypes.scala
+++ b/src/main/scala/firrtl/passes/InferTypes.scala
@@ -23,16 +23,16 @@ object InferTypes extends Pass {
def remove_unknowns_b(b: Bound): Bound = b match {
case UnknownBound => VarBound(namespace.newName("b"))
- case k => k
+ case k => k
}
def remove_unknowns_w(w: Width): Width = w match {
case UnknownWidth => VarWidth(namespace.newName("w"))
- case wx => wx
+ case wx => wx
}
def remove_unknowns(t: Type): Type = {
- t map remove_unknowns map remove_unknowns_w match {
+ t.map(remove_unknowns).map(remove_unknowns_w) match {
case IntervalType(l, u, p) =>
IntervalType(remove_unknowns_b(l), remove_unknowns_b(u), p)
case x => x
@@ -41,18 +41,18 @@ object InferTypes extends Pass {
// we first need to remove the unknown widths and bounds from all ports,
// as their type will determine the module types
- val portsKnown = c.modules.map(_.map{ p: Port => p.copy(tpe = remove_unknowns(p.tpe)) })
+ val portsKnown = c.modules.map(_.map { p: Port => p.copy(tpe = remove_unknowns(p.tpe)) })
val mtypes = portsKnown.map(m => m.name -> module_type(m)).toMap
def infer_types_e(types: TypeLookup)(e: Expression): Expression =
- e map infer_types_e(types) match {
- case e: WRef => e copy (tpe = types(e.name))
- case e: WSubField => e copy (tpe = field_type(e.expr.tpe, e.name))
- case e: WSubIndex => e copy (tpe = sub_type(e.expr.tpe))
- case e: WSubAccess => e copy (tpe = sub_type(e.expr.tpe))
- case e: DoPrim => PrimOps.set_primop_type(e)
- case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval))
- case e: ValidIf => e copy (tpe = e.value.tpe)
+ e.map(infer_types_e(types)) match {
+ case e: WRef => e.copy(tpe = types(e.name))
+ case e: WSubField => e.copy(tpe = field_type(e.expr.tpe, e.name))
+ case e: WSubIndex => e.copy(tpe = sub_type(e.expr.tpe))
+ case e: WSubAccess => e.copy(tpe = sub_type(e.expr.tpe))
+ case e: DoPrim => PrimOps.set_primop_type(e)
+ case e: Mux => e.copy(tpe = mux_type_and_widths(e.tval, e.fval))
+ case e: ValidIf => e.copy(tpe = e.value.tpe)
case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
@@ -60,37 +60,37 @@ object InferTypes extends Pass {
case sx: WDefInstance =>
val t = mtypes(sx.module)
types(sx.name) = t
- sx copy (tpe = t)
+ sx.copy(tpe = t)
case sx: DefWire =>
val t = remove_unknowns(sx.tpe)
types(sx.name) = t
- sx copy (tpe = t)
+ sx.copy(tpe = t)
case sx: DefNode =>
- val sxx = (sx map infer_types_e(types)).asInstanceOf[DefNode]
+ val sxx = (sx.map(infer_types_e(types))).asInstanceOf[DefNode]
val t = remove_unknowns(sxx.value.tpe)
types(sx.name) = t
sxx
case sx: DefRegister =>
val t = remove_unknowns(sx.tpe)
types(sx.name) = t
- sx copy (tpe = t) map infer_types_e(types)
+ sx.copy(tpe = t).map(infer_types_e(types))
case sx: DefMemory =>
// we need to remove the unknowns from the data type so that all ports get the same VarWidth
val knownDataType = sx.copy(dataType = remove_unknowns(sx.dataType))
types(sx.name) = MemPortUtils.memType(knownDataType)
knownDataType
- case sx => sx map infer_types_s(types) map infer_types_e(types)
+ case sx => sx.map(infer_types_s(types)).map(infer_types_e(types))
}
def infer_types_p(types: TypeLookup)(p: Port): Port = {
val t = remove_unknowns(p.tpe)
types(p.name) = t
- p copy (tpe = t)
+ p.copy(tpe = t)
}
def infer_types(m: DefModule): DefModule = {
val types = new TypeLookup
- m map infer_types_p(types) map infer_types_s(types)
+ m.map(infer_types_p(types)).map(infer_types_s(types))
}
c.copy(modules = portsKnown.map(infer_types))
@@ -108,45 +108,45 @@ object CInferTypes extends Pass {
private type TypeLookup = collection.mutable.HashMap[String, Type]
def run(c: Circuit): Circuit = {
- val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
-
- def infer_types_e(types: TypeLookup)(e: Expression) : Expression =
- e map infer_types_e(types) match {
- case (e: Reference) => e copy (tpe = types.getOrElse(e.name, UnknownType))
- case (e: SubField) => e copy (tpe = field_type(e.expr.tpe, e.name))
- case (e: SubIndex) => e copy (tpe = sub_type(e.expr.tpe))
- case (e: SubAccess) => e copy (tpe = sub_type(e.expr.tpe))
- case (e: DoPrim) => PrimOps.set_primop_type(e)
- case (e: Mux) => e copy (tpe = mux_type(e.tval, e.fval))
- case (e: ValidIf) => e copy (tpe = e.value.tpe)
- case e @ (_: UIntLiteral | _: SIntLiteral) => e
+ val mtypes = (c.modules.map(m => m.name -> module_type(m))).toMap
+
+ def infer_types_e(types: TypeLookup)(e: Expression): Expression =
+ e.map(infer_types_e(types)) match {
+ case (e: Reference) => e.copy(tpe = types.getOrElse(e.name, UnknownType))
+ case (e: SubField) => e.copy(tpe = field_type(e.expr.tpe, e.name))
+ case (e: SubIndex) => e.copy(tpe = sub_type(e.expr.tpe))
+ case (e: SubAccess) => e.copy(tpe = sub_type(e.expr.tpe))
+ case (e: DoPrim) => PrimOps.set_primop_type(e)
+ case (e: Mux) => e.copy(tpe = mux_type(e.tval, e.fval))
+ case (e: ValidIf) => e.copy(tpe = e.value.tpe)
+ case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
def infer_types_s(types: TypeLookup)(s: Statement): Statement = s match {
case sx: DefRegister =>
types(sx.name) = sx.tpe
- sx map infer_types_e(types)
+ sx.map(infer_types_e(types))
case sx: DefWire =>
types(sx.name) = sx.tpe
sx
case sx: DefNode =>
- val sxx = (sx map infer_types_e(types)).asInstanceOf[DefNode]
+ val sxx = (sx.map(infer_types_e(types))).asInstanceOf[DefNode]
types(sxx.name) = sxx.value.tpe
sxx
case sx: DefMemory =>
types(sx.name) = MemPortUtils.memType(sx)
sx
case sx: CDefMPort =>
- val t = types getOrElse(sx.mem, UnknownType)
+ val t = types.getOrElse(sx.mem, UnknownType)
types(sx.name) = t
- sx copy (tpe = t)
+ sx.copy(tpe = t)
case sx: CDefMemory =>
types(sx.name) = sx.tpe
sx
case sx: DefInstance =>
types(sx.name) = mtypes(sx.module)
sx
- case sx => sx map infer_types_s(types) map infer_types_e(types)
+ case sx => sx.map(infer_types_s(types)).map(infer_types_e(types))
}
def infer_types_p(types: TypeLookup)(p: Port): Port = {
@@ -156,9 +156,9 @@ object CInferTypes extends Pass {
def infer_types(m: DefModule): DefModule = {
val types = new TypeLookup
- m map infer_types_p(types) map infer_types_s(types)
+ m.map(infer_types_p(types)).map(infer_types_s(types))
}
- c copy (modules = c.modules map infer_types)
+ c.copy(modules = c.modules.map(infer_types))
}
}
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 3720523b..eae9690f 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -14,7 +14,7 @@ import firrtl.options.Dependency
object InferWidths {
def apply(): InferWidths = new InferWidths()
- def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver)
+ def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver)
def execute(state: CircuitState): CircuitState = new InferWidths().execute(state)
}
@@ -22,12 +22,14 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg
def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = {
val newLoc :: newExp :: Nil = Seq(loc, exp).map { target =>
renameMap.get(target) match {
- case None => Some(target)
- case Some(Seq()) => None
+ case None => Some(target)
+ case Some(Seq()) => None
case Some(Seq(one)) => Some(one)
case Some(many) =>
- throw new Exception(s"Target below is an AggregateType, which " +
- "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ throw new Exception(
+ s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()
+ )
}
}
@@ -60,28 +62,31 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg
*
* Uses firrtl.constraint package to infer widths
*/
-class InferWidths extends Transform
- with ResolvedAnnotationPaths
- with DependencyAPIMigration {
+class InferWidths extends Transform with ResolvedAnnotationPaths with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.ResolveFlows),
- Dependency[passes.InferBinaryPoints],
- Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR
+ Seq(
+ Dependency(passes.ResolveKinds),
+ Dependency(passes.InferTypes),
+ Dependency(passes.ResolveFlows),
+ Dependency[passes.InferBinaryPoints],
+ Dependency[passes.TrimIntervals]
+ ) ++ firrtl.stage.Forms.WorkingIR
override def invalidates(a: Transform) = false
val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
- private def addTypeConstraints
- (r1: ReferenceTarget, r2: ReferenceTarget)
- (t1: Type, t2: Type)
- (implicit constraintSolver: ConstraintSolver)
- : Unit = (t1,t2) match {
+ private def addTypeConstraints(
+ r1: ReferenceTarget,
+ r2: ReferenceTarget
+ )(t1: Type,
+ t2: Type
+ )(
+ implicit constraintSolver: ConstraintSolver
+ ): Unit = (t1, t2) match {
case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
- case (ClockType, ClockType) =>
+ case (ClockType, ClockType) =>
case (FixedType(w1, p1), FixedType(w2, p2)) =>
constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
@@ -93,101 +98,119 @@ class InferWidths extends Transform
constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
constraintSolver.addGeq(w2, w1, r1.prettyPrint(""), r2.prettyPrint(""))
case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields) foreach { case (f1, f2) =>
- (f1.flip, f2.flip) match {
- case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
- case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
- case _ => sys.error("Shouldn't be here")
- }
+ (t1.fields.zip(t2.fields)).foreach {
+ case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
+ }
}
case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
case (AsyncResetType, AsyncResetType) => Nil
- case (ResetType, _) => Nil
- case (_, ResetType) => Nil
+ case (ResetType, _) => Nil
+ case (_, ResetType) => Nil
}
- private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver)
- : Expression = e map addExpConstraints match {
- case m@Mux(p, tVal, fVal, t) =>
- constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
- m
- case other => other
- }
+ private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver): Expression =
+ e.map(addExpConstraints) match {
+ case m @ Mux(p, tVal, fVal, t) =>
+ constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
+ m
+ case other => other
+ }
- private def addStmtConstraints(mt: ModuleTarget)(s: Statement)(implicit constraintSolver: ConstraintSolver)
- : Statement = s map addExpConstraints match {
+ private def addStmtConstraints(
+ mt: ModuleTarget
+ )(s: Statement
+ )(
+ implicit constraintSolver: ConstraintSolver
+ ): Statement = s.map(addExpConstraints) match {
case c: Connect =>
val n = get_size(c.loc.tpe)
val locs = create_exps(c.loc)
val exps = create_exps(c.expr)
- (locs zip exps).foreach { case (loc, exp) =>
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
- }
+ (locs.zip(exps)).foreach {
+ case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
+ }
c
case pc: PartialConnect =>
val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
val locs = create_exps(pc.loc)
val exps = create_exps(pc.expr)
- ls foreach { case (x, y) =>
- val loc = locs(x)
- val exp = exps(y)
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
+ ls.foreach {
+ case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
pc
case r: DefRegister =>
- if (r.reset.tpe != AsyncResetType ) {
+ if (r.reset.tpe != AsyncResetType) {
addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1)))
}
addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
r
- case a@Attach(_, exprs) =>
- val widths = exprs map (e => (e, getWidth(e.tpe)))
+ case a @ Attach(_, exprs) =>
+ val widths = exprs.map(e => (e, getWidth(e.tpe)))
val maxWidth = IsMax(widths.map(x => width2constraint(x._2)))
- widths.foreach { case (e, w) =>
- constraintSolver.addGeq(w, CalcWidth(maxWidth), Target.asTarget(mt)(e).prettyPrint(""), mt.ref(a.serialize).prettyPrint(""))
+ widths.foreach {
+ case (e, w) =>
+ constraintSolver.addGeq(
+ w,
+ CalcWidth(maxWidth),
+ Target.asTarget(mt)(e).prettyPrint(""),
+ mt.ref(a.serialize).prettyPrint("")
+ )
}
a
case c: Conditionally =>
addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1)))
- c map addStmtConstraints(mt)
- case x => x map addStmtConstraints(mt)
+ c.map(addStmtConstraints(mt))
+ case x => x.map(addStmtConstraints(mt))
}
private def fixWidth(w: Width)(implicit constraintSolver: ConstraintSolver): Width = constraintSolver.get(w) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => w
- case _ => sys.error("Shouldn't be here")
+ case None => w
+ case _ => sys.error("Shouldn't be here")
}
- private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t map fixType map fixWidth match {
+ private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t.map(fixType).map(fixWidth) match {
case IntervalType(l, u, p) =>
val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match {
case (Some(x: Bound), Some(y: Bound)) => (x, y)
case (None, None) => (l, u)
- case x => sys.error(s"Shouldn't be here: $x")
-
+ case x => sys.error(s"Shouldn't be here: $x")
}
IntervalType(lx, ux, fixWidth(p))
case FixedType(w, p) => FixedType(w, fixWidth(p))
- case x => x
+ case x => x
}
- private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = s map fixStmt map fixType
+ private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement =
+ s.map(fixStmt).map(fixType)
private def fixPort(p: Port)(implicit constraintSolver: ConstraintSolver): Port = {
Port(p.info, p.name, p.direction, fixType(p.tpe))
}
- def run (c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = {
+ def run(c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = {
val ct = CircuitTarget(c.main)
- c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name)))
+ c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name))))
constraintSolver.solve()
- val ret = InferTypes.run(c.copy(modules = c.modules map (_
- map fixPort
- map fixStmt)))
+ val ret = InferTypes.run(
+ c.copy(modules =
+ c.modules.map(
+ _.map(fixPort)
+ .map(fixStmt)
+ )
+ )
+ )
constraintSolver.clear()
ret
}
@@ -200,15 +223,16 @@ class InferWidths extends Transform
def getDeclTypes(modName: String)(stmt: Statement): Unit = {
val pairOpt = stmt match {
- case w: DefWire => Some(w.name -> w.tpe)
- case r: DefRegister => Some(r.name -> r.tpe)
- case n: DefNode => Some(n.name -> n.value.tpe)
+ case w: DefWire => Some(w.name -> w.tpe)
+ case r: DefRegister => Some(r.name -> r.tpe)
+ case n: DefNode => Some(n.name -> n.value.tpe)
case i: WDefInstance => Some(i.name -> i.tpe)
- case m: DefMemory => Some(m.name -> MemPortUtils.memType(m))
+ case m: DefMemory => Some(m.name -> MemPortUtils.memType(m))
case other => None
}
- pairOpt.foreach { case (ref, tpe) =>
- typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe)
+ pairOpt.foreach {
+ case (ref, tpe) =>
+ typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe)
}
stmt.foreachStmt(getDeclTypes(modName))
}
@@ -223,14 +247,20 @@ class InferWidths extends Transform
}
state.annotations.foreach {
- case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
- val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target =>
- val baseType = typeMap.getOrElse(target.copy(component = Seq.empty),
- throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint()))
+ case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
+ val locType :: expType :: Nil = Seq(anno.loc, anno.exp).map { target =>
+ val baseType = typeMap.getOrElse(
+ target.copy(component = Seq.empty),
+ throw new Exception(
+ s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint()
+ )
+ )
val leafType = target.componentType(baseType)
if (leafType.isInstanceOf[AggregateType]) {
- throw new Exception(s"Target below is an AggregateType, which " +
- "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ throw new Exception(
+ s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()
+ )
}
leafType
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala
index ad963b19..316878fb 100644
--- a/src/main/scala/firrtl/passes/Inline.scala
+++ b/src/main/scala/firrtl/passes/Inline.scala
@@ -32,89 +32,100 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
override def invalidates(a: Transform): Boolean = a == ResolveKinds
- private [firrtl] val inlineDelim: String = "_"
+ private[firrtl] val inlineDelim: String = "_"
val options = Seq(
new ShellOption[Seq[String]](
longOption = "inline",
- toAnnotationSeq = (a: Seq[String]) => a.map { value =>
- value.split('.') match {
- case Array(circuit) =>
- InlineAnnotation(CircuitName(circuit))
- case Array(circuit, module) =>
- InlineAnnotation(ModuleName(module, CircuitName(circuit)))
- case Array(circuit, module, inst) =>
- InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit))))
- }
- } :+ RunFirrtlTransformAnnotation(new InlineInstances),
+ toAnnotationSeq = (a: Seq[String]) =>
+ a.map { value =>
+ value.split('.') match {
+ case Array(circuit) =>
+ InlineAnnotation(CircuitName(circuit))
+ case Array(circuit, module) =>
+ InlineAnnotation(ModuleName(module, CircuitName(circuit)))
+ case Array(circuit, module, inst) =>
+ InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit))))
+ }
+ } :+ RunFirrtlTransformAnnotation(new InlineInstances),
helpText = "Inline selected modules",
shortOption = Some("fil"),
- helpValueName = Some("<circuit>[.<module>[.<instance>]][,...]") ) )
-
- private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
- anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) {
- case ((modNames, instNames), ann) => ann match {
- case InlineAnnotation(CircuitName(c)) =>
- (circuit.modules.collect {
- case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
- }.toSet, instNames)
- case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
- case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
- case _ => (modNames, instNames)
- }
- }
-
- def execute(state: CircuitState): CircuitState = {
- // TODO Add error check for more than one annotation for inlining
- val (modNames, instNames) = collectAnns(state.circuit, state.annotations)
- if (modNames.nonEmpty || instNames.nonEmpty) {
- run(state.circuit, modNames, instNames, state.annotations)
- } else {
- state
- }
- }
-
- // Checks the following properties:
- // 1) All annotated modules exist
- // 2) All annotated modules are InModules (can be inlined)
- // 3) All annotated instances exist, and their modules can be inline
- def check(c: Circuit, moduleNames: Set[ModuleName], instanceNames: Set[ComponentName]): Unit = {
- val errors = mutable.ArrayBuffer[PassException]()
- val moduleMap = InstanceKeyGraph(c).moduleMap
- def checkExists(name: String): Unit =
- if (!moduleMap.contains(name))
- errors += new PassException(s"Annotated module does not exist: $name")
- def checkExternal(name: String): Unit = moduleMap(name) match {
- case m: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: $name")
- case _ =>
- }
- def checkInstance(cn: ComponentName): Unit = {
- var containsCN = false
- def onStmt(name: String)(s: Statement): Statement = {
- s match {
- case WDefInstance(_, inst_name, module_name, tpe) =>
- if (name == inst_name) {
- containsCN = true
- checkExternal(module_name)
- }
- case _ =>
+ helpValueName = Some("<circuit>[.<module>[.<instance>]][,...]")
+ )
+ )
+
+ private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
+ anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) {
+ case ((modNames, instNames), ann) =>
+ ann match {
+ case InlineAnnotation(CircuitName(c)) =>
+ (
+ circuit.modules.collect {
+ case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
+ }.toSet,
+ instNames
+ )
+ case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
+ case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
+ case _ => (modNames, instNames)
+ }
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ // TODO Add error check for more than one annotation for inlining
+ val (modNames, instNames) = collectAnns(state.circuit, state.annotations)
+ if (modNames.nonEmpty || instNames.nonEmpty) {
+ run(state.circuit, modNames, instNames, state.annotations)
+ } else {
+ state
+ }
+ }
+
+ // Checks the following properties:
+ // 1) All annotated modules exist
+ // 2) All annotated modules are InModules (can be inlined)
+ // 3) All annotated instances exist, and their modules can be inline
+ def check(c: Circuit, moduleNames: Set[ModuleName], instanceNames: Set[ComponentName]): Unit = {
+ val errors = mutable.ArrayBuffer[PassException]()
+ val moduleMap = InstanceKeyGraph(c).moduleMap
+ def checkExists(name: String): Unit =
+ if (!moduleMap.contains(name))
+ errors += new PassException(s"Annotated module does not exist: $name")
+ def checkExternal(name: String): Unit = moduleMap(name) match {
+ case m: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: $name")
+ case _ =>
+ }
+ def checkInstance(cn: ComponentName): Unit = {
+ var containsCN = false
+ def onStmt(name: String)(s: Statement): Statement = {
+ s match {
+ case WDefInstance(_, inst_name, module_name, tpe) =>
+ if (name == inst_name) {
+ containsCN = true
+ checkExternal(module_name)
}
- s map onStmt(name)
- }
- onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[Module].body)
- if (!containsCN) errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}")
+ case _ =>
+ }
+ s.map(onStmt(name))
}
+ onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[Module].body)
+ if (!containsCN) errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}")
+ }
- moduleNames.foreach{mn => checkExists(mn.name)}
- if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
- moduleNames.foreach{mn => checkExternal(mn.name)}
- if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
- instanceNames.foreach{cn => checkInstance(cn)}
- if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
- }
-
+ moduleNames.foreach { mn => checkExists(mn.name) }
+ if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
+ moduleNames.foreach { mn => checkExternal(mn.name) }
+ if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
+ instanceNames.foreach { cn => checkInstance(cn) }
+ if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
+ }
- def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName], annos: AnnotationSeq): CircuitState = {
+ def run(
+ c: Circuit,
+ modsToInline: Set[ModuleName],
+ instsToInline: Set[ComponentName],
+ annos: AnnotationSeq
+ ): CircuitState = {
def getInstancesOf(c: Circuit, modules: Set[String]): Set[(OfModule, Instance)] =
c.modules.foldLeft(Set[(OfModule, Instance)]()) { (set, d) =>
d match {
@@ -125,7 +136,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
case WDefInstance(info, instName, moduleName, instTpe) if modules.contains(moduleName) =>
instances += (OfModule(m.name) -> Instance(instName))
s
- case sx => sx map findInstances
+ case sx => sx.map(findInstances)
}
findInstances(m.body)
instances.toSet ++ set
@@ -135,7 +146,8 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
// Check annotations and circuit match up
check(c, modsToInline, instsToInline)
val flatModules = modsToInline.map(m => m.name)
- val flatInstances: Set[(OfModule, Instance)] = instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules)
+ val flatInstances: Set[(OfModule, Instance)] =
+ instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules)
val iGraph = InstanceKeyGraph(c)
val namespaceMap = collection.mutable.Map[String, Namespace]()
// Map of Module name to Map of instance name to Module name
@@ -144,11 +156,13 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
/** Add a prefix to all declarations updating a [[Namespace]] and appending to a [[RenameMap]] */
def appendNamePrefix(
currentModule: IsModule,
- nextModule: IsModule,
- prefix: String,
- ns: Namespace,
- renames: mutable.HashMap[String, String],
- renameMap: RenameMap)(s: Statement): Statement = {
+ nextModule: IsModule,
+ prefix: String,
+ ns: Namespace,
+ renames: mutable.HashMap[String, String],
+ renameMap: RenameMap
+ )(s: Statement
+ ): Statement = {
def onName(ofModuleOpt: Option[String])(name: String) = {
if (prefix.nonEmpty && !ns.tryName(prefix + name)) {
throw new Exception(s"Inlining failed. Inlined name '${prefix + name}' already exists")
@@ -164,25 +178,29 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
}
s match {
- case s: WDefInstance => s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
- case other => s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
+ case s: WDefInstance =>
+ s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
+ case other =>
+ s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
}
}
/** Modify all references */
def appendRefPrefix(
currentModule: IsModule,
- renames: mutable.HashMap[String, String])(s: Statement): Statement = {
- def onExpr(e: Expression): Expression = e match {
- case wr@ WRef(name, _, _, _) =>
- renames.get(name) match {
- case Some(prefixedName) => wr.copy(name = prefixedName)
- case None => wr
- }
- case ex => ex.map(onExpr)
- }
- s.map(onExpr).map(appendRefPrefix(currentModule, renames))
+ renames: mutable.HashMap[String, String]
+ )(s: Statement
+ ): Statement = {
+ def onExpr(e: Expression): Expression = e match {
+ case wr @ WRef(name, _, _, _) =>
+ renames.get(name) match {
+ case Some(prefixedName) => wr.copy(name = prefixedName)
+ case None => wr
+ }
+ case ex => ex.map(onExpr)
}
+ s.map(onExpr).map(appendRefPrefix(currentModule, renames))
+ }
val cache = mutable.HashMap.empty[ModuleTarget, Statement]
@@ -194,16 +212,19 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
val (renamesMap, renamesSeq) = {
val mutableDiGraph = new MutableDiGraph[(OfModule, Instance)]
// compute instance graph
- instMaps.foreach { case (grandParentOfMod, parents) =>
- parents.foreach { case (parentInst, parentOfMod) =>
- val from = grandParentOfMod -> parentInst
- mutableDiGraph.addVertex(from)
- instMaps(parentOfMod).foreach { case (childInst, _) =>
- val to = parentOfMod -> childInst
- mutableDiGraph.addVertex(to)
- mutableDiGraph.addEdge(from, to)
+ instMaps.foreach {
+ case (grandParentOfMod, parents) =>
+ parents.foreach {
+ case (parentInst, parentOfMod) =>
+ val from = grandParentOfMod -> parentInst
+ mutableDiGraph.addVertex(from)
+ instMaps(parentOfMod).foreach {
+ case (childInst, _) =>
+ val to = parentOfMod -> childInst
+ mutableDiGraph.addVertex(to)
+ mutableDiGraph.addEdge(from, to)
+ }
}
- }
}
val diGraph = DiGraph(mutableDiGraph)
@@ -226,10 +247,12 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
}
def fixupRefs(
- instMap: collection.Map[Instance, OfModule],
- currentModule: IsModule)(e: Expression): Expression = {
+ instMap: collection.Map[Instance, OfModule],
+ currentModule: IsModule
+ )(e: Expression
+ ): Expression = {
e match {
- case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) =>
+ case wsf @ WSubField(wr @ WRef(ref, _, InstanceKind, _), field, tpe, gen) =>
val inst = currentModule.instOf(ref, instMap(Instance(ref)).value)
val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance))
val port = inst.ref(field)
@@ -242,12 +265,12 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
}
case None => wsf
}
- case wr@ WRef(name, _, InstanceKind, _) =>
+ case wr @ WRef(name, _, InstanceKind, _) =>
val inst = currentModule.instOf(name, instMap(Instance(name)).value)
val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance))
val comp = currentModule.ref(name)
renamesOpt.flatMap(_.get(comp)).getOrElse(Seq(comp)) match {
- case Seq(car: ReferenceTarget) => wr.copy(name=car.ref)
+ case Seq(car: ReferenceTarget) => wr.copy(name = car.ref)
}
case ex => ex.map(fixupRefs(instMap, currentModule))
}
@@ -258,7 +281,8 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
val ns = namespaceMap.getOrElseUpdate(currentModuleName, Namespace(iGraph.moduleMap(currentModuleName)))
val instMap = instMaps(OfModule(currentModuleName))
s match {
- case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) =>
+ case wDef @ WDefInstance(_, instName, modName, _)
+ if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) =>
val renames = renamesMap(OfModule(currentModuleName) -> Instance(instName))
val toInline = iGraph.moduleMap(modName) match {
case m: ExtModule => throw new PassException(s"Cannot inline external module ${m.name}")
@@ -269,7 +293,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
val bodyx = {
val module = currentModule.copy(module = modName)
- cache.getOrElseUpdate(module, Block(ports :+ toInline.body) map onStmt(module))
+ cache.getOrElseUpdate(module, Block(ports :+ toInline.body).map(onStmt(module)))
}
val names = "" +: Uniquify
@@ -294,14 +318,14 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
renamedBody
case sx =>
sx
- .map(fixupRefs(instMap, currentModule))
- .map(onStmt(currentModule))
+ .map(fixupRefs(instMap, currentModule))
+ .map(onStmt(currentModule))
}
}
val flatCircuit = c.copy(modules = c.modules.flatMap {
case m if flatModules.contains(m.name) => None
- case m =>
+ case m =>
Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main)))))
})
diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala
index 8b7b733a..5d59e075 100644
--- a/src/main/scala/firrtl/passes/Legalize.scala
+++ b/src/main/scala/firrtl/passes/Legalize.scala
@@ -1,11 +1,11 @@
package firrtl.passes
import firrtl.PrimOps._
-import firrtl.Utils.{BoolType, error, zero}
+import firrtl.Utils.{error, zero, BoolType}
import firrtl.ir._
import firrtl.options.Dependency
import firrtl.transforms.ConstantPropagation
-import firrtl.{Transform, bitWidth}
+import firrtl.{bitWidth, Transform}
import firrtl.Mappers._
// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts
@@ -62,30 +62,31 @@ object Legalize extends Pass {
} else {
val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w)))
val expr = t match {
- case UIntType(_) => bits
- case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
+ case UIntType(_) => bits
+ case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t)
}
Connect(c.info, c.loc, expr)
}
}
- def run (c: Circuit): Circuit = {
- def legalizeE(expr: Expression): Expression = expr map legalizeE match {
- case prim: DoPrim => prim.op match {
- case Shr => legalizeShiftRight(prim)
- case Pad => legalizePad(prim)
- case Bits | Head | Tail => legalizeBitExtract(prim)
- case _ => prim
- }
+ def run(c: Circuit): Circuit = {
+ def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match {
+ case prim: DoPrim =>
+ prim.op match {
+ case Shr => legalizeShiftRight(prim)
+ case Pad => legalizePad(prim)
+ case Bits | Head | Tail => legalizeBitExtract(prim)
+ case _ => prim
+ }
case e => e // respect pre-order traversal
}
- def legalizeS (s: Statement): Statement = {
+ def legalizeS(s: Statement): Statement = {
val legalizedStmt = s match {
case c: Connect => legalizeConnect(c)
case _ => s
}
- legalizedStmt map legalizeS map legalizeE
+ legalizedStmt.map(legalizeS).map(legalizeE)
}
- c copy (modules = c.modules map (_ map legalizeS))
+ c.copy(modules = c.modules.map(_.map(legalizeS)))
}
}
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index ace4f3e8..ad608cec 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -3,8 +3,26 @@
package firrtl.passes
import firrtl.analyses.{InstanceKeyGraph, SymbolTable}
-import firrtl.annotations.{CircuitTarget, MemoryInitAnnotation, MemoryRandomInitAnnotation, ModuleTarget, ReferenceTarget}
-import firrtl.{CircuitForm, CircuitState, DependencyAPIMigration, InstanceKind, Kind, MemKind, PortKind, RenameMap, Transform, UnknownForm, Utils}
+import firrtl.annotations.{
+ CircuitTarget,
+ MemoryInitAnnotation,
+ MemoryRandomInitAnnotation,
+ ModuleTarget,
+ ReferenceTarget
+}
+import firrtl.{
+ CircuitForm,
+ CircuitState,
+ DependencyAPIMigration,
+ InstanceKind,
+ Kind,
+ MemKind,
+ PortKind,
+ RenameMap,
+ Transform,
+ UnknownForm,
+ Utils
+}
import firrtl.ir._
import firrtl.options.Dependency
import firrtl.stage.TransformManager.TransformDependency
@@ -20,18 +38,19 @@ import scala.collection.mutable
object LowerTypes extends Transform with DependencyAPIMigration {
override def prerequisites: Seq[TransformDependency] = Seq(
Dependency(RemoveAccesses), // we require all SubAccess nodes to have been removed
- Dependency(CheckTypes), // we require all types to be correct
- Dependency(InferTypes), // we require instance types to be resolved (i.e., DefInstance.tpe != UnknownType)
- Dependency(ExpandConnects) // we require all PartialConnect nodes to have been expanded
+ Dependency(CheckTypes), // we require all types to be correct
+ Dependency(InferTypes), // we require instance types to be resolved (i.e., DefInstance.tpe != UnknownType)
+ Dependency(ExpandConnects) // we require all PartialConnect nodes to have been expanded
)
- override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq.empty
+ override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq.empty
override def invalidates(a: Transform): Boolean = a match {
case ResolveFlows => true // we generate UnknownFlow for now (could be fixed)
- case _ => false
+ case _ => false
}
/** Delimiter used in lowering names */
val delim = "_"
+
/** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name
* @param e [[firrtl.ir.Expression]] made up of _only_ [[firrtl.WRef]], [[firrtl.WSubField]], and [[firrtl.WSubIndex]]
* @return Lowered name of e
@@ -39,8 +58,8 @@ object LowerTypes extends Transform with DependencyAPIMigration {
*/
def loweredName(e: Expression): String = e match {
case e: Reference => e.name
- case e: SubField => s"${loweredName(e.expr)}$delim${e.name}"
- case e: SubIndex => s"${loweredName(e.expr)}$delim${e.value}"
+ case e: SubField => s"${loweredName(e.expr)}$delim${e.name}"
+ case e: SubIndex => s"${loweredName(e.expr)}$delim${e.value}"
}
def loweredName(s: Seq[String]): String = s.mkString(delim)
@@ -48,7 +67,7 @@ object LowerTypes extends Transform with DependencyAPIMigration {
// When memories are lowered to ground type, we have to fix the init annotation or error on it.
val (memInitAnnos, otherAnnos) = state.annotations.partition {
case _: MemoryRandomInitAnnotation => false
- case _: MemoryInitAnnotation => true
+ case _: MemoryInitAnnotation => true
case _ => false
}
val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule)
@@ -61,14 +80,18 @@ object LowerTypes extends Transform with DependencyAPIMigration {
val newAnnos = otherAnnos ++ resultAndRenames.flatMap(_._3)
// chain module renames in topological order
- val moduleRenames = resultAndRenames.map{ case(m,r, _) => m.name -> r }.toMap
+ val moduleRenames = resultAndRenames.map { case (m, r, _) => m.name -> r }.toMap
val moduleOrderBottomUp = InstanceKeyGraph(result).moduleOrder.reverseIterator
- val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a,b) => a.andThen(b))
+ val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a, b) => a.andThen(b))
state.copy(circuit = result, renames = Some(renames), annotations = newAnnos)
}
- private def onModule(c: CircuitTarget, m: DefModule, memoryInit: Seq[MemoryInitAnnotation]): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = {
+ private def onModule(
+ c: CircuitTarget,
+ m: DefModule,
+ memoryInit: Seq[MemoryInitAnnotation]
+ ): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = {
val renameMap = RenameMap()
val ref = c.module(m.name)
@@ -86,26 +109,36 @@ object LowerTypes extends Transform with DependencyAPIMigration {
}
// We lower ports in a separate pass in order to ensure that statements inside the module do not influence port names.
- private def lowerPorts(ref: ModuleTarget, m: DefModule, renameMap: RenameMap):
- (DefModule, Seq[(String, Seq[Reference])]) = {
+ private def lowerPorts(
+ ref: ModuleTarget,
+ m: DefModule,
+ renameMap: RenameMap
+ ): (DefModule, Seq[(String, Seq[Reference])]) = {
val namespace = mutable.HashSet[String]() ++ m.ports.map(_.name)
val loweredPortsAndRefs = m.ports.flatMap { p =>
- val fieldsAndRefs = DestructTypes.destruct(ref, Field(p.name, Utils.to_flip(p.direction), p.tpe), namespace, renameMap, Set())
- fieldsAndRefs.map { case (f, ref) =>
- (Port(p.info, f.name, Utils.to_dir(f.flip), f.tpe), ref -> Seq(Reference(f.name, f.tpe, PortKind)))
+ val fieldsAndRefs =
+ DestructTypes.destruct(ref, Field(p.name, Utils.to_flip(p.direction), p.tpe), namespace, renameMap, Set())
+ fieldsAndRefs.map {
+ case (f, ref) =>
+ (Port(p.info, f.name, Utils.to_dir(f.flip), f.tpe), ref -> Seq(Reference(f.name, f.tpe, PortKind)))
}
}
val newM = m match {
- case e : ExtModule => e.copy(ports = loweredPortsAndRefs.map(_._1))
- case mod: Module => mod.copy(ports = loweredPortsAndRefs.map(_._1))
+ case e: ExtModule => e.copy(ports = loweredPortsAndRefs.map(_._1))
+ case mod: Module => mod.copy(ports = loweredPortsAndRefs.map(_._1))
}
(newM, loweredPortsAndRefs.map(_._2))
}
- private def onStatement(s: Statement)(implicit symbols: LoweringTable, memInit: Seq[MemoryInitAnnotation]): Statement = s match {
+ private def onStatement(
+ s: Statement
+ )(
+ implicit symbols: LoweringTable,
+ memInit: Seq[MemoryInitAnnotation]
+ ): Statement = s match {
// declarations
- case d : DefWire =>
- Block(symbols.lower(d.name, d.tpe, firrtl.WireKind).map { case (name, tpe, _) => d.copy(name=name, tpe=tpe) })
+ case d: DefWire =>
+ Block(symbols.lower(d.name, d.tpe, firrtl.WireKind).map { case (name, tpe, _) => d.copy(name = name, tpe = tpe) })
case d @ DefRegister(info, _, _, clock, reset, _) =>
// clock and reset are always of ground type
val loweredClock = onExpression(clock)
@@ -113,41 +146,41 @@ object LowerTypes extends Transform with DependencyAPIMigration {
// It is important to first lower the declaration, because the reset can refer to the register itself!
val loweredRegs = symbols.lower(d.name, d.tpe, firrtl.RegKind)
val inits = Utils.create_exps(d.init).map(onExpression)
- Block(
- loweredRegs.zip(inits).map { case ((name, tpe, _), init) =>
+ Block(loweredRegs.zip(inits).map {
+ case ((name, tpe, _), init) =>
DefRegister(info, name, tpe, loweredClock, loweredReset, init)
})
- case d : DefNode =>
+ case d: DefNode =>
val values = Utils.create_exps(d.value).map(onExpression)
- Block(
- symbols.lower(d.name, d.value.tpe, firrtl.NodeKind).zip(values).map{ case((name, tpe, _), value) =>
+ Block(symbols.lower(d.name, d.value.tpe, firrtl.NodeKind).zip(values).map {
+ case ((name, tpe, _), value) =>
assert(tpe == value.tpe)
DefNode(d.info, name, value)
})
- case d : DefMemory =>
+ case d: DefMemory =>
// TODO: as an optimization, we could just skip ground type memories here.
// This would require that we don't error in getReferences() but instead return the old reference.
val mems = symbols.lower(d)
- if(mems.length > 1 && memInit.exists(_.target.ref == d.name)) {
+ if (mems.length > 1 && memInit.exists(_.target.ref == d.name)) {
val mod = memInit.find(_.target.ref == d.name).get.target.encapsulatingModule
val msg = s"[module $mod] Cannot initialize memory ${d.name} of non ground type ${d.dataType.serialize}"
throw new RuntimeException(msg)
}
Block(mems)
- case d : DefInstance => symbols.lower(d)
+ case d: DefInstance => symbols.lower(d)
// connections
case Connect(info, loc, expr) =>
- if(!expr.tpe.isInstanceOf[GroundType]) {
+ if (!expr.tpe.isInstanceOf[GroundType]) {
throw new RuntimeException(s"LowerTypes expects Connects to have been expanded! ${expr.tpe.serialize}")
}
val rhs = onExpression(expr)
// We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated.
val lhs = symbols.getReferences(loc.asInstanceOf[RefLikeExpression])
Block(lhs.map(loc => Connect(info, loc, rhs)))
- case p : PartialConnect =>
+ case p: PartialConnect =>
throw new RuntimeException(s"LowerTypes expects PartialConnects to be resolved! $p")
case IsInvalid(info, expr) =>
- if(!expr.tpe.isInstanceOf[GroundType]) {
+ if (!expr.tpe.isInstanceOf[GroundType]) {
throw new RuntimeException(s"LowerTypes expects IsInvalids to have been expanded! ${expr.tpe.serialize}")
}
// We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated.
@@ -172,15 +205,18 @@ object LowerTypes extends Transform with DependencyAPIMigration {
// Holds the first level of the module-level namespace.
// (i.e. everything that can be addressed directly by a Reference node)
private class LoweringSymbolTable extends SymbolTable {
- def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(name)
+ def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(name)
def declareInstance(name: String, module: String): Unit = symbols.append(name)
private val symbols = mutable.ArrayBuffer[String]()
def getSymbolNames: Iterable[String] = symbols
}
// Lowers types and keeps track of references to lowered types.
-private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m: ModuleTarget,
- portNameToExprs: Seq[(String, Seq[Reference])]) {
+private class LoweringTable(
+ table: LoweringSymbolTable,
+ renameMap: RenameMap,
+ m: ModuleTarget,
+ portNameToExprs: Seq[(String, Seq[Reference])]) {
private val portNames: Set[String] = portNameToExprs.map(_._2.head.name).toSet
private val namespace = mutable.HashSet[String]() ++ table.getSymbolNames
// Serialized old access string to new ground type reference.
@@ -196,10 +232,11 @@ private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m:
nameToExprs ++= refs.map { case (name, r) => name -> List(r) }
newInst
}
+
/** used to lower nodes, registers and wires */
def lower(name: String, tpe: Type, kind: Kind, flip: Orientation = Default): Seq[(String, Type, Orientation)] = {
val fieldsAndRefs = DestructTypes.destruct(m, Field(name, flip, tpe), namespace, renameMap, portNames)
- nameToExprs ++= fieldsAndRefs.map{ case (f, ref) => ref -> List(Reference(f.name, f.tpe, kind)) }
+ nameToExprs ++= fieldsAndRefs.map { case (f, ref) => ref -> List(Reference(f.name, f.tpe, kind)) }
fieldsAndRefs.map { case (f, _) => (f.name, f.tpe, f.flip) }
}
def lower(p: Port): Seq[Port] = {
@@ -211,10 +248,10 @@ private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m:
// We could just use FirrtlNode.serialize here, but we want to make sure there are not SubAccess nodes left.
private def serialize(expr: RefLikeExpression): String = expr match {
- case Reference(name, _, _, _) => name
- case SubField(expr, name, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "." + name
+ case Reference(name, _, _, _) => name
+ case SubField(expr, name, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "." + name
case SubIndex(expr, index, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "[" + index.toString + "]"
- case a : SubAccess =>
+ case a: SubAccess =>
throw new RuntimeException(s"LowerTypes expects all SubAccesses to have been expanded! ${a.serialize}")
}
}
@@ -230,13 +267,18 @@ private object DestructTypes {
* - generates a list of all old reference name that now refer to the particular ground type field
* - updates namespace with all possibly conflicting names
*/
- def destruct(m: ModuleTarget, ref: Field, namespace: Namespace, renameMap: RenameMap, reserved: Set[String]):
- Seq[(Field, String)] = {
+ def destruct(
+ m: ModuleTarget,
+ ref: Field,
+ namespace: Namespace,
+ renameMap: RenameMap,
+ reserved: Set[String]
+ ): Seq[(Field, String)] = {
// field renames (uniquify) are computed bottom up
val (rename, _) = uniquify(ref, namespace, reserved)
// early exit for ground types that do not need renaming
- if(ref.tpe.isInstanceOf[GroundType] && rename.isEmpty) {
+ if (ref.tpe.isInstanceOf[GroundType] && rename.isEmpty) {
return List((ref, ref.name))
}
@@ -253,8 +295,13 @@ private object DestructTypes {
* Note that the list of fields is only of the child fields, and needs a SubField node
* instead of a flat Reference when turning them into access expressions.
*/
- def destructInstance(m: ModuleTarget, instance: DefInstance, namespace: Namespace, renameMap: RenameMap,
- reserved: Set[String]): (DefInstance, Seq[(String, SubField)]) = {
+ def destructInstance(
+ m: ModuleTarget,
+ instance: DefInstance,
+ namespace: Namespace,
+ renameMap: RenameMap,
+ reserved: Set[String]
+ ): (DefInstance, Seq[(String, SubField)]) = {
val (rename, _) = uniquify(Field(instance.name, Default, instance.tpe), namespace, reserved)
val newName = rename.map(_.name).getOrElse(instance.name)
@@ -266,14 +313,14 @@ private object DestructTypes {
}
// rename all references to the instance if necessary
- if(newName != instance.name) {
+ if (newName != instance.name) {
renameMap.record(m.instOf(instance.name, instance.module), m.instOf(newName, instance.module))
}
// The ports do not need to be explicitly renamed here. They are renamed when the module ports are lowered.
val newInstance = instance.copy(name = newName, tpe = BundleType(children.map(_._1)))
val instanceRef = Reference(newName, newInstance.tpe, InstanceKind)
- val refs = children.map{ case(c,r) => extractGroundTypeRefString(r) -> SubField(instanceRef, c.name, c.tpe) }
+ val refs = children.map { case (c, r) => extractGroundTypeRefString(r) -> SubField(instanceRef, c.name, c.tpe) }
(newInstance, refs)
}
@@ -285,8 +332,13 @@ private object DestructTypes {
* e.g. ("mem_a.r.clk", "mem.r.clk") and ("mem_b.r.clk", "mem.r.clk")
* Thus it is appropriate to groupBy old reference string instead of just inserting into a hash table.
*/
- def destructMemory(m: ModuleTarget, mem: DefMemory, namespace: Namespace, renameMap: RenameMap,
- reserved: Set[String]): (Seq[DefMemory], Seq[(String, SubField)]) = {
+ def destructMemory(
+ m: ModuleTarget,
+ mem: DefMemory,
+ namespace: Namespace,
+ renameMap: RenameMap,
+ reserved: Set[String]
+ ): (Seq[DefMemory], Seq[(String, SubField)]) = {
// Uniquify the lowered memory names: When memories get split up into ground types, the access order is changes.
// E.g. `mem.r.data.x` becomes `mem_x.r.data`.
// This is why we need to create the new bundle structure before we can resolve any name clashes.
@@ -301,48 +353,50 @@ private object DestructTypes {
// the "old dummy field" is used as a template for the new memory port types
val oldDummyField = Field("dummy", Default, MemPortUtils.memType(mem.copy(dataType = BoolType)))
- val newMemAndSubFields = res.map { case (field, refs) =>
- val newMem = mem.copy(name = field.name, dataType = field.tpe)
- val newMemRef = m.ref(field.name)
- val memWasRenamed = field.name != mem.name // false iff the dataType was a GroundType
- if(memWasRenamed) { renameMap.record(oldMemRef, newMemRef) }
-
- val newMemReference = Reference(field.name, MemPortUtils.memType(newMem), MemKind)
- val refSuffixes = refs.map(_.component).filterNot(_.isEmpty)
-
- val subFields = oldDummyField.tpe.asInstanceOf[BundleType].fields.flatMap { port =>
- val oldPortRef = oldMemRef.field(port.name)
- val newPortRef = newMemRef.field(port.name)
-
- val newPortType = newMemReference.tpe.asInstanceOf[BundleType].fields.find(_.name == port.name).get.tpe
- val newPortAccess = SubField(newMemReference, port.name, newPortType)
-
- port.tpe.asInstanceOf[BundleType].fields.map { portField =>
- val isDataField = portField.name == "data" || portField.name == "wdata" || portField.name == "rdata"
- val isMaskField = portField.name == "mask" || portField.name == "wmask"
- val isDataOrMaskField = isDataField || isMaskField
- val oldFieldRefs = if(memWasRenamed && isDataOrMaskField) {
- // there might have been multiple different fields which now alias to the same lowered field.
- val oldPortFieldBaseRef = oldPortRef.field(portField.name)
- refSuffixes.map(s => oldPortFieldBaseRef.copy(component = oldPortFieldBaseRef.component ++ s))
- } else {
- List(oldPortRef.field(portField.name))
+ val newMemAndSubFields = res.map {
+ case (field, refs) =>
+ val newMem = mem.copy(name = field.name, dataType = field.tpe)
+ val newMemRef = m.ref(field.name)
+ val memWasRenamed = field.name != mem.name // false iff the dataType was a GroundType
+ if (memWasRenamed) { renameMap.record(oldMemRef, newMemRef) }
+
+ val newMemReference = Reference(field.name, MemPortUtils.memType(newMem), MemKind)
+ val refSuffixes = refs.map(_.component).filterNot(_.isEmpty)
+
+ val subFields = oldDummyField.tpe.asInstanceOf[BundleType].fields.flatMap { port =>
+ val oldPortRef = oldMemRef.field(port.name)
+ val newPortRef = newMemRef.field(port.name)
+
+ val newPortType = newMemReference.tpe.asInstanceOf[BundleType].fields.find(_.name == port.name).get.tpe
+ val newPortAccess = SubField(newMemReference, port.name, newPortType)
+
+ port.tpe.asInstanceOf[BundleType].fields.map { portField =>
+ val isDataField = portField.name == "data" || portField.name == "wdata" || portField.name == "rdata"
+ val isMaskField = portField.name == "mask" || portField.name == "wmask"
+ val isDataOrMaskField = isDataField || isMaskField
+ val oldFieldRefs = if (memWasRenamed && isDataOrMaskField) {
+ // there might have been multiple different fields which now alias to the same lowered field.
+ val oldPortFieldBaseRef = oldPortRef.field(portField.name)
+ refSuffixes.map(s => oldPortFieldBaseRef.copy(component = oldPortFieldBaseRef.component ++ s))
+ } else {
+ List(oldPortRef.field(portField.name))
+ }
+
+ val newPortType = if (isDataField) { newMem.dataType }
+ else { portField.tpe }
+ val newPortFieldAccess = SubField(newPortAccess, portField.name, newPortType)
+
+ // record renames only for the data field which is the only port field of non-ground type
+ val newPortFieldRef = newPortRef.field(portField.name)
+ if (memWasRenamed && isDataOrMaskField) {
+ oldFieldRefs.foreach { o => renameMap.record(o, newPortFieldRef) }
+ }
+
+ val oldFieldStringRef = extractGroundTypeRefString(oldFieldRefs)
+ (oldFieldStringRef, newPortFieldAccess)
}
-
- val newPortType = if(isDataField) { newMem.dataType } else { portField.tpe }
- val newPortFieldAccess = SubField(newPortAccess, portField.name, newPortType)
-
- // record renames only for the data field which is the only port field of non-ground type
- val newPortFieldRef = newPortRef.field(portField.name)
- if(memWasRenamed && isDataOrMaskField) {
- oldFieldRefs.foreach { o => renameMap.record(o, newPortFieldRef) }
- }
-
- val oldFieldStringRef = extractGroundTypeRefString(oldFieldRefs)
- (oldFieldStringRef, newPortFieldAccess)
}
- }
- (newMem, subFields)
+ (newMem, subFields)
}
(newMemAndSubFields.map(_._1), newMemAndSubFields.flatMap(_._2))
@@ -356,22 +410,30 @@ private object DestructTypes {
Field(mem.name, Default, BundleType(fields))
}
- private def recordRenames(fieldToRefs: Seq[(Field, Seq[ReferenceTarget])], renameMap: RenameMap, parent: ParentRef):
- Unit = {
+ private def recordRenames(
+ fieldToRefs: Seq[(Field, Seq[ReferenceTarget])],
+ renameMap: RenameMap,
+ parent: ParentRef
+ ): Unit = {
// TODO: if we group by ReferenceTarget, we could reduce the number of calls to `record`. Is it worth it?
- fieldToRefs.foreach { case(field, refs) =>
- val fieldRef = parent.ref(field.name)
- refs.foreach{ r => renameMap.record(r, fieldRef) }
+ fieldToRefs.foreach {
+ case (field, refs) =>
+ val fieldRef = parent.ref(field.name)
+ refs.foreach { r => renameMap.record(r, fieldRef) }
}
}
private def extractGroundTypeRefString(refs: Seq[ReferenceTarget]): String = {
- if (refs.isEmpty) { "" } else {
+ if (refs.isEmpty) { "" }
+ else {
// Since we depend on ExpandConnects any reference we encounter will be of ground type
// and thus the one with the longest access path.
- refs.reduceLeft((x, y) => if (x.component.length > y.component.length) x else y)
+ refs
+ .reduceLeft((x, y) => if (x.component.length > y.component.length) x else y)
// convert references to strings relative to the module
- .serialize.dropWhile(_ != '>').tail
+ .serialize
+ .dropWhile(_ != '>')
+ .tail
}
}
@@ -385,14 +447,19 @@ private object DestructTypes {
* @return a sequence of ground type fields with new names and, for each field,
* a sequence of old references that should to be renamed to point to the particular field
*/
- private def destruct(prefix: String, oldParent: ParentRef, oldField: Field,
- isVecField: Boolean, rename: Option[RenameNode]): Seq[(Field, Seq[ReferenceTarget])] = {
+ private def destruct(
+ prefix: String,
+ oldParent: ParentRef,
+ oldField: Field,
+ isVecField: Boolean,
+ rename: Option[RenameNode]
+ ): Seq[(Field, Seq[ReferenceTarget])] = {
val newName = rename.map(_.name).getOrElse(oldField.name)
val oldRef = oldParent.ref(oldField.name, isVecField)
oldField.tpe match {
- case _ : GroundType => List((oldField.copy(name = prefix + newName), List(oldRef)))
- case _ : BundleType | _ : VectorType =>
+ case _: GroundType => List((oldField.copy(name = prefix + newName), List(oldRef)))
+ case _: BundleType | _: VectorType =>
val newPrefix = prefix + newName + LowerTypes.delim
val isVecField = oldField.tpe.isInstanceOf[VectorType]
val fields = getFields(oldField.tpe)
@@ -401,7 +468,7 @@ private object DestructTypes {
destruct(newPrefix, RefParentRef(oldRef), f, isVecField, rename.flatMap(_.children.get(f.name)))
}
// the bundle/vec reference refers to all children
- children.map{ case(c, r) => (c, r :+ oldRef) }
+ children.map { case (c, r) => (c, r :+ oldRef) }
}
}
@@ -409,7 +476,8 @@ private object DestructTypes {
/** Implements the core functionality of the old Uniquify pass: rename bundle fields and top-level references
* where necessary in order to avoid name clashes when lowering aggregate type with the `_` delimiter.
- * We don't actually do the rename here but just calculate a rename tree. */
+ * We don't actually do the rename here but just calculate a rename tree.
+ */
private def uniquify(ref: Field, namespace: Namespace, reserved: Set[String]): (Option[RenameNode], Seq[String]) = {
// ensure that there are no name clashes with the list of reserved (port) names
val newRefName = findValidPrefix(ref.name, reserved.contains)
@@ -426,23 +494,23 @@ private object DestructTypes {
// We added f.name in previous map, delete if we change it
val renamed = prefix != ref.name
if (renamed) {
- if(!reserved.contains(ref.name)) namespace -= ref.name
+ if (!reserved.contains(ref.name)) namespace -= ref.name
namespace += prefix
}
val suffixes = renamedFieldNames.map(f => prefix + LowerTypes.delim + f)
val anyChildRenamed = renamedFields.exists(_._1.isDefined)
- val rename = if(renamed || anyChildRenamed){
- val children = renamedFields.map(_._1).zip(fields).collect{ case (Some(r), f) => f.name -> r }.toMap
+ val rename = if (renamed || anyChildRenamed) {
+ val children = renamedFields.map(_._1).zip(fields).collect { case (Some(r), f) => f.name -> r }.toMap
Some(RenameNode(prefix, children))
} else { None }
(rename, suffixes :+ prefix)
- case v : VectorType=>
+ case v: VectorType =>
// if Vecs are to be lowered, we can just treat them like a bundle
uniquify(ref.copy(tpe = vecToBundle(v)), namespace, reserved)
- case _ : GroundType =>
- if(newRefName == ref.name) {
+ case _: GroundType =>
+ if (newRefName == ref.name) {
(None, List(ref.name))
} else {
(Some(RenameNode(newRefName, Map())), List(newRefName))
@@ -452,22 +520,23 @@ private object DestructTypes {
}
/** Appends delim to prefix until no collisions of prefix + elts in names We don't add an _ in the collision check
- * because elts could be Seq("") In this case, we're just really checking if prefix itself collides */
+ * because elts could be Seq("") In this case, we're just really checking if prefix itself collides
+ */
@tailrec
private def findValidPrefix(prefix: String, inNamespace: String => Boolean, elts: Seq[String] = List("")): String = {
elts.find(elt => inNamespace(prefix + elt)) match {
case Some(_) => findValidPrefix(prefix + "_", inNamespace, elts)
- case None => prefix
+ case None => prefix
}
}
private def getFields(tpe: Type): Seq[Field] = tpe match {
case BundleType(fields) => fields
- case v : VectorType => vecToBundle(v).fields
+ case v: VectorType => vecToBundle(v).fields
}
private def vecToBundle(v: VectorType): BundleType = {
- BundleType(( 0 until v.size).map(i => Field(i.toString, Default, v.tpe)))
+ BundleType((0 until v.size).map(i => Field(i.toString, Default, v.tpe)))
}
/** Used to abstract over module and reference parents.
@@ -480,6 +549,7 @@ private object DestructTypes {
}
private case class RefParentRef(r: ReferenceTarget) extends ParentRef {
override def ref(name: String, asVecField: Boolean): ReferenceTarget =
- if(asVecField) { r.index(name.toInt) } else { r.field(name) }
+ if (asVecField) { r.index(name.toInt) }
+ else { r.field(name) }
}
}
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala
index ca5c2544..79560605 100644
--- a/src/main/scala/firrtl/passes/PadWidths.scala
+++ b/src/main/scala/firrtl/passes/PadWidths.scala
@@ -15,23 +15,21 @@ object PadWidths extends Pass {
override def prerequisites =
((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)
- + Dependency(firrtl.passes.RemoveValidIf)).toSeq
+ ++ firrtl.stage.Forms.LowForm
+ - Dependency(firrtl.passes.Legalize)
+ + Dependency(firrtl.passes.RemoveValidIf)).toSeq
override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation])
override def optionalPrerequisiteOf =
- Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
case _: firrtl.transforms.ConstantPropagation | Legalize => true
case _ => false
}
- private def width(t: Type): Int = bitWidth(t).toInt
+ private def width(t: Type): Int = bitWidth(t).toInt
private def width(e: Expression): Int = width(e.tpe)
// Returns an expression with the correct integer width
private def fixup(i: Int)(e: Expression) = {
@@ -54,31 +52,31 @@ object PadWidths extends Pass {
}
// Recursive, updates expression so children exp's have correct widths
- private def onExp(e: Expression): Expression = e map onExp match {
+ private def onExp(e: Expression): Expression = e.map(onExp) match {
case Mux(cond, tval, fval, tpe) =>
Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe)
- case ex: ValidIf => ex copy (value = fixup(width(ex.tpe))(ex.value))
- case ex: DoPrim => ex.op match {
- case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor |
- Add | Sub | Mul | Div | Rem | Shr =>
- // sensitive ops
- ex map fixup((ex.args map width foldLeft 0)(math.max))
- case Dshl =>
- // special case as args aren't all same width
- ex copy (op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1)))
- case _ => ex
- }
+ case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value))
+ case ex: DoPrim =>
+ ex.op match {
+ case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Mul | Div | Rem | Shr =>
+ // sensitive ops
+ ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max)))
+ case Dshl =>
+ // special case as args aren't all same width
+ ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1)))
+ case _ => ex
+ }
case ex => ex
}
// Recursive. Fixes assignments and register initialization widths
- private def onStmt(s: Statement): Statement = s map onExp match {
+ private def onStmt(s: Statement): Statement = s.map(onExp) match {
case sx: Connect =>
- sx copy (expr = fixup(width(sx.loc))(sx.expr))
+ sx.copy(expr = fixup(width(sx.loc))(sx.expr))
case sx: DefRegister =>
- sx copy (init = fixup(width(sx.tpe))(sx.init))
- case sx => sx map onStmt
+ sx.copy(init = fixup(width(sx.tpe))(sx.init))
+ case sx => sx.map(onStmt)
}
- def run(c: Circuit): Circuit = c copy (modules = c.modules map (_ map onStmt))
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(_.map(onStmt)))
}
diff --git a/src/main/scala/firrtl/passes/Pass.scala b/src/main/scala/firrtl/passes/Pass.scala
index 036bd06a..b5eac4ed 100644
--- a/src/main/scala/firrtl/passes/Pass.scala
+++ b/src/main/scala/firrtl/passes/Pass.scala
@@ -8,7 +8,7 @@ import firrtl.{CircuitState, FirrtlUserException, Transform}
* Has an [[UnknownForm]], because larger [[Transform]] should specify form
*/
trait Pass extends Transform with DependencyAPIMigration {
- def run(c: Circuit): Circuit
+ def run(c: Circuit): Circuit
def execute(state: CircuitState): CircuitState = state.copy(circuit = run(state.circuit))
}
diff --git a/src/main/scala/firrtl/passes/PullMuxes.scala b/src/main/scala/firrtl/passes/PullMuxes.scala
index b805b5fc..27543d63 100644
--- a/src/main/scala/firrtl/passes/PullMuxes.scala
+++ b/src/main/scala/firrtl/passes/PullMuxes.scala
@@ -11,38 +11,50 @@ object PullMuxes extends Pass {
override def invalidates(a: Transform) = false
def run(c: Circuit): Circuit = {
- def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match {
- case ex: WSubField => ex.expr match {
- case exx: Mux => Mux(exx.cond,
- WSubField(exx.tval, ex.name, ex.tpe, ex.flow),
- WSubField(exx.fval, ex.name, ex.tpe, ex.flow), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex: WSubIndex => ex.expr match {
- case exx: Mux => Mux(exx.cond,
- WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow),
- WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex: WSubAccess => ex.expr match {
- case exx: Mux => Mux(exx.cond,
- WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow),
- WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex => ex
- }
- def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e
- val modulesx = c.modules.map {
- case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body))
- case (m:ExtModule) => m
- }
- Circuit(c.info, modulesx, c.main)
- }
+ def pull_muxes_e(e: Expression): Expression = e.map(pull_muxes_e) match {
+ case ex: WSubField =>
+ ex.expr match {
+ case exx: Mux =>
+ Mux(
+ exx.cond,
+ WSubField(exx.tval, ex.name, ex.tpe, ex.flow),
+ WSubField(exx.fval, ex.name, ex.tpe, ex.flow),
+ ex.tpe
+ )
+ case exx: ValidIf => ValidIf(exx.cond, WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex: WSubIndex =>
+ ex.expr match {
+ case exx: Mux =>
+ Mux(
+ exx.cond,
+ WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow),
+ WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow),
+ ex.tpe
+ )
+ case exx: ValidIf => ValidIf(exx.cond, WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex: WSubAccess =>
+ ex.expr match {
+ case exx: Mux =>
+ Mux(
+ exx.cond,
+ WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow),
+ WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow),
+ ex.tpe
+ )
+ case exx: ValidIf => ValidIf(exx.cond, WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex => ex
+ }
+ def pull_muxes(s: Statement): Statement = s.map(pull_muxes).map(pull_muxes_e)
+ val modulesx = c.modules.map {
+ case (m: Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body))
+ case (m: ExtModule) => m
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
}
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index 18db5939..015346ff 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -2,7 +2,7 @@
package firrtl.passes
-import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubIndex, WSubField}
+import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubField, WSubIndex}
import firrtl.PrimOps.{And, Eq}
import firrtl.ir._
import firrtl.Mappers._
@@ -17,10 +17,12 @@ import scala.collection.mutable
object RemoveAccesses extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ZeroLengthVecs),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ZeroLengthVecs),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects)
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
case Uniquify | ResolveKinds | ResolveFlows => true
@@ -28,8 +30,8 @@ object RemoveAccesses extends Pass {
}
private def AND(e1: Expression, e2: Expression) =
- if(e1 == one) e2
- else if(e2 == one) e1
+ if (e1 == one) e2
+ else if (e2 == one) e1
else DoPrim(And, Seq(e1, e2), Nil, BoolType)
private def EQV(e1: Expression, e2: Expression): Expression =
@@ -45,30 +47,35 @@ object RemoveAccesses extends Pass {
* Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1)))
*/
private def getLocations(e: Expression): Seq[Location] = e match {
- case e: WRef => create_exps(e).map(Location(_,one))
+ case e: WRef => create_exps(e).map(Location(_, one))
case e: WSubIndex =>
val ls = getLocations(e.expr)
val start = get_point(e)
val end = start + get_size(e.tpe)
val stride = get_size(e.expr.tpe)
- for ((l, i) <- ls.zipWithIndex
- if ((i % stride) >= start) & ((i % stride) < end)) yield l
+ for (
+ (l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)
+ ) yield l
case e: WSubField =>
val ls = getLocations(e.expr)
val start = get_point(e)
val end = start + get_size(e.tpe)
val stride = get_size(e.expr.tpe)
- for ((l, i) <- ls.zipWithIndex
- if ((i % stride) >= start) & ((i % stride) < end)) yield l
+ for (
+ (l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)
+ ) yield l
case e: WSubAccess =>
val ls = getLocations(e.expr)
val stride = get_size(e.tpe)
val wrap = e.expr.tpe.asInstanceOf[VectorType].size
- ls.zipWithIndex map {case (l, i) =>
- val c = (i / stride) % wrap
- val basex = l.base
- val guardx = AND(l.guard,EQV(UIntLiteral(c),e.index))
- Location(basex,guardx)
+ ls.zipWithIndex.map {
+ case (l, i) =>
+ val c = (i / stride) % wrap
+ val basex = l.base
+ val guardx = AND(l.guard, EQV(UIntLiteral(c), e.index))
+ Location(basex, guardx)
}
}
@@ -78,10 +85,10 @@ object RemoveAccesses extends Pass {
var ret: Boolean = false
def rec_has_access(e: Expression): Expression = {
e match {
- case _ : WSubAccess => ret = true
+ case _: WSubAccess => ret = true
case _ =>
}
- e map rec_has_access
+ e.map(rec_has_access)
}
rec_has_access(e)
ret
@@ -90,7 +97,7 @@ object RemoveAccesses extends Pass {
// This improves the performance of this pass
private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]()
private def create_exps(e: Expression) =
- createExpsCache getOrElseUpdate (e, firrtl.Utils.create_exps(e))
+ createExpsCache.getOrElseUpdate(e, firrtl.Utils.create_exps(e))
def run(c: Circuit): Circuit = {
def remove_m(m: Module): Module = {
@@ -105,21 +112,21 @@ object RemoveAccesses extends Pass {
*/
val stmts = mutable.ArrayBuffer[Statement]()
def removeSource(e: Expression): Expression = e match {
- case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(e) =>
+ case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) =>
val rs = getLocations(e)
- rs find (x => x.guard != one) match {
+ rs.find(x => x.guard != one) match {
case None => throwInternalError(s"removeSource: shouldn't be here - $e")
case Some(_) =>
val (wire, temp) = create_temp(e)
val temps = create_exps(temp)
def getTemp(i: Int) = temps(i % temps.size)
stmts += wire
- rs.zipWithIndex foreach {
+ rs.zipWithIndex.foreach {
case (x, i) if i < temps.size =>
- stmts += IsInvalid(get_info(s),getTemp(i))
- stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
+ stmts += IsInvalid(get_info(s), getTemp(i))
+ stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
case (x, i) =>
- stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
+ stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
}
temp
}
@@ -129,14 +136,16 @@ object RemoveAccesses extends Pass {
/** Replaces a subaccess in a given sink expression
*/
def removeSink(info: Info, loc: Expression): Expression = loc match {
- case (_: WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(loc) =>
+ case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(loc) =>
val ls = getLocations(loc)
- if (ls.size == 1 & weq(ls.head.guard,one)) loc
+ if (ls.size == 1 & weq(ls.head.guard, one)) loc
else {
val (wire, temp) = create_temp(loc)
stmts += wire
- ls foreach (x => stmts +=
- Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt))
+ ls.foreach(x =>
+ stmts +=
+ Conditionally(info, x.guard, Connect(info, x.base, temp), EmptyStmt)
+ )
temp
}
case _ => loc
@@ -150,7 +159,7 @@ object RemoveAccesses extends Pass {
case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow))
//case w: WSubIndex => removeSource(w)
//case w: WSubField => removeSource(w)
- case x => x map fixSource
+ case x => x.map(fixSource)
}
/** Recursively walks a sink expression and fixes all subaccesses
@@ -159,13 +168,13 @@ object RemoveAccesses extends Pass {
*/
def fixSink(e: Expression): Expression = e match {
case w: WSubAccess => WSubAccess(fixSink(w.expr), fixSource(w.index), w.tpe, w.flow)
- case x => x map fixSink
+ case x => x.map(fixSink)
}
val sx = s match {
case Connect(info, loc, exp) =>
Connect(info, removeSink(info, fixSink(loc)), fixSource(exp))
- case sxx => sxx map fixSource map onStmt
+ case sxx => sxx.map(fixSource).map(onStmt)
}
stmts += sx
if (stmts.size != 1) Block(stmts.toSeq) else stmts(0)
@@ -173,9 +182,9 @@ object RemoveAccesses extends Pass {
Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body)))
}
- c copy (modules = c.modules map {
+ c.copy(modules = c.modules.map {
case m: ExtModule => m
- case m: Module => remove_m(m)
+ case m: Module => remove_m(m)
})
}
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 61fd6258..624138ab 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -17,8 +17,7 @@ case class DataRef(exp: Expression, source: String, sink: String, mask: String,
object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.ChirrtlForm ++
- Seq( Dependency(passes.CInferTypes),
- Dependency(passes.CInferMDir) )
+ Seq(Dependency(passes.CInferTypes), Dependency(passes.CInferMDir))
override def invalidates(a: Transform) = false
@@ -31,10 +30,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match {
case _: GroundType => Seq(ex)
- case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) =>
- exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex)
- case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
- exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex)
+ case t: BundleType =>
+ (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(
+ ex
+ )
+ case t: VectorType =>
+ ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) =>
+ exps ++ create_all_exps(SubIndex(ex, i, t.tpe))
+ ) ++ Seq(ex)
case UnknownType => Seq(ex)
}
@@ -42,17 +45,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
case ex: Mux =>
val e1s = create_exps(ex.tval)
val e2s = create_exps(ex.fval)
- (e1s zip e2s) map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) }
+ (e1s.zip(e2s)).map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) }
case ex: ValidIf =>
- create_exps(ex.value) map (e1 => ValidIf(ex.cond, e1, e1.tpe))
- case ex => ex.tpe match {
- case _: GroundType => Seq(ex)
- case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) =>
- exps ++ create_exps(SubField(ex, f.name, f.tpe)))
- case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
- exps ++ create_exps(SubIndex(ex, i, t.tpe)))
- case UnknownType => Seq(ex)
- }
+ create_exps(ex.value).map(e1 => ValidIf(ex.cond, e1, e1.tpe))
+ case ex =>
+ ex.tpe match {
+ case _: GroundType => Seq(ex)
+ case t: BundleType =>
+ (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_exps(SubField(ex, f.name, f.tpe)))
+ case t: VectorType =>
+ ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) => exps ++ create_exps(SubIndex(ex, i, t.tpe)))
+ case UnknownType => Seq(ex)
+ }
}
private def EMPs: MPorts = MPorts(ArrayBuffer[MPort](), ArrayBuffer[MPort](), ArrayBuffer[MPort]())
@@ -61,40 +65,48 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
s match {
case sx: CDefMemory if sx.seq => smems += sx.name
case sx: CDefMPort =>
- val p = mports getOrElse (sx.mem, EMPs)
+ val p = mports.getOrElse(sx.mem, EMPs)
sx.direction match {
- case MRead => p.readers += MPort(sx.name, sx.exps(1))
- case MWrite => p.writers += MPort(sx.name, sx.exps(1))
+ case MRead => p.readers += MPort(sx.name, sx.exps(1))
+ case MWrite => p.writers += MPort(sx.name, sx.exps(1))
case MReadWrite => p.readwriters += MPort(sx.name, sx.exps(1))
- case MInfer => // direction may not be inferred if it's not being used
+ case MInfer => // direction may not be inferred if it's not being used
}
mports(sx.mem) = p
case _ =>
}
- s map collect_smems_and_mports(mports, smems)
+ s.map(collect_smems_and_mports(mports, smems))
}
- def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap,
- refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match {
+ def collect_refs(
+ mports: MPortMap,
+ smems: SeqMemSet,
+ types: MPortTypeMap,
+ refs: DataRefMap,
+ raddrs: AddrMap,
+ renames: RenameMap
+ )(s: Statement
+ ): Statement = s match {
case sx: CDefMemory =>
types(sx.name) = sx.tpe
- val taddr = UIntType(IntWidth(1 max getUIntWidth(sx.size - 1)))
+ val taddr = UIntType(IntWidth(1.max(getUIntWidth(sx.size - 1))))
val tdata = sx.tpe
- def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap (r => Seq(
- IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)),
- IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType))
- ))
- def set_enable(vec: scala.collection.Seq[MPort], en: String) = vec.toSeq.map (r =>
- Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero)
+ def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap(r =>
+ Seq(
+ IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)),
+ IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType))
+ )
)
+ def set_enable(vec: scala.collection.Seq[MPort], en: String) =
+ vec.toSeq.map(r => Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero))
def set_write(vec: scala.collection.Seq[MPort], data: String, mask: String) = vec.toSeq.flatMap { r =>
val tmask = createMask(sx.tpe)
val portRef = SubField(Reference(sx.name, ut), r.name, ut)
Seq(IsInvalid(sx.info, SubField(portRef, data, tdata)), IsInvalid(sx.info, SubField(portRef, mask, tmask)))
}
- val rds = (mports getOrElse (sx.name, EMPs)).readers
- val wrs = (mports getOrElse (sx.name, EMPs)).writers
- val rws = (mports getOrElse (sx.name, EMPs)).readwriters
+ val rds = (mports.getOrElse(sx.name, EMPs)).readers
+ val wrs = (mports.getOrElse(sx.name, EMPs)).writers
+ val rws = (mports.getOrElse(sx.name, EMPs)).readwriters
val stmts = set_poison(rds) ++
set_enable(rds, "en") ++
set_poison(wrs) ++
@@ -104,8 +116,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
set_enable(rws, "wmode") ++
set_enable(rws, "en") ++
set_write(rws, "wdata", "wmask")
- val mem = DefMemory(sx.info, sx.name, sx.tpe, sx.size, 1, if (sx.seq) 1 else 0,
- rds.map(_.name).toSeq, wrs.map(_.name).toSeq, rws.map(_.name).toSeq, sx.readUnderWrite)
+ val mem = DefMemory(
+ sx.info,
+ sx.name,
+ sx.tpe,
+ sx.size,
+ 1,
+ if (sx.seq) 1 else 0,
+ rds.map(_.name).toSeq,
+ wrs.map(_.name).toSeq,
+ rws.map(_.name).toSeq,
+ sx.readUnderWrite
+ )
Block(mem +: stmts)
case sx: CDefMPort =>
types.get(sx.mem) match {
@@ -130,8 +152,8 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val es = create_all_exps(WRef(sx.name, sx.tpe))
val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe))
val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe))
- ((es zip rs) zip ws) map {
- case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
+ ((es.zip(rs)).zip(ws)).map {
+ case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
}
case MWrite =>
refs(sx.name) = DataRef(portRef, "data", "data", "mask", rdwrite = false)
@@ -142,7 +164,7 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
val es = create_all_exps(WRef(sx.name, sx.tpe))
val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
- (es zip ws) map {
+ (es.zip(ws)).map {
case (e, w) => renames.rename(e.serialize, w.serialize)
}
case MRead =>
@@ -157,63 +179,69 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
val es = create_all_exps(WRef(sx.name, sx.tpe))
val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
- (es zip rs) map {
+ (es.zip(rs)).map {
case (e, r) => renames.rename(e.serialize, r.serialize)
}
case MInfer => // do nothing if it's not being used
}
- Block(List() ++
- (addrs.map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++
- (clks map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++
- (ens map (x => Connect(sx.info,SubField(portRef, x, ut), one))) ++
- masks.map(lhs => Connect(sx.info, lhs, zero))
+ Block(
+ List() ++
+ (addrs.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++
+ (clks.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++
+ (ens.map(x => Connect(sx.info, SubField(portRef, x, ut), one))) ++
+ masks.map(lhs => Connect(sx.info, lhs, zero))
)
- case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames)
+ case sx => sx.map(collect_refs(mports, smems, types, refs, raddrs, renames))
}
def get_mask(refs: DataRefMap)(e: Expression): Expression =
- e map get_mask(refs) match {
- case ex: Reference => refs get ex.name match {
- case None => ex
- case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe))
- }
+ e.map(get_mask(refs)) match {
+ case ex: Reference =>
+ refs.get(ex.name) match {
+ case None => ex
+ case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe))
+ }
case ex => ex
}
def remove_chirrtl_s(refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = {
var has_write_mport = false
var has_readwrite_mport: Option[Expression] = None
- var has_read_mport: Option[Expression] = None
+ var has_read_mport: Option[Expression] = None
def remove_chirrtl_e(g: Flow)(e: Expression): Expression = e match {
- case Reference(name, tpe, _, _) => refs get name match {
- case Some(p) => g match {
- case SinkFlow =>
- has_write_mport = true
- if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType))
- SubField(p.exp, p.sink, tpe)
- case SourceFlow =>
- SubField(p.exp, p.source, tpe)
- }
- case None => g match {
- case SinkFlow => raddrs get name match {
- case Some(en) => has_read_mport = Some(en) ; e
- case None => e
- }
- case SourceFlow => e
+ case Reference(name, tpe, _, _) =>
+ refs.get(name) match {
+ case Some(p) =>
+ g match {
+ case SinkFlow =>
+ has_write_mport = true
+ if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType))
+ SubField(p.exp, p.sink, tpe)
+ case SourceFlow =>
+ SubField(p.exp, p.source, tpe)
+ }
+ case None =>
+ g match {
+ case SinkFlow =>
+ raddrs.get(name) match {
+ case Some(en) => has_read_mport = Some(en); e
+ case None => e
+ }
+ case SourceFlow => e
+ }
}
- }
- case SubAccess(expr, index, tpe, _) => SubAccess(
- remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe)
- case ex => ex map remove_chirrtl_e(g)
- }
- s match {
+ case SubAccess(expr, index, tpe, _) =>
+ SubAccess(remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe)
+ case ex => ex.map(remove_chirrtl_e(g))
+ }
+ s match {
case DefNode(info, name, value) =>
val valuex = remove_chirrtl_e(SourceFlow)(value)
val sx = DefNode(info, name, valuex)
// Check node is used for read port address
remove_chirrtl_e(SinkFlow)(Reference(name, value.tpe))
has_read_mport match {
- case None => sx
+ case None => sx
case Some(en) => Block(sx, Connect(info, en, one))
}
case Connect(info, loc, expr) =>
@@ -222,14 +250,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val sx = Connect(info, locx, rocx)
val stmts = ArrayBuffer[Statement]()
has_read_mport match {
- case None =>
+ case None =>
case Some(en) => stmts += Connect(info, en, one)
}
if (has_write_mport) {
val locs = create_exps(get_mask(refs)(loc))
- stmts ++= (locs map (x => Connect(info, x, one)))
+ stmts ++= (locs.map(x => Connect(info, x, one)))
has_readwrite_mport match {
- case None =>
+ case None =>
case Some(wmode) => stmts += Connect(info, wmode, one)
}
}
@@ -240,20 +268,20 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val sx = PartialConnect(info, locx, rocx)
val stmts = ArrayBuffer[Statement]()
has_read_mport match {
- case None =>
+ case None =>
case Some(en) => stmts += Connect(info, en, one)
}
if (has_write_mport) {
val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default)
val locs = create_exps(get_mask(refs)(loc))
- stmts ++= (ls map { case (x, _) => Connect(info, locs(x), one) })
+ stmts ++= (ls.map { case (x, _) => Connect(info, locs(x), one) })
has_readwrite_mport match {
- case None =>
+ case None =>
case Some(wmode) => stmts += Connect(info, wmode, one)
}
}
if (stmts.isEmpty) sx else Block(sx +: stmts.toSeq)
- case sx => sx map remove_chirrtl_s(refs, raddrs) map remove_chirrtl_e(SourceFlow)
+ case sx => sx.map(remove_chirrtl_s(refs, raddrs)).map(remove_chirrtl_e(SourceFlow))
}
}
@@ -264,16 +292,16 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val refs = new DataRefMap
val raddrs = new AddrMap
renames.setModule(m.name)
- (m map collect_smems_and_mports(mports, smems)
- map collect_refs(mports, smems, types, refs, raddrs, renames)
- map remove_chirrtl_s(refs, raddrs))
+ (m.map(collect_smems_and_mports(mports, smems))
+ .map(collect_refs(mports, smems, types, refs, raddrs, renames))
+ .map(remove_chirrtl_s(refs, raddrs)))
}
def execute(state: CircuitState): CircuitState = {
val c = state.circuit
val renames = RenameMap()
renames.setCircuit(c.main)
- val result = c copy (modules = c.modules map remove_chirrtl_m(renames))
+ val result = c.copy(modules = c.modules.map(remove_chirrtl_m(renames)))
state.copy(circuit = result, renames = Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/RemoveEmpty.scala b/src/main/scala/firrtl/passes/RemoveEmpty.scala
index eabf667c..eb25dcc4 100644
--- a/src/main/scala/firrtl/passes/RemoveEmpty.scala
+++ b/src/main/scala/firrtl/passes/RemoveEmpty.scala
@@ -15,7 +15,7 @@ object RemoveEmpty extends Pass with DependencyAPIMigration {
private def onModule(m: DefModule): DefModule = {
m match {
- case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body))
+ case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body))
case m: ExtModule => m
}
}
diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala
index 7059526c..657b4356 100644
--- a/src/main/scala/firrtl/passes/RemoveIntervals.scala
+++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala
@@ -13,14 +13,13 @@ import firrtl.options.Dependency
import scala.math.BigDecimal.RoundingMode._
class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim)
- extends PassException({
- val toWrap = wrap.args.head.serialize
- val toWrapTpe = wrap.args.head.tpe.serialize
- val wrapTo = wrap.args(1).serialize
- val wrapToTpe = wrap.args(1).tpe.serialize
- s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe"
- })
-
+ extends PassException({
+ val toWrap = wrap.args.head.serialize
+ val toWrapTpe = wrap.args.head.tpe.serialize
+ val wrapTo = wrap.args(1).serialize
+ val wrapToTpe = wrap.args(1).tpe.serialize
+ s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe"
+ })
/** Replaces IntervalType with SIntType, three AST walks:
* 1) Align binary points
@@ -39,48 +38,50 @@ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim)
class RemoveIntervals extends Pass {
override def prerequisites: Seq[Dependency[Transform]] =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses),
- Dependency[ExpandWhensAndCheck] ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses),
+ Dependency[ExpandWhensAndCheck]
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(transform: Transform): Boolean = {
transform match {
case InferTypes | ResolveKinds => true
- case _ => false
+ case _ => false
}
}
def run(c: Circuit): Circuit = {
val alignedCircuit = c
val errors = new Errors()
- val wiredCircuit = alignedCircuit map makeWireModule
- val replacedCircuit = wiredCircuit map replaceModuleInterval(errors)
+ val wiredCircuit = alignedCircuit.map(makeWireModule)
+ val replacedCircuit = wiredCircuit.map(replaceModuleInterval(errors))
errors.trigger()
replacedCircuit
}
/* Replace interval types */
private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule =
- m map replaceStmtInterval(errors, m.name) map replacePortInterval
+ m.map(replaceStmtInterval(errors, m.name)).map(replacePortInterval)
private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = {
val info = s match {
case h: HasInfo => h.info
case _ => NoInfo
}
- s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname)
+ s.map(replaceTypeInterval).map(replaceStmtInterval(errors, mname)).map(replaceExprInterval(errors, info, mname))
}
private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match {
case _: WRef | _: WSubIndex | _: WSubField => e
case o =>
- o map replaceExprInterval(errors, info, mname) match {
+ o.map(replaceExprInterval(errors, info, mname)) match {
case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe)
- case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe)
- case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe)
+ case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe)
+ case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe)
case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) =>
// Output interval (pre-calculated)
val clipLo = tpe.minAdjusted.get
@@ -94,13 +95,13 @@ class RemoveIntervals extends Pass {
val ltOpt = clipLo <= inLow
(gtOpt, ltOpt) match {
// input range within output range -> no optimization
- case (true, true) => a1
+ case (true, true) => a1
case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1)
case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1)
- case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1))
+ case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1))
}
- case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) =>
+ case sqz @ DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) =>
// Using (conditional) reassign interval w/o adding mux
val a1tpe = a1.tpe.asInstanceOf[IntervalType]
val a2tpe = a2.tpe.asInstanceOf[IntervalType]
@@ -117,54 +118,55 @@ class RemoveIntervals extends Pass {
val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2)))
DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2)))
}
- case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match {
- // If a2 type is Interval wrap around range. If UInt, wrap around width
- case t: IntervalType =>
- // Need to match binary points before getting *adjusted!
- val (wrapLo, wrapHi) = t.copy(point = tpe.point) match {
- case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get)
- case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType")
- }
- val (inLo, inHi) = a1.tpe match {
- case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
- case _ => sys.error("Shouldn't be here")
- }
- // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap)
- val range = wrapHi - wrapLo
- val ltOpt = Add(a1, (range + 1).S)
- val gtOpt = Sub(a1, (range + 1).S)
- // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it.
- // If x < wl
- // output: wh - (wl - x) + 1 AKA x + r + 1
- // worst case: wh - (wl - xl) + 1 = wl
- // -> xl + wr + 1 = wl
- // If x > wh
- // output: wl + (x - wh) - 1 AKA x - r - 1
- // worst case: wl + (xh - wh) - 1 = wh
- // -> xh - wr - 1 = wh
- val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S)
- (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match {
- case (true, true, _, _) => a1
- case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1)
- case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1)
- // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work)
- case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1))
- case _ =>
- errors.append(new WrapWithRemainder(info, mname, w))
- default
- }
- case _ => sys.error("Shouldn't be here")
- }
+ case w @ DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) =>
+ a2.tpe match {
+ // If a2 type is Interval wrap around range. If UInt, wrap around width
+ case t: IntervalType =>
+ // Need to match binary points before getting *adjusted!
+ val (wrapLo, wrapHi) = t.copy(point = tpe.point) match {
+ case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get)
+ case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType")
+ }
+ val (inLo, inHi) = a1.tpe match {
+ case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
+ case _ => sys.error("Shouldn't be here")
+ }
+ // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap)
+ val range = wrapHi - wrapLo
+ val ltOpt = Add(a1, (range + 1).S)
+ val gtOpt = Sub(a1, (range + 1).S)
+ // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it.
+ // If x < wl
+ // output: wh - (wl - x) + 1 AKA x + r + 1
+ // worst case: wh - (wl - xl) + 1 = wl
+ // -> xl + wr + 1 = wl
+ // If x > wh
+ // output: wl + (x - wh) - 1 AKA x - r - 1
+ // worst case: wl + (xh - wh) - 1 = wh
+ // -> xh - wr - 1 = wh
+ val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S)
+ (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match {
+ case (true, true, _, _) => a1
+ case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1)
+ case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1)
+ // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work)
+ case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1))
+ case _ =>
+ errors.append(new WrapWithRemainder(info, mname, w))
+ default
+ }
+ case _ => sys.error("Shouldn't be here")
+ }
case other => other
}
}
- private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+ private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval)
private def replaceTypeInterval(t: Type): Type = t match {
- case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width)
+ case i @ IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width)
case i: IntervalType => sys.error(s"Shouldn't be here: $i")
- case v => v map replaceTypeInterval
+ case v => v.map(replaceTypeInterval)
}
/** Replace Interval Nodes with Interval Wires
@@ -174,15 +176,16 @@ class RemoveIntervals extends Pass {
* @param m module to replace nodes with wire + connection
* @return
*/
- private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt
+ private def makeWireModule(m: DefModule): DefModule = m.map(makeWireStmt)
private def makeWireStmt(s: Statement): Statement = s match {
- case DefNode(info, name, value) => value.tpe match {
- case IntervalType(l, u, p) =>
- val newType = IntervalType(l, u, p)
- Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value)))
- case other => s
- }
- case other => other map makeWireStmt
+ case DefNode(info, name, value) =>
+ value.tpe match {
+ case IntervalType(l, u, p) =>
+ val newType = IntervalType(l, u, p)
+ Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value)))
+ case other => s
+ }
+ case other => other.map(makeWireStmt)
}
}
diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala
index 895cb10f..7e82b37b 100644
--- a/src/main/scala/firrtl/passes/RemoveValidIf.scala
+++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala
@@ -26,14 +26,13 @@ object RemoveValidIf extends Pass {
case ClockType => ClockZero
case _: FixedType => FixedZero
case AsyncResetType => AsyncZero
- case other => throwInternalError(s"Unexpected type $other")
+ case other => throwInternalError(s"Unexpected type $other")
}
override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
- Seq( Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
case Legalize | _: firrtl.transforms.ConstantPropagation => true
@@ -42,24 +41,25 @@ object RemoveValidIf extends Pass {
// Recursive. Removes ValidIfs
private def onExp(e: Expression): Expression = {
- e map onExp match {
+ e.map(onExp) match {
case ValidIf(_, value, _) => value
- case x => x
+ case x => x
}
}
// Recursive. Replaces IsInvalid with connecting zero
- private def onStmt(s: Statement): Statement = s map onStmt map onExp match {
- case invalid @ IsInvalid(info, loc) => loc.tpe match {
- case _: AnalogType => EmptyStmt
- case tpe => Connect(info, loc, getGroundZero(tpe))
- }
+ private def onStmt(s: Statement): Statement = s.map(onStmt).map(onExp) match {
+ case invalid @ IsInvalid(info, loc) =>
+ loc.tpe match {
+ case _: AnalogType => EmptyStmt
+ case tpe => Connect(info, loc, getGroundZero(tpe))
+ }
case other => other
}
private def onModule(m: DefModule): DefModule = {
m match {
- case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body))
+ case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body))
case m: ExtModule => m
}
}
diff --git a/src/main/scala/firrtl/passes/ReplaceAccesses.scala b/src/main/scala/firrtl/passes/ReplaceAccesses.scala
index e31d9410..4a3cd697 100644
--- a/src/main/scala/firrtl/passes/ReplaceAccesses.scala
+++ b/src/main/scala/firrtl/passes/ReplaceAccesses.scala
@@ -18,15 +18,16 @@ object ReplaceAccesses extends Pass {
override def invalidates(a: Transform) = false
def run(c: Circuit): Circuit = {
- def onStmt(s: Statement): Statement = s map onStmt map onExp
- def onExp(e: Expression): Expression = e match {
- case WSubAccess(ex, UIntLiteral(value, _), t, g) => ex.tpe match {
- case VectorType(_, len) if (value < len) => WSubIndex(onExp(ex), value.toInt, t, g)
- case _ => e map onExp
- }
- case _ => e map onExp
+ def onStmt(s: Statement): Statement = s.map(onStmt).map(onExp)
+ def onExp(e: Expression): Expression = e match {
+ case WSubAccess(ex, UIntLiteral(value, _), t, g) =>
+ ex.tpe match {
+ case VectorType(_, len) if (value < len) => WSubIndex(onExp(ex), value.toInt, t, g)
+ case _ => e.map(onExp)
+ }
+ case _ => e.map(onExp)
}
- c copy (modules = c.modules map (_ map onStmt))
+ c.copy(modules = c.modules.map(_.map(onStmt)))
}
}
diff --git a/src/main/scala/firrtl/passes/ResolveFlows.scala b/src/main/scala/firrtl/passes/ResolveFlows.scala
index 85a0a26f..48b9479c 100644
--- a/src/main/scala/firrtl/passes/ResolveFlows.scala
+++ b/src/main/scala/firrtl/passes/ResolveFlows.scala
@@ -14,17 +14,22 @@ object ResolveFlows extends Pass {
override def invalidates(a: Transform) = false
def resolve_e(g: Flow)(e: Expression): Expression = e match {
- case ex: WRef => ex copy (flow = g)
- case WSubField(exp, name, tpe, _) => WSubField(
- Utils.field_flip(exp.tpe, name) match {
- case Default => resolve_e(g)(exp)
- case Flip => resolve_e(Utils.swap(g))(exp)
- }, name, tpe, g)
+ case ex: WRef => ex.copy(flow = g)
+ case WSubField(exp, name, tpe, _) =>
+ WSubField(
+ Utils.field_flip(exp.tpe, name) match {
+ case Default => resolve_e(g)(exp)
+ case Flip => resolve_e(Utils.swap(g))(exp)
+ },
+ name,
+ tpe,
+ g
+ )
case WSubIndex(exp, value, tpe, _) =>
WSubIndex(resolve_e(g)(exp), value, tpe, g)
case WSubAccess(exp, index, tpe, _) =>
WSubAccess(resolve_e(g)(exp), resolve_e(SourceFlow)(index), tpe, g)
- case _ => e map resolve_e(g)
+ case _ => e.map(resolve_e(g))
}
def resolve_s(s: Statement): Statement = s match {
@@ -35,11 +40,11 @@ object ResolveFlows extends Pass {
Connect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr))
case PartialConnect(info, loc, expr) =>
PartialConnect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr))
- case sx => sx map resolve_e(SourceFlow) map resolve_s
+ case sx => sx.map(resolve_e(SourceFlow)).map(resolve_s)
}
- def resolve_flow(m: DefModule): DefModule = m map resolve_s
+ def resolve_flow(m: DefModule): DefModule = m.map(resolve_s)
def run(c: Circuit): Circuit =
- c copy (modules = c.modules map resolve_flow)
+ c.copy(modules = c.modules.map(resolve_flow))
}
diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala
index 67360b74..fcbac163 100644
--- a/src/main/scala/firrtl/passes/ResolveKinds.scala
+++ b/src/main/scala/firrtl/passes/ResolveKinds.scala
@@ -20,21 +20,21 @@ object ResolveKinds extends Pass {
}
def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match {
- case ex: WRef => ex copy (kind = kinds(ex.name))
- case _ => e map resolve_expr(kinds)
+ case ex: WRef => ex.copy(kind = kinds(ex.name))
+ case _ => e.map(resolve_expr(kinds))
}
def resolve_stmt(kinds: KindMap)(s: Statement): Statement = {
s match {
- case sx: DefWire => kinds(sx.name) = WireKind
- case sx: DefNode => kinds(sx.name) = NodeKind
- case sx: DefRegister => kinds(sx.name) = RegKind
+ case sx: DefWire => kinds(sx.name) = WireKind
+ case sx: DefNode => kinds(sx.name) = NodeKind
+ case sx: DefRegister => kinds(sx.name) = RegKind
case sx: WDefInstance => kinds(sx.name) = InstanceKind
- case sx: DefMemory => kinds(sx.name) = MemKind
+ case sx: DefMemory => kinds(sx.name) = MemKind
case _ =>
}
s.map(resolve_stmt(kinds))
- .map(resolve_expr(kinds))
+ .map(resolve_expr(kinds))
}
def resolve_kinds(m: DefModule): DefModule = {
@@ -44,5 +44,5 @@ object ResolveKinds extends Pass {
}
def run(c: Circuit): Circuit =
- c copy (modules = c.modules map resolve_kinds)
+ c.copy(modules = c.modules.map(resolve_kinds))
}
diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala
index c536cd5d..a65f8921 100644
--- a/src/main/scala/firrtl/passes/SplitExpressions.scala
+++ b/src/main/scala/firrtl/passes/SplitExpressions.scala
@@ -7,7 +7,7 @@ import firrtl.{SystemVerilogEmitter, Transform, VerilogEmitter}
import firrtl.ir._
import firrtl.options.Dependency
import firrtl.Mappers._
-import firrtl.Utils.{kind, flow, get_info}
+import firrtl.Utils.{flow, get_info, kind}
// Datastructures
import scala.collection.mutable
@@ -17,65 +17,63 @@ import scala.collection.mutable
object SplitExpressions extends Pass {
override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq( Dependency(firrtl.passes.RemoveValidIf),
- Dependency(firrtl.passes.memlib.VerilogMemDelays) )
+ Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays))
override def optionalPrerequisiteOf =
- Seq( Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform) = a match {
case ResolveKinds => true
case _ => false
}
- private def onModule(m: Module): Module = {
- val namespace = Namespace(m)
- def onStmt(s: Statement): Statement = {
- val v = mutable.ArrayBuffer[Statement]()
- // Splits current expression if needed
- // Adds named temporaries to v
- def split(e: Expression): Expression = e match {
- case e: DoPrim =>
- val name = namespace.newTemp
- v += DefNode(get_info(s), name, e)
- WRef(name, e.tpe, kind(e), flow(e))
- case e: Mux =>
- val name = namespace.newTemp
- v += DefNode(get_info(s), name, e)
- WRef(name, e.tpe, kind(e), flow(e))
- case e: ValidIf =>
- val name = namespace.newTemp
- v += DefNode(get_info(s), name, e)
- WRef(name, e.tpe, kind(e), flow(e))
- case _ => e
- }
-
- // Recursive. Splits compound nodes
- def onExp(e: Expression): Expression =
- e map onExp match {
- case ex: DoPrim => ex map split
- case ex => ex
- }
+ private def onModule(m: Module): Module = {
+ val namespace = Namespace(m)
+ def onStmt(s: Statement): Statement = {
+ val v = mutable.ArrayBuffer[Statement]()
+ // Splits current expression if needed
+ // Adds named temporaries to v
+ def split(e: Expression): Expression = e match {
+ case e: DoPrim =>
+ val name = namespace.newTemp
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), flow(e))
+ case e: Mux =>
+ val name = namespace.newTemp
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), flow(e))
+ case e: ValidIf =>
+ val name = namespace.newTemp
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), flow(e))
+ case _ => e
+ }
- s map onExp match {
- case x: Block => x map onStmt
- case EmptyStmt => EmptyStmt
- case x =>
- v += x
- v.size match {
- case 1 => v.head
- case _ => Block(v.toSeq)
- }
+ // Recursive. Splits compound nodes
+ def onExp(e: Expression): Expression =
+ e.map(onExp) match {
+ case ex: DoPrim => ex.map(split)
+ case ex => ex
}
+
+ s.map(onExp) match {
+ case x: Block => x.map(onStmt)
+ case EmptyStmt => EmptyStmt
+ case x =>
+ v += x
+ v.size match {
+ case 1 => v.head
+ case _ => Block(v.toSeq)
+ }
}
- Module(m.info, m.name, m.ports, onStmt(m.body))
- }
- def run(c: Circuit): Circuit = {
- val modulesx = c.modules map {
- case m: Module => onModule(m)
- case m: ExtModule => m
- }
- Circuit(c.info, modulesx, c.main)
- }
+ }
+ Module(m.info, m.name, m.ports, onStmt(m.body))
+ }
+ def run(c: Circuit): Circuit = {
+ val modulesx = c.modules.map {
+ case m: Module => onModule(m)
+ case m: ExtModule => m
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
}
diff --git a/src/main/scala/firrtl/passes/ToWorkingIR.scala b/src/main/scala/firrtl/passes/ToWorkingIR.scala
index c271302a..03faaf3c 100644
--- a/src/main/scala/firrtl/passes/ToWorkingIR.scala
+++ b/src/main/scala/firrtl/passes/ToWorkingIR.scala
@@ -6,5 +6,5 @@ import firrtl.Transform
object ToWorkingIR extends Pass {
override def prerequisites = firrtl.stage.Forms.MinimalHighForm
override def invalidates(a: Transform) = false
- def run(c:Circuit): Circuit = c
+ def run(c: Circuit): Circuit = c
}
diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala
index 822a8125..0a05bd4e 100644
--- a/src/main/scala/firrtl/passes/TrimIntervals.scala
+++ b/src/main/scala/firrtl/passes/TrimIntervals.scala
@@ -23,10 +23,7 @@ import firrtl.Transform
class TrimIntervals extends Pass {
override def prerequisites =
- Seq( Dependency(ResolveKinds),
- Dependency(InferTypes),
- Dependency(ResolveFlows),
- Dependency[InferBinaryPoints] )
+ Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows), Dependency[InferBinaryPoints])
override def optionalPrerequisiteOf = Seq.empty
@@ -34,48 +31,51 @@ class TrimIntervals extends Pass {
def run(c: Circuit): Circuit = {
// Open -> closed
- val firstPass = InferTypes.run(c map replaceModuleInterval)
+ val firstPass = InferTypes.run(c.map(replaceModuleInterval))
// Align binary points and adjust range accordingly (loss of precision changes range)
- firstPass map alignModuleBP
+ firstPass.map(alignModuleBP)
}
/* Replace interval types */
- private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval
+ private def replaceModuleInterval(m: DefModule): DefModule = m.map(replaceStmtInterval).map(replacePortInterval)
- private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval
+ private def replaceStmtInterval(s: Statement): Statement = s.map(replaceTypeInterval).map(replaceStmtInterval)
- private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+ private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval)
private def replaceTypeInterval(t: Type): Type = t match {
- case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) =>
+ case i @ IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) =>
IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p))
case i: IntervalType => i
- case v => v map replaceTypeInterval
+ case v => v.map(replaceTypeInterval)
}
/* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */
- private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP
-
- private def alignStmtBP(s: Statement): Statement = s map alignExpBP match {
- case c@Connect(info, loc, expr) => loc.tpe match {
- case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr))
- case _ => c
- }
- case c@PartialConnect(info, loc, expr) => loc.tpe match {
- case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr))
- case _ => c
- }
- case other => other map alignStmtBP
+ private def alignModuleBP(m: DefModule): DefModule = m.map(alignStmtBP)
+
+ private def alignStmtBP(s: Statement): Statement = s.map(alignExpBP) match {
+ case c @ Connect(info, loc, expr) =>
+ loc.tpe match {
+ case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case c @ PartialConnect(info, loc, expr) =>
+ loc.tpe match {
+ case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case other => other.map(alignStmtBP)
}
// Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned
// Note - Mul does not need its binary points aligned, because multiplication is cool like that
- private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/)
+ private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq /*, Wrap, Clip, Squeeze*/ )
- private def alignExpBP(e: Expression): Expression = e map alignExpBP match {
+ private def alignExpBP(e: Expression): Expression = e.map(alignExpBP) match {
case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg)
- case DoPrim(o, args, consts, t) if opsToFix.contains(o) &&
- (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size =>
+ case DoPrim(o, args, consts, t)
+ if opsToFix.contains(o) &&
+ (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size =>
val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _)
DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t)
case Mux(cond, tval, fval, t: IntervalType) =>
@@ -85,9 +85,9 @@ class TrimIntervals extends Pass {
}
private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match {
case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e
- case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current =>
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current =>
DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired)))
- case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current =>
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current =>
val shiftAmt = current - desired
val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt)
val shiftMul = Closed(BigDecimal(1) / shiftGain)
diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala
index b9cd32fa..10198b33 100644
--- a/src/main/scala/firrtl/passes/Uniquify.scala
+++ b/src/main/scala/firrtl/passes/Uniquify.scala
@@ -2,7 +2,6 @@
package firrtl.passes
-
import scala.annotation.tailrec
import firrtl._
import firrtl.ir._
@@ -35,12 +34,11 @@ import MemPortUtils.memType
object Uniquify extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(ResolveKinds),
- Dependency(InferTypes) ) ++ firrtl.stage.Forms.WorkingIR
+ Seq(Dependency(ResolveKinds), Dependency(InferTypes)) ++ firrtl.stage.Forms.WorkingIR
override def invalidates(a: Transform): Boolean = a match {
case ResolveKinds | InferTypes => true
- case _ => false
+ case _ => false
}
private case class UniquifyException(msg: String) extends FirrtlInternalException(msg)
@@ -55,12 +53,13 @@ object Uniquify extends Transform with DependencyAPIMigration {
*/
@tailrec
def findValidPrefix(
- prefix: String,
- elts: Seq[String],
- namespace: collection.mutable.HashSet[String]): String = {
- elts find (elt => namespace.contains(prefix + elt)) match {
+ prefix: String,
+ elts: Seq[String],
+ namespace: collection.mutable.HashSet[String]
+ ): String = {
+ elts.find(elt => namespace.contains(prefix + elt)) match {
case Some(_) => findValidPrefix(prefix + "_", elts, namespace)
- case None => prefix
+ case None => prefix
}
}
@@ -70,16 +69,16 @@ object Uniquify extends Transform with DependencyAPIMigration {
* => foo, foo bar, foo bar 0, foo bar 1, foo bar 0 a, foo bar 0 b, foo bar 1 a, foo bar 1 b, foo c
* }}}
*/
- private [firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match {
+ private[firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match {
case t: BundleType =>
- t.fields flatMap { f =>
- (enumerateNames(f.tpe) map (f.name +: _)) ++ Seq(Seq(f.name))
+ t.fields.flatMap { f =>
+ (enumerateNames(f.tpe).map(f.name +: _)) ++ Seq(Seq(f.name))
}
case t: VectorType =>
- ((0 until t.size) map (i => Seq(i.toString))) ++
- ((0 until t.size) flatMap { i =>
- enumerateNames(t.tpe) map (i.toString +: _)
- })
+ ((0 until t.size).map(i => Seq(i.toString))) ++
+ ((0 until t.size).flatMap { i =>
+ enumerateNames(t.tpe).map(i.toString +: _)
+ })
case _ => Seq()
}
@@ -87,27 +86,38 @@ object Uniquify extends Transform with DependencyAPIMigration {
def stmtToType(s: Statement)(implicit sinfo: Info, mname: String): BundleType = {
// Recursive helper
def recStmtToType(s: Statement): Seq[Field] = s match {
- case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe))
+ case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe))
case sx: DefRegister => Seq(Field(sx.name, Default, sx.tpe))
case sx: WDefInstance => Seq(Field(sx.name, Default, sx.tpe))
- case sx: DefMemory => sx.dataType match {
- case (_: UIntType | _: SIntType | _: FixedType) =>
- Seq(Field(sx.name, Default, memType(sx)))
- case tpe: BundleType =>
- val newFields = tpe.fields map ( f =>
- DefMemory(sx.info, f.name, f.tpe, sx.depth, sx.writeLatency,
- sx.readLatency, sx.readers, sx.writers, sx.readwriters)
- ) flatMap recStmtToType
- Seq(Field(sx.name, Default, BundleType(newFields)))
- case tpe: VectorType =>
- val newFields = (0 until tpe.size) map ( i =>
- sx.copy(name = i.toString, dataType = tpe.tpe)
- ) flatMap recStmtToType
- Seq(Field(sx.name, Default, BundleType(newFields)))
- }
- case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe))
+ case sx: DefMemory =>
+ sx.dataType match {
+ case (_: UIntType | _: SIntType | _: FixedType) =>
+ Seq(Field(sx.name, Default, memType(sx)))
+ case tpe: BundleType =>
+ val newFields = tpe.fields
+ .map(f =>
+ DefMemory(
+ sx.info,
+ f.name,
+ f.tpe,
+ sx.depth,
+ sx.writeLatency,
+ sx.readLatency,
+ sx.readers,
+ sx.writers,
+ sx.readwriters
+ )
+ )
+ .flatMap(recStmtToType)
+ Seq(Field(sx.name, Default, BundleType(newFields)))
+ case tpe: VectorType =>
+ val newFields =
+ (0 until tpe.size).map(i => sx.copy(name = i.toString, dataType = tpe.tpe)).flatMap(recStmtToType)
+ Seq(Field(sx.name, Default, BundleType(newFields)))
+ }
+ case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe))
case sx: Conditionally => recStmtToType(sx.conseq) ++ recStmtToType(sx.alt)
- case sx: Block => (sx.stmts map recStmtToType).flatten
+ case sx: Block => (sx.stmts.map(recStmtToType)).flatten
case sx => Seq()
}
BundleType(recStmtToType(s))
@@ -116,40 +126,44 @@ object Uniquify extends Transform with DependencyAPIMigration {
// Accepts a Type and an initial namespace
// Returns new Type with uniquified names
private def uniquifyNames(
- t: BundleType,
- namespace: collection.mutable.HashSet[String])
- (implicit sinfo: Info, mname: String): BundleType = {
+ t: BundleType,
+ namespace: collection.mutable.HashSet[String]
+ )(
+ implicit sinfo: Info,
+ mname: String
+ ): BundleType = {
def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): (Type, Seq[String]) = t match {
case tx: BundleType =>
// First add everything
- val newFieldsAndElts = tx.fields map { f =>
+ val newFieldsAndElts = tx.fields.map { f =>
val newName = findValidPrefix(f.name, Seq(""), namespace)
namespace += newName
Field(newName, f.flip, f.tpe)
- } map { f => f.tpe match {
- case _: GroundType => (f, Seq[String](f.name))
- case _ =>
- val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet())
- // Need leading _ for findValidPrefix, it doesn't add _ for checks
- val eltsNames: Seq[String] = eltsx map (e => "_" + e)
- val prefix = findValidPrefix(f.name, eltsNames, namespace)
- // We added f.name in previous map, delete if we change it
- if (prefix != f.name) {
- namespace -= f.name
- namespace += prefix
- }
- val newElts: Seq[String] = eltsx map (e => LowerTypes.loweredName(prefix +: Seq(e)))
- namespace ++= newElts
- (Field(prefix, f.flip, tpe), prefix +: newElts)
+ }.map { f =>
+ f.tpe match {
+ case _: GroundType => (f, Seq[String](f.name))
+ case _ =>
+ val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet())
+ // Need leading _ for findValidPrefix, it doesn't add _ for checks
+ val eltsNames: Seq[String] = eltsx.map(e => "_" + e)
+ val prefix = findValidPrefix(f.name, eltsNames, namespace)
+ // We added f.name in previous map, delete if we change it
+ if (prefix != f.name) {
+ namespace -= f.name
+ namespace += prefix
+ }
+ val newElts: Seq[String] = eltsx.map(e => LowerTypes.loweredName(prefix +: Seq(e)))
+ namespace ++= newElts
+ (Field(prefix, f.flip, tpe), prefix +: newElts)
}
}
val (newFields, elts) = newFieldsAndElts.unzip
(BundleType(newFields), elts.flatten)
case tx: VectorType =>
val (tpe, elts) = recUniquifyNames(tx.tpe, namespace)
- val newElts = ((0 until tx.size) map (i => i.toString)) ++
- ((0 until tx.size) flatMap { i =>
- elts map (e => LowerTypes.loweredName(Seq(i.toString, e)))
+ val newElts = ((0 until tx.size).map(i => i.toString)) ++
+ ((0 until tx.size).flatMap { i =>
+ elts.map(e => LowerTypes.loweredName(Seq(i.toString, e)))
})
(VectorType(tpe, tx.size), newElts)
case tx => (tx, Nil)
@@ -164,19 +178,26 @@ object Uniquify extends Transform with DependencyAPIMigration {
// Creates a mapping from flattened references to members of $from ->
// flattened references to members of $to
private def createNameMapping(
- from: Type,
- to: Type)
- (implicit sinfo: Info, mname: String): Map[String, NameMapNode] = {
+ from: Type,
+ to: Type
+ )(
+ implicit sinfo: Info,
+ mname: String
+ ): Map[String, NameMapNode] = {
(from, to) match {
case (fromx: BundleType, tox: BundleType) =>
- (fromx.fields zip tox.fields flatMap { case (f, t) =>
- val eltsMap = createNameMapping(f.tpe, t.tpe)
- if ((f.name != t.name) || eltsMap.nonEmpty) {
- Map(f.name -> NameMapNode(t.name, eltsMap))
- } else {
- Map[String, NameMapNode]()
- }
- }).toMap
+ (fromx.fields
+ .zip(tox.fields)
+ .flatMap {
+ case (f, t) =>
+ val eltsMap = createNameMapping(f.tpe, t.tpe)
+ if ((f.name != t.name) || eltsMap.nonEmpty) {
+ Map(f.name -> NameMapNode(t.name, eltsMap))
+ } else {
+ Map[String, NameMapNode]()
+ }
+ })
+ .toMap
case (fromx: VectorType, tox: VectorType) =>
createNameMapping(fromx.tpe, tox.tpe)
case (fromx, tox) =>
@@ -187,18 +208,19 @@ object Uniquify extends Transform with DependencyAPIMigration {
// Maps names in expression to new uniquified names
private def uniquifyNamesExp(
- exp: Expression,
- map: Map[String, NameMapNode])
- (implicit sinfo: Info, mname: String): Expression = {
+ exp: Expression,
+ map: Map[String, NameMapNode]
+ )(
+ implicit sinfo: Info,
+ mname: String
+ ): Expression = {
// Recursive Helper
- def rec(exp: Expression, m: Map[String, NameMapNode]):
- (Expression, Map[String, NameMapNode]) = exp match {
+ def rec(exp: Expression, m: Map[String, NameMapNode]): (Expression, Map[String, NameMapNode]) = exp match {
case e: WRef =>
if (m.contains(e.name)) {
val node = m(e.name)
(WRef(node.name, e.tpe, e.kind, e.flow), node.elts)
- }
- else (e, Map())
+ } else (e, Map())
case e: WSubField =>
val (subExp, subMap) = rec(e.expr, m)
val (retName, retMap) =
@@ -218,18 +240,21 @@ object Uniquify extends Transform with DependencyAPIMigration {
(WSubAccess(subExp, index, e.tpe, e.flow), subMap)
case (_: UIntLiteral | _: SIntLiteral) => (exp, m)
case (_: Mux | _: ValidIf | _: DoPrim) =>
- (exp map ((e: Expression) => uniquifyNamesExp(e, map)), m)
+ (exp.map((e: Expression) => uniquifyNamesExp(e, map)), m)
}
rec(exp, map)._1
}
// Uses map to recursively rename fields of tpe
private def uniquifyNamesType(
- tpe: Type,
- map: Map[String, NameMapNode])
- (implicit sinfo: Info, mname: String): Type = tpe match {
+ tpe: Type,
+ map: Map[String, NameMapNode]
+ )(
+ implicit sinfo: Info,
+ mname: String
+ ): Type = tpe match {
case t: BundleType =>
- val newFields = t.fields map { f =>
+ val newFields = t.fields.map { f =>
if (map.contains(f.name)) {
val node = map(f.name)
Field(node.name, f.flip, uniquifyNamesType(f.tpe, node.elts))
@@ -244,8 +269,11 @@ object Uniquify extends Transform with DependencyAPIMigration {
}
// Everything wrapped in run so that it's thread safe
- @deprecated("The functionality of Uniquify is now part of LowerTypes." +
- "Please file an issue with firrtl if you use Uniquify outside of the context of LowerTypes.", "Firrtl 1.4")
+ @deprecated(
+ "The functionality of Uniquify is now part of LowerTypes." +
+ "Please file an issue with firrtl if you use Uniquify outside of the context of LowerTypes.",
+ "Firrtl 1.4"
+ )
def execute(state: CircuitState): CircuitState = {
val c = state.circuit
val renames = RenameMap()
@@ -263,22 +291,22 @@ object Uniquify extends Transform with DependencyAPIMigration {
val nameMap = collection.mutable.HashMap[String, NameMapNode]()
def uniquifyExp(e: Expression): Expression = e match {
- case (_: WRef | _: WSubField | _: WSubIndex | _: WSubAccess ) =>
+ case (_: WRef | _: WSubField | _: WSubIndex | _: WSubAccess) =>
uniquifyNamesExp(e, nameMap.toMap)
- case e: Mux => e map uniquifyExp
- case e: ValidIf => e map uniquifyExp
+ case e: Mux => e.map(uniquifyExp)
+ case e: ValidIf => e.map(uniquifyExp)
case (_: UIntLiteral | _: SIntLiteral) => e
- case e: DoPrim => e map uniquifyExp
+ case e: DoPrim => e.map(uniquifyExp)
}
def uniquifyStmt(s: Statement): Statement = {
- s map uniquifyStmt map uniquifyExp match {
+ s.map(uniquifyStmt).map(uniquifyExp) match {
case sx: DefWire =>
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
val newType = uniquifyNamesType(sx.tpe, node.elts)
- (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach {
case (from, to) => renames.rename(from.serialize, to.serialize)
}
DefWire(sx.info, node.name, newType)
@@ -290,7 +318,7 @@ object Uniquify extends Transform with DependencyAPIMigration {
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
val newType = uniquifyNamesType(sx.tpe, node.elts)
- (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach {
case (from, to) => renames.rename(from.serialize, to.serialize)
}
DefRegister(sx.info, node.name, newType, sx.clock, sx.reset, sx.init)
@@ -302,7 +330,7 @@ object Uniquify extends Transform with DependencyAPIMigration {
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
val newType = portTypeMap(m.name)
- (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach {
case (from, to) => renames.rename(from.serialize, to.serialize)
}
WDefInstance(sx.info, node.name, sx.module, newType)
@@ -317,7 +345,7 @@ object Uniquify extends Transform with DependencyAPIMigration {
val mem = sx.copy(name = node.name, dataType = dataType)
// Create new mapping to handle references to memory data fields
val uniqueMemMap = createNameMapping(memType(sx), memType(mem))
- (Utils.create_exps(sx.name, memType(sx)) zip Utils.create_exps(node.name, memType(mem))) foreach {
+ (Utils.create_exps(sx.name, memType(sx)).zip(Utils.create_exps(node.name, memType(mem)))).foreach {
case (from, to) => renames.rename(from.serialize, to.serialize)
}
nameMap(sx.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap)
@@ -329,9 +357,12 @@ object Uniquify extends Transform with DependencyAPIMigration {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
- (Utils.create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) zip Utils.create_exps(node.name, sx.value.tpe)) foreach {
- case (from, to) => renames.rename(from.serialize, to.serialize)
- }
+ (Utils
+ .create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe)
+ .zip(Utils.create_exps(node.name, sx.value.tpe)))
+ .foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
DefNode(sx.info, node.name, sx.value)
} else {
sx
@@ -354,19 +385,18 @@ object Uniquify extends Transform with DependencyAPIMigration {
mname = m.name
m match {
case m: ExtModule => m
- case m: Module =>
+ case m: Module =>
// Adds port names to namespace and namemap
nameMap ++= portNameMap(m.name)
- namespace ++= create_exps("", portTypeMap(m.name)) map
- LowerTypes.loweredName map (_.tail)
- m.copy(body = uniquifyBody(m.body) )
+ namespace ++= create_exps("", portTypeMap(m.name)).map(LowerTypes.loweredName).map(_.tail)
+ m.copy(body = uniquifyBody(m.body))
}
}
def uniquifyPorts(renames: RenameMap)(m: DefModule): DefModule = {
renames.setModule(m.name)
def uniquifyPorts(ports: Seq[Port]): Seq[Port] = {
- val portsType = BundleType(ports map {
+ val portsType = BundleType(ports.map {
case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe)
})
val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet())
@@ -374,11 +404,12 @@ object Uniquify extends Transform with DependencyAPIMigration {
portNameMap += (m.name -> localMap)
portTypeMap += (m.name -> uniquePortsType)
- ports zip uniquePortsType.fields map { case (p, f) =>
- (Utils.create_exps(p.name, p.tpe) zip Utils.create_exps(f.name, f.tpe)) foreach {
- case (from, to) => renames.rename(from.serialize, to.serialize)
- }
- Port(p.info, f.name, p.direction, f.tpe)
+ ports.zip(uniquePortsType.fields).map {
+ case (p, f) =>
+ (Utils.create_exps(p.name, p.tpe).zip(Utils.create_exps(f.name, f.tpe))).foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
+ Port(p.info, f.name, p.direction, f.tpe)
}
}
@@ -386,12 +417,12 @@ object Uniquify extends Transform with DependencyAPIMigration {
mname = m.name
m match {
case m: ExtModule => m.copy(ports = uniquifyPorts(m.ports))
- case m: Module => m.copy(ports = uniquifyPorts(m.ports))
+ case m: Module => m.copy(ports = uniquifyPorts(m.ports))
}
}
sinfo = c.info
- val result = Circuit(c.info, c.modules map uniquifyPorts(renames) map uniquifyModule(renames), c.main)
+ val result = Circuit(c.info, c.modules.map(uniquifyPorts(renames)).map(uniquifyModule(renames)), c.main)
state.copy(circuit = result, renames = Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala
index 36eff379..0b046a5f 100644
--- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala
+++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala
@@ -12,28 +12,30 @@ import firrtl.options.Dependency
import scala.collection.mutable
/**
- * Verilog has the width of (a % b) = Max(W(a), W(b))
- * FIRRTL has the width of (a % b) = Min(W(a), W(b)), which makes more sense,
- * but nevertheless is a problem when emitting verilog
- *
- * This pass finds every instance of (a % b) and:
- * 1) adds a temporary node equal to (a % b) with width Max(W(a), W(b))
- * 2) replaces the reference to (a % b) with a bitslice of the temporary node
- * to get back down to width Min(W(a), W(b))
- *
- * This is technically incorrect firrtl, but allows the verilog emitter
- * to emit correct verilog without needing to add temporary nodes
- */
+ * Verilog has the width of (a % b) = Max(W(a), W(b))
+ * FIRRTL has the width of (a % b) = Min(W(a), W(b)), which makes more sense,
+ * but nevertheless is a problem when emitting verilog
+ *
+ * This pass finds every instance of (a % b) and:
+ * 1) adds a temporary node equal to (a % b) with width Max(W(a), W(b))
+ * 2) replaces the reference to (a % b) with a bitslice of the temporary node
+ * to get back down to width Min(W(a), W(b))
+ *
+ * This is technically incorrect firrtl, but allows the verilog emitter
+ * to emit correct verilog without needing to add temporary nodes
+ */
object VerilogModulusCleanup extends Pass {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.FixAddingNegativeLiterals],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.InlineBitExtractionsTransform],
- Dependency[firrtl.transforms.InlineCastsTransform],
- Dependency[firrtl.transforms.LegalizeClocksTransform],
- Dependency[firrtl.transforms.FlattenRegUpdate] )
+ Seq(
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.FixAddingNegativeLiterals],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.InlineBitExtractionsTransform],
+ Dependency[firrtl.transforms.InlineCastsTransform],
+ Dependency[firrtl.transforms.LegalizeClocksTransform],
+ Dependency[firrtl.transforms.FlattenRegUpdate]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
@@ -51,32 +53,35 @@ object VerilogModulusCleanup extends Pass {
case t => UnknownWidth
}
- def maxWidth(ws: Seq[Width]): Width = ws reduceLeft { (x,y) => (x,y) match {
- case (IntWidth(x), IntWidth(y)) => IntWidth(x max y)
- case (x, y) => UnknownWidth
- }}
+ def maxWidth(ws: Seq[Width]): Width = ws.reduceLeft { (x, y) =>
+ (x, y) match {
+ case (IntWidth(x), IntWidth(y)) => IntWidth(x.max(y))
+ case (x, y) => UnknownWidth
+ }
+ }
def verilogRemWidth(e: DoPrim)(tpe: Type): Type = {
val newWidth = maxWidth(e.args.map(exp => getWidth(exp)))
- tpe mapWidth (w => newWidth)
+ tpe.mapWidth(w => newWidth)
}
def removeRem(e: Expression): Expression = e match {
- case e: DoPrim => e.op match {
- case Rem =>
- val name = namespace.newTemp
- val newType = e mapType verilogRemWidth(e)
- v += DefNode(get_info(s), name, e mapType verilogRemWidth(e))
- val remRef = WRef(name, newType.tpe, kind(e), flow(e))
- val remWidth = bitWidth(e.tpe)
- DoPrim(Bits, Seq(remRef), Seq(remWidth - 1, BigInt(0)), e.tpe)
- case _ => e
- }
+ case e: DoPrim =>
+ e.op match {
+ case Rem =>
+ val name = namespace.newTemp
+ val newType = e.mapType(verilogRemWidth(e))
+ v += DefNode(get_info(s), name, e.mapType(verilogRemWidth(e)))
+ val remRef = WRef(name, newType.tpe, kind(e), flow(e))
+ val remWidth = bitWidth(e.tpe)
+ DoPrim(Bits, Seq(remRef), Seq(remWidth - 1, BigInt(0)), e.tpe)
+ case _ => e
+ }
case _ => e
}
- s map removeRem match {
- case x: Block => x map onStmt
+ s.map(removeRem) match {
+ case x: Block => x.map(onStmt)
case EmptyStmt => EmptyStmt
case x =>
v += x
@@ -90,8 +95,8 @@ object VerilogModulusCleanup extends Pass {
}
def run(c: Circuit): Circuit = {
- val modules = c.modules map {
- case m: Module => onModule(m)
+ val modules = c.modules.map {
+ case m: Module => onModule(m)
case m: ExtModule => m
}
Circuit(c.info, modules, c.main)
diff --git a/src/main/scala/firrtl/passes/VerilogPrep.scala b/src/main/scala/firrtl/passes/VerilogPrep.scala
index 03d47cfc..eeb34fa9 100644
--- a/src/main/scala/firrtl/passes/VerilogPrep.scala
+++ b/src/main/scala/firrtl/passes/VerilogPrep.scala
@@ -21,15 +21,17 @@ import scala.collection.mutable
object VerilogPrep extends Pass {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.FixAddingNegativeLiterals],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.InlineBitExtractionsTransform],
- Dependency[firrtl.transforms.InlineCastsTransform],
- Dependency[firrtl.transforms.LegalizeClocksTransform],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename] )
+ Seq(
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.FixAddingNegativeLiterals],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.InlineBitExtractionsTransform],
+ Dependency[firrtl.transforms.InlineCastsTransform],
+ Dependency[firrtl.transforms.LegalizeClocksTransform],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
@@ -46,9 +48,9 @@ object VerilogPrep extends Pass {
val sourceMap = mutable.HashMap.empty[WrappedExpression, Expression]
lazy val namespace = Namespace(m)
- def onStmt(stmt: Statement): Statement = stmt map onStmt match {
+ def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
case attach: Attach =>
- val wires = attach.exprs groupBy kind
+ val wires = attach.exprs.groupBy(kind)
val sources = wires.getOrElse(PortKind, Seq.empty) ++ wires.getOrElse(WireKind, Seq.empty)
val instPorts = wires.getOrElse(InstanceKind, Seq.empty)
// Sanity check (Should be caught by CheckTypes)
@@ -71,14 +73,14 @@ object VerilogPrep extends Pass {
case s => s
}
- (m map onStmt, sourceMap.toMap)
+ (m.map(onStmt), sourceMap.toMap)
}
def run(c: Circuit): Circuit = {
def lowerE(e: Expression): Expression = e match {
case (_: WRef | _: WSubField) if kind(e) == InstanceKind =>
WRef(LowerTypes.loweredName(e), e.tpe, kind(e), flow(e))
- case _ => e map lowerE
+ case _ => e.map(lowerE)
}
def lowerS(attachMap: AttachSourceMap)(s: Statement): Statement = s match {
@@ -96,12 +98,12 @@ object VerilogPrep extends Pass {
}.unzip
val newInst = WDefInstanceConnector(info, name, module, tpe, portCons)
Block(wires.flatten :+ newInst)
- case other => other map lowerS(attachMap) map lowerE
+ case other => other.map(lowerS(attachMap)).map(lowerE)
}
- val modulesx = c.modules map { mod =>
+ val modulesx = c.modules.map { mod =>
val (modx, attachMap) = collectAndRemoveAttach(mod)
- modx map lowerS(attachMap)
+ modx.map(lowerS(attachMap))
}
c.copy(modules = modulesx)
}
diff --git a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala
index 39c127de..e61780a4 100644
--- a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala
+++ b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala
@@ -17,10 +17,7 @@ import firrtl.options.Dependency
*/
object ZeroLengthVecs extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ResolveKinds),
- Dependency(InferTypes),
- Dependency(ExpandConnects) )
+ Seq(Dependency(PullMuxes), Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ExpandConnects))
override def invalidates(a: Transform) = false
@@ -28,8 +25,8 @@ object ZeroLengthVecs extends Pass {
// interval type with the type alone unless you declare a component
private def replaceWithDontCare(toReplace: Expression): Expression = {
val default = toReplace.tpe match {
- case UIntType(w) => UIntLiteral(0, w)
- case SIntType(w) => SIntLiteral(0, w)
+ case UIntType(w) => UIntLiteral(0, w)
+ case SIntType(w) => SIntLiteral(0, w)
case FixedType(w, p) => FixedLiteral(0, w, p)
case it: IntervalType =>
val zeroType = IntervalType(Closed(0), Closed(0), IntWidth(0))
@@ -40,11 +37,11 @@ object ZeroLengthVecs extends Pass {
}
private def zeroLenDerivedRefLike(expr: Expression): Boolean = (expr, expr.tpe) match {
- case (_, VectorType(_, 0)) => true
- case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e)
+ case (_, VectorType(_, 0)) => true
+ case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e)
case (WSubAccess(e, _, _, _), _) => zeroLenDerivedRefLike(e)
- case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e)
- case _ => false
+ case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e)
+ case _ => false
}
// The connects have all been lowered, so all aggregate-typed expressions are "grounded" by WSubField/WSubAccess/WSubIndex
@@ -52,13 +49,13 @@ object ZeroLengthVecs extends Pass {
private def dropZeroLenSubAccesses(expr: Expression): Expression = expr match {
case _: WSubIndex | _: WSubAccess | _: WSubField =>
if (zeroLenDerivedRefLike(expr)) replaceWithDontCare(expr) else expr
- case e => e map dropZeroLenSubAccesses
+ case e => e.map(dropZeroLenSubAccesses)
}
// Attach semantics: drop all zero-length-derived members of attach group, drop stmt if trivial
private def onStmt(stmt: Statement): Statement = stmt match {
case Connect(_, sink, _) if zeroLenDerivedRefLike(sink) => EmptyStmt
- case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt
+ case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt
case Attach(info, sinks) =>
val filtered = Attach(info, sinks.filterNot(zeroLenDerivedRefLike))
if (filtered.exprs.length < 2) EmptyStmt else filtered
diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala
index 56d66ef0..82321f95 100644
--- a/src/main/scala/firrtl/passes/ZeroWidth.scala
+++ b/src/main/scala/firrtl/passes/ZeroWidth.scala
@@ -11,12 +11,14 @@ import firrtl.options.Dependency
object ZeroWidth extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses),
- Dependency[ExpandWhensAndCheck],
- Dependency(ConvertFixedToSInt) ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses),
+ Dependency[ExpandWhensAndCheck],
+ Dependency(ConvertFixedToSInt)
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
case InferTypes => true
@@ -24,30 +26,41 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
}
private def makeEmptyMemBundle(name: String): Field =
- Field(name, Flip, BundleType(Seq(
- Field("addr", Default, UIntType(IntWidth(0))),
- Field("en", Default, UIntType(IntWidth(0))),
- Field("clk", Default, UIntType(IntWidth(0))),
- Field("data", Flip, UIntType(IntWidth(0)))
- )))
+ Field(
+ name,
+ Flip,
+ BundleType(
+ Seq(
+ Field("addr", Default, UIntType(IntWidth(0))),
+ Field("en", Default, UIntType(IntWidth(0))),
+ Field("clk", Default, UIntType(IntWidth(0))),
+ Field("data", Flip, UIntType(IntWidth(0)))
+ )
+ )
+ )
private def onEmptyMemStmt(s: Statement): Statement = s match {
- case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) => removeZero(tpe) match {
- case None =>
- DefWire(info, name, BundleType(
- rs.map(r => makeEmptyMemBundle(r)) ++
- ws.map(w => makeEmptyMemBundle(w)) ++
- rws.map(rw => makeEmptyMemBundle(rw))
- ))
- case Some(_) => d
- }
- case sx => sx map onEmptyMemStmt
+ case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) =>
+ removeZero(tpe) match {
+ case None =>
+ DefWire(
+ info,
+ name,
+ BundleType(
+ rs.map(r => makeEmptyMemBundle(r)) ++
+ ws.map(w => makeEmptyMemBundle(w)) ++
+ rws.map(rw => makeEmptyMemBundle(rw))
+ )
+ )
+ case Some(_) => d
+ }
+ case sx => sx.map(onEmptyMemStmt)
}
private def onModuleEmptyMemStmt(m: DefModule): DefModule = {
m match {
case ext: ExtModule => ext
- case in: Module => in.copy(body = onEmptyMemStmt(in.body))
+ case in: Module => in.copy(body = onEmptyMemStmt(in.body))
}
}
@@ -59,20 +72,20 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
* This replaces memories with a DefWire() bundle that contains the address, en,
* clk, and data fields implemented as zero width wires. Running the rest of the ZeroWidth
* transform will remove these dangling references properly.
- *
*/
def executeEmptyMemStmt(state: CircuitState): CircuitState = {
val c = state.circuit
- val result = c.copy(modules = c.modules map onModuleEmptyMemStmt)
+ val result = c.copy(modules = c.modules.map(onModuleEmptyMemStmt))
state.copy(circuit = result)
}
// This is slightly different and specialized version of create_exps, TODO unify?
private def findRemovable(expr: => Expression, tpe: Type): Seq[Expression] = tpe match {
- case GroundType(width) => width match {
- case IntWidth(ZERO) => List(expr)
- case _ => List.empty
- }
+ case GroundType(width) =>
+ width match {
+ case IntWidth(ZERO) => List(expr)
+ case _ => List.empty
+ }
case BundleType(fields) =>
if (fields.isEmpty) List(expr)
else fields.flatMap(f => findRemovable(WSubField(expr, f.name, f.tpe, SourceFlow), f.tpe))
@@ -95,7 +108,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
t
}
x match {
- case s: Statement => s map onType(s.name)
+ case s: Statement => s.map(onType(s.name))
case Port(_, name, _, t) => onType(name)(t)
}
removedNames
@@ -103,14 +116,14 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
private[passes] def removeZero(t: Type): Option[Type] = t match {
case GroundType(IntWidth(ZERO)) => None
case BundleType(fields) =>
- fields map (f => (f, removeZero(f.tpe))) collect {
+ fields.map(f => (f, removeZero(f.tpe))).collect {
case (Field(name, flip, _), Some(t)) => Field(name, flip, t)
} match {
case Nil => None
case seq => Some(BundleType(seq))
}
- case VectorType(t, size) => removeZero(t) map (VectorType(_, size))
- case x => Some(x)
+ case VectorType(t, size) => removeZero(t).map(VectorType(_, size))
+ case x => Some(x)
}
private def onExp(e: Expression): Expression = e match {
case DoPrim(Cat, args, consts, tpe) =>
@@ -118,26 +131,27 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
x.tpe match {
case UIntType(IntWidth(ZERO)) => Seq.empty[Expression]
case SIntType(IntWidth(ZERO)) => Seq.empty[Expression]
- case other => Seq(x)
+ case other => Seq(x)
}
}
nonZeros match {
- case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1)))
+ case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1)))
case Seq(x) => x
- case seq => DoPrim(Cat, seq, consts, tpe) map onExp
+ case seq => DoPrim(Cat, seq, consts, tpe).map(onExp)
}
case DoPrim(Andr, Seq(x), _, _) if (bitWidth(x.tpe) == 0) => UIntLiteral(1) // nothing false
- case other => other.tpe match {
- case UIntType(IntWidth(ZERO)) => UIntLiteral(ZERO, IntWidth(BigInt(1)))
- case SIntType(IntWidth(ZERO)) => SIntLiteral(ZERO, IntWidth(BigInt(1)))
- case _ => e map onExp
- }
+ case other =>
+ other.tpe match {
+ case UIntType(IntWidth(ZERO)) => UIntLiteral(ZERO, IntWidth(BigInt(1)))
+ case SIntType(IntWidth(ZERO)) => SIntLiteral(ZERO, IntWidth(BigInt(1)))
+ case _ => e.map(onExp)
+ }
}
private def onStmt(renames: RenameMap)(s: Statement): Statement = s match {
case d @ DefWire(info, name, tpe) =>
renames.delete(getRemoved(d))
removeZero(tpe) match {
- case None => EmptyStmt
+ case None => EmptyStmt
case Some(t) => DefWire(info, name, t)
}
case d @ DefRegister(info, name, tpe, clock, reset, init) =>
@@ -145,7 +159,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
removeZero(tpe) match {
case None => EmptyStmt
case Some(t) =>
- DefRegister(info, name, t, onExp(clock), onExp(reset), onExp(init))
+ DefRegister(info, name, t, onExp(clock), onExp(reset), onExp(init))
}
case d: DefMemory =>
renames.delete(getRemoved(d))
@@ -154,25 +168,28 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
Utils.throwInternalError(s"private pass ZeroWidthMemRemove should have removed this memory: $d")
case Some(t) => d.copy(dataType = t)
}
- case Connect(info, loc, exp) => removeZero(loc.tpe) match {
- case None => EmptyStmt
- case Some(t) => Connect(info, loc, onExp(exp))
- }
- case IsInvalid(info, exp) => removeZero(exp.tpe) match {
- case None => EmptyStmt
- case Some(t) => IsInvalid(info, onExp(exp))
- }
- case DefNode(info, name, value) => removeZero(value.tpe) match {
- case None => EmptyStmt
- case Some(t) => DefNode(info, name, onExp(value))
- }
- case sx => sx map onStmt(renames) map onExp
+ case Connect(info, loc, exp) =>
+ removeZero(loc.tpe) match {
+ case None => EmptyStmt
+ case Some(t) => Connect(info, loc, onExp(exp))
+ }
+ case IsInvalid(info, exp) =>
+ removeZero(exp.tpe) match {
+ case None => EmptyStmt
+ case Some(t) => IsInvalid(info, onExp(exp))
+ }
+ case DefNode(info, name, value) =>
+ removeZero(value.tpe) match {
+ case None => EmptyStmt
+ case Some(t) => DefNode(info, name, onExp(value))
+ }
+ case sx => sx.map(onStmt(renames)).map(onExp)
}
private def onModule(renames: RenameMap)(m: DefModule): DefModule = {
renames.setModule(m.name)
// For each port, record deleted subcomponents
- m.ports.foreach{p => renames.delete(getRemoved(p))}
- val ports = m.ports map (p => (p, removeZero(p.tpe))) flatMap {
+ m.ports.foreach { p => renames.delete(getRemoved(p)) }
+ val ports = m.ports.map(p => (p, removeZero(p.tpe))).flatMap {
case (Port(info, name, dir, _), Some(t)) => Seq(Port(info, name, dir, t))
case (Port(_, name, _, _), None) =>
renames.delete(name)
@@ -180,7 +197,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
}
m match {
case ext: ExtModule => ext.copy(ports = ports)
- case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body))
+ case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body))
}
}
def execute(state: CircuitState): CircuitState = {
@@ -189,7 +206,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
val c = InferTypes.run(executeEmptyMemStmt(state).circuit)
val renames = RenameMap()
renames.setCircuit(c.main)
- val result = c.copy(modules = c.modules map onModule(renames))
+ val result = c.copy(modules = c.modules.map(onModule(renames)))
CircuitState(result, outputForm, state.annotations, Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala
index c2323d4c..bfc03b51 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala
@@ -13,8 +13,8 @@ import Utils._
import memlib.AnalysisUtils._
/** Starting with a top module, determine the clock origins of each child instance.
- * Write the result to writer.
- */
+ * Write the result to writer.
+ */
class ClockList(top: String, writer: Writer) extends Pass {
def run(c: Circuit): Circuit = {
// Build useful datastructures
@@ -29,7 +29,7 @@ class ClockList(top: String, writer: Writer) extends Pass {
// Clock sources must be blackbox outputs and top's clock
val partialSourceList = getSourceList(moduleMap)(lineages)
- val sourceList = partialSourceList ++ moduleMap(top).ports.collect{ case Port(i, n, Input, ClockType) => n }
+ val sourceList = partialSourceList ++ moduleMap(top).ports.collect { case Port(i, n, Input, ClockType) => n }
writer.append(s"Sourcelist: $sourceList \n")
// Remove everything from the circuit, unless it has a clock type
@@ -37,8 +37,9 @@ class ClockList(top: String, writer: Writer) extends Pass {
val onlyClockCircuit = RemoveAllButClocks.run(c)
// Inline the clock-only circuit up to the specified top module
- val modulesToInline = (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet
- val inlineTransform = new InlineInstances{ override val inlineDelim = "$" }
+ val modulesToInline =
+ (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet
+ val inlineTransform = new InlineInstances { override val inlineDelim = "$" }
val inlinedCircuit = inlineTransform.run(onlyClockCircuit, modulesToInline, Set(), Seq()).circuit
val topModule = inlinedCircuit.modules.find(_.name == top).getOrElse(throwInternalError("no top module"))
@@ -49,13 +50,14 @@ class ClockList(top: String, writer: Writer) extends Pass {
val origins = getOrigins(connects, "", moduleMap)(lineages)
// If the clock origin is contained in the source list, label good (otherwise bad)
- origins.foreach { case (instance, origin) =>
- val sep = if(instance == "") "" else "."
- if(!sourceList.contains(origin.replace('.','$'))){
- outputBuffer.append(s"Bad Origin of $instance${sep}clock is $origin\n")
- } else {
- outputBuffer.append(s"Good Origin of $instance${sep}clock is $origin\n")
- }
+ origins.foreach {
+ case (instance, origin) =>
+ val sep = if (instance == "") "" else "."
+ if (!sourceList.contains(origin.replace('.', '$'))) {
+ outputBuffer.append(s"Bad Origin of $instance${sep}clock is $origin\n")
+ } else {
+ outputBuffer.append(s"Good Origin of $instance${sep}clock is $origin\n")
+ }
}
// Write to output file
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
index e6617857..468ba905 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
@@ -12,8 +12,7 @@ import memlib._
import firrtl.options.{RegisteredTransform, ShellOption}
import firrtl.stage.{Forms, RunFirrtlTransformAnnotation}
-case class ClockListAnnotation(target: ModuleName, outputConfig: String) extends
- SingleTargetAnnotation[ModuleName] {
+case class ClockListAnnotation(target: ModuleName, outputConfig: String) extends SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = ClockListAnnotation(n, outputConfig)
}
@@ -44,7 +43,7 @@ Usage:
)
passOptions.get(InputConfigFileName) match {
case Some(x) => error("Unneeded input config file name!" + usage)
- case None =>
+ case None =>
}
val target = ModuleName(passModule, CircuitName(passCircuit))
ClockListAnnotation(target, outputConfig)
@@ -53,18 +52,20 @@ Usage:
class ClockListTransform extends Transform with DependencyAPIMigration with RegisteredTransform {
- override def prerequisites = Forms.LowForm
- override def optionalPrerequisites = Seq.empty
- override def optionalPrerequisiteOf = Forms.LowEmitters
+ override def prerequisites = Forms.LowForm
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Forms.LowEmitters
val options = Seq(
new ShellOption[String](
longOption = "list-clocks",
- toAnnotationSeq = (a: String) => Seq( passes.clocklist.ClockListAnnotation.parse(a),
- RunFirrtlTransformAnnotation(new ClockListTransform) ),
+ toAnnotationSeq = (a: String) =>
+ Seq(passes.clocklist.ClockListAnnotation.parse(a), RunFirrtlTransformAnnotation(new ClockListTransform)),
helpText = "List which signal drives each clock of every descendent of specified modules",
shortOption = Some("clks"),
- helpValueName = Some("-c:<circuit>:-m:<module>:-o:<filename>") ) )
+ helpValueName = Some("-c:<circuit>:-m:<module>:-o:<filename>")
+ )
+ )
def passSeq(top: String, writer: Writer): Seq[Pass] =
Seq(new ClockList(top, writer))
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
index b77629fc..00e07588 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
@@ -10,45 +10,56 @@ import Utils._
import memlib.AnalysisUtils._
object ClockListUtils {
+
/** Returns a list of clock outputs from instances of external modules
- */
+ */
def getSourceList(moduleMap: Map[String, DefModule])(lin: Lineage): Seq[String] = {
- val s = lin.foldLeft(Seq[String]()){case (sL, (i, l)) =>
- val sLx = getSourceList(moduleMap)(l)
- val sLxx = sLx map (i + "$" + _)
- sL ++ sLxx
+ val s = lin.foldLeft(Seq[String]()) {
+ case (sL, (i, l)) =>
+ val sLx = getSourceList(moduleMap)(l)
+ val sLxx = sLx.map(i + "$" + _)
+ sL ++ sLxx
}
val sourceList = moduleMap(lin.name) match {
case ExtModule(i, n, ports, dn, p) =>
- val portExps = ports.flatMap{p => create_exps(WRef(p.name, p.tpe, PortKind, to_flow(p.direction)))}
+ val portExps = ports.flatMap { p => create_exps(WRef(p.name, p.tpe, PortKind, to_flow(p.direction))) }
portExps.filter(e => (e.tpe == ClockType) && (flow(e) == SinkFlow)).map(_.serialize)
case _ => Nil
}
val sx = sourceList ++ s
sx
}
+
/** Returns a map from instance name to its clock origin.
- * Child instances are not included if they share the same clock as their parent
- */
- def getOrigins(connects: Connects, me: String, moduleMap: Map[String, DefModule])(lin: Lineage): Map[String, String] = {
- val sep = if(me == "") "" else "$"
+ * Child instances are not included if they share the same clock as their parent
+ */
+ def getOrigins(
+ connects: Connects,
+ me: String,
+ moduleMap: Map[String, DefModule]
+ )(lin: Lineage
+ ): Map[String, String] = {
+ val sep = if (me == "") "" else "$"
// Get origins from all children
- val childrenOrigins = lin.foldLeft(Map[String, String]()){case (o, (i, l)) =>
- o ++ getOrigins(connects, me + sep + i, moduleMap)(l)
+ val childrenOrigins = lin.foldLeft(Map[String, String]()) {
+ case (o, (i, l)) =>
+ o ++ getOrigins(connects, me + sep + i, moduleMap)(l)
}
// If I have a clock, get it
val clockOpt = moduleMap(lin.name) match {
- case Module(i, n, ports, b) => ports.collectFirst { case p if p.name == "clock" => me + sep + "clock" }
+ case Module(i, n, ports, b) => ports.collectFirst { case p if p.name == "clock" => me + sep + "clock" }
case ExtModule(i, n, ports, dn, p) => None
}
// Return new origins with direct children removed, if they match my clock
clockOpt match {
case Some(clock) =>
val myOrigin = getOrigin(connects, clock).serialize
- childrenOrigins.foldLeft(Map(me -> myOrigin)) { case (o, (childInstance, childOrigin)) =>
- val childrenInstances = lin.children.map { case (instance, _) => me + sep + instance }
- // If direct child shares my origin, omit it
- if(childOrigin == myOrigin && childrenInstances.contains(childInstance)) o else o + (childInstance -> childOrigin)
+ childrenOrigins.foldLeft(Map(me -> myOrigin)) {
+ case (o, (childInstance, childOrigin)) =>
+ val childrenInstances = lin.children.map { case (instance, _) => me + sep + instance }
+ // If direct child shares my origin, omit it
+ if (childOrigin == myOrigin && childrenInstances.contains(childInstance)) o
+ else o + (childInstance -> childOrigin)
}
case None => childrenOrigins
}
diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
index 6eb8c138..d72bc293 100644
--- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
+++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
@@ -9,22 +9,22 @@ import Utils._
import Mappers._
/** Remove all statements and ports (except instances/whens/blocks) whose
- * expressions do not relate to ground types.
- */
+ * expressions do not relate to ground types.
+ */
object RemoveAllButClocks extends Pass {
- def onStmt(s: Statement): Statement = (s map onStmt) match {
- case DefWire(i, n, ClockType) => s
+ def onStmt(s: Statement): Statement = (s.map(onStmt)) match {
+ case DefWire(i, n, ClockType) => s
case DefNode(i, n, value) if value.tpe == ClockType => s
- case Connect(i, l, r) if l.tpe == ClockType => s
- case sx: WDefInstance => sx
- case sx: DefInstance => sx
- case sx: Block => sx
+ case Connect(i, l, r) if l.tpe == ClockType => s
+ case sx: WDefInstance => sx
+ case sx: DefInstance => sx
+ case sx: Block => sx
case sx: Conditionally => sx
case _ => EmptyStmt
}
def onModule(m: DefModule): DefModule = m match {
- case Module(i, n, ps, b) => Module(i, n, ps.filter(_.tpe == ClockType), squashEmpty(onStmt(b)))
+ case Module(i, n, ps, b) => Module(i, n, ps.filter(_.tpe == ClockType), squashEmpty(onStmt(b)))
case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps.filter(_.tpe == ClockType), dn, p)
}
- def run(c: Circuit): Circuit = c.copy(modules = c.modules map onModule)
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(onModule))
}
diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
index 14bd9e44..d237c36a 100644
--- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
+++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
@@ -19,8 +19,9 @@ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform
import CustomYAMLProtocol._
val configs = r.parse[Config]
val oldAnnos = state.annotations
- val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { case ((annos, pins), config) =>
- (annos, pins :+ config.pin.name)
+ val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) {
+ case ((annos, pins), config) =>
+ (annos, pins :+ config.pin.name)
}
state.copy(annotations = PinAnnotation(pins.toSeq) +: as)
}
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index 4847a698..e290633e 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -10,12 +10,11 @@ import firrtl.PrimOps._
import firrtl.Utils.{one, zero, BoolType}
import firrtl.options.{HasShellOptions, ShellOption}
import MemPortUtils.memPortField
-import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin}
+import firrtl.passes.memlib.AnalysisUtils.{getConnects, getOrigin, Connects}
import WrappedExpression.weq
import annotations._
import firrtl.stage.{Forms, RunFirrtlTransformAnnotation}
-
case object InferReadWriteAnnotation extends NoTargetAnnotation
// This pass examine the enable signals of the read & write ports of memories
@@ -40,12 +39,13 @@ object InferReadWritePass extends Pass {
getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval)
// Visit each term of AND operation
case DoPrim(op, args, consts, tpe) if op == And =>
- e +: (args flatMap getProductTerms(connects))
+ e +: (args.flatMap(getProductTerms(connects)))
// Visit connected nodes to references
- case _: WRef | _: WSubField | _: WSubIndex => connects get e match {
- case None => Seq(e)
- case Some(ex) => e +: getProductTerms(connects)(ex)
- }
+ case _: WRef | _: WSubField | _: WSubIndex =>
+ connects.get(e) match {
+ case None => Seq(e)
+ case Some(ex) => e +: getProductTerms(connects)(ex)
+ }
// Otherwise just return itself
case _ => Seq(e)
}
@@ -58,96 +58,103 @@ object InferReadWritePass extends Pass {
// b ?= Eq(a, 0) or b ?= Eq(0, a)
case (_, DoPrim(Eq, args, _, _)) =>
weq(args.head, a) && weq(args(1), zero) ||
- weq(args(1), a) && weq(args.head, zero)
+ weq(args(1), a) && weq(args.head, zero)
// a ?= Eq(b, 0) or b ?= Eq(0, a)
case (DoPrim(Eq, args, _, _), _) =>
weq(args.head, b) && weq(args(1), zero) ||
- weq(args(1), b) && weq(args.head, zero)
+ weq(args(1), b) && weq(args.head, zero)
case _ => false
}
-
def replaceExp(repl: Netlist)(e: Expression): Expression =
- e map replaceExp(repl) match {
- case ex: WSubField => repl getOrElse (ex.serialize, ex)
+ e.map(replaceExp(repl)) match {
+ case ex: WSubField => repl.getOrElse(ex.serialize, ex)
case ex => ex
}
def replaceStmt(repl: Netlist)(s: Statement): Statement =
- s map replaceStmt(repl) map replaceExp(repl) match {
+ s.map(replaceStmt(repl)).map(replaceExp(repl)) match {
case Connect(_, EmptyExpression, _) => EmptyStmt
- case sx => sx
+ case sx => sx
}
- def inferReadWriteStmt(connects: Connects,
- repl: Netlist,
- stmts: Statements)
- (s: Statement): Statement = s match {
+ def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match {
// infer readwrite ports only for non combinational memories
case mem: DefMemory if mem.readLatency > 0 =>
val readers = new PortSet
val writers = new PortSet
val readwriters = collection.mutable.ArrayBuffer[String]()
val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters)
- for (w <- mem.writers ; r <- mem.readers) {
+ for {
+ w <- mem.writers
+ r <- mem.readers
+ } {
val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en"))
val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en"))
- val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b)))
+ val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b)))
val wclk = getOrigin(connects)(memPortField(mem, w, "clk"))
val rclk = getOrigin(connects)(memPortField(mem, r, "clk"))
if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) {
- val rw = namespace newName "rw"
+ val rw = namespace.newName("rw")
val rwExp = WSubField(WRef(mem.name), rw)
readwriters += rw
readers += r
writers += w
- repl(memPortField(mem, r, "clk")) = EmptyExpression
- repl(memPortField(mem, r, "en")) = EmptyExpression
+ repl(memPortField(mem, r, "clk")) = EmptyExpression
+ repl(memPortField(mem, r, "en")) = EmptyExpression
repl(memPortField(mem, r, "addr")) = EmptyExpression
repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata")
- repl(memPortField(mem, w, "clk")) = EmptyExpression
- repl(memPortField(mem, w, "en")) = EmptyExpression
+ repl(memPortField(mem, w, "clk")) = EmptyExpression
+ repl(memPortField(mem, w, "en")) = EmptyExpression
repl(memPortField(mem, w, "addr")) = EmptyExpression
repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata")
repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask")
stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get)
stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk)
- stmts += Connect(NoInfo, WSubField(rwExp, "en"),
- DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
- connects(memPortField(mem, w, "en"))), Nil, BoolType))
- stmts += Connect(NoInfo, WSubField(rwExp, "addr"),
- Mux(connects(memPortField(mem, w, "en")),
- connects(memPortField(mem, w, "addr")),
- connects(memPortField(mem, r, "addr")), UnknownType))
+ stmts += Connect(
+ NoInfo,
+ WSubField(rwExp, "en"),
+ DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType)
+ )
+ stmts += Connect(
+ NoInfo,
+ WSubField(rwExp, "addr"),
+ Mux(
+ connects(memPortField(mem, w, "en")),
+ connects(memPortField(mem, w, "addr")),
+ connects(memPortField(mem, r, "addr")),
+ UnknownType
+ )
+ )
}
}
- if (readwriters.isEmpty) mem else mem copy (
- readers = mem.readers filterNot readers,
- writers = mem.writers filterNot writers,
- readwriters = mem.readwriters ++ readwriters)
- case sx => sx map inferReadWriteStmt(connects, repl, stmts)
+ if (readwriters.isEmpty) mem
+ else
+ mem.copy(
+ readers = mem.readers.filterNot(readers),
+ writers = mem.writers.filterNot(writers),
+ readwriters = mem.readwriters ++ readwriters
+ )
+ case sx => sx.map(inferReadWriteStmt(connects, repl, stmts))
}
def inferReadWrite(m: DefModule) = {
val connects = getConnects(m)
val repl = new Netlist
val stmts = new Statements
- (m map inferReadWriteStmt(connects, repl, stmts)
- map replaceStmt(repl)) match {
+ (m.map(inferReadWriteStmt(connects, repl, stmts))
+ .map(replaceStmt(repl))) match {
case m: ExtModule => m
- case m: Module => m copy (body = Block(m.body +: stmts.toSeq))
+ case m: Module => m.copy(body = Block(m.body +: stmts.toSeq))
}
}
- def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(inferReadWrite))
}
// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl"
// To use this transform, circuit name should be annotated with its TransId.
-class InferReadWrite extends Transform
- with DependencyAPIMigration
- with SeqTransformBased
- with HasShellOptions {
+class InferReadWrite extends Transform with DependencyAPIMigration with SeqTransformBased with HasShellOptions {
override def prerequisites = Forms.MidForm
override def optionalPrerequisites = Seq.empty
@@ -159,7 +166,9 @@ class InferReadWrite extends Transform
longOption = "infer-rw",
toAnnotationSeq = (_: Unit) => Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)),
helpText = "Enable read/write port inference for memories",
- shortOption = Some("firw") ) )
+ shortOption = Some("firw")
+ )
+ )
def transforms = Seq(
InferReadWritePass,
diff --git a/src/main/scala/firrtl/passes/memlib/MemConf.scala b/src/main/scala/firrtl/passes/memlib/MemConf.scala
index 3809c47c..871a1093 100644
--- a/src/main/scala/firrtl/passes/memlib/MemConf.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemConf.scala
@@ -3,7 +3,6 @@
package firrtl.passes
package memlib
-
sealed abstract class MemPort(val name: String) { override def toString = name }
case object ReadPort extends MemPort("read")
@@ -19,22 +18,27 @@ object MemPort {
def apply(s: String): Option[MemPort] = MemPort.all.find(_.name == s)
def fromString(s: String): Map[MemPort, Int] = {
- s.split(",").toSeq.map(MemPort.apply).map(_ match {
- case Some(x) => x
- case _ => throw new Exception(s"Error parsing MemPort string : ${s}")
- }).groupBy(identity).mapValues(_.size).toMap
+ s.split(",")
+ .toSeq
+ .map(MemPort.apply)
+ .map(_ match {
+ case Some(x) => x
+ case _ => throw new Exception(s"Error parsing MemPort string : ${s}")
+ })
+ .groupBy(identity)
+ .mapValues(_.size)
+ .toMap
}
}
case class MemConf(
- name: String,
- depth: BigInt,
- width: Int,
- ports: Map[MemPort, Int],
- maskGranularity: Option[Int]
-) {
+ name: String,
+ depth: BigInt,
+ width: Int,
+ ports: Map[MemPort, Int],
+ maskGranularity: Option[Int]) {
- private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") } mkString (",")
+ private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") }.mkString(",")
private def maskGranStr = maskGranularity.map((p) => s"mask_gran $p").getOrElse("")
// Assert that all of the entries in the port map are greater than zero to make it easier to compare two of these case classes
@@ -49,21 +53,34 @@ object MemConf {
val regex = raw"\s*name\s+(\w+)\s+depth\s+(\d+)\s+width\s+(\d+)\s+ports\s+([^\s]+)\s+(?:mask_gran\s+(\d+))?\s*".r
def fromString(s: String): Seq[MemConf] = {
- s.split("\n").toSeq.map(_ match {
- case MemConf.regex(name, depth, width, ports, maskGran) => Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt)))
- case "" => None
- case _ => throw new Exception(s"Error parsing MemConf string : ${s}")
- }).flatten
+ s.split("\n")
+ .toSeq
+ .map(_ match {
+ case MemConf.regex(name, depth, width, ports, maskGran) =>
+ Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt)))
+ case "" => None
+ case _ => throw new Exception(s"Error parsing MemConf string : ${s}")
+ })
+ .flatten
}
- def apply(name: String, depth: BigInt, width: Int, readPorts: Int, writePorts: Int, readWritePorts: Int, maskGranularity: Option[Int]): MemConf = {
+ def apply(
+ name: String,
+ depth: BigInt,
+ width: Int,
+ readPorts: Int,
+ writePorts: Int,
+ readWritePorts: Int,
+ maskGranularity: Option[Int]
+ ): MemConf = {
val ports: Seq[(MemPort, Int)] = (if (maskGranularity.isEmpty) {
- (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++
- (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts))
- } else {
- (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++
- (if (readWritePorts == 0) Seq() else Seq(MaskedReadWritePort -> readWritePorts))
- }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts))
+ (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++
+ (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts))
+ } else {
+ (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++
+ (if (readWritePorts == 0) Seq()
+ else Seq(MaskedReadWritePort -> readWritePorts))
+ }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts))
new MemConf(name, depth, width, ports.toMap, maskGranularity)
}
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala
index 3731ea86..c8cd3e8d 100644
--- a/src/main/scala/firrtl/passes/memlib/MemIR.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala
@@ -19,38 +19,38 @@ object DefAnnotatedMemory {
m.readwriters,
m.readUnderWrite,
None, // mask granularity annotation
- None // No reference yet to another memory
+ None // No reference yet to another memory
)
}
}
case class DefAnnotatedMemory(
- info: Info,
- name: String,
- dataType: Type,
- depth: BigInt,
- writeLatency: Int,
- readLatency: Int,
- readers: Seq[String],
- writers: Seq[String],
- readwriters: Seq[String],
- readUnderWrite: ReadUnderWrite.Value,
- maskGran: Option[BigInt],
- memRef: Option[(String, String)] /* (Module, Mem) */
- //pins: Seq[Pin],
- ) extends Statement with IsDeclaration {
+ info: Info,
+ name: String,
+ dataType: Type,
+ depth: BigInt,
+ writeLatency: Int,
+ readLatency: Int,
+ readers: Seq[String],
+ writers: Seq[String],
+ readwriters: Seq[String],
+ readUnderWrite: ReadUnderWrite.Value,
+ maskGran: Option[BigInt],
+ memRef: Option[(String, String)] /* (Module, Mem) */
+ //pins: Seq[Pin],
+) extends Statement
+ with IsDeclaration {
override def serialize: String = this.toMem.serialize
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = this
- def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def toMem = DefMemory(info, name, dataType, depth,
- writeLatency, readLatency, readers, writers,
- readwriters, readUnderWrite)
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = f(dataType)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def toMem =
+ DefMemory(info, name, dataType, depth, writeLatency, readLatency, readers, writers, readwriters, readUnderWrite)
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = f(dataType)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala
index f0c9ebf4..1db132f7 100644
--- a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala
@@ -7,8 +7,7 @@ import firrtl.options.{RegisteredLibrary, ShellOption}
class MemLibOptions extends RegisteredLibrary {
val name: String = "MemLib Options"
- val options: Seq[ShellOption[_]] = Seq( new InferReadWrite,
- new ReplSeqMem )
+ val options: Seq[ShellOption[_]] = Seq(new InferReadWrite, new ReplSeqMem)
.flatMap(_.options)
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
index b6a9a23d..f153fa2b 100644
--- a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
@@ -11,12 +11,12 @@ import MemPortUtils.{MemPortMap}
object MemTransformUtils {
/** Replaces references to old memory port names with new memory port names
- */
+ */
def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = {
//TODO(izraelevitz): check speed
def updateRef(e: Expression): Expression = {
- val ex = e map updateRef
- repl getOrElse (ex.serialize, ex)
+ val ex = e.map(updateRef)
+ repl.getOrElse(ex.serialize, ex)
}
def hasEmptyExpr(stmt: Statement): Boolean = {
@@ -24,16 +24,16 @@ object MemTransformUtils {
def testEmptyExpr(e: Expression): Expression = {
e match {
case EmptyExpression => foundEmpty = true
- case _ =>
+ case _ =>
}
- e map testEmptyExpr // map must return; no foreach
+ e.map(testEmptyExpr) // map must return; no foreach
}
- stmt map testEmptyExpr
+ stmt.map(testEmptyExpr)
foundEmpty
}
def updateStmtRefs(s: Statement): Statement =
- s map updateStmtRefs map updateRef match {
+ s.map(updateStmtRefs).map(updateRef) match {
case c: Connect if hasEmptyExpr(c) => EmptyStmt
case s => s
}
@@ -42,6 +42,6 @@ object MemTransformUtils {
}
def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem)
- def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField =
+ def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField =
MemPortUtils.memPortField(s.toMem, p, f)
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
index 69c6b284..f325c0ba 100644
--- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
@@ -7,19 +7,19 @@ import firrtl.ir._
import firrtl.Utils._
/** Given a mask, return a bitmask corresponding to the desired datatype.
- * Requirements:
- * - The mask type and datatype must be equivalent, except any ground type in
- * datatype must be matched by a 1-bit wide UIntType.
- * - The mask must be a reference, subfield, or subindex
- * The bitmask is a series of concatenations of the single mask bit over the
- * length of the corresponding ground type, e.g.:
- *{{{
- * wire mask: {x: UInt<1>, y: UInt<1>}
- * wire data: {x: UInt<2>, y: SInt<2>}
- * // this would return:
- * cat(cat(mask.x, mask.x), cat(mask.y, mask.y))
- * }}}
- */
+ * Requirements:
+ * - The mask type and datatype must be equivalent, except any ground type in
+ * datatype must be matched by a 1-bit wide UIntType.
+ * - The mask must be a reference, subfield, or subindex
+ * The bitmask is a series of concatenations of the single mask bit over the
+ * length of the corresponding ground type, e.g.:
+ * {{{
+ * wire mask: {x: UInt<1>, y: UInt<1>}
+ * wire data: {x: UInt<2>, y: SInt<2>}
+ * // this would return:
+ * cat(cat(mask.x, mask.x), cat(mask.y, mask.y))
+ * }}}
+ */
object toBitMask {
def apply(mask: Expression, dataType: Type): Expression = mask match {
case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType)
@@ -28,12 +28,13 @@ object toBitMask {
private def hiermask(mask: Expression, dataType: Type): Expression =
(mask.tpe, dataType) match {
case (mt: VectorType, dt: VectorType) =>
- seqCat((0 until mt.size).reverse map { i =>
+ seqCat((0 until mt.size).reverse.map { i =>
hiermask(WSubIndex(mask, i, mt.tpe, UnknownFlow), dt.tpe)
})
case (mt: BundleType, dt: BundleType) =>
- seqCat((mt.fields zip dt.fields) map { case (mf, df) =>
- hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe)
+ seqCat((mt.fields.zip(dt.fields)).map {
+ case (mf, df) =>
+ hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe)
})
case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) =>
seqCat(List.fill(bitWidth(dt).intValue)(mask))
@@ -44,7 +45,7 @@ object toBitMask {
object createMask {
def apply(dt: Type): Type = dt match {
case t: VectorType => VectorType(apply(t.tpe), t.size)
- case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe))))
+ case t: BundleType => BundleType(t.fields.map(f => f.copy(tpe = apply(f.tpe))))
case GroundType(w) if w == IntWidth(0) => UIntType(IntWidth(0))
case t: GroundType => BoolType
}
@@ -56,27 +57,33 @@ object MemPortUtils {
type Modules = collection.mutable.ArrayBuffer[DefModule]
def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq(
- Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))),
+ Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1).max(1)))),
Field("en", Default, BoolType),
Field("clk", Default, ClockType)
)
// Todo: merge it with memToBundle
def memType(mem: DefMemory): BundleType = {
- val rType = BundleType(defaultPortSeq(mem) :+
- Field("data", Flip, mem.dataType))
- val wType = BundleType(defaultPortSeq(mem) ++ Seq(
- Field("data", Default, mem.dataType),
- Field("mask", Default, createMask(mem.dataType))))
- val rwType = BundleType(defaultPortSeq(mem) ++ Seq(
- Field("rdata", Flip, mem.dataType),
- Field("wmode", Default, BoolType),
- Field("wdata", Default, mem.dataType),
- Field("wmask", Default, createMask(mem.dataType))))
+ val rType = BundleType(
+ defaultPortSeq(mem) :+
+ Field("data", Flip, mem.dataType)
+ )
+ val wType = BundleType(
+ defaultPortSeq(mem) ++ Seq(Field("data", Default, mem.dataType), Field("mask", Default, createMask(mem.dataType)))
+ )
+ val rwType = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("rdata", Flip, mem.dataType),
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, mem.dataType),
+ Field("wmask", Default, createMask(mem.dataType))
+ )
+ )
BundleType(
- (mem.readers map (Field(_, Flip, rType))) ++
- (mem.writers map (Field(_, Flip, wType))) ++
- (mem.readwriters map (Field(_, Flip, rwType))))
+ (mem.readers.map(Field(_, Flip, rType))) ++
+ (mem.writers.map(Field(_, Flip, wType))) ++
+ (mem.readwriters.map(Field(_, Flip, rwType)))
+ )
}
def memPortField(s: DefMemory, p: String, f: String): WSubField = {
diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
index c51a0adc..30529119 100644
--- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
+++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
@@ -9,27 +9,27 @@ import firrtl.Mappers._
import MemPortUtils._
import MemTransformUtils._
-
/** Changes memory port names to standard port names (i.e. RW0 instead T_408)
- */
+ */
object RenameAnnotatedMemoryPorts extends Pass {
+
/** Renames memory ports to a standard naming scheme:
- * - R0, R1, ... for each read port
- * - W0, W1, ... for each write port
- * - RW0, RW1, ... for each readwrite port
- */
+ * - R0, R1, ... for each read port
+ * - W0, W1, ... for each write port
+ * - RW0, RW1, ... for each readwrite port
+ */
def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = {
- val rports = m.readers.indices map (i => s"R$i")
- val wports = m.writers.indices map (i => s"W$i")
- val rwports = m.readwriters.indices map (i => s"RW$i")
- m copy (readers = rports, writers = wports, readwriters = rwports)
+ val rports = m.readers.indices.map(i => s"R$i")
+ val wports = m.writers.indices.map(i => s"W$i")
+ val rwports = m.readwriters.indices.map(i => s"RW$i")
+ m.copy(readers = rports, writers = wports, readwriters = rwports)
}
/** Maps the serialized form of all memory port field names to the
- * corresponding new memory port field Expression.
- * E.g.:
- * - ("m.read.addr") becomes (m.R0.addr)
- */
+ * corresponding new memory port field Expression.
+ * E.g.:
+ * - ("m.read.addr") becomes (m.R0.addr)
+ */
def getMemPortMap(m: DefAnnotatedMemory, memPortMap: MemPortMap): Unit = {
val defaultFields = Seq("addr", "en", "clk")
val rFields = defaultFields :+ "data"
@@ -37,7 +37,10 @@ object RenameAnnotatedMemoryPorts extends Pass {
val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask")
def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit =
- for ((p, i) <- ports.zipWithIndex; f <- fields) {
+ for {
+ (p, i) <- ports.zipWithIndex
+ f <- fields
+ } {
val newPort = WSubField(WRef(m.name), newPortKind + i)
val field = WSubField(newPort, f)
memPortMap(s"${m.name}.$p.$f") = field
@@ -55,16 +58,16 @@ object RenameAnnotatedMemoryPorts extends Pass {
val updatedMem = createMemProto(m)
getMemPortMap(m, memPortMap)
updatedMem
- case s => s map updateMemStmts(memPortMap)
+ case s => s.map(updateMemStmts(memPortMap))
}
/** Replaces candidate memories and their references with standard port names
- */
+ */
def updateMemMods(m: DefModule) = {
val memPortMap = new MemPortMap
- (m map updateMemStmts(memPortMap)
- map updateStmtRefs(memPortMap))
+ (m.map(updateMemStmts(memPortMap))
+ .map(updateStmtRefs(memPortMap)))
}
- def run(c: Circuit) = c copy (modules = c.modules map updateMemMods)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(updateMemMods))
}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
index bfbc163a..fc381e88 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
@@ -13,7 +13,6 @@ import firrtl.annotations._
import firrtl.stage.Forms
import wiring._
-
/** Annotates the name of the pins to add for WiringTransform */
case class PinAnnotation(pins: Seq[String]) extends NoTargetAnnotation
@@ -35,14 +34,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
/** Return true if mask granularity is per bit, false if per byte or unspecified
*/
private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match {
- case None => false
+ case None => false
case Some(v) => v == 1
}
private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
+ defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)
+ )
private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
+ defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))
+ )
/** Catch incorrect memory instantiations when there are masked memories with unsupported aggregate types.
*
@@ -82,7 +83,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
)
private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
(defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match {
- case None => Nil
+ case None => Nil
case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType)))
case Some(_) => {
checkMaskDatatype(mem)
@@ -111,7 +112,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
Field("wdata", Default, flattenType(mem.dataType)),
Field("rdata", Flip, flattenType(mem.dataType))
) ++ (mem.maskGran match {
- case None => Nil
+ case None => Nil
case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType)))
case Some(_) => {
checkMaskDatatype(mem)
@@ -122,32 +123,34 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
def memToBundle(s: DefAnnotatedMemory) = BundleType(
s.readers.map(Field(_, Flip, rPortToBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToBundle(s))))
+ s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))
+ )
def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType(
s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))))
+ s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))
+ )
/** Creates a wrapper module and external module to replace a candidate memory
- * The wrapper module has the same type as the memory it replaces
- * The external module
- */
+ * The wrapper module has the same type as the memory it replaces
+ * The external module
+ */
def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = {
assert(m.dataType != UnknownType)
val wrapperIoType = memToBundle(m)
- val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val wrapperIoPorts = wrapperIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
// Creates a type with the write/readwrite masks omitted if necessary
val bbIoType = memToFlattenBundle(m)
- val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val bbIoPorts = bbIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
val bbRef = WRef(m.name, bbIoType)
val hasMask = m.maskGran.isDefined
val fillMask = getFillWMask(m)
def portRef(p: String) = WRef(p, field_type(wrapperIoType, p))
val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++
- (m.readers flatMap (r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++
- (m.writers flatMap (w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++
- (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask)))
+ (m.readers.flatMap(r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++
+ (m.writers.flatMap(w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++
+ (m.readwriters.flatMap(rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask)))
val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts))
val bb = ExtModule(NoInfo, m.name, bbIoPorts, m.name, Seq.empty)
// TODO: Annotate? -- use actual annotation map
@@ -160,16 +163,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
// TODO(shunshou): get rid of copy pasta
// Connects the clk, en, and addr fields from the wrapperPort to the bbPort
def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] =
- Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f))
+ Seq("clk", "en", "addr").map(f => connectFields(bbPort, f, wrapperPort, f))
// Generates mask bits (concatenates an aggregate to ground type)
// depending on mask granularity (# bits = data width / mask granularity)
def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression =
if (fillMask) toBitMask(mask, dataType) else toBits(mask)
- def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] =
+ def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] =
defaultConnects(wrapperPort, bbPort) :+
- fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data"))
+ fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data"))
def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = {
val wrapperData = WSubField(wrapperPort, "data")
@@ -177,11 +180,12 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
Connect(NoInfo, WSubField(bbPort, "data"), toBits(wrapperData))
hasMask match {
case false => defaultSeq
- case true => defaultSeq :+ Connect(
- NoInfo,
- WSubField(bbPort, "mask"),
- maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
- )
+ case true =>
+ defaultSeq :+ Connect(
+ NoInfo,
+ WSubField(bbPort, "mask"),
+ maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
+ )
}
}
@@ -190,61 +194,67 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq(
fromBits(WSubField(wrapperPort, "rdata"), WSubField(bbPort, "rdata")),
connectFields(bbPort, "wmode", wrapperPort, "wmode"),
- Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData)))
+ Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData))
+ )
hasMask match {
case false => defaultSeq
- case true => defaultSeq :+ Connect(
- NoInfo,
- WSubField(bbPort, "wmask"),
- maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
- )
+ case true =>
+ defaultSeq :+ Connect(
+ NoInfo,
+ WSubField(bbPort, "wmask"),
+ maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
+ )
}
}
/** Mapping from (module, memory name) pairs to blackbox names */
private type NameMap = collection.mutable.HashMap[(String, String), String]
+
/** Construct NameMap by assigning unique names for each memory blackbox */
def constructNameMap(namespace: Namespace, nameMap: NameMap, mname: String)(s: Statement): Statement = {
s match {
- case m: DefAnnotatedMemory => m.memRef match {
- case None => nameMap(mname -> m.name) = namespace newName m.name
- case Some(_) =>
- }
+ case m: DefAnnotatedMemory =>
+ m.memRef match {
+ case None => nameMap(mname -> m.name) = namespace.newName(m.name)
+ case Some(_) =>
+ }
case _ =>
}
- s map constructNameMap(namespace, nameMap, mname)
+ s.map(constructNameMap(namespace, nameMap, mname))
}
- def updateMemStmts(namespace: Namespace,
- nameMap: NameMap,
- mname: String,
- memPortMap: MemPortMap,
- memMods: Modules)
- (s: Statement): Statement = s match {
+ def updateMemStmts(
+ namespace: Namespace,
+ nameMap: NameMap,
+ mname: String,
+ memPortMap: MemPortMap,
+ memMods: Modules
+ )(s: Statement
+ ): Statement = s match {
case m: DefAnnotatedMemory =>
if (m.maskGran.isEmpty) {
- m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression }
- m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression }
+ m.writers.foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression }
+ m.readwriters.foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression }
}
m.memRef match {
case None =>
// prototype mem
val newWrapperName = nameMap(mname -> m.name)
- val newMemBBName = namespace newName s"${newWrapperName}_ext"
- val newMem = m copy (name = newMemBBName)
+ val newMemBBName = namespace.newName(s"${newWrapperName}_ext")
+ val newMem = m.copy(name = newMemBBName)
memMods ++= createMemModule(newMem, newWrapperName)
WDefInstance(m.info, m.name, newWrapperName, UnknownType)
case Some((module, mem)) =>
WDefInstance(m.info, m.name, nameMap(module -> mem), UnknownType)
}
- case sx => sx map updateMemStmts(namespace, nameMap, mname, memPortMap, memMods)
+ case sx => sx.map(updateMemStmts(namespace, nameMap, mname, memPortMap, memMods))
}
def updateMemMods(namespace: Namespace, nameMap: NameMap, memMods: Modules)(m: DefModule) = {
val memPortMap = new MemPortMap
- (m map updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods)
- map updateStmtRefs(memPortMap))
+ (m.map(updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods))
+ .map(updateStmtRefs(memPortMap)))
}
def execute(state: CircuitState): CircuitState = {
@@ -252,15 +262,15 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
val namespace = Namespace(c)
val memMods = new Modules
val nameMap = new NameMap
- c.modules map (m => m map constructNameMap(namespace, nameMap, m.name))
- val modules = c.modules map updateMemMods(namespace, nameMap, memMods)
+ c.modules.map(m => m.map(constructNameMap(namespace, nameMap, m.name)))
+ val modules = c.modules.map(updateMemMods(namespace, nameMap, memMods))
// print conf
writer.serialize()
val pannos = state.annotations.collect { case a: PinAnnotation => a }
val pins = pannos match {
- case Seq() => Nil
+ case Seq() => Nil
case Seq(PinAnnotation(pins)) => pins
- case _ => throwInternalError("Something went wrong")
+ case _ => throwInternalError("Something went wrong")
}
val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) =>
seq ++ memMods.collect {
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index 87321ea0..79e07640 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -7,7 +7,7 @@ import firrtl._
import firrtl.annotations._
import firrtl.options.{HasShellOptions, ShellOption}
import Utils.error
-import java.io.{File, CharArrayWriter, PrintWriter}
+import java.io.{CharArrayWriter, File, PrintWriter}
import wiring._
import firrtl.stage.{Forms, RunFirrtlTransformAnnotation}
@@ -50,7 +50,15 @@ class ConfWriter(filename: String) {
// assert that we don't overflow going from BigInt to Int conversion
require(bitWidth(m.dataType) <= Int.MaxValue)
m.maskGran.foreach { case x => require(x <= Int.MaxValue) }
- val conf = MemConf(m.name, m.depth, bitWidth(m.dataType).toInt, m.readers.length, m.writers.length, m.readwriters.length, m.maskGran.map(_.toInt))
+ val conf = MemConf(
+ m.name,
+ m.depth,
+ bitWidth(m.dataType).toInt,
+ m.readers.length,
+ m.writers.length,
+ m.readwriters.length,
+ m.maskGran.map(_.toInt)
+ )
outputBuffer.append(conf.toString)
}
def serialize() = {
@@ -113,27 +121,31 @@ class ReplSeqMem extends Transform with HasShellOptions with DependencyAPIMigrat
val options = Seq(
new ShellOption[String](
longOption = "repl-seq-mem",
- toAnnotationSeq = (a: String) => Seq( passes.memlib.ReplSeqMemAnnotation.parse(a),
- RunFirrtlTransformAnnotation(new ReplSeqMem) ),
+ toAnnotationSeq =
+ (a: String) => Seq(passes.memlib.ReplSeqMemAnnotation.parse(a), RunFirrtlTransformAnnotation(new ReplSeqMem)),
helpText = "Blackbox and emit a configuration file for each sequential memory",
shortOption = Some("frsq"),
- helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>") ) )
+ helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>")
+ )
+ )
def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] =
- Seq(new SimpleMidTransform(Legalize),
- new SimpleMidTransform(ToMemIR),
- new SimpleMidTransform(ResolveMaskGranularity),
- new SimpleMidTransform(RenameAnnotatedMemoryPorts),
- new ResolveMemoryReference,
- new CreateMemoryAnnotations(inConfigFile),
- new ReplaceMemMacros(outConfigFile),
- new WiringTransform,
- new SimpleMidTransform(RemoveEmpty),
- new SimpleMidTransform(CheckInitialization),
- new SimpleMidTransform(InferTypes),
- Uniquify,
- new SimpleMidTransform(ResolveKinds),
- new SimpleMidTransform(ResolveFlows))
+ Seq(
+ new SimpleMidTransform(Legalize),
+ new SimpleMidTransform(ToMemIR),
+ new SimpleMidTransform(ResolveMaskGranularity),
+ new SimpleMidTransform(RenameAnnotatedMemoryPorts),
+ new ResolveMemoryReference,
+ new CreateMemoryAnnotations(inConfigFile),
+ new ReplaceMemMacros(outConfigFile),
+ new WiringTransform,
+ new SimpleMidTransform(RemoveEmpty),
+ new SimpleMidTransform(CheckInitialization),
+ new SimpleMidTransform(InferTypes),
+ Uniquify,
+ new SimpleMidTransform(ResolveKinds),
+ new SimpleMidTransform(ResolveFlows)
+ )
def execute(state: CircuitState): CircuitState = {
val annos = state.annotations.collect { case a: ReplSeqMemAnnotation => a }
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
index 41c47dce..434c7602 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -28,10 +28,10 @@ object AnalysisUtils {
connects(value.serialize) = WInvalid
case _ => // do nothing
}
- s map getConnects(connects)
+ s.map(getConnects(connects))
}
val connects = new Connects
- m map getConnects(connects)
+ m.map(getConnects(connects))
connects
}
@@ -56,8 +56,8 @@ object AnalysisUtils {
else if (weq(tvOrigin, fvOrigin)) tvOrigin
else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
else e
- case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
- case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args.exists(weq(_, one)) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args.exists(weq(_, zero)) => zero
case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
val extractionWidth = (msb - lsb) + 1
val nodeWidth = bitWidth(args.head.tpe)
@@ -69,10 +69,10 @@ object AnalysisUtils {
case ValidIf(cond, value, _) => getOrigin(connects)(value)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
- connects get e.serialize match {
- case Some(ex) => getOrigin(connects)(ex)
- case None => e
- }
+ connects.get(e.serialize) match {
+ case Some(ex) => getOrigin(connects)(ex)
+ case None => e
+ }
case _ => e
}
}
@@ -90,10 +90,9 @@ object ResolveMaskGranularity extends Pass {
*/
def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
val wenOrigin = getOrigin(connects)(wen)
- val wmaskOrigin = connects.keys filter
- (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)}
+ val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { s: String => getOrigin(connects, s) }
// all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
- val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
+ val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one))
if (redundantMask) None else Some(wmaskOrigin.size)
}
@@ -103,18 +102,17 @@ object ResolveMaskGranularity extends Pass {
def updateStmts(connects: Connects)(s: Statement): Statement = s match {
case m: DefAnnotatedMemory =>
val dataBits = bitWidth(m.dataType)
- val rwMasks = m.readwriters map (rw =>
- getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
- val wMasks = m.writers map (w =>
- getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
+ val rwMasks =
+ m.readwriters.map(rw => getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
+ val wMasks = m.writers.map(w => getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
val maskGran = (rwMasks ++ wMasks).head match {
- case None => None
+ case None => None
case Some(maskBits) => Some(dataBits / maskBits)
}
m.copy(maskGran = maskGran)
- case sx => sx map updateStmts(connects)
+ case sx => sx.map(updateStmts(connects))
}
- def annotateModMems(m: DefModule): DefModule = m map updateStmts(getConnects(m))
- def run(c: Circuit): Circuit = c copy (modules = c.modules map annotateModMems)
+ def annotateModMems(m: DefModule): DefModule = m.map(updateStmts(getConnects(m)))
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(annotateModMems))
}
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
index b5ff10c6..e80e0c4a 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
@@ -14,7 +14,7 @@ case class NoDedupMemAnnotation(target: ComponentName) extends SingleTargetAnnot
}
/** Resolves annotation ref to memories that exactly match (except name) another memory
- */
+ */
class ResolveMemoryReference extends Transform with DependencyAPIMigration {
override def prerequisites = Forms.MidForm
@@ -45,10 +45,12 @@ class ResolveMemoryReference extends Transform with DependencyAPIMigration {
/** If a candidate memory is identical except for name to another, add an
* annotation that references the name of the other memory.
*/
- def updateMemStmts(mname: String,
- existingMems: AnnotatedMemories,
- noDedupMap: Map[String, Set[String]])
- (s: Statement): Statement = s match {
+ def updateMemStmts(
+ mname: String,
+ existingMems: AnnotatedMemories,
+ noDedupMap: Map[String, Set[String]]
+ )(s: Statement
+ ): Statement = s match {
// If not dedupable, no need to add to existing (since nothing can dedup with it)
// We just return the DefAnnotatedMemory as is in the default case below
case m: DefAnnotatedMemory if dedupable(noDedupMap, mname, m.name) =>
diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
index 554a3572..9fe7f852 100644
--- a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
+++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
@@ -14,16 +14,17 @@ import firrtl.ir._
* - undefined read-under-write behavior
*/
object ToMemIR extends Pass {
+
/** Only annotate memories that are candidates for memory macro replacements
* i.e. rw, w + r (read, write 1 cycle delay) and read-under-write "undefined."
*/
import ReadUnderWrite._
def updateStmts(s: Statement): Statement = s match {
- case m @ DefMemory(_,_,_,_,1,1,r,w,rw,Undefined) if (w.length + rw.length) == 1 && r.length <= 1 =>
+ case m @ DefMemory(_, _, _, _, 1, 1, r, w, rw, Undefined) if (w.length + rw.length) == 1 && r.length <= 1 =>
DefAnnotatedMemory(m)
- case sx => sx map updateStmts
+ case sx => sx.map(updateStmts)
}
- def annotateModMems(m: DefModule) = m map updateStmts
- def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
+ def annotateModMems(m: DefModule) = m.map(updateStmts)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(annotateModMems))
}
diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
index dd644323..a2b14343 100644
--- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
+++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
@@ -24,19 +24,19 @@ object MemDelayAndReadwriteTransformer {
case class SplitStatements(decls: Seq[Statement], conns: Seq[Connect])
// Utilities for generating hardware
- def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType)
- def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType)
- def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r)
- def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe))
+ def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType)
+ def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType)
+ def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r)
+ def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe))
// Utilities for working with WithValid groups
def connect(l: WithValid, r: WithValid): Seq[Connect] = {
- val paired = (l.valid +: l.payload) zip (r.valid +: r.payload)
+ val paired = (l.valid +: l.payload).zip(r.valid +: r.payload)
paired.map { case (le, re) => connect(le, re) }
}
def condConnect(l: WithValid, r: WithValid): Seq[Connect] = {
- connect(l.valid, r.valid) +: (l.payload zip r.payload).map { case (le, re) => condConnect(r.valid)(le, re) }
+ connect(l.valid, r.valid) +: (l.payload.zip(r.payload)).map { case (le, re) => condConnect(r.valid)(le, re) }
}
// Internal representation of a pipeline stage with an associated valid signal
@@ -47,20 +47,23 @@ object MemDelayAndReadwriteTransformer {
private def flatName(e: Expression) = metaChars.replaceAllIn(e.serialize, "_")
// Pipeline a group of signals with an associated valid signal. Gate registers when possible.
- def pipelineWithValid(ns: Namespace)(
- clock: Expression,
- depth: Int,
- src: WithValid,
- nameTemplate: Option[WithValid] = None): (WithValid, Seq[Statement], Seq[Connect]) = {
+ def pipelineWithValid(
+ ns: Namespace
+ )(clock: Expression,
+ depth: Int,
+ src: WithValid,
+ nameTemplate: Option[WithValid] = None
+ ): (WithValid, Seq[Statement], Seq[Connect]) = {
def asReg(e: Expression) = DefRegister(NoInfo, e.serialize, e.tpe, clock, zero, e)
val template = nameTemplate.getOrElse(src)
- val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { case prev =>
- def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind)
- val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef))
- val regs = (ref.valid +: ref.payload).map(asReg)
- PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref)))
+ val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) {
+ case prev =>
+ def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind)
+ val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef))
+ val regs = (ref.valid +: ref.payload).map(asReg)
+ PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref)))
}
(stages.last.ref, stages.flatMap(_.stmts.decls), stages.flatMap(_.stmts.conns))
}
@@ -84,10 +87,10 @@ class MemDelayAndReadwriteTransformer(m: DefModule) {
private def findMemConns(s: Statement): Unit = s match {
case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(we(loc)) = expr
- case _ => s.foreach(findMemConns)
+ case _ => s.foreach(findMemConns)
}
- private def swapMemRefs(e: Expression): Expression = e map swapMemRefs match {
+ private def swapMemRefs(e: Expression): Expression = e.map(swapMemRefs) match {
case sf: WSubField => exprReplacements.getOrElse(we(sf), sf)
case ex => ex
}
@@ -105,51 +108,57 @@ class MemDelayAndReadwriteTransformer(m: DefModule) {
val rRespDelay = if (mem.readUnderWrite == ReadUnderWrite.Old) mem.readLatency else 0
val wCmdDelay = mem.writeLatency - 1
- val readStmts = (mem.readers ++ mem.readwriters).map { case r =>
- def oldDriver(f: String) = netlist(we(memPortField(mem, r, f)))
- def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f)
- val clk = oldDriver("clk")
-
- // Pack sources of read command inputs into WithValid object -> different for readwriter
- val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en")
- val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr")))
- val cmdSink = WithValid(newField("en"), Seq(newField("addr")))
- val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
- val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
-
- // Pipeline read response using *last* command pipe stage enable as the valid signal
- val resp = WithValid(cmdPiped.valid, Seq(newField("data")))
- val respPipeNameTemplate = Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names
- val (respPiped, respDecls, respConns) = pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate)
-
- // Make sure references to the read data get appropriately substituted
- val oldRDataName = if (rMap.contains(r)) "rdata" else "data"
- exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head
-
- // Return all statements; they're separated so connects can go after all declarations
- SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns)
+ val readStmts = (mem.readers ++ mem.readwriters).map {
+ case r =>
+ def oldDriver(f: String) = netlist(we(memPortField(mem, r, f)))
+ def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f)
+ val clk = oldDriver("clk")
+
+ // Pack sources of read command inputs into WithValid object -> different for readwriter
+ val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en")
+ val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr")))
+ val cmdSink = WithValid(newField("en"), Seq(newField("addr")))
+ val (cmdPiped, cmdDecls, cmdConns) =
+ pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
+ val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
+
+ // Pipeline read response using *last* command pipe stage enable as the valid signal
+ val resp = WithValid(cmdPiped.valid, Seq(newField("data")))
+ val respPipeNameTemplate =
+ Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names
+ val (respPiped, respDecls, respConns) =
+ pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate)
+
+ // Make sure references to the read data get appropriately substituted
+ val oldRDataName = if (rMap.contains(r)) "rdata" else "data"
+ exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head
+
+ // Return all statements; they're separated so connects can go after all declarations
+ SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns)
}
- val writeStmts = (mem.writers ++ mem.readwriters).map { case w =>
- def oldDriver(f: String) = netlist(we(memPortField(mem, w, f)))
- def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f)
- val clk = oldDriver("clk")
-
- // Pack sources of write command inputs into WithValid object -> different for readwriter
- val cmdSrc = if (wMap.contains(w)) {
- val en = AND(oldDriver("en"), oldDriver("wmode"))
- WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata")))
- } else {
- WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data")))
- }
-
- // Pipeline write command, connect to memory
- val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data")))
- val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
- val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
-
- // Return all statements; they're separated so connects can go after all declarations
- SplitStatements(cmdDecls, cmdConns ++ cmdPortConns)
+ val writeStmts = (mem.writers ++ mem.readwriters).map {
+ case w =>
+ def oldDriver(f: String) = netlist(we(memPortField(mem, w, f)))
+ def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f)
+ val clk = oldDriver("clk")
+
+ // Pack sources of write command inputs into WithValid object -> different for readwriter
+ val cmdSrc = if (wMap.contains(w)) {
+ val en = AND(oldDriver("en"), oldDriver("wmode"))
+ WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata")))
+ } else {
+ WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data")))
+ }
+
+ // Pipeline write command, connect to memory
+ val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data")))
+ val (cmdPiped, cmdDecls, cmdConns) =
+ pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
+ val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
+
+ // Return all statements; they're separated so connects can go after all declarations
+ SplitStatements(cmdDecls, cmdConns ++ cmdPortConns)
}
newConns ++= (readStmts ++ writeStmts).flatMap(_.conns)
@@ -171,8 +180,7 @@ object VerilogMemDelays extends Pass {
override def prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf)
override val optionalPrerequisiteOf =
- Seq( Dependency[VerilogEmitter],
- Dependency[SystemVerilogEmitter] )
+ Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
case _: transforms.ConstantPropagation | ResolveFlows => true
@@ -180,5 +188,5 @@ object VerilogMemDelays extends Pass {
}
def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed
- def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform))
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform))
}
diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
index a43adfe2..b5f91e7b 100644
--- a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
@@ -6,7 +6,6 @@ import net.jcazevedo.moultingyaml._
import java.io.{CharArrayWriter, File, PrintWriter}
import firrtl.FileUtils
-
object CustomYAMLProtocol extends DefaultYamlProtocol {
// bottom depends on top
implicit val _pin = yamlFormat1(Pin)
@@ -20,17 +19,15 @@ case class Source(name: String, module: String)
case class Top(name: String)
case class Config(pin: Pin, source: Source, top: Top)
-
class YamlFileReader(file: String) {
- def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
+ def parse[A](implicit reader: YamlReader[A]): Seq[A] = {
if (new File(file).exists) {
val yamlString = FileUtils.getText(file)
- yamlString.parseYamls flatMap (x =>
- try Some(reader read x)
+ yamlString.parseYamls.flatMap(x =>
+ try Some(reader.read(x))
catch { case e: Exception => None }
)
- }
- else sys.error("Yaml file doesn't exist!")
+ } else sys.error("Yaml file doesn't exist!")
}
}
@@ -38,11 +35,11 @@ class YamlFileWriter(file: String) {
val outputBuffer = new CharArrayWriter
val separator = "--- \n"
def append(in: YamlValue): Unit = {
- outputBuffer append s"$separator${in.prettyPrint}"
+ outputBuffer.append(s"$separator${in.prettyPrint}")
}
def dump(): Unit = {
val outputFile = new PrintWriter(file)
- outputFile write outputBuffer.toString
+ outputFile.write(outputBuffer.toString)
outputFile.close()
}
}
diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala
index 3f74e5d2..a69b7797 100644
--- a/src/main/scala/firrtl/passes/wiring/Wiring.scala
+++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala
@@ -18,8 +18,7 @@ import firrtl.graph.EulerTour
case class WiringInfo(source: ComponentName, sinks: Seq[Named], pin: String)
/** A data store of wiring names */
-case class WiringNames(compName: String, source: String, sinks: Seq[Named],
- pin: String)
+case class WiringNames(compName: String, source: String, sinks: Seq[Named], pin: String)
/** Pass that computes and applies a sequence of wiring modifications
*
@@ -28,31 +27,39 @@ case class WiringNames(compName: String, source: String, sinks: Seq[Named],
*/
class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
def run(c: Circuit): Circuit = analyze(c)
- .foldLeft(c){
- case (cx, (tpe, modsMap)) => cx.copy(
- modules = cx.modules map onModule(tpe, modsMap)) }
+ .foldLeft(c) {
+ case (cx, (tpe, modsMap)) => cx.copy(modules = cx.modules.map(onModule(tpe, modsMap)))
+ }
/** Converts multiple units of wiring information to module modifications */
private def analyze(c: Circuit): Seq[(Type, Map[String, Modifications])] = {
val names = wiSeq
- .map ( wi => (wi.source, wi.sinks, wi.pin) match {
- case (ComponentName(comp, ModuleName(source,_)), sinks, pin) =>
- WiringNames(comp, source, sinks, pin) })
+ .map(wi =>
+ (wi.source, wi.sinks, wi.pin) match {
+ case (ComponentName(comp, ModuleName(source, _)), sinks, pin) =>
+ WiringNames(comp, source, sinks, pin)
+ }
+ )
val portNames = mutable.Seq.fill(names.size)(Map[String, String]())
- c.modules.foreach{ m =>
+ c.modules.foreach { m =>
val ns = Namespace(m)
- names.zipWithIndex.foreach{ case (WiringNames(c, so, si, p), i) =>
- portNames(i) = portNames(i) +
- ( m.name -> {
- if (si.exists(getModuleName(_) == m.name)) ns.newName(p)
- else ns.newName(tokenize(c) filterNot ("[]." contains _) mkString "_")
- })}}
+ names.zipWithIndex.foreach {
+ case (WiringNames(c, so, si, p), i) =>
+ portNames(i) = portNames(i) +
+ (m.name -> {
+ if (si.exists(getModuleName(_) == m.name)) ns.newName(p)
+ else ns.newName(tokenize(c).filterNot("[]." contains _).mkString("_"))
+ })
+ }
+ }
val iGraph = InstanceKeyGraph(c)
- names.zip(portNames).map{ case(WiringNames(comp, so, si, _), pn) =>
- computeModifications(c, iGraph, comp, so, si, pn) }
+ names.zip(portNames).map {
+ case (WiringNames(comp, so, si, _), pn) =>
+ computeModifications(c, iGraph, comp, so, si, pn)
+ }
}
/** Converts a single unit of wiring information to module modifications
@@ -69,19 +76,20 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
* @return a tuple of the component type and a map of module names
* to pending modifications
*/
- private def computeModifications(c: Circuit,
- iGraph: InstanceKeyGraph,
- compName: String,
- source: String,
- sinks: Seq[Named],
- portNames: Map[String, String]):
- (Type, Map[String, Modifications]) = {
+ private def computeModifications(
+ c: Circuit,
+ iGraph: InstanceKeyGraph,
+ compName: String,
+ source: String,
+ sinks: Seq[Named],
+ portNames: Map[String, String]
+ ): (Type, Map[String, Modifications]) = {
val sourceComponentType = getType(c, source, compName)
- val sinkComponents: Map[String, Seq[String]] = sinks
- .collect{ case ComponentName(c, ModuleName(m, _)) => (c, m) }
- .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]){
- case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)) ) }
+ val sinkComponents: Map[String, Seq[String]] = sinks.collect { case ComponentName(c, ModuleName(m, _)) => (c, m) }
+ .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]) {
+ case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)))
+ }
// Determine "ownership" of sources to sinks via minimum distance
val owners = sinksToSourcesSeq(sinks, source, iGraph)
@@ -95,86 +103,88 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
def makeWire(m: Modifications, portName: String): Modifications =
m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))))
def makeWireC(m: Modifications, portName: String, c: (String, String)): Modifications =
- m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))), cons = (m.cons :+ c).distinct )
+ m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))), cons = (m.cons :+ c).distinct)
val tour = EulerTour(iGraph.graph, iGraph.top)
// Finds the lowest common ancestor instances for two module names in a design
def lowestCommonAncestor(moduleA: Seq[InstanceKey], moduleB: Seq[InstanceKey]): Seq[InstanceKey] =
tour.rmq(moduleA, moduleB)
- owners.foreach { case (sink, source) =>
- val lca = lowestCommonAncestor(sink, source)
-
- // Compute metadata along Sink to LCA paths.
- sink.drop(lca.size - 1).sliding(2).toList.reverse.foreach {
- case Seq(InstanceKey(_,pm), InstanceKey(ci,cm)) =>
- val to = s"$ci.${portNames(cm)}"
- val from = s"${portNames(pm)}"
- meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from))
- meta(cm) = meta(cm).copy(
- addPortOrWire = Some((portNames(cm), DecInput))
- )
- // Case where the sink is the LCA
- case Seq(InstanceKey(_,pm)) =>
- // Case where the source is also the LCA
- if (source.drop(lca.size).isEmpty) {
- meta(pm) = makeWire(meta(pm), portNames(pm))
- } else {
- val InstanceKey(ci,cm) = source.drop(lca.size).head
- val to = s"${portNames(pm)}"
- val from = s"$ci.${portNames(cm)}"
- meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from))
- }
- }
+ owners.foreach {
+ case (sink, source) =>
+ val lca = lowestCommonAncestor(sink, source)
- // Compute metadata for the Sink
- sink.last match { case InstanceKey( _, m) =>
- if (sinkComponents.contains(m)) {
- val from = s"${portNames(m)}"
- sinkComponents(m).foreach( to =>
- meta(m) = meta(m).copy(
- cons = (meta(m).cons :+( (to, from) )).distinct
+ // Compute metadata along Sink to LCA paths.
+ sink.drop(lca.size - 1).sliding(2).toList.reverse.foreach {
+ case Seq(InstanceKey(_, pm), InstanceKey(ci, cm)) =>
+ val to = s"$ci.${portNames(cm)}"
+ val from = s"${portNames(pm)}"
+ meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from))
+ meta(cm) = meta(cm).copy(
+ addPortOrWire = Some((portNames(cm), DecInput))
)
- )
+ // Case where the sink is the LCA
+ case Seq(InstanceKey(_, pm)) =>
+ // Case where the source is also the LCA
+ if (source.drop(lca.size).isEmpty) {
+ meta(pm) = makeWire(meta(pm), portNames(pm))
+ } else {
+ val InstanceKey(ci, cm) = source.drop(lca.size).head
+ val to = s"${portNames(pm)}"
+ val from = s"$ci.${portNames(cm)}"
+ meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from))
+ }
}
- }
- // Compute metadata for the Source
- source.last match { case InstanceKey( _, m) =>
- val to = s"${portNames(m)}"
- val from = compName
- meta(m) = meta(m).copy(
- cons = (meta(m).cons :+( (to, from) )).distinct
- )
- }
+ // Compute metadata for the Sink
+ sink.last match {
+ case InstanceKey(_, m) =>
+ if (sinkComponents.contains(m)) {
+ val from = s"${portNames(m)}"
+ sinkComponents(m).foreach(to =>
+ meta(m) = meta(m).copy(
+ cons = (meta(m).cons :+ ((to, from))).distinct
+ )
+ )
+ }
+ }
- // Compute metadata along Source to LCA path
- source.drop(lca.size - 1).sliding(2).toList.reverse.map {
- case Seq(InstanceKey(_,pm), InstanceKey(ci,cm)) => {
- val to = s"${portNames(pm)}"
- val from = s"$ci.${portNames(cm)}"
- meta(pm) = meta(pm).copy(
- cons = (meta(pm).cons :+( (to, from) )).distinct
- )
- meta(cm) = meta(cm).copy(
- addPortOrWire = Some((portNames(cm), DecOutput))
- )
+ // Compute metadata for the Source
+ source.last match {
+ case InstanceKey(_, m) =>
+ val to = s"${portNames(m)}"
+ val from = compName
+ meta(m) = meta(m).copy(
+ cons = (meta(m).cons :+ ((to, from))).distinct
+ )
}
- // Case where the source is the LCA
- case Seq(InstanceKey(_,pm)) => {
- // Case where the sink is also the LCA. We do nothing here,
- // as we've created the connecting wire above
- if (sink.drop(lca.size).isEmpty) {
- } else {
- val InstanceKey(ci,cm) = sink.drop(lca.size).head
- val to = s"$ci.${portNames(cm)}"
- val from = s"${portNames(pm)}"
+
+ // Compute metadata along Source to LCA path
+ source.drop(lca.size - 1).sliding(2).toList.reverse.map {
+ case Seq(InstanceKey(_, pm), InstanceKey(ci, cm)) => {
+ val to = s"${portNames(pm)}"
+ val from = s"$ci.${portNames(cm)}"
meta(pm) = meta(pm).copy(
- cons = (meta(pm).cons :+( (to, from) )).distinct
+ cons = (meta(pm).cons :+ ((to, from))).distinct
)
+ meta(cm) = meta(cm).copy(
+ addPortOrWire = Some((portNames(cm), DecOutput))
+ )
+ }
+ // Case where the source is the LCA
+ case Seq(InstanceKey(_, pm)) => {
+ // Case where the sink is also the LCA. We do nothing here,
+ // as we've created the connecting wire above
+ if (sink.drop(lca.size).isEmpty) {} else {
+ val InstanceKey(ci, cm) = sink.drop(lca.size).head
+ val to = s"$ci.${portNames(cm)}"
+ val from = s"${portNames(pm)}"
+ meta(pm) = meta(pm).copy(
+ cons = (meta(pm).cons :+ ((to, from))).distinct
+ )
+ }
}
}
- }
}
(sourceComponentType, meta.toMap)
}
@@ -189,20 +199,22 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
val ports = mutable.ArrayBuffer[Port]()
l.addPortOrWire match {
case None =>
- case Some((s, dt)) => dt match {
- case DecInput => ports += Port(NoInfo, s, Input, t)
- case DecOutput => ports += Port(NoInfo, s, Output, t)
- case DecWire => defines += DefWire(NoInfo, s, t)
- }
+ case Some((s, dt)) =>
+ dt match {
+ case DecInput => ports += Port(NoInfo, s, Input, t)
+ case DecOutput => ports += Port(NoInfo, s, Output, t)
+ case DecWire => defines += DefWire(NoInfo, s, t)
+ }
}
- connects ++= (l.cons map { case ((l, r)) =>
- Connect(NoInfo, toExp(l), toExp(r))
+ connects ++= (l.cons.map {
+ case ((l, r)) =>
+ Connect(NoInfo, toExp(l), toExp(r))
})
m match {
case Module(i, n, ps, body) =>
val stmts = body match {
case Block(sx) => sx
- case s => Seq(s)
+ case s => Seq(s)
}
Module(i, n, ps ++ ports, Block(List() ++ defines ++ stmts ++ connects))
case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps ++ ports, dn, p)
diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
index 20fb1215..d6658f16 100644
--- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
+++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
@@ -14,14 +14,12 @@ import firrtl.stage.Forms
case class WiringException(msg: String) extends PassException(msg)
/** A component, e.g. register etc. Must be declared only once under the TopAnnotation */
-case class SourceAnnotation(target: ComponentName, pin: String) extends
- SingleTargetAnnotation[ComponentName] {
+case class SourceAnnotation(target: ComponentName, pin: String) extends SingleTargetAnnotation[ComponentName] {
def duplicate(n: ComponentName) = this.copy(target = n)
}
/** A module, e.g. ExtModule etc., that should add the input pin */
-case class SinkAnnotation(target: Named, pin: String) extends
- SingleTargetAnnotation[Named] {
+case class SinkAnnotation(target: Named, pin: String) extends SingleTargetAnnotation[Named] {
def duplicate(n: Named) = this.copy(target = n)
}
@@ -76,8 +74,9 @@ class WiringTransform extends Transform with DependencyAPIMigration {
(sources.size, sinks.size) match {
case (0, p) => state
case (s, p) if (p > 0) =>
- val wis = sources.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, source)) =>
- seq :+ WiringInfo(source, sinks(pin), pin)
+ val wis = sources.foldLeft(Seq[WiringInfo]()) {
+ case (seq, (pin, source)) =>
+ seq :+ WiringInfo(source, sinks(pin), pin)
}
val annosx = state.annotations.filterNot(annos.toSet.contains)
transforms(wis)
diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
index c220692a..5e8f8616 100644
--- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
+++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
@@ -25,54 +25,54 @@ case object DecWire extends DecKind
/** Store of pending wiring information for a Module */
case class Modifications(
addPortOrWire: Option[(String, DecKind)] = None,
- cons: Seq[(String, String)] = Seq.empty) {
+ cons: Seq[(String, String)] = Seq.empty) {
override def toString: String = serialize("")
def serialize(tab: String): String = s"""
- |$tab addPortOrWire: $addPortOrWire
- |$tab cons: $cons
- |""".stripMargin
+ |$tab addPortOrWire: $addPortOrWire
+ |$tab cons: $cons
+ |""".stripMargin
}
/** A lineage tree representing the instance hierarchy in a design
*/
@deprecated("Use DiGraph/InstanceGraph", "1.1.1")
case class Lineage(
- name: String,
- children: Seq[(String, Lineage)] = Seq.empty,
- source: Boolean = false,
- sink: Boolean = false,
- sourceParent: Boolean = false,
- sinkParent: Boolean = false,
- sharedParent: Boolean = false,
- addPort: Option[(String, DecKind)] = None,
- cons: Seq[(String, String)] = Seq.empty) {
+ name: String,
+ children: Seq[(String, Lineage)] = Seq.empty,
+ source: Boolean = false,
+ sink: Boolean = false,
+ sourceParent: Boolean = false,
+ sinkParent: Boolean = false,
+ sharedParent: Boolean = false,
+ addPort: Option[(String, DecKind)] = None,
+ cons: Seq[(String, String)] = Seq.empty) {
def map(f: Lineage => Lineage): Lineage =
- this.copy(children = children.map{ case (i, m) => (i, f(m)) })
+ this.copy(children = children.map { case (i, m) => (i, f(m)) })
override def toString: String = shortSerialize("")
def shortSerialize(tab: String): String = s"""
- |$tab name: $name,
- |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))}
- |""".stripMargin
+ |$tab name: $name,
+ |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))}
+ |""".stripMargin
def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B =
this.children.foldLeft(z)(op)
def serialize(tab: String): String = s"""
- |$tab name: $name,
- |$tab source: $source,
- |$tab sink: $sink,
- |$tab sourceParent: $sourceParent,
- |$tab sinkParent: $sinkParent,
- |$tab sharedParent: $sharedParent,
- |$tab addPort: $addPort
- |$tab cons: $cons
- |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))}
- |""".stripMargin
+ |$tab name: $name,
+ |$tab source: $source,
+ |$tab sink: $sink,
+ |$tab sourceParent: $sourceParent,
+ |$tab sinkParent: $sinkParent,
+ |$tab sharedParent: $sharedParent,
+ |$tab addPort: $addPort
+ |$tab cons: $cons
+ |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))}
+ |""".stripMargin
}
object WiringUtils {
@@ -87,12 +87,12 @@ object WiringUtils {
val childrenMap = new ChildrenMap()
def getChildren(mname: String)(s: Statement): Unit = s match {
case s: WDefInstance =>
- childrenMap(mname) = childrenMap(mname) :+( (s.name, s.module) )
+ childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module))
case s: DefInstance =>
- childrenMap(mname) = childrenMap(mname) :+( (s.name, s.module) )
+ childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module))
case s => s.foreach(getChildren(mname))
}
- c.modules.foreach{ m =>
+ c.modules.foreach { m =>
childrenMap(m.name) = Nil
m.foreach(getChildren(m.name))
}
@@ -103,7 +103,7 @@ object WiringUtils {
*/
@deprecated("Use DiGraph/InstanceGraph", "1.1.1")
def getLineage(childrenMap: ChildrenMap, module: String): Lineage =
- Lineage(module, childrenMap(module) map { case (i, m) => (i, getLineage(childrenMap, m)) } )
+ Lineage(module, childrenMap(module).map { case (i, m) => (i, getLineage(childrenMap, m)) })
/** Return a map of sink instances to source instances that minimizes
* distance
@@ -114,22 +114,25 @@ object WiringUtils {
* @return a map of sink instance names to source instance names
* @throws WiringException if a sink is equidistant to two sources
*/
- @deprecated("This method can lead to non-determinism in your compiler pass and exposes internal details." +
- " Please file an issue with firrtl if you have a use case!", "Firrtl 1.4")
+ @deprecated(
+ "This method can lead to non-determinism in your compiler pass and exposes internal details." +
+ " Please file an issue with firrtl if you have a use case!",
+ "Firrtl 1.4"
+ )
def sinksToSources(sinks: Seq[Named], source: String, i: InstanceGraph): Map[Seq[WDefInstance], Seq[WDefInstance]] = {
// The order of owners influences the order of the results, it thus needs to be deterministic with a LinkedHashMap.
val owners = new mutable.LinkedHashMap[Seq[WDefInstance], Vector[Seq[WDefInstance]]]
val queue = new mutable.Queue[Seq[WDefInstance]]
val visited = new mutable.HashMap[Seq[WDefInstance], Boolean].withDefaultValue(false)
- val sourcePaths = i.fullHierarchy.collect { case (k,v) if k.module == source => v }
+ val sourcePaths = i.fullHierarchy.collect { case (k, v) if k.module == source => v }
sourcePaths.flatten.foreach { l =>
queue.enqueue(l)
owners(l) = Vector(l)
}
val sinkModuleNames = sinks.map(getModuleName).toSet
- val sinkPaths = i.fullHierarchy.collect { case (k,v) if sinkModuleNames.contains(k.module) => v }
+ val sinkPaths = i.fullHierarchy.collect { case (k, v) if sinkModuleNames.contains(k.module) => v }
// sinkInsts needs to have unique entries but is also iterated over which is why we use a LinkedHashSet
val sinkInsts = mutable.LinkedHashSet() ++ sinkPaths.flatten
@@ -156,8 +159,8 @@ object WiringUtils {
// [todo] This is the critical section
edges
- .filter( e => !visited(e) && e.nonEmpty )
- .foreach{ v =>
+ .filter(e => !visited(e) && e.nonEmpty)
+ .foreach { v =>
owners(v) = owners.getOrElse(v, Vector()) ++ owners(u)
queue.enqueue(v)
}
@@ -167,8 +170,8 @@ object WiringUtils {
// this should fail is if a sink is equidistant to two sources.
sinkInsts.foreach { s =>
if (!owners.contains(s) || owners(s).size > 1) {
- throw new WiringException(
- s"Unable to determine source mapping for sink '${s.map(_.name)}'") }
+ throw new WiringException(s"Unable to determine source mapping for sink '${s.map(_.name)}'")
+ }
}
}
@@ -184,21 +187,24 @@ object WiringUtils {
* @return a map of sink instance names to source instance names
* @throws WiringException if a sink is equidistant to two sources
*/
- private[firrtl] def sinksToSourcesSeq(sinks: Seq[Named], source: String, i: InstanceKeyGraph):
- Seq[(Seq[InstanceKey], Seq[InstanceKey])] = {
+ private[firrtl] def sinksToSourcesSeq(
+ sinks: Seq[Named],
+ source: String,
+ i: InstanceKeyGraph
+ ): Seq[(Seq[InstanceKey], Seq[InstanceKey])] = {
// The order of owners influences the order of the results, it thus needs to be deterministic with a LinkedHashMap.
val owners = new mutable.LinkedHashMap[Seq[InstanceKey], Vector[Seq[InstanceKey]]]
val queue = new mutable.Queue[Seq[InstanceKey]]
val visited = new mutable.HashMap[Seq[InstanceKey], Boolean].withDefaultValue(false)
- val sourcePaths = i.fullHierarchy.collect { case (k,v) if k.module == source => v }
+ val sourcePaths = i.fullHierarchy.collect { case (k, v) if k.module == source => v }
sourcePaths.flatten.foreach { l =>
queue.enqueue(l)
owners(l) = Vector(l)
}
val sinkModuleNames = sinks.map(getModuleName).toSet
- val sinkPaths = i.fullHierarchy.collect { case (k,v) if sinkModuleNames.contains(k.module) => v }
+ val sinkPaths = i.fullHierarchy.collect { case (k, v) if sinkModuleNames.contains(k.module) => v }
// sinkInsts needs to have unique entries but is also iterated over which is why we use a LinkedHashSet
val sinkInsts = mutable.LinkedHashSet() ++ sinkPaths.flatten
@@ -225,8 +231,8 @@ object WiringUtils {
// [todo] This is the critical section
edges
- .filter( e => !visited(e) && e.nonEmpty )
- .foreach{ v =>
+ .filter(e => !visited(e) && e.nonEmpty)
+ .foreach { v =>
owners(v) = owners.getOrElse(v, Vector()) ++ owners(u)
queue.enqueue(v)
}
@@ -236,8 +242,8 @@ object WiringUtils {
// this should fail is if a sink is equidistant to two sources.
sinkInsts.foreach { s =>
if (!owners.contains(s) || owners(s).size > 1) {
- throw new WiringException(
- s"Unable to determine source mapping for sink '${s.map(_.name)}'") }
+ throw new WiringException(s"Unable to determine source mapping for sink '${s.map(_.name)}'")
+ }
}
}
@@ -249,8 +255,7 @@ object WiringUtils {
n match {
case ModuleName(m, _) => m
case ComponentName(_, ModuleName(m, _)) => m
- case _ => throw new WiringException(
- "Only Components or Modules have an associated Module name")
+ case _ => throw new WiringException("Only Components or Modules have an associated Module name")
}
}
@@ -266,9 +271,9 @@ object WiringUtils {
def getType(c: Circuit, module: String, comp: String): Type = {
def getRoot(e: Expression): String = e match {
case r: Reference => r.name
- case i: SubIndex => getRoot(i.expr)
+ case i: SubIndex => getRoot(i.expr)
case a: SubAccess => getRoot(a.expr)
- case f: SubField => getRoot(f.expr)
+ case f: SubField => getRoot(f.expr)
}
val eComp = toExp(comp)
val root = getRoot(eComp)
@@ -289,11 +294,12 @@ object WiringUtils {
case sx: DefMemory if sx.name == root =>
tpe = Some(MemPortUtils.memType(sx))
sx
- case sx => sx map getType
+ case sx => sx.map(getType)
+ }
+ val m = c.modules.find(_.name == module).getOrElse {
+ throw new WiringException(s"Must have a module named $module")
}
- val m = c.modules find (_.name == module) getOrElse {
- throw new WiringException(s"Must have a module named $module") }
- tpe = m.ports find (_.name == root) map (_.tpe)
+ tpe = m.ports.find(_.name == root).map(_.tpe)
m match {
case Module(i, n, ps, b) => getType(b)
case e: ExtModule =>
@@ -301,10 +307,10 @@ object WiringUtils {
tpe match {
case None => throw new WiringException(s"Didn't find $comp in $module!")
case Some(t) =>
- def setType(e: Expression): Expression = e map setType match {
+ def setType(e: Expression): Expression = e.map(setType) match {
case ex: Reference => ex.copy(tpe = t)
- case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name))
- case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe))
+ case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name))
+ case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe))
case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe))
}
setType(eComp).tpe