diff options
| author | Donggyu Kim | 2016-08-26 00:07:36 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-08 13:15:56 -0700 |
| commit | 2a513ff47eebe38a81a1312c51972fcecaeb114f (patch) | |
| tree | 1f02ee22d028ff50a656a1b475160aa650d4113c /src | |
| parent | a6c0ee1c556d8e2ccd3aaf05f2c132734152a706 (diff) | |
refactor RemoveCHIRRTL
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveCHIRRTL.scala | 432 |
1 files changed, 196 insertions, 236 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 47e6bbbc..2bae92a7 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -27,270 +27,230 @@ MODIFICATIONS. package firrtl.passes -import com.typesafe.scalalogging.LazyLogging -import java.nio.file.{Paths, Files} - // Datastructures -import scala.collection.mutable.LinkedHashMap -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet import scala.collection.mutable.ArrayBuffer import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -import firrtl.PrimOps._ -import firrtl.WrappedExpression._ -case class MPort( val name : String, val clk : Expression) -case class MPorts( val readers : ArrayBuffer[MPort], val writers : ArrayBuffer[MPort], val readwriters : ArrayBuffer[MPort]) -case class DataRef( val exp : Expression, val male : String, val female : String, val mask : String, val rdwrite : Boolean) +case class MPort(name: String, clk: Expression) +case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort]) +case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean) object RemoveCHIRRTL extends Pass { def name = "Remove CHIRRTL" - var mname = "" - def create_exps (e:Expression) : Seq[Expression] = e match { - case (e:Mux) => + + val ut = UnknownType + type MPortMap = collection.mutable.LinkedHashMap[String, MPorts] + type SeqMemSet = collection.mutable.HashSet[String] + type MPortTypeMap = collection.mutable.LinkedHashMap[String, Type] + type DataRefMap = collection.mutable.LinkedHashMap[String, DataRef] + type AddrMap = collection.mutable.HashMap[String, Expression] + + def create_exps(e: Expression): Seq[Expression] = e match { + case (e: Mux) => val e1s = create_exps(e.tval) val e2s = create_exps(e.fval) - (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2))) - case (e:ValidIf) => - create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe)) + (e1s zip e2s) map { case (e1, e2) => Mux(e.cond, e1, e2, mux_type(e1, e2)) } + case (e: ValidIf) => + create_exps(e.value) map (e1 => ValidIf(e.cond, e1, e1.tpe)) case (e) => (e.tpe) match { - case (_:GroundType) => Seq(e) - case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => - exps ++ create_exps(SubField(e,f.name,f.tpe))) - case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => - exps ++ create_exps(SubIndex(e,i,t.tpe))) + case (_: GroundType) => Seq(e) + case (t: BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => + exps ++ create_exps(SubField(e, f.name, f.tpe))) + case (t: VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + exps ++ create_exps(SubIndex(e, i, t.tpe))) case UnknownType => Seq(e) } } - def run (c:Circuit) : Circuit = { - def remove_chirrtl_m (m:Module) : Module = { - val hash = LinkedHashMap[String,MPorts]() - val repl = LinkedHashMap[String,DataRef]() - val raddrs = HashMap[String, Expression]() - val ut = UnknownType - val mport_types = LinkedHashMap[String,Type]() - val smems = HashSet[String]() - def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]()) - def collect_smems_and_mports (s:Statement) : Statement = { - (s) match { - case (s:CDefMemory) if s.seq => - smems += s.name - s - case (s:CDefMPort) => { - val mports = hash.getOrElse(s.mem,EMPs()) - s.direction match { - case MRead => mports.readers += MPort(s.name,s.exps(1)) - case MWrite => mports.writers += MPort(s.name,s.exps(1)) - case MReadWrite => mports.readwriters += MPort(s.name,s.exps(1)) - } - hash(s.mem) = mports - s - } - case (s) => s map (collect_smems_and_mports) + + private def EMPs: MPorts = MPorts(ArrayBuffer[MPort](), ArrayBuffer[MPort](), ArrayBuffer[MPort]()) + + def collect_smems_and_mports(mports: MPortMap, smems: SeqMemSet)(s: Statement): Statement = { + s match { + case (s:CDefMemory) if s.seq => smems += s.name + case (s:CDefMPort) => + val p = mports getOrElse (s.mem, EMPs) + s.direction match { + case MRead => p.readers += MPort(s.name,s.exps(1)) + case MWrite => p.writers += MPort(s.name,s.exps(1)) + case MReadWrite => p.readwriters += MPort(s.name,s.exps(1)) } + mports(s.mem) = p + case s => + } + s map collect_smems_and_mports(mports, smems) + } + + def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap, + refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = s match { + case (s: CDefMemory) => + types(s.name) = s.tpe + val taddr = UIntType(IntWidth(math.max(1, ceil_log2(s.size)))) + val tdata = s.tpe + def set_poison(vec: Seq[MPort], addr: String) = vec flatMap (r => Seq( + IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), addr, taddr)), + IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "clk", taddr)) + )) + def set_enable(vec: Seq[MPort], en: String) = vec map (r => + Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), en, taddr), zero) + ) + def set_wmode (vec: Seq[MPort], wmode: String) = vec map (r => + Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), wmode, taddr), zero) + ) + def set_write (vec: Seq[MPort], data: String, mask: String) = vec flatMap {r => + val tmask = create_mask(s.tpe) + IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), data, tdata)) +: + (create_exps(SubField(SubField(Reference(s.name, ut), r.name, ut), mask, tmask)) + map (Connect(s.info, _, zero)) + ) } - def collect_refs (s:Statement) : Statement = { - (s) match { - case (s:CDefMemory) => { - mport_types(s.name) = s.tpe - val stmts = ArrayBuffer[Statement]() - val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size)))) - val tdata = s.tpe - def set_poison (vec:Seq[MPort],addr:String) : Unit = { - for (r <- vec ) { - stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),addr,taddr)) - stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),"clk",taddr)) - } - } - def set_enable (vec:Seq[MPort],en:String) : Unit = { - for (r <- vec ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),en,taddr),zero) - }} - def set_wmode (vec:Seq[MPort],wmode:String) : Unit = { - for (r <- vec) { - stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),wmode,taddr),zero) - }} - def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = { - val tmask = create_mask(s.tpe) - for (r <- vec ) { - stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),data,tdata)) - for (x <- create_exps(SubField(SubField(Reference(s.name,ut),r.name,ut),mask,tmask)) ) { - stmts += Connect(s.info,x,zero) - }}} - val rds = (hash.getOrElse(s.name,EMPs())).readers - set_poison(rds,"addr") - set_enable(rds,"en") - val wrs = (hash.getOrElse(s.name,EMPs())).writers - set_poison(wrs,"addr") - set_enable(wrs,"en") - set_write(wrs,"data","mask") - val rws = (hash.getOrElse(s.name,EMPs())).readwriters - set_poison(rws,"addr") - set_wmode(rws,"wmode") - set_enable(rws,"en") - set_write(rws,"wdata","wmask") - val read_l = if (s.seq) 1 else 0 - val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name)) - Block(Seq(mem,Block(stmts))) + val rds = (mports getOrElse (s.name, EMPs)).readers + val wrs = (mports getOrElse (s.name, EMPs)).writers + val rws = (mports getOrElse (s.name, EMPs)).readwriters + val stmts = set_poison(rds, "addr") ++ + set_enable(rds, "en") ++ + set_poison(wrs, "addr") ++ + set_enable(wrs, "en") ++ + set_write(wrs, "data", "mask") ++ + set_poison(rws, "addr") ++ + set_wmode(rws, "wmode") ++ + set_enable(rws, "en") ++ + set_write(rws, "wdata", "wmask") + val mem = DefMemory(s.info, s.name, s.tpe, s.size, 1, if (s.seq) 1 else 0, + rds map (_.name), wrs map (_.name), rws map (_.name)) + Block(mem +: stmts) + case (s: CDefMPort) => { + types(s.name) = types(s.mem) + val addrs = ArrayBuffer[String]() + val clks = ArrayBuffer[String]() + val ens = ArrayBuffer[String]() + s.direction match { + case MReadWrite => + refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "rdata", "wdata", "wmask", true) + addrs += "addr" + clks += "clk" + ens += "en" + case MWrite => + refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "data", "data", "mask", false) + addrs += "addr" + clks += "clk" + ens += "en" + case MRead => + refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "data", "data", "blah", false) + addrs += "addr" + clks += "clk" + s.exps.head match { + case e: Reference if smems(s.mem) => + raddrs(e.name) = SubField(SubField(Reference(s.mem, ut), s.name, ut), "en", ut) + case _ => ens += "en" } - case (s:CDefMPort) => { - mport_types(s.name) = mport_types(s.mem) - val addrs = ArrayBuffer[String]() - val clks = ArrayBuffer[String]() - val ens = ArrayBuffer[String]() - val masks = ArrayBuffer[String]() - s.direction match { - case MReadWrite => { - repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"rdata","wdata","wmask",true) - addrs += "addr" - clks += "clk" - ens += "en" - masks += "wmask" - } - case MWrite => { - repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","mask",false) - addrs += "addr" - clks += "clk" - ens += "en" - masks += "mask" - } - case MRead => { - repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","blah",false) - addrs += "addr" - clks += "clk" - s.exps(0) match { - case e: Reference if smems(s.mem) => - raddrs(e.name) = SubField(SubField(Reference(s.mem,ut),s.name,ut),"en",ut) - case _ => ens += "en" - } - } - } - val stmts = ArrayBuffer[Statement]() - for (x <- addrs ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(0)) - } - for (x <- clks ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(1)) - } - for (x <- ens ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),one) - } - Block(stmts) + } + Block( + (addrs map (x => Connect(s.info, SubField(SubField(Reference(s.mem, ut), s.name, ut), x, ut), s.exps(0)))) ++ + (clks map (x => Connect(s.info, SubField(SubField(Reference(s.mem, ut), s.name, ut), x, ut), s.exps(1)))) ++ + (ens map (x => Connect(s.info,SubField(SubField(Reference(s.mem,ut), s.name, ut), x, ut), one)))) + } + case (s) => s map collect_refs(mports, smems, types, refs, raddrs) + } + + def get_mask(refs: DataRefMap)(e: Expression): Expression = + e map get_mask(refs) match { + case e: Reference => refs get e.name match { + case None => e + case Some(p) => SubField(p.exp, p.mask, create_mask(e.tpe)) + } + case e => e + } + + def remove_chirrtl_s(refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = { + var has_write_mport = false + var has_readwrite_mport: Option[Expression] = None + var has_read_mport: Option[Expression] = None + def remove_chirrtl_e(g: Gender)(e: Expression): Expression = e match { + case Reference(name, tpe) => refs get name match { + case Some(p) => g match { + case FEMALE => + has_write_mport = true + if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", UIntType(IntWidth(1)))) + SubField(p.exp, p.female, tpe) + case MALE => + SubField(p.exp, p.male, tpe) + } + case None => g match { + case FEMALE => raddrs get name match { + case Some(en) => has_read_mport = Some(en) ; e + case None => e } - case (s) => s map (collect_refs) + case MALE => e } } - def remove_chirrtl_s (s:Statement) : Statement = { - var has_write_mport = false - var has_read_mport: Option[Expression] = None - var has_readwrite_mport: Option[Expression] = None - def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = { - (e) match { - case (e:Reference) if repl contains e.name => - val vt = repl(e.name) - g match { - case MALE => SubField(vt.exp,vt.male,e.tpe) - case FEMALE => { - has_write_mport = true - if (vt.rdwrite) - has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1)))) - SubField(vt.exp,vt.female,e.tpe) - } - } - case (e:Reference) if g == FEMALE && (raddrs contains e.name) => - has_read_mport = Some(raddrs(e.name)) - e - case (e:Reference) => e - case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.expr),remove_chirrtl_e(MALE)(e.index),e.tpe) - case (e) => e map (remove_chirrtl_e(g)) - } + case SubAccess(expr, index, tpe) => SubAccess( + remove_chirrtl_e(g)(expr), remove_chirrtl_e(MALE)(index), tpe) + case e => e map remove_chirrtl_e(g) + } + (s) match { + case DefNode(info, name, value) => + val valuex = remove_chirrtl_e(MALE)(value) + val sx = DefNode(info, name, valuex) + has_read_mport match { + case None => sx + case Some(en) => Block(Seq(sx, Connect(info, en, one))) } - def get_mask (e:Expression) : Expression = { - (e map (get_mask)) match { - case (e:Reference) => { - if (repl.contains(e.name)) { - val vt = repl(e.name) - val t = create_mask(e.tpe) - SubField(vt.exp,vt.mask,t) - } else e - } - case (e) => e - } + case Connect(info, loc, expr) => + val rocx = remove_chirrtl_e(MALE)(expr) + val locx = remove_chirrtl_e(FEMALE)(loc) + val sx = Connect(info, locx, rocx) + val stmts = ArrayBuffer[Statement]() + has_read_mport match { + case None => + case Some(en) => stmts += Connect(info, en, one) } - (s) match { - case (s:DefNode) => { - val stmts = ArrayBuffer[Statement]() - val valuex = remove_chirrtl_e(MALE)(s.value) - stmts += DefNode(s.info,s.name,valuex) - has_read_mport match { - case None => - case Some(en) => stmts += Connect(s.info,en,one) - } - if (stmts.size > 1) Block(stmts) - else stmts(0) - } - case (s:Connect) => { - val stmts = ArrayBuffer[Statement]() - val rocx = remove_chirrtl_e(MALE)(s.expr) - val locx = remove_chirrtl_e(FEMALE)(s.loc) - stmts += Connect(s.info,locx,rocx) - has_read_mport match { - case None => - case Some(en) => stmts += Connect(s.info,en,one) - } - if (has_write_mport) { - val e = get_mask(s.loc) - for (x <- create_exps(e) ) { - stmts += Connect(s.info,x,one) - } - has_readwrite_mport match { - case None => - case Some(wmode) => stmts += Connect(s.info,wmode,one) - } - } - if (stmts.size > 1) Block(stmts) - else stmts(0) + if (has_write_mport) { + val locs = create_exps(get_mask(refs)(loc)) + stmts ++= (locs map (x => Connect(info, x, one))) + has_readwrite_mport match { + case None => + case Some(wmode) => stmts += Connect(info, wmode, one) } - case (s:PartialConnect) => { - val stmts = ArrayBuffer[Statement]() - val locx = remove_chirrtl_e(FEMALE)(s.loc) - val rocx = remove_chirrtl_e(MALE)(s.expr) - stmts += PartialConnect(s.info,locx,rocx) - has_read_mport match { - case None => - case Some(en) => stmts += Connect(s.info,en,one) - } - if (has_write_mport) { - val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) - val locs = create_exps(get_mask(s.loc)) - for (x <- ls ) { - val locx = locs(x._1) - stmts += Connect(s.info,locx,one) - } - has_readwrite_mport match { - case None => - case Some(wmode) => stmts += Connect(s.info,wmode,one) - } - } - if (stmts.size > 1) Block(stmts) - else stmts(0) + } + if (stmts.isEmpty) sx else Block(sx +: stmts) + case PartialConnect(info, loc, expr) => + val locx = remove_chirrtl_e(FEMALE)(loc) + val rocx = remove_chirrtl_e(MALE)(expr) + val sx = PartialConnect(info, locx, rocx) + val stmts = ArrayBuffer[Statement]() + has_read_mport match { + case None => + case Some(en) => stmts += Connect(info, en, one) + } + if (has_write_mport) { + val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default) + val locs = create_exps(get_mask(refs)(loc)) + stmts ++= (ls map { case (x, _) => Connect(info, locs(x), one) }) + has_readwrite_mport match { + case None => + case Some(wmode) => stmts += Connect(info, wmode, one) } - case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE)) } - } - collect_smems_and_mports(m.body) - val sx = collect_refs(m.body) - Module(m.info,m.name, m.ports, remove_chirrtl_s(sx)) + if (stmts.isEmpty) sx else Block(sx +: stmts) + case s => s map remove_chirrtl_s(refs, raddrs) map remove_chirrtl_e(MALE) } - val modulesx = c.modules.map{ m => { - (m) match { - case (m:Module) => remove_chirrtl_m(m) - case (m:ExtModule) => m - }}} - Circuit(c.info,modulesx, c.main) } + + def remove_chirrtl_m(m: DefModule): DefModule = { + val mports = new MPortMap + val smems = new SeqMemSet + val types = new MPortTypeMap + val refs = new DataRefMap + val raddrs = new AddrMap + (m map collect_smems_and_mports(mports, smems) + map collect_refs(mports, smems, types, refs, raddrs) + map remove_chirrtl_s(refs, raddrs)) + } + + def run(c: Circuit): Circuit = + c copy (modules = (c.modules map remove_chirrtl_m)) } |
