aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorjackkoenig2016-04-21 16:18:25 -0700
committerjackkoenig2016-05-03 16:56:52 -0700
commita5526c177563b2c4de2a9c2b39a5b51a05697292 (patch)
tree93cd641cad513e5e4a670b4661563dc849ee4e3b /src
parent75cbdf7682381c511345edc2a51c398251a8db8c (diff)
Change style and spacing of Expand Whens to be more idiomatic Scala
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala345
1 files changed, 159 insertions, 186 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index ef0d3414..1b6030e2 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -34,209 +34,182 @@ import firrtl.PrimOps._
import firrtl.WrappedExpression._
// Datastructures
-import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.HashMap
import scala.collection.mutable.ArrayBuffer
/** Expand Whens
- *
- * @note This pass does three things: remove last connect semantics,
- * remove conditional blocks, and eliminate concept of scoping.
- */
+*
+* @note This pass does three things: remove last connect semantics,
+* remove conditional blocks, and eliminate concept of scoping.
+*/
object ExpandWhens extends Pass {
- def name = "Expand Whens"
- var mname = ""
-// ; ========== Expand When Utilz ==========
- def add (hash:LinkedHashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = {
- hash += (key -> value)
- }
+ def name = "Expand Whens"
+ var mname = ""
+ // ========== Expand When Utilz ==========
+ def getEntries(
+ hash: HashMap[WrappedExpression, Expression],
+ exps: Seq[Expression]): HashMap[WrappedExpression, Expression] = {
+ val hashx = HashMap[WrappedExpression, Expression]()
+ exps foreach (e => if (hash.contains(e)) hashx(e) = hash(e))
+ hashx
+ }
+ def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = {
+ def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, DEFAULT))
+ val exps = create_exps(WRef(n, t, ExpKind(), g))
+ val expsx = ArrayBuffer[Expression]()
+ for (i <- 0 until exps.size) {
+ getGender(t, i, g) match {
+ case (BIGENDER | FEMALE) => expsx += exps(i)
+ case _ =>
+ }
+ }
+ expsx
+ }
- def get_entries (hash:LinkedHashMap[WrappedExpression,Expression],exps:Seq[Expression]) : LinkedHashMap[WrappedExpression,Expression] = {
- val hashx = LinkedHashMap[WrappedExpression,Expression]()
- exps.foreach { e => {
- val value = hash.get(e)
- value match {
- case (value:Some[Expression]) => add(hashx,e,value.get)
- case (None) => {}
- }
- }}
- hashx
- }
- def get_female_refs (n:String,t:Type,g:Gender) : Seq[Expression] = {
- val exps = create_exps(WRef(n,t,ExpKind(),g))
- val expsx = ArrayBuffer[Expression]()
- def get_gender (t:Type, i:Int, g:Gender) : Gender = {
- val f = get_flip(t,i,DEFAULT)
- times(g, f)
+ // ------------ Pass -------------------
+ def run(c: Circuit): Circuit = {
+ def voidAll(m: InModule): InModule = {
+ mname = m.name
+ def voidAllStmt(s: Stmt): Stmt = s match {
+ case (_: DefWire | _: DefRegister | _: WDefInstance |_: DefMemory) =>
+ val voids = ArrayBuffer[Stmt]()
+ for (e <- getFemaleRefs(get_name(s),get_type(s),get_gender(s))) {
+ voids += Connect(get_info(s),e,WVoid())
+ }
+ Begin(Seq(s,Begin(voids)))
+ case s => s map voidAllStmt
}
- for (i <- 0 until exps.size) {
- get_gender(t,i,g) match {
- case BIGENDER => expsx += exps(i)
- case FEMALE => expsx += exps(i)
- case _ => false
- }
+ val voids = ArrayBuffer[Stmt]()
+ for (p <- m.ports) {
+ for (e <- getFemaleRefs(p.name,p.tpe,get_gender(p))) {
+ voids += Connect(p.info,e,WVoid())
+ }
}
- expsx
- }
-
- // ------------ Pass -------------------
- def run (c:Circuit): Circuit = {
- def void_all (m:InModule) : InModule = {
- mname = m.name
- def void_all_s (s:Stmt) : Stmt = {
- (s) match {
- case (_:DefWire|_:DefRegister|_:WDefInstance|_:DefMemory) => {
- val voids = ArrayBuffer[Stmt]()
- for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) {
- voids += Connect(get_info(s),e,WVoid())
+ val bodyx = voidAllStmt(m.body)
+ InModule(m.info, m.name, m.ports, Begin(Seq(Begin(voids),bodyx)))
+ }
+ def expandWhens(m: InModule): (HashMap[WrappedExpression, Expression], ArrayBuffer[Stmt]) = {
+ val simlist = ArrayBuffer[Stmt]()
+ mname = m.name
+ def expandWhens(netlist: HashMap[WrappedExpression, Expression], p: Expression)(s: Stmt): Stmt = {
+ s match {
+ case s: Connect => netlist(s.loc) = s.exp
+ case s: IsInvalid => netlist(s.exp) = WInvalid()
+ case s: Conditionally =>
+ val exps = ArrayBuffer[Expression]()
+ def prefetch(s: Stmt): Stmt = s match {
+ case s: Connect => exps += s.loc; s
+ case s => s map prefetch
+ }
+ prefetch(s.conseq)
+ val c_netlist = getEntries(netlist,exps)
+ expandWhens(c_netlist, AND(p, s.pred))(s.conseq)
+ expandWhens(netlist, AND(p, NOT(s.pred)))(s.alt)
+ for (lvalue <- c_netlist.keys) {
+ val value = netlist.get(lvalue)
+ value match {
+ case value: Some[Expression] =>
+ val tv = c_netlist(lvalue)
+ val fv = value.get
+ val res = (tv, fv) match {
+ case (tv:WInvalid, fv:WInvalid) => WInvalid()
+ case (tv:WInvalid, fv) => ValidIf(NOT(s.pred), fv,tpe(fv))
+ case (tv, fv:WInvalid) => ValidIf(s.pred, tv, tpe(tv))
+ case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv))
}
- Begin(Seq(s,Begin(voids)))
- }
- case (s) => s map (void_all_s)
+ netlist(lvalue) = res
+ case None => netlist(lvalue) = c_netlist(lvalue)
+ }
}
- }
- val voids = ArrayBuffer[Stmt]()
- for (p <- m.ports) {
- for (e <- get_female_refs(p.name,p.tpe,get_gender(p))) {
- voids += Connect(p.info,e,WVoid())
+ case s: Print =>
+ if(weq(p, one)) {
+ simlist += s
+ } else {
+ simlist += Print(s.info, s.string, s.args, s.clk, AND(p, s.en))
}
- }
- val bodyx = void_all_s(m.body)
- InModule(m.info,m.name,m.ports,Begin(Seq(Begin(voids),bodyx)))
- }
- def expand_whens (m:InModule) : Tuple2[LinkedHashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = {
- val simlist = ArrayBuffer[Stmt]()
- mname = m.name
- def expand_whens (netlist:LinkedHashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = {
- (s) match {
- case (s:Connect) => netlist(s.loc) = s.exp
- case (s:IsInvalid) => netlist(s.exp) = WInvalid()
- case (s:Conditionally) => {
- val exps = ArrayBuffer[Expression]()
- def prefetch (s:Stmt) : Stmt = {
- (s) match {
- case (s:Connect) => exps += s.loc; s
- case (s) => s map(prefetch)
- }
- }
- prefetch(s.conseq)
- val c_netlist = get_entries(netlist,exps)
- expand_whens(c_netlist,AND(p,s.pred))(s.conseq)
- expand_whens(netlist,AND(p,NOT(s.pred)))(s.alt)
- for (lvalue <- c_netlist.keys) {
- val value = netlist.get(lvalue)
- (value) match {
- case (value:Some[Expression]) => {
- val tv = c_netlist(lvalue)
- val fv = value.get
- val res = (tv,fv) match {
- case (tv:WInvalid,fv:WInvalid) => WInvalid()
- case (tv:WInvalid,fv) => ValidIf(NOT(s.pred),fv,tpe(fv))
- case (tv,fv:WInvalid) => ValidIf(s.pred,tv,tpe(tv))
- case (tv,fv) => Mux(s.pred,tv,fv,mux_type_and_widths(tv,fv))
- }
- netlist(lvalue) = res
- }
- case (None) => add(netlist,lvalue,c_netlist(lvalue))
- }
- }
- }
- case (s:Print) => {
- if (weq(p,one)) {
- simlist += s
- } else {
- simlist += Print(s.info,s.string,s.args,s.clk,AND(p,s.en))
- }
- }
- case (s:Stop) => {
- if (weq(p,one)) {
- simlist += s
- } else {
- simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en))
- }
- }
- case (s) => s map(expand_whens(netlist,p))
+ case s: Stop =>
+ if (weq(p, one)) {
+ simlist += s
+ } else {
+ simlist += Stop(s.info, s.ret, s.clk, AND(p, s.en))
}
- s
- }
- val netlist = LinkedHashMap[WrappedExpression,Expression]()
- expand_whens(netlist,one)(m.body)
-
- //println("Netlist:")
- //println(netlist)
- //println("Simlist:")
- //println(simlist)
- ( netlist, simlist )
+ case s => s map expandWhens(netlist, p)
+ }
+ s
}
+ val netlist = HashMap[WrappedExpression, Expression]()
+ expandWhens(netlist, one)(m.body)
- def create_module (netlist:LinkedHashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = {
- mname = m.name
- val stmts = ArrayBuffer[Stmt]()
- val connections = ArrayBuffer[Stmt]()
- def replace_void (e:Expression)(rvalue:Expression) : Expression = {
- (rvalue) match {
- case (rv:WVoid) => e
- case (rv) => rv map (replace_void(e))
- }
- }
- def create (s:Stmt) : Stmt = {
- (s) match {
- case (_:DefWire|_:WDefInstance|_:DefMemory) => {
- stmts += s
- for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) {
- val rvalue = netlist(e)
- val con = (rvalue) match {
- case (rvalue:WInvalid) => IsInvalid(get_info(s),e)
- case (rvalue) => Connect(get_info(s),e,rvalue)
- }
- connections += con
- }
- }
- case (s:DefRegister) => {
- stmts += s
- for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) {
- val rvalue = replace_void(e)(netlist(e))
- val con = (rvalue) match {
- case (rvalue:WInvalid) => IsInvalid(get_info(s),e)
- case (rvalue) => Connect(get_info(s),e,rvalue)
- }
- connections += con
- }
- }
- case (_:DefPoison|_:DefNode) => stmts += s
- case (s) => s map(create)
+ (netlist, simlist)
+ }
+
+ def createModule(netlist: HashMap[WrappedExpression,Expression], simlist: ArrayBuffer[Stmt], m: InModule): InModule = {
+ mname = m.name
+ val stmts = ArrayBuffer[Stmt]()
+ val connections = ArrayBuffer[Stmt]()
+ def replace_void(e: Expression)(rvalue: Expression): Expression = rvalue match {
+ case rv: WVoid => e
+ case rv => rv map replace_void(e)
+ }
+ def create(s: Stmt): Stmt = {
+ s match {
+ case (_: DefWire | _: WDefInstance | _: DefMemory) =>
+ stmts += s
+ for (e <- getFemaleRefs(get_name(s), get_type(s), get_gender(s))) {
+ val rvalue = netlist(e)
+ val con = rvalue match {
+ case rvalue: WInvalid => IsInvalid(get_info(s), e)
+ case rvalue => Connect(get_info(s), e, rvalue)
+ }
+ connections += con
}
- s
- }
- create(m.body)
- for (p <- m.ports) {
- for (e <- get_female_refs(p.name,p.tpe,get_gender(p))) {
- val rvalue = netlist(e)
- val con = (rvalue) match {
- case (rvalue:WInvalid) => IsInvalid(p.info,e)
- case (rvalue) => Connect(p.info,e,rvalue)
- }
- connections += con
+ case s: DefRegister =>
+ stmts += s
+ for (e <- getFemaleRefs(get_name(s), get_type(s), get_gender(s))) {
+ val rvalue = replace_void(e)(netlist(e))
+ val con = rvalue match {
+ case rvalue: WInvalid => IsInvalid(get_info(s), e)
+ case rvalue => Connect(get_info(s), e, rvalue)
+ }
+ connections += con
}
- }
- for (x <- simlist) { stmts += x }
- InModule(m.info,m.name,m.ports,Begin(Seq(Begin(stmts),Begin(connections))))
+ case (_: DefPoison | _: DefNode) => stmts += s
+ case s => s map create
+ }
+ s
}
+ create(m.body)
+ for (p <- m.ports) {
+ for (e <- getFemaleRefs(p.name, p.tpe, get_gender(p))) {
+ val rvalue = netlist(e)
+ val con = rvalue match {
+ case rvalue: WInvalid => IsInvalid(p.info, e)
+ case rvalue => Connect(p.info, e, rvalue)
+ }
+ connections += con
+ }
+ }
+ for (x <- simlist) { stmts += x }
+ InModule(m.info, m.name, m.ports, Begin(Seq(Begin(stmts), Begin(connections))))
+ }
+
+ val voided_modules = c.modules map { m =>
+ m match {
+ case m: ExModule => m
+ case m: InModule => voidAll(m)
+ }
+ }
- val voided_modules = c.modules.map{ m => {
- (m) match {
- case (m:ExModule) => m
- case (m:InModule) => void_all(m)
- } } }
- val modulesx = voided_modules.map{ m => {
- (m) match {
- case (m:ExModule) => m
- case (m:InModule) => {
- val (netlist, simlist) = expand_whens(m)
- create_module(netlist,simlist,m)
- }
- }}}
- Circuit(c.info,modulesx,c.main)
- }
+ val modulesx = voided_modules map { m =>
+ m match {
+ case m: ExModule => m
+ case m: InModule =>
+ val (netlist, simlist) = expandWhens(m)
+ createModule(netlist, simlist, m)
+
+ }
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
}