aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/passes/RemoveCHIRRTL.scala')
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala196
1 files changed, 112 insertions, 84 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 61fd6258..624138ab 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -17,8 +17,7 @@ case class DataRef(exp: Expression, source: String, sink: String, mask: String,
object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.ChirrtlForm ++
- Seq( Dependency(passes.CInferTypes),
- Dependency(passes.CInferMDir) )
+ Seq(Dependency(passes.CInferTypes), Dependency(passes.CInferMDir))
override def invalidates(a: Transform) = false
@@ -31,10 +30,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match {
case _: GroundType => Seq(ex)
- case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) =>
- exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex)
- case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
- exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex)
+ case t: BundleType =>
+ (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(
+ ex
+ )
+ case t: VectorType =>
+ ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) =>
+ exps ++ create_all_exps(SubIndex(ex, i, t.tpe))
+ ) ++ Seq(ex)
case UnknownType => Seq(ex)
}
@@ -42,17 +45,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
case ex: Mux =>
val e1s = create_exps(ex.tval)
val e2s = create_exps(ex.fval)
- (e1s zip e2s) map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) }
+ (e1s.zip(e2s)).map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) }
case ex: ValidIf =>
- create_exps(ex.value) map (e1 => ValidIf(ex.cond, e1, e1.tpe))
- case ex => ex.tpe match {
- case _: GroundType => Seq(ex)
- case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) =>
- exps ++ create_exps(SubField(ex, f.name, f.tpe)))
- case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
- exps ++ create_exps(SubIndex(ex, i, t.tpe)))
- case UnknownType => Seq(ex)
- }
+ create_exps(ex.value).map(e1 => ValidIf(ex.cond, e1, e1.tpe))
+ case ex =>
+ ex.tpe match {
+ case _: GroundType => Seq(ex)
+ case t: BundleType =>
+ (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_exps(SubField(ex, f.name, f.tpe)))
+ case t: VectorType =>
+ ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) => exps ++ create_exps(SubIndex(ex, i, t.tpe)))
+ case UnknownType => Seq(ex)
+ }
}
private def EMPs: MPorts = MPorts(ArrayBuffer[MPort](), ArrayBuffer[MPort](), ArrayBuffer[MPort]())
@@ -61,40 +65,48 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
s match {
case sx: CDefMemory if sx.seq => smems += sx.name
case sx: CDefMPort =>
- val p = mports getOrElse (sx.mem, EMPs)
+ val p = mports.getOrElse(sx.mem, EMPs)
sx.direction match {
- case MRead => p.readers += MPort(sx.name, sx.exps(1))
- case MWrite => p.writers += MPort(sx.name, sx.exps(1))
+ case MRead => p.readers += MPort(sx.name, sx.exps(1))
+ case MWrite => p.writers += MPort(sx.name, sx.exps(1))
case MReadWrite => p.readwriters += MPort(sx.name, sx.exps(1))
- case MInfer => // direction may not be inferred if it's not being used
+ case MInfer => // direction may not be inferred if it's not being used
}
mports(sx.mem) = p
case _ =>
}
- s map collect_smems_and_mports(mports, smems)
+ s.map(collect_smems_and_mports(mports, smems))
}
- def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap,
- refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match {
+ def collect_refs(
+ mports: MPortMap,
+ smems: SeqMemSet,
+ types: MPortTypeMap,
+ refs: DataRefMap,
+ raddrs: AddrMap,
+ renames: RenameMap
+ )(s: Statement
+ ): Statement = s match {
case sx: CDefMemory =>
types(sx.name) = sx.tpe
- val taddr = UIntType(IntWidth(1 max getUIntWidth(sx.size - 1)))
+ val taddr = UIntType(IntWidth(1.max(getUIntWidth(sx.size - 1))))
val tdata = sx.tpe
- def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap (r => Seq(
- IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)),
- IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType))
- ))
- def set_enable(vec: scala.collection.Seq[MPort], en: String) = vec.toSeq.map (r =>
- Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero)
+ def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap(r =>
+ Seq(
+ IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)),
+ IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType))
+ )
)
+ def set_enable(vec: scala.collection.Seq[MPort], en: String) =
+ vec.toSeq.map(r => Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero))
def set_write(vec: scala.collection.Seq[MPort], data: String, mask: String) = vec.toSeq.flatMap { r =>
val tmask = createMask(sx.tpe)
val portRef = SubField(Reference(sx.name, ut), r.name, ut)
Seq(IsInvalid(sx.info, SubField(portRef, data, tdata)), IsInvalid(sx.info, SubField(portRef, mask, tmask)))
}
- val rds = (mports getOrElse (sx.name, EMPs)).readers
- val wrs = (mports getOrElse (sx.name, EMPs)).writers
- val rws = (mports getOrElse (sx.name, EMPs)).readwriters
+ val rds = (mports.getOrElse(sx.name, EMPs)).readers
+ val wrs = (mports.getOrElse(sx.name, EMPs)).writers
+ val rws = (mports.getOrElse(sx.name, EMPs)).readwriters
val stmts = set_poison(rds) ++
set_enable(rds, "en") ++
set_poison(wrs) ++
@@ -104,8 +116,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
set_enable(rws, "wmode") ++
set_enable(rws, "en") ++
set_write(rws, "wdata", "wmask")
- val mem = DefMemory(sx.info, sx.name, sx.tpe, sx.size, 1, if (sx.seq) 1 else 0,
- rds.map(_.name).toSeq, wrs.map(_.name).toSeq, rws.map(_.name).toSeq, sx.readUnderWrite)
+ val mem = DefMemory(
+ sx.info,
+ sx.name,
+ sx.tpe,
+ sx.size,
+ 1,
+ if (sx.seq) 1 else 0,
+ rds.map(_.name).toSeq,
+ wrs.map(_.name).toSeq,
+ rws.map(_.name).toSeq,
+ sx.readUnderWrite
+ )
Block(mem +: stmts)
case sx: CDefMPort =>
types.get(sx.mem) match {
@@ -130,8 +152,8 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val es = create_all_exps(WRef(sx.name, sx.tpe))
val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe))
val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe))
- ((es zip rs) zip ws) map {
- case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
+ ((es.zip(rs)).zip(ws)).map {
+ case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
}
case MWrite =>
refs(sx.name) = DataRef(portRef, "data", "data", "mask", rdwrite = false)
@@ -142,7 +164,7 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
val es = create_all_exps(WRef(sx.name, sx.tpe))
val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
- (es zip ws) map {
+ (es.zip(ws)).map {
case (e, w) => renames.rename(e.serialize, w.serialize)
}
case MRead =>
@@ -157,63 +179,69 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
val es = create_all_exps(WRef(sx.name, sx.tpe))
val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
- (es zip rs) map {
+ (es.zip(rs)).map {
case (e, r) => renames.rename(e.serialize, r.serialize)
}
case MInfer => // do nothing if it's not being used
}
- Block(List() ++
- (addrs.map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++
- (clks map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++
- (ens map (x => Connect(sx.info,SubField(portRef, x, ut), one))) ++
- masks.map(lhs => Connect(sx.info, lhs, zero))
+ Block(
+ List() ++
+ (addrs.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++
+ (clks.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++
+ (ens.map(x => Connect(sx.info, SubField(portRef, x, ut), one))) ++
+ masks.map(lhs => Connect(sx.info, lhs, zero))
)
- case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames)
+ case sx => sx.map(collect_refs(mports, smems, types, refs, raddrs, renames))
}
def get_mask(refs: DataRefMap)(e: Expression): Expression =
- e map get_mask(refs) match {
- case ex: Reference => refs get ex.name match {
- case None => ex
- case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe))
- }
+ e.map(get_mask(refs)) match {
+ case ex: Reference =>
+ refs.get(ex.name) match {
+ case None => ex
+ case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe))
+ }
case ex => ex
}
def remove_chirrtl_s(refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = {
var has_write_mport = false
var has_readwrite_mport: Option[Expression] = None
- var has_read_mport: Option[Expression] = None
+ var has_read_mport: Option[Expression] = None
def remove_chirrtl_e(g: Flow)(e: Expression): Expression = e match {
- case Reference(name, tpe, _, _) => refs get name match {
- case Some(p) => g match {
- case SinkFlow =>
- has_write_mport = true
- if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType))
- SubField(p.exp, p.sink, tpe)
- case SourceFlow =>
- SubField(p.exp, p.source, tpe)
- }
- case None => g match {
- case SinkFlow => raddrs get name match {
- case Some(en) => has_read_mport = Some(en) ; e
- case None => e
- }
- case SourceFlow => e
+ case Reference(name, tpe, _, _) =>
+ refs.get(name) match {
+ case Some(p) =>
+ g match {
+ case SinkFlow =>
+ has_write_mport = true
+ if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType))
+ SubField(p.exp, p.sink, tpe)
+ case SourceFlow =>
+ SubField(p.exp, p.source, tpe)
+ }
+ case None =>
+ g match {
+ case SinkFlow =>
+ raddrs.get(name) match {
+ case Some(en) => has_read_mport = Some(en); e
+ case None => e
+ }
+ case SourceFlow => e
+ }
}
- }
- case SubAccess(expr, index, tpe, _) => SubAccess(
- remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe)
- case ex => ex map remove_chirrtl_e(g)
- }
- s match {
+ case SubAccess(expr, index, tpe, _) =>
+ SubAccess(remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe)
+ case ex => ex.map(remove_chirrtl_e(g))
+ }
+ s match {
case DefNode(info, name, value) =>
val valuex = remove_chirrtl_e(SourceFlow)(value)
val sx = DefNode(info, name, valuex)
// Check node is used for read port address
remove_chirrtl_e(SinkFlow)(Reference(name, value.tpe))
has_read_mport match {
- case None => sx
+ case None => sx
case Some(en) => Block(sx, Connect(info, en, one))
}
case Connect(info, loc, expr) =>
@@ -222,14 +250,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val sx = Connect(info, locx, rocx)
val stmts = ArrayBuffer[Statement]()
has_read_mport match {
- case None =>
+ case None =>
case Some(en) => stmts += Connect(info, en, one)
}
if (has_write_mport) {
val locs = create_exps(get_mask(refs)(loc))
- stmts ++= (locs map (x => Connect(info, x, one)))
+ stmts ++= (locs.map(x => Connect(info, x, one)))
has_readwrite_mport match {
- case None =>
+ case None =>
case Some(wmode) => stmts += Connect(info, wmode, one)
}
}
@@ -240,20 +268,20 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val sx = PartialConnect(info, locx, rocx)
val stmts = ArrayBuffer[Statement]()
has_read_mport match {
- case None =>
+ case None =>
case Some(en) => stmts += Connect(info, en, one)
}
if (has_write_mport) {
val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default)
val locs = create_exps(get_mask(refs)(loc))
- stmts ++= (ls map { case (x, _) => Connect(info, locs(x), one) })
+ stmts ++= (ls.map { case (x, _) => Connect(info, locs(x), one) })
has_readwrite_mport match {
- case None =>
+ case None =>
case Some(wmode) => stmts += Connect(info, wmode, one)
}
}
if (stmts.isEmpty) sx else Block(sx +: stmts.toSeq)
- case sx => sx map remove_chirrtl_s(refs, raddrs) map remove_chirrtl_e(SourceFlow)
+ case sx => sx.map(remove_chirrtl_s(refs, raddrs)).map(remove_chirrtl_e(SourceFlow))
}
}
@@ -264,16 +292,16 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val refs = new DataRefMap
val raddrs = new AddrMap
renames.setModule(m.name)
- (m map collect_smems_and_mports(mports, smems)
- map collect_refs(mports, smems, types, refs, raddrs, renames)
- map remove_chirrtl_s(refs, raddrs))
+ (m.map(collect_smems_and_mports(mports, smems))
+ .map(collect_refs(mports, smems, types, refs, raddrs, renames))
+ .map(remove_chirrtl_s(refs, raddrs)))
}
def execute(state: CircuitState): CircuitState = {
val c = state.circuit
val renames = RenameMap()
renames.setCircuit(c.main)
- val result = c copy (modules = c.modules map remove_chirrtl_m(renames))
+ val result = c.copy(modules = c.modules.map(remove_chirrtl_m(renames)))
state.copy(circuit = result, renames = Some(renames))
}
}