diff options
| author | jackkoenig | 2016-04-21 16:18:25 -0700 |
|---|---|---|
| committer | jackkoenig | 2016-05-03 16:56:52 -0700 |
| commit | a5526c177563b2c4de2a9c2b39a5b51a05697292 (patch) | |
| tree | 93cd641cad513e5e4a670b4661563dc849ee4e3b /src | |
| parent | 75cbdf7682381c511345edc2a51c398251a8db8c (diff) | |
Change style and spacing of Expand Whens to be more idiomatic Scala
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 345 |
1 files changed, 159 insertions, 186 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index ef0d3414..1b6030e2 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -34,209 +34,182 @@ import firrtl.PrimOps._ import firrtl.WrappedExpression._ // Datastructures -import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.HashMap import scala.collection.mutable.ArrayBuffer /** Expand Whens - * - * @note This pass does three things: remove last connect semantics, - * remove conditional blocks, and eliminate concept of scoping. - */ +* +* @note This pass does three things: remove last connect semantics, +* remove conditional blocks, and eliminate concept of scoping. +*/ object ExpandWhens extends Pass { - def name = "Expand Whens" - var mname = "" -// ; ========== Expand When Utilz ========== - def add (hash:LinkedHashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = { - hash += (key -> value) - } + def name = "Expand Whens" + var mname = "" + // ========== Expand When Utilz ========== + def getEntries( + hash: HashMap[WrappedExpression, Expression], + exps: Seq[Expression]): HashMap[WrappedExpression, Expression] = { + val hashx = HashMap[WrappedExpression, Expression]() + exps foreach (e => if (hash.contains(e)) hashx(e) = hash(e)) + hashx + } + def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = { + def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, DEFAULT)) + val exps = create_exps(WRef(n, t, ExpKind(), g)) + val expsx = ArrayBuffer[Expression]() + for (i <- 0 until exps.size) { + getGender(t, i, g) match { + case (BIGENDER | FEMALE) => expsx += exps(i) + case _ => + } + } + expsx + } - def get_entries (hash:LinkedHashMap[WrappedExpression,Expression],exps:Seq[Expression]) : LinkedHashMap[WrappedExpression,Expression] = { - val hashx = LinkedHashMap[WrappedExpression,Expression]() - exps.foreach { e => { - val value = hash.get(e) - value match { - case (value:Some[Expression]) => add(hashx,e,value.get) - case (None) => {} - } - }} - hashx - } - def get_female_refs (n:String,t:Type,g:Gender) : Seq[Expression] = { - val exps = create_exps(WRef(n,t,ExpKind(),g)) - val expsx = ArrayBuffer[Expression]() - def get_gender (t:Type, i:Int, g:Gender) : Gender = { - val f = get_flip(t,i,DEFAULT) - times(g, f) + // ------------ Pass ------------------- + def run(c: Circuit): Circuit = { + def voidAll(m: InModule): InModule = { + mname = m.name + def voidAllStmt(s: Stmt): Stmt = s match { + case (_: DefWire | _: DefRegister | _: WDefInstance |_: DefMemory) => + val voids = ArrayBuffer[Stmt]() + for (e <- getFemaleRefs(get_name(s),get_type(s),get_gender(s))) { + voids += Connect(get_info(s),e,WVoid()) + } + Begin(Seq(s,Begin(voids))) + case s => s map voidAllStmt } - for (i <- 0 until exps.size) { - get_gender(t,i,g) match { - case BIGENDER => expsx += exps(i) - case FEMALE => expsx += exps(i) - case _ => false - } + val voids = ArrayBuffer[Stmt]() + for (p <- m.ports) { + for (e <- getFemaleRefs(p.name,p.tpe,get_gender(p))) { + voids += Connect(p.info,e,WVoid()) + } } - expsx - } - - // ------------ Pass ------------------- - def run (c:Circuit): Circuit = { - def void_all (m:InModule) : InModule = { - mname = m.name - def void_all_s (s:Stmt) : Stmt = { - (s) match { - case (_:DefWire|_:DefRegister|_:WDefInstance|_:DefMemory) => { - val voids = ArrayBuffer[Stmt]() - for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) { - voids += Connect(get_info(s),e,WVoid()) + val bodyx = voidAllStmt(m.body) + InModule(m.info, m.name, m.ports, Begin(Seq(Begin(voids),bodyx))) + } + def expandWhens(m: InModule): (HashMap[WrappedExpression, Expression], ArrayBuffer[Stmt]) = { + val simlist = ArrayBuffer[Stmt]() + mname = m.name + def expandWhens(netlist: HashMap[WrappedExpression, Expression], p: Expression)(s: Stmt): Stmt = { + s match { + case s: Connect => netlist(s.loc) = s.exp + case s: IsInvalid => netlist(s.exp) = WInvalid() + case s: Conditionally => + val exps = ArrayBuffer[Expression]() + def prefetch(s: Stmt): Stmt = s match { + case s: Connect => exps += s.loc; s + case s => s map prefetch + } + prefetch(s.conseq) + val c_netlist = getEntries(netlist,exps) + expandWhens(c_netlist, AND(p, s.pred))(s.conseq) + expandWhens(netlist, AND(p, NOT(s.pred)))(s.alt) + for (lvalue <- c_netlist.keys) { + val value = netlist.get(lvalue) + value match { + case value: Some[Expression] => + val tv = c_netlist(lvalue) + val fv = value.get + val res = (tv, fv) match { + case (tv:WInvalid, fv:WInvalid) => WInvalid() + case (tv:WInvalid, fv) => ValidIf(NOT(s.pred), fv,tpe(fv)) + case (tv, fv:WInvalid) => ValidIf(s.pred, tv, tpe(tv)) + case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv)) } - Begin(Seq(s,Begin(voids))) - } - case (s) => s map (void_all_s) + netlist(lvalue) = res + case None => netlist(lvalue) = c_netlist(lvalue) + } } - } - val voids = ArrayBuffer[Stmt]() - for (p <- m.ports) { - for (e <- get_female_refs(p.name,p.tpe,get_gender(p))) { - voids += Connect(p.info,e,WVoid()) + case s: Print => + if(weq(p, one)) { + simlist += s + } else { + simlist += Print(s.info, s.string, s.args, s.clk, AND(p, s.en)) } - } - val bodyx = void_all_s(m.body) - InModule(m.info,m.name,m.ports,Begin(Seq(Begin(voids),bodyx))) - } - def expand_whens (m:InModule) : Tuple2[LinkedHashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = { - val simlist = ArrayBuffer[Stmt]() - mname = m.name - def expand_whens (netlist:LinkedHashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = { - (s) match { - case (s:Connect) => netlist(s.loc) = s.exp - case (s:IsInvalid) => netlist(s.exp) = WInvalid() - case (s:Conditionally) => { - val exps = ArrayBuffer[Expression]() - def prefetch (s:Stmt) : Stmt = { - (s) match { - case (s:Connect) => exps += s.loc; s - case (s) => s map(prefetch) - } - } - prefetch(s.conseq) - val c_netlist = get_entries(netlist,exps) - expand_whens(c_netlist,AND(p,s.pred))(s.conseq) - expand_whens(netlist,AND(p,NOT(s.pred)))(s.alt) - for (lvalue <- c_netlist.keys) { - val value = netlist.get(lvalue) - (value) match { - case (value:Some[Expression]) => { - val tv = c_netlist(lvalue) - val fv = value.get - val res = (tv,fv) match { - case (tv:WInvalid,fv:WInvalid) => WInvalid() - case (tv:WInvalid,fv) => ValidIf(NOT(s.pred),fv,tpe(fv)) - case (tv,fv:WInvalid) => ValidIf(s.pred,tv,tpe(tv)) - case (tv,fv) => Mux(s.pred,tv,fv,mux_type_and_widths(tv,fv)) - } - netlist(lvalue) = res - } - case (None) => add(netlist,lvalue,c_netlist(lvalue)) - } - } - } - case (s:Print) => { - if (weq(p,one)) { - simlist += s - } else { - simlist += Print(s.info,s.string,s.args,s.clk,AND(p,s.en)) - } - } - case (s:Stop) => { - if (weq(p,one)) { - simlist += s - } else { - simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en)) - } - } - case (s) => s map(expand_whens(netlist,p)) + case s: Stop => + if (weq(p, one)) { + simlist += s + } else { + simlist += Stop(s.info, s.ret, s.clk, AND(p, s.en)) } - s - } - val netlist = LinkedHashMap[WrappedExpression,Expression]() - expand_whens(netlist,one)(m.body) - - //println("Netlist:") - //println(netlist) - //println("Simlist:") - //println(simlist) - ( netlist, simlist ) + case s => s map expandWhens(netlist, p) + } + s } + val netlist = HashMap[WrappedExpression, Expression]() + expandWhens(netlist, one)(m.body) - def create_module (netlist:LinkedHashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = { - mname = m.name - val stmts = ArrayBuffer[Stmt]() - val connections = ArrayBuffer[Stmt]() - def replace_void (e:Expression)(rvalue:Expression) : Expression = { - (rvalue) match { - case (rv:WVoid) => e - case (rv) => rv map (replace_void(e)) - } - } - def create (s:Stmt) : Stmt = { - (s) match { - case (_:DefWire|_:WDefInstance|_:DefMemory) => { - stmts += s - for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) { - val rvalue = netlist(e) - val con = (rvalue) match { - case (rvalue:WInvalid) => IsInvalid(get_info(s),e) - case (rvalue) => Connect(get_info(s),e,rvalue) - } - connections += con - } - } - case (s:DefRegister) => { - stmts += s - for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) { - val rvalue = replace_void(e)(netlist(e)) - val con = (rvalue) match { - case (rvalue:WInvalid) => IsInvalid(get_info(s),e) - case (rvalue) => Connect(get_info(s),e,rvalue) - } - connections += con - } - } - case (_:DefPoison|_:DefNode) => stmts += s - case (s) => s map(create) + (netlist, simlist) + } + + def createModule(netlist: HashMap[WrappedExpression,Expression], simlist: ArrayBuffer[Stmt], m: InModule): InModule = { + mname = m.name + val stmts = ArrayBuffer[Stmt]() + val connections = ArrayBuffer[Stmt]() + def replace_void(e: Expression)(rvalue: Expression): Expression = rvalue match { + case rv: WVoid => e + case rv => rv map replace_void(e) + } + def create(s: Stmt): Stmt = { + s match { + case (_: DefWire | _: WDefInstance | _: DefMemory) => + stmts += s + for (e <- getFemaleRefs(get_name(s), get_type(s), get_gender(s))) { + val rvalue = netlist(e) + val con = rvalue match { + case rvalue: WInvalid => IsInvalid(get_info(s), e) + case rvalue => Connect(get_info(s), e, rvalue) + } + connections += con } - s - } - create(m.body) - for (p <- m.ports) { - for (e <- get_female_refs(p.name,p.tpe,get_gender(p))) { - val rvalue = netlist(e) - val con = (rvalue) match { - case (rvalue:WInvalid) => IsInvalid(p.info,e) - case (rvalue) => Connect(p.info,e,rvalue) - } - connections += con + case s: DefRegister => + stmts += s + for (e <- getFemaleRefs(get_name(s), get_type(s), get_gender(s))) { + val rvalue = replace_void(e)(netlist(e)) + val con = rvalue match { + case rvalue: WInvalid => IsInvalid(get_info(s), e) + case rvalue => Connect(get_info(s), e, rvalue) + } + connections += con } - } - for (x <- simlist) { stmts += x } - InModule(m.info,m.name,m.ports,Begin(Seq(Begin(stmts),Begin(connections)))) + case (_: DefPoison | _: DefNode) => stmts += s + case s => s map create + } + s } + create(m.body) + for (p <- m.ports) { + for (e <- getFemaleRefs(p.name, p.tpe, get_gender(p))) { + val rvalue = netlist(e) + val con = rvalue match { + case rvalue: WInvalid => IsInvalid(p.info, e) + case rvalue => Connect(p.info, e, rvalue) + } + connections += con + } + } + for (x <- simlist) { stmts += x } + InModule(m.info, m.name, m.ports, Begin(Seq(Begin(stmts), Begin(connections)))) + } + + val voided_modules = c.modules map { m => + m match { + case m: ExModule => m + case m: InModule => voidAll(m) + } + } - val voided_modules = c.modules.map{ m => { - (m) match { - case (m:ExModule) => m - case (m:InModule) => void_all(m) - } } } - val modulesx = voided_modules.map{ m => { - (m) match { - case (m:ExModule) => m - case (m:InModule) => { - val (netlist, simlist) = expand_whens(m) - create_module(netlist,simlist,m) - } - }}} - Circuit(c.info,modulesx,c.main) - } + val modulesx = voided_modules map { m => + m match { + case m: ExModule => m + case m: InModule => + val (netlist, simlist) = expandWhens(m) + createModule(netlist, simlist, m) + + } + } + Circuit(c.info, modulesx, c.main) + } } |
