diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/Passes.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 512 |
1 files changed, 454 insertions, 58 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 6c77d35d..a6b53e86 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -9,12 +9,13 @@ import scala.sys.process._ import scala.io.Source // Datastructures -import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ArrayBuffer import firrtl._ import firrtl.Utils._ import firrtl.PrimOps._ +import firrtl.WrappedExpression._ trait Pass extends LazyLogging { def name: String @@ -107,7 +108,7 @@ object ResolveKinds extends Pass { def name = "Resolve Kinds" def run (c:Circuit): Circuit = { def resolve_kinds (m:Module, c:Circuit):Module = { - val kinds = HashMap[String,Kind]() + val kinds = LinkedHashMap[String,Kind]() def resolve (body:Stmt) = { def resolve_expr (e:Expression):Expression = { e match { @@ -157,7 +158,7 @@ object ResolveKinds extends Pass { object InferTypes extends Pass { private var mname = "" def name = "Infer Types" - val width_name_hash = HashMap[String,Int]() + val width_name_hash = LinkedHashMap[String,Int]() def set_type (s:Stmt,t:Type) : Stmt = { s match { case s:DefWire => DefWire(s.info,s.name,t) @@ -175,9 +176,9 @@ object InferTypes extends Pass { } def remove_unknowns (t:Type): Type = mapr(remove_unknowns_w _,t) def run (c:Circuit): Circuit = { - val module_types = HashMap[String,Type]() + val module_types = LinkedHashMap[String,Type]() def infer_types (m:Module) : Module = { - val types = HashMap[String,Type]() + val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { eMap(infer_types_e _,e) match { case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value)) @@ -328,10 +329,10 @@ object CheckGenders extends Pass with StanzaPass { object InferWidths extends Pass { def name = "Infer Widths" var mname = "" - def solve_constraints (l:Seq[WGeq]) : HashMap[String,Width] = { + def solve_constraints (l:Seq[WGeq]) : LinkedHashMap[String,Width] = { def unique (ls:Seq[Width]) : Seq[Width] = ls.map(w => new WrappedWidth(w)).distinct.map(_.w) - def make_unique (ls:Seq[WGeq]) : HashMap[String,Width] = { - val h = HashMap[String,Width]() + def make_unique (ls:Seq[WGeq]) : LinkedHashMap[String,Width] = { + val h = LinkedHashMap[String,Width]() for (g <- ls) { (g.loc) match { case (w:VarWidth) => { @@ -369,10 +370,10 @@ object InferWidths extends Pass { case (w1,w2) => w }} case (w:ExpWidth) => { (w.arg1) match { - case (w1:IntWidth) => IntWidth((2 ^ w1.width) - 1) + case (w1:IntWidth) => IntWidth(BigInt((scala.math.pow(2,w1.width.toDouble) - 1).toLong)) case (w1) => w }} case (w) => w } } - def substitute (h:HashMap[String,Width])(w:Width) : Width = { + def substitute (h:LinkedHashMap[String,Width])(w:Width) : Width = { //;println-all-debug(["Substituting for [" w "]"]) val wx = simplify(w) //;println-all-debug(["After Simplify: [" wx "]"]) @@ -394,7 +395,7 @@ object InferWidths extends Pass { //;println-all-debug(["not varwidth!" w]) } } - def b_sub (h:HashMap[String,Width])(w:Width) : Width = { + def b_sub (h:LinkedHashMap[String,Width])(w:Width) : Width = { (wMap(b_sub(h) _,w)) match { case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w case (w) => w @@ -433,54 +434,44 @@ object InferWidths extends Pass { //; 2) Remove Cycles //; 3) Move to solved if not self-recursive val u = make_unique(l) - /* - println("======== UNIQUE CONSTRAINTS ========") - for (x <- u) { println(x) } - println("====================================") - */ + + //println("======== UNIQUE CONSTRAINTS ========") + //for (x <- u) { println(x) } + //println("====================================") + - val f = HashMap[String,Width]() + val f = LinkedHashMap[String,Width]() val o = ArrayBuffer[String]() for (x <- u) { - /* - println("==== SOLUTIONS TABLE ====") - for (x <- f) println(x) - println("=========================") - */ + //println("==== SOLUTIONS TABLE ====") + //for (x <- f) println(x) + //println("=========================") val (n, e) = (x._1, x._2) val e_sub = substitute(f)(e) - /* - println("Solving " + n + " => " + e) - println("After Substitute: " + n + " => " + e_sub) - println("==== SOLUTIONS TABLE (Post Substitute) ====") - for (x <- f) println(x) - println("=========================") - */ + //println("Solving " + n + " => " + e) + //println("After Substitute: " + n + " => " + e_sub) + //println("==== SOLUTIONS TABLE (Post Substitute) ====") + //for (x <- f) println(x) + //println("=========================") val ex = remove_cycle(n)(e_sub) - /* - println("After Remove Cycle: " + n + " => " + ex) - */ + //println("After Remove Cycle: " + n + " => " + ex) if (!self_rec(n,ex)) { - /* - println("Not rec!: " + n + " => " + ex) - println("Adding [" + n + "=>" + ex + "] to Solutions Table") - */ + //println("Not rec!: " + n + " => " + ex) + //println("Adding [" + n + "=>" + ex + "] to Solutions Table") o += n f(n) = ex } } - /* - println("Forward Solved Constraints") - for (x <- f) println(x) - */ + //println("Forward Solved Constraints") + //for (x <- f) println(x) //; Backwards Solve - val b = HashMap[String,Width]() + val b = LinkedHashMap[String,Width]() for (i <- 0 until o.size) { val n = o(o.size - 1 - i) /* @@ -510,7 +501,7 @@ object InferWidths extends Pass { case (t:ClockType) => IntWidth(1) case (t) => error("No width!"); IntWidth(-1) } } def width_BANG (e:Expression) : Width = width_BANG(tpe(e)) - def reduce_var_widths (c:Circuit,h:HashMap[String,Width]) : Circuit = { + def reduce_var_widths (c:Circuit,h:LinkedHashMap[String,Width]) : Circuit = { def evaluate (w:Width) : Width = { def apply_2 (a:Option[BigInt],b:Option[BigInt], f: (BigInt,BigInt) => BigInt) : Option[BigInt] = { (a,b) match { @@ -525,6 +516,7 @@ object InferWidths extends Pass { } def max (a:BigInt,b:BigInt) : BigInt = if (a >= b) a else b def min (a:BigInt,b:BigInt) : BigInt = if (a >= b) b else a + def pow (a:BigInt,b:BigInt) : BigInt = BigInt((scala.math.pow(a.toDouble,b.toDouble) - 1).toLong) def solve (w:Width) : Option[BigInt] = { (w) match { case (w:VarWidth) => { @@ -539,7 +531,7 @@ object InferWidths extends Pass { case (w:MinWidth) => apply_l(w.args.map(solve _),min) case (w:PlusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ + _}) case (w:MinusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ - _}) - case (w:ExpWidth) => apply_2(Some(BigInt(2)),solve(w.arg1),{(x,y) => (x ^ y) - BigInt(1)}) + case (w:ExpWidth) => apply_2(Some(BigInt(2)),solve(w.arg1),pow) case (w:IntWidth) => Some(w.width) case (w) => println(w); error("Shouldn't be here"); None; } @@ -691,7 +683,7 @@ object ExpandConnects extends Pass { def run (c:Circuit): Circuit = { def expand_connects (m:InModule) : InModule = { mname = m.name - val genders = HashMap[String,Gender]() + val genders = LinkedHashMap[String,Gender]() def expand_s (s:Stmt) : Stmt = { def set_gender (e:Expression) : Expression = { eMap(set_gender _,e) match { @@ -854,7 +846,7 @@ object RemoveAccesses extends Pass { def remove_s (s:Stmt) : Stmt = { val stmts = ArrayBuffer[Stmt]() def create_temp (e:Expression) : Expression = { - val n = firrtl_gensym("GEN",sh) + val n = firrtl_gensym_module(mname) stmts += DefWire(info(s),n,tpe(e)) WRef(n,tpe(e),kind(e),gender(e)) } @@ -897,7 +889,7 @@ object RemoveAccesses extends Pass { if (has_access(s.loc)) { val ls = get_locations(s.loc) val locx = - if (ls.size == 1 & ls(0).guard == one) s.loc + if (ls.size == 1 & weq(ls(0).guard,one)) s.loc else { val temp = create_temp(s.loc) for (x <- ls) { stmts += Conditionally(s.info,x.guard,Connect(s.info,x.base,temp),Empty()) } @@ -930,12 +922,12 @@ object ExpandWhens extends Pass { def name = "Expand Whens" var mname = "" // ; ========== Expand When Utilz ========== - def add (hash:HashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = { + def add (hash:LinkedHashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = { hash += (key -> value) } - def get_entries (hash:HashMap[WrappedExpression,Expression],exps:Seq[Expression]) : HashMap[WrappedExpression,Expression] = { - val hashx = HashMap[WrappedExpression,Expression]() + def get_entries (hash:LinkedHashMap[WrappedExpression,Expression],exps:Seq[Expression]) : LinkedHashMap[WrappedExpression,Expression] = { + val hashx = LinkedHashMap[WrappedExpression,Expression]() exps.foreach { e => { val value = hash.get(e) value match { @@ -987,10 +979,10 @@ object ExpandWhens extends Pass { val bodyx = void_all_s(m.body) InModule(m.info,m.name,m.ports,Begin(Seq(Begin(voids),bodyx))) } - def expand_whens (m:InModule) : Tuple2[HashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = { + def expand_whens (m:InModule) : Tuple2[LinkedHashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = { val simlist = ArrayBuffer[Stmt]() mname = m.name - def expand_whens (netlist:HashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = { + def expand_whens (netlist:LinkedHashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = { (s) match { case (s:Connect) => netlist(s.loc) = s.exp case (s:IsInvalid) => netlist(s.exp) = WInvalid() @@ -1025,14 +1017,14 @@ object ExpandWhens extends Pass { } } case (s:Print) => { - if (p == one) { + if (weq(p,one)) { simlist += s } else { simlist += Print(s.info,s.string,s.args,s.clk,AND(p,s.en)) } } case (s:Stop) => { - if (p == one) { + if (weq(p,one)) { simlist += s } else { simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en)) @@ -1042,7 +1034,7 @@ object ExpandWhens extends Pass { } s } - val netlist = HashMap[WrappedExpression,Expression]() + val netlist = LinkedHashMap[WrappedExpression,Expression]() expand_whens(netlist,one)(m.body) //println("Netlist:") @@ -1052,7 +1044,7 @@ object ExpandWhens extends Pass { ( netlist, simlist ) } - def create_module (netlist:HashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = { + def create_module (netlist:LinkedHashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = { mname = m.name val stmts = ArrayBuffer[Stmt]() val connections = ArrayBuffer[Stmt]() @@ -1242,10 +1234,9 @@ object SplitExp extends Pass { def split_exp (m:InModule) : InModule = { mname = m.name val v = ArrayBuffer[Stmt]() - val sh = sym_hash def split_exp_s (s:Stmt) : Stmt = { def split (e:Expression) : Expression = { - val n = firrtl_gensym("GEN",sh) + val n = firrtl_gensym_module(mname) v += DefNode(info(s),n,e) WRef(n,tpe(e),kind(e),gender(e)) } @@ -1385,7 +1376,7 @@ object LowerTypes extends Pass { //;------------- Pass ------------------ def lower_types (m:Module) : Module = { - val mdt = HashMap[String,Type]() + val mdt = LinkedHashMap[String,Type]() mname = m.name def lower_types (s:Stmt) : Stmt = { def lower_mem (e:Expression) : Seq[Expression] = { @@ -1522,3 +1513,408 @@ object LowerTypes extends Pass { } } +object CInferTypes extends Pass { + def name = "CInfer Types" + var mname = "" + def set_type (s:Stmt,t:Type) : Stmt = { + (s) match { + case (s:DefWire) => DefWire(s.info,s.name,t) + case (s:DefRegister) => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) + case (s:CDefMemory) => CDefMemory(s.info,s.name,t,s.size,s.seq) + case (s:CDefMPort) => CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) + case (s:DefNode) => s + case (s:DefPoison) => DefPoison(s.info,s.name,t) + } + } + + def to_field (p:Port) : Field = { + if (p.direction == OUTPUT) Field(p.name,DEFAULT,p.tpe) + else if (p.direction == INPUT) Field(p.name,REVERSE,p.tpe) + else error("Shouldn't be here"); Field(p.name,REVERSE,p.tpe) + } + def module_type (m:Module) : Type = BundleType(m.ports.map(p => to_field(p))) + def field_type (v:Type,s:String) : Type = { + (v) match { + case (v:BundleType) => { + val ft = v.fields.find(p => p.name == s) + if (ft != None) ft.get.tpe + else UnknownType() + } + case (v) => UnknownType() + } + } + def sub_type (v:Type) : Type = + (v) match { + case (v:VectorType) => v.tpe + case (v) => UnknownType() + } + def run (c:Circuit) : Circuit = { + val module_types = LinkedHashMap[String,Type]() + def infer_types (m:Module) : Module = { + val types = LinkedHashMap[String,Type]() + def infer_types_e (e:Expression) : Expression = { + (eMap(infer_types_e _,e)) match { + case (e:Ref) => Ref(e.name, types.getOrElse(e.name,UnknownType())) + case (e:SubField) => SubField(e.exp,e.name,field_type(tpe(e.exp),e.name)) + case (e:SubIndex) => SubIndex(e.exp,e.value,sub_type(tpe(e.exp))) + case (e:SubAccess) => SubAccess(e.exp,e.index,sub_type(tpe(e.exp))) + case (e:DoPrim) => set_primop_type(e) + case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval)) + case (e:ValidIf) => ValidIf(e.cond,e.value,tpe(e.value)) + case (_:UIntValue|_:SIntValue) => e + } + } + def infer_types_s (s:Stmt) : Stmt = { + (s) match { + case (s:DefRegister) => { + types(s.name) = s.tpe + eMap(infer_types_e _,s) + s + } + case (s:DefWire) => { + types(s.name) = s.tpe + s + } + case (s:DefPoison) => { + types(s.name) = s.tpe + s + } + case (s:DefNode) => { + val sx = eMap(infer_types_e _,s) + val t = get_type(sx) + types(s.name) = t + sx + } + case (s:DefMemory) => { + types(s.name) = get_type(s) + s + } + case (s:CDefMPort) => { + val t = types.getOrElse(s.mem,UnknownType()) + types(s.name) = t + CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) + } + case (s:CDefMemory) => { + types(s.name) = s.tpe + s + } + case (s:DefInstance) => { + types(s.name) = module_types.getOrElse(s.module,UnknownType()) + s + } + case (s) => eMap(infer_types_e _,sMap(infer_types_s _,s)) + } + } + for (p <- m.ports) { + types(p.name) = p.tpe + } + (m) match { + case (m:InModule) => InModule(m.info,m.name,m.ports,infer_types_s(m.body)) + case (m:ExModule) => m + } + } + + //; MAIN + for (m <- c.modules) { + module_types(m.name) = module_type(m) + } + val modulesx = c.modules.map(m => infer_types(m)) + Circuit(c.info, modulesx, c.main) + } +} + +object CInferMDir extends Pass { + def name = "CInfer MDir" + var mname = "" + def run (c:Circuit) : Circuit = { + def infer_mdir (m:Module) : Module = { + val mports = LinkedHashMap[String,MPortDir]() + def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = { + (eMap(infer_mdir_e(dir) _,e)) match { + case (e:Ref) => { + if (mports.contains(e.name)) { + val new_mport_dir = { + (mports(e.name),dir) match { + case (MInfer,MInfer) => error("Shouldn't be here") + case (MInfer,MWrite) => MWrite + case (MInfer,MRead) => MRead + case (MInfer,MReadWrite) => MReadWrite + case (MWrite,MInfer) => error("Shouldn't be here") + case (MWrite,MWrite) => MWrite + case (MWrite,MRead) => MReadWrite + case (MWrite,MReadWrite) => MReadWrite + case (MRead,MInfer) => error("Shouldn't be here") + case (MRead,MWrite) => MReadWrite + case (MRead,MRead) => MRead + case (MRead,MReadWrite) => MReadWrite + case (MReadWrite,MInfer) => error("Shouldn't be here") + case (MReadWrite,MWrite) => MReadWrite + case (MReadWrite,MRead) => MReadWrite + case (MReadWrite,MReadWrite) => MReadWrite + } + } + mports(e.name) = new_mport_dir + } + e + } + case (e) => e + } + } + def infer_mdir_s (s:Stmt) : Stmt = { + (s) match { + case (s:CDefMPort) => { + mports(s.name) = s.direction + eMap(infer_mdir_e(MRead) _,s) + } + case (s:Connect) => { + infer_mdir_e(MRead)(s.exp) + infer_mdir_e(MWrite)(s.loc) + s + } + case (s:BulkConnect) => { + infer_mdir_e(MRead)(s.exp) + infer_mdir_e(MWrite)(s.loc) + s + } + case (s) => eMap(infer_mdir_e(MRead) _, sMap(infer_mdir_s,s)) + } + } + def set_mdir_s (s:Stmt) : Stmt = { + (s) match { + case (s:CDefMPort) => + CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name)) + case (s) => sMap(set_mdir_s _,s) + } + } + (m) match { + case (m:InModule) => { + infer_mdir_s(m.body) + InModule(m.info,m.name,m.ports,set_mdir_s(m.body)) + } + case (m:ExModule) => m + } + } + + //; MAIN + Circuit(c.info, c.modules.map(m => infer_mdir(m)), c.main) + } +} + +case class MPort( val name : String, val clk : Expression) +case class MPorts( val readers : ArrayBuffer[MPort], val writers : ArrayBuffer[MPort], val readwriters : ArrayBuffer[MPort]) +case class DataRef( val exp : Expression, val male : String, val female : String, val mask : String, val rdwrite : Boolean) + +object RemoveCHIRRTL extends Pass { + def name = "Remove CHIRRTL" + var mname = "" + def create_exps (e:Expression) : Seq[Expression] = { + (e) match { + case (e:Mux)=> + (create_exps(e.tval),create_exps(e.fval)).zipped.map((e1,e2) => { + Mux(e.cond,e1,e2,mux_type(e1,e2)) + }) + case (e:ValidIf) => + create_exps(e.value).map(e1 => { + ValidIf(e.cond,e1,tpe(e1)) + }) + case (e) => (tpe(e)) match { + case (_:UIntType|_:SIntType|_:ClockType) => Seq(e) + case (t:BundleType) => + t.fields.flatMap(f => create_exps(SubField(e,f.name,f.tpe))) + case (t:VectorType)=> + (0 until t.size).flatMap(i => create_exps(SubIndex(e,i,t.tpe))) + case (t:UnknownType) => Seq(e) + } + } + } + def run (c:Circuit) : Circuit = { + def remove_chirrtl_m (m:InModule) : InModule = { + val hash = LinkedHashMap[String,MPorts]() + val repl = LinkedHashMap[String,DataRef]() + val ut = UnknownType() + val mport_types = LinkedHashMap[String,Type]() + def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]()) + def collect_mports (s:Stmt) : Stmt = { + (s) match { + case (s:CDefMPort) => { + val mports = hash.getOrElse(s.mem,EMPs()) + s.direction match { + case MRead => mports.readers += MPort(s.name,s.exps(1)) + case MWrite => mports.writers += MPort(s.name,s.exps(1)) + case MReadWrite => mports.readwriters += MPort(s.name,s.exps(1)) + } + hash(s.mem) = mports + s + } + case (s) => sMap(collect_mports _,s) + } + } + def collect_refs (s:Stmt) : Stmt = { + (s) match { + case (s:CDefMemory) => { + mport_types(s.name) = s.tpe + val stmts = ArrayBuffer[Stmt]() + val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size)))) + val tdata = s.tpe + def set_poison (vec:Seq[MPort],addr:String) : Unit = { + for (r <- vec ) { + stmts += IsInvalid(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),addr,taddr)) + stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),"clk",taddr),r.clk) + } + } + def set_enable (vec:Seq[MPort],en:String) : Unit = { + for (r <- vec ) { + stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),en,taddr),zero) + }} + def set_wmode (vec:Seq[MPort],wmode:String) : Unit = { + for (r <- vec) { + stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),wmode,taddr),zero) + }} + def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = { + val tmask = create_mask(s.tpe) + for (r <- vec ) { + stmts += IsInvalid(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),data,tdata)) + for (x <- create_exps(SubField(SubField(Ref(s.name,ut),r.name,ut),mask,tmask)) ) { + stmts += Connect(s.info,x,zero) + }}} + val rds = (hash.getOrElse(s.name,EMPs())).readers + set_poison(rds,"addr") + set_enable(rds,"en") + val wrs = (hash.getOrElse(s.name,EMPs())).writers + set_poison(wrs,"addr") + set_enable(wrs,"en") + set_write(wrs,"data","mask") + val rws = (hash.getOrElse(s.name,EMPs())).readwriters + set_poison(rws,"addr") + set_wmode(rws,"wmode") + set_enable(rws,"en") + set_write(rws,"data","mask") + val read_l = if (s.seq) 1 else 0 + val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name)) + Begin(Seq(mem,Begin(stmts))) + } + case (s:CDefMPort) => { + mport_types(s.name) = mport_types(s.mem) + val addrs = ArrayBuffer[String]() + val ens = ArrayBuffer[String]() + val masks = ArrayBuffer[String]() + s.direction match { + case MReadWrite => { + repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"rdata","data","mask",true) + addrs += "addr" + ens += "en" + masks += "mask" + } + case MWrite => { + repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"data","data","mask",false) + addrs += "addr" + ens += "en" + masks += "mask" + } + case _ => { + repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"data","data","blah",false) + addrs += "addr" + ens += "en" + } + } + val stmts = ArrayBuffer[Stmt]() + for (x <- addrs ) { + stmts += Connect(s.info,SubField(SubField(Ref(s.mem,ut),s.name,ut),x,ut),s.exps(0)) + } + for (x <- ens ) { + stmts += Connect(s.info,SubField(SubField(Ref(s.mem,ut),s.name,ut),x,ut),one) + } + Begin(stmts) + } + case (s) => sMap(collect_refs _,s) + } + } + def remove_chirrtl_s (s:Stmt) : Stmt = { + var has_write_mport = false + var has_readwrite_mport:Option[Expression] = None + def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = { + (e) match { + case (e:Ref) => { + if (repl.contains(e.name)) { + val vt = repl(e.name) + g match { + case MALE => SubField(vt.exp,vt.male,e.tpe) + case FEMALE => { + has_write_mport = true + if (vt.rdwrite == true) + has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1)))) + SubField(vt.exp,vt.female,e.tpe) + } + } + } else e + } + case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.exp),remove_chirrtl_e(MALE)(e.index),e.tpe) + case (e) => eMap(remove_chirrtl_e(g) _,e) + } + } + def get_mask (e:Expression) : Expression = { + (eMap(get_mask _,e)) match { + case (e:Ref) => { + if (repl.contains(e.name)) { + val vt = repl(e.name) + val t = create_mask(e.tpe) + SubField(vt.exp,vt.mask,t) + } else e + } + case (e) => e + } + } + (s) match { + case (s:Connect) => { + val stmts = ArrayBuffer[Stmt]() + val rocx = remove_chirrtl_e(MALE)(s.exp) + val locx = remove_chirrtl_e(FEMALE)(s.loc) + stmts += Connect(s.info,locx,rocx) + if (has_write_mport) { + val e = get_mask(s.loc) + for (x <- create_exps(e) ) { + stmts += Connect(s.info,x,one) + } + if (has_readwrite_mport != None) { + val wmode = has_readwrite_mport.get + stmts += Connect(s.info,wmode,one) + } + } + if (stmts.size > 1) Begin(stmts) + else stmts(0) + } + case (s:BulkConnect) => { + val stmts = ArrayBuffer[Stmt]() + val locx = remove_chirrtl_e(FEMALE)(s.loc) + val rocx = remove_chirrtl_e(MALE)(s.exp) + stmts += BulkConnect(s.info,locx,rocx) + if (has_write_mport != false) { + val ls = get_valid_points(tpe(s.loc),tpe(s.exp),DEFAULT,DEFAULT) + val locs = create_exps(get_mask(s.loc)) + for (x <- ls ) { + val locx = locs(x._1) + stmts += Connect(s.info,locx,one) + } + if (has_readwrite_mport != None) { + val wmode = has_readwrite_mport.get + stmts += Connect(s.info,wmode,one) + } + } + if (stmts.size > 1) Begin(stmts) + else stmts(0) + } + case (s) => eMap(remove_chirrtl_e(MALE) _, sMap(remove_chirrtl_s,s)) + } + } + collect_mports(m.body) + val sx = collect_refs(m.body) + InModule(m.info,m.name, m.ports, remove_chirrtl_s(sx)) + } + val modulesx = c.modules.map{ m => { + (m) match { + case (m:InModule) => remove_chirrtl_m(m) + case (m:ExModule) => m + }}} + Circuit(c.info,modulesx, c.main) + } +} |
