aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-08-26 00:07:36 -0700
committerDonggyu Kim2016-09-08 13:15:56 -0700
commit2a513ff47eebe38a81a1312c51972fcecaeb114f (patch)
tree1f02ee22d028ff50a656a1b475160aa650d4113c /src
parenta6c0ee1c556d8e2ccd3aaf05f2c132734152a706 (diff)
refactor RemoveCHIRRTL
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala432
1 files changed, 196 insertions, 236 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 47e6bbbc..2bae92a7 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -27,270 +27,230 @@ MODIFICATIONS.
package firrtl.passes
-import com.typesafe.scalalogging.LazyLogging
-import java.nio.file.{Paths, Files}
-
// Datastructures
-import scala.collection.mutable.LinkedHashMap
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
import scala.collection.mutable.ArrayBuffer
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
-import firrtl.PrimOps._
-import firrtl.WrappedExpression._
-case class MPort( val name : String, val clk : Expression)
-case class MPorts( val readers : ArrayBuffer[MPort], val writers : ArrayBuffer[MPort], val readwriters : ArrayBuffer[MPort])
-case class DataRef( val exp : Expression, val male : String, val female : String, val mask : String, val rdwrite : Boolean)
+case class MPort(name: String, clk: Expression)
+case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort])
+case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean)
object RemoveCHIRRTL extends Pass {
def name = "Remove CHIRRTL"
- var mname = ""
- def create_exps (e:Expression) : Seq[Expression] = e match {
- case (e:Mux) =>
+
+ val ut = UnknownType
+ type MPortMap = collection.mutable.LinkedHashMap[String, MPorts]
+ type SeqMemSet = collection.mutable.HashSet[String]
+ type MPortTypeMap = collection.mutable.LinkedHashMap[String, Type]
+ type DataRefMap = collection.mutable.LinkedHashMap[String, DataRef]
+ type AddrMap = collection.mutable.HashMap[String, Expression]
+
+ def create_exps(e: Expression): Seq[Expression] = e match {
+ case (e: Mux) =>
val e1s = create_exps(e.tval)
val e2s = create_exps(e.fval)
- (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2)))
- case (e:ValidIf) =>
- create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe))
+ (e1s zip e2s) map { case (e1, e2) => Mux(e.cond, e1, e2, mux_type(e1, e2)) }
+ case (e: ValidIf) =>
+ create_exps(e.value) map (e1 => ValidIf(e.cond, e1, e1.tpe))
case (e) => (e.tpe) match {
- case (_:GroundType) => Seq(e)
- case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
- exps ++ create_exps(SubField(e,f.name,f.tpe)))
- case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
- exps ++ create_exps(SubIndex(e,i,t.tpe)))
+ case (_: GroundType) => Seq(e)
+ case (t: BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
+ exps ++ create_exps(SubField(e, f.name, f.tpe)))
+ case (t: VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
+ exps ++ create_exps(SubIndex(e, i, t.tpe)))
case UnknownType => Seq(e)
}
}
- def run (c:Circuit) : Circuit = {
- def remove_chirrtl_m (m:Module) : Module = {
- val hash = LinkedHashMap[String,MPorts]()
- val repl = LinkedHashMap[String,DataRef]()
- val raddrs = HashMap[String, Expression]()
- val ut = UnknownType
- val mport_types = LinkedHashMap[String,Type]()
- val smems = HashSet[String]()
- def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]())
- def collect_smems_and_mports (s:Statement) : Statement = {
- (s) match {
- case (s:CDefMemory) if s.seq =>
- smems += s.name
- s
- case (s:CDefMPort) => {
- val mports = hash.getOrElse(s.mem,EMPs())
- s.direction match {
- case MRead => mports.readers += MPort(s.name,s.exps(1))
- case MWrite => mports.writers += MPort(s.name,s.exps(1))
- case MReadWrite => mports.readwriters += MPort(s.name,s.exps(1))
- }
- hash(s.mem) = mports
- s
- }
- case (s) => s map (collect_smems_and_mports)
+
+ private def EMPs: MPorts = MPorts(ArrayBuffer[MPort](), ArrayBuffer[MPort](), ArrayBuffer[MPort]())
+
+ def collect_smems_and_mports(mports: MPortMap, smems: SeqMemSet)(s: Statement): Statement = {
+ s match {
+ case (s:CDefMemory) if s.seq => smems += s.name
+ case (s:CDefMPort) =>
+ val p = mports getOrElse (s.mem, EMPs)
+ s.direction match {
+ case MRead => p.readers += MPort(s.name,s.exps(1))
+ case MWrite => p.writers += MPort(s.name,s.exps(1))
+ case MReadWrite => p.readwriters += MPort(s.name,s.exps(1))
}
+ mports(s.mem) = p
+ case s =>
+ }
+ s map collect_smems_and_mports(mports, smems)
+ }
+
+ def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap,
+ refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = s match {
+ case (s: CDefMemory) =>
+ types(s.name) = s.tpe
+ val taddr = UIntType(IntWidth(math.max(1, ceil_log2(s.size))))
+ val tdata = s.tpe
+ def set_poison(vec: Seq[MPort], addr: String) = vec flatMap (r => Seq(
+ IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), addr, taddr)),
+ IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "clk", taddr))
+ ))
+ def set_enable(vec: Seq[MPort], en: String) = vec map (r =>
+ Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), en, taddr), zero)
+ )
+ def set_wmode (vec: Seq[MPort], wmode: String) = vec map (r =>
+ Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), wmode, taddr), zero)
+ )
+ def set_write (vec: Seq[MPort], data: String, mask: String) = vec flatMap {r =>
+ val tmask = create_mask(s.tpe)
+ IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), data, tdata)) +:
+ (create_exps(SubField(SubField(Reference(s.name, ut), r.name, ut), mask, tmask))
+ map (Connect(s.info, _, zero))
+ )
}
- def collect_refs (s:Statement) : Statement = {
- (s) match {
- case (s:CDefMemory) => {
- mport_types(s.name) = s.tpe
- val stmts = ArrayBuffer[Statement]()
- val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size))))
- val tdata = s.tpe
- def set_poison (vec:Seq[MPort],addr:String) : Unit = {
- for (r <- vec ) {
- stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),addr,taddr))
- stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),"clk",taddr))
- }
- }
- def set_enable (vec:Seq[MPort],en:String) : Unit = {
- for (r <- vec ) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),en,taddr),zero)
- }}
- def set_wmode (vec:Seq[MPort],wmode:String) : Unit = {
- for (r <- vec) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),wmode,taddr),zero)
- }}
- def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = {
- val tmask = create_mask(s.tpe)
- for (r <- vec ) {
- stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),data,tdata))
- for (x <- create_exps(SubField(SubField(Reference(s.name,ut),r.name,ut),mask,tmask)) ) {
- stmts += Connect(s.info,x,zero)
- }}}
- val rds = (hash.getOrElse(s.name,EMPs())).readers
- set_poison(rds,"addr")
- set_enable(rds,"en")
- val wrs = (hash.getOrElse(s.name,EMPs())).writers
- set_poison(wrs,"addr")
- set_enable(wrs,"en")
- set_write(wrs,"data","mask")
- val rws = (hash.getOrElse(s.name,EMPs())).readwriters
- set_poison(rws,"addr")
- set_wmode(rws,"wmode")
- set_enable(rws,"en")
- set_write(rws,"wdata","wmask")
- val read_l = if (s.seq) 1 else 0
- val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name))
- Block(Seq(mem,Block(stmts)))
+ val rds = (mports getOrElse (s.name, EMPs)).readers
+ val wrs = (mports getOrElse (s.name, EMPs)).writers
+ val rws = (mports getOrElse (s.name, EMPs)).readwriters
+ val stmts = set_poison(rds, "addr") ++
+ set_enable(rds, "en") ++
+ set_poison(wrs, "addr") ++
+ set_enable(wrs, "en") ++
+ set_write(wrs, "data", "mask") ++
+ set_poison(rws, "addr") ++
+ set_wmode(rws, "wmode") ++
+ set_enable(rws, "en") ++
+ set_write(rws, "wdata", "wmask")
+ val mem = DefMemory(s.info, s.name, s.tpe, s.size, 1, if (s.seq) 1 else 0,
+ rds map (_.name), wrs map (_.name), rws map (_.name))
+ Block(mem +: stmts)
+ case (s: CDefMPort) => {
+ types(s.name) = types(s.mem)
+ val addrs = ArrayBuffer[String]()
+ val clks = ArrayBuffer[String]()
+ val ens = ArrayBuffer[String]()
+ s.direction match {
+ case MReadWrite =>
+ refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "rdata", "wdata", "wmask", true)
+ addrs += "addr"
+ clks += "clk"
+ ens += "en"
+ case MWrite =>
+ refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "data", "data", "mask", false)
+ addrs += "addr"
+ clks += "clk"
+ ens += "en"
+ case MRead =>
+ refs(s.name) = DataRef(SubField(Reference(s.mem, ut), s.name, ut), "data", "data", "blah", false)
+ addrs += "addr"
+ clks += "clk"
+ s.exps.head match {
+ case e: Reference if smems(s.mem) =>
+ raddrs(e.name) = SubField(SubField(Reference(s.mem, ut), s.name, ut), "en", ut)
+ case _ => ens += "en"
}
- case (s:CDefMPort) => {
- mport_types(s.name) = mport_types(s.mem)
- val addrs = ArrayBuffer[String]()
- val clks = ArrayBuffer[String]()
- val ens = ArrayBuffer[String]()
- val masks = ArrayBuffer[String]()
- s.direction match {
- case MReadWrite => {
- repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"rdata","wdata","wmask",true)
- addrs += "addr"
- clks += "clk"
- ens += "en"
- masks += "wmask"
- }
- case MWrite => {
- repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","mask",false)
- addrs += "addr"
- clks += "clk"
- ens += "en"
- masks += "mask"
- }
- case MRead => {
- repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","blah",false)
- addrs += "addr"
- clks += "clk"
- s.exps(0) match {
- case e: Reference if smems(s.mem) =>
- raddrs(e.name) = SubField(SubField(Reference(s.mem,ut),s.name,ut),"en",ut)
- case _ => ens += "en"
- }
- }
- }
- val stmts = ArrayBuffer[Statement]()
- for (x <- addrs ) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(0))
- }
- for (x <- clks ) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(1))
- }
- for (x <- ens ) {
- stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),one)
- }
- Block(stmts)
+ }
+ Block(
+ (addrs map (x => Connect(s.info, SubField(SubField(Reference(s.mem, ut), s.name, ut), x, ut), s.exps(0)))) ++
+ (clks map (x => Connect(s.info, SubField(SubField(Reference(s.mem, ut), s.name, ut), x, ut), s.exps(1)))) ++
+ (ens map (x => Connect(s.info,SubField(SubField(Reference(s.mem,ut), s.name, ut), x, ut), one))))
+ }
+ case (s) => s map collect_refs(mports, smems, types, refs, raddrs)
+ }
+
+ def get_mask(refs: DataRefMap)(e: Expression): Expression =
+ e map get_mask(refs) match {
+ case e: Reference => refs get e.name match {
+ case None => e
+ case Some(p) => SubField(p.exp, p.mask, create_mask(e.tpe))
+ }
+ case e => e
+ }
+
+ def remove_chirrtl_s(refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = {
+ var has_write_mport = false
+ var has_readwrite_mport: Option[Expression] = None
+ var has_read_mport: Option[Expression] = None
+ def remove_chirrtl_e(g: Gender)(e: Expression): Expression = e match {
+ case Reference(name, tpe) => refs get name match {
+ case Some(p) => g match {
+ case FEMALE =>
+ has_write_mport = true
+ if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", UIntType(IntWidth(1))))
+ SubField(p.exp, p.female, tpe)
+ case MALE =>
+ SubField(p.exp, p.male, tpe)
+ }
+ case None => g match {
+ case FEMALE => raddrs get name match {
+ case Some(en) => has_read_mport = Some(en) ; e
+ case None => e
}
- case (s) => s map (collect_refs)
+ case MALE => e
}
}
- def remove_chirrtl_s (s:Statement) : Statement = {
- var has_write_mport = false
- var has_read_mport: Option[Expression] = None
- var has_readwrite_mport: Option[Expression] = None
- def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = {
- (e) match {
- case (e:Reference) if repl contains e.name =>
- val vt = repl(e.name)
- g match {
- case MALE => SubField(vt.exp,vt.male,e.tpe)
- case FEMALE => {
- has_write_mport = true
- if (vt.rdwrite)
- has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1))))
- SubField(vt.exp,vt.female,e.tpe)
- }
- }
- case (e:Reference) if g == FEMALE && (raddrs contains e.name) =>
- has_read_mport = Some(raddrs(e.name))
- e
- case (e:Reference) => e
- case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.expr),remove_chirrtl_e(MALE)(e.index),e.tpe)
- case (e) => e map (remove_chirrtl_e(g))
- }
+ case SubAccess(expr, index, tpe) => SubAccess(
+ remove_chirrtl_e(g)(expr), remove_chirrtl_e(MALE)(index), tpe)
+ case e => e map remove_chirrtl_e(g)
+ }
+ (s) match {
+ case DefNode(info, name, value) =>
+ val valuex = remove_chirrtl_e(MALE)(value)
+ val sx = DefNode(info, name, valuex)
+ has_read_mport match {
+ case None => sx
+ case Some(en) => Block(Seq(sx, Connect(info, en, one)))
}
- def get_mask (e:Expression) : Expression = {
- (e map (get_mask)) match {
- case (e:Reference) => {
- if (repl.contains(e.name)) {
- val vt = repl(e.name)
- val t = create_mask(e.tpe)
- SubField(vt.exp,vt.mask,t)
- } else e
- }
- case (e) => e
- }
+ case Connect(info, loc, expr) =>
+ val rocx = remove_chirrtl_e(MALE)(expr)
+ val locx = remove_chirrtl_e(FEMALE)(loc)
+ val sx = Connect(info, locx, rocx)
+ val stmts = ArrayBuffer[Statement]()
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(info, en, one)
}
- (s) match {
- case (s:DefNode) => {
- val stmts = ArrayBuffer[Statement]()
- val valuex = remove_chirrtl_e(MALE)(s.value)
- stmts += DefNode(s.info,s.name,valuex)
- has_read_mport match {
- case None =>
- case Some(en) => stmts += Connect(s.info,en,one)
- }
- if (stmts.size > 1) Block(stmts)
- else stmts(0)
- }
- case (s:Connect) => {
- val stmts = ArrayBuffer[Statement]()
- val rocx = remove_chirrtl_e(MALE)(s.expr)
- val locx = remove_chirrtl_e(FEMALE)(s.loc)
- stmts += Connect(s.info,locx,rocx)
- has_read_mport match {
- case None =>
- case Some(en) => stmts += Connect(s.info,en,one)
- }
- if (has_write_mport) {
- val e = get_mask(s.loc)
- for (x <- create_exps(e) ) {
- stmts += Connect(s.info,x,one)
- }
- has_readwrite_mport match {
- case None =>
- case Some(wmode) => stmts += Connect(s.info,wmode,one)
- }
- }
- if (stmts.size > 1) Block(stmts)
- else stmts(0)
+ if (has_write_mport) {
+ val locs = create_exps(get_mask(refs)(loc))
+ stmts ++= (locs map (x => Connect(info, x, one)))
+ has_readwrite_mport match {
+ case None =>
+ case Some(wmode) => stmts += Connect(info, wmode, one)
}
- case (s:PartialConnect) => {
- val stmts = ArrayBuffer[Statement]()
- val locx = remove_chirrtl_e(FEMALE)(s.loc)
- val rocx = remove_chirrtl_e(MALE)(s.expr)
- stmts += PartialConnect(s.info,locx,rocx)
- has_read_mport match {
- case None =>
- case Some(en) => stmts += Connect(s.info,en,one)
- }
- if (has_write_mport) {
- val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
- val locs = create_exps(get_mask(s.loc))
- for (x <- ls ) {
- val locx = locs(x._1)
- stmts += Connect(s.info,locx,one)
- }
- has_readwrite_mport match {
- case None =>
- case Some(wmode) => stmts += Connect(s.info,wmode,one)
- }
- }
- if (stmts.size > 1) Block(stmts)
- else stmts(0)
+ }
+ if (stmts.isEmpty) sx else Block(sx +: stmts)
+ case PartialConnect(info, loc, expr) =>
+ val locx = remove_chirrtl_e(FEMALE)(loc)
+ val rocx = remove_chirrtl_e(MALE)(expr)
+ val sx = PartialConnect(info, locx, rocx)
+ val stmts = ArrayBuffer[Statement]()
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(info, en, one)
+ }
+ if (has_write_mport) {
+ val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default)
+ val locs = create_exps(get_mask(refs)(loc))
+ stmts ++= (ls map { case (x, _) => Connect(info, locs(x), one) })
+ has_readwrite_mport match {
+ case None =>
+ case Some(wmode) => stmts += Connect(info, wmode, one)
}
- case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE))
}
- }
- collect_smems_and_mports(m.body)
- val sx = collect_refs(m.body)
- Module(m.info,m.name, m.ports, remove_chirrtl_s(sx))
+ if (stmts.isEmpty) sx else Block(sx +: stmts)
+ case s => s map remove_chirrtl_s(refs, raddrs) map remove_chirrtl_e(MALE)
}
- val modulesx = c.modules.map{ m => {
- (m) match {
- case (m:Module) => remove_chirrtl_m(m)
- case (m:ExtModule) => m
- }}}
- Circuit(c.info,modulesx, c.main)
}
+
+ def remove_chirrtl_m(m: DefModule): DefModule = {
+ val mports = new MPortMap
+ val smems = new SeqMemSet
+ val types = new MPortTypeMap
+ val refs = new DataRefMap
+ val raddrs = new AddrMap
+ (m map collect_smems_and_mports(mports, smems)
+ map collect_refs(mports, smems, types, refs, raddrs)
+ map remove_chirrtl_s(refs, raddrs))
+ }
+
+ def run(c: Circuit): Circuit =
+ c copy (modules = (c.modules map remove_chirrtl_m))
}