aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/Passes.scala
diff options
context:
space:
mode:
authorjackkoenig2016-03-01 12:15:28 -0800
committerjackkoenig2016-03-01 12:21:11 -0800
commit079005f630590bdaf4671c9d8ab127b649cd61df (patch)
tree94885d84691570e43a59684d9facf71e10bdab0f /src/main/scala/firrtl/passes/Passes.scala
parentaa2322eb09e9059ad1cdf066c3e7270e0b98679d (diff)
Move mapper functions to implicit methods on IR vertices.
Diffstat (limited to 'src/main/scala/firrtl/passes/Passes.scala')
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala137
1 files changed, 69 insertions, 68 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 8a2fb5c8..7490c479 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -40,6 +40,7 @@ import scala.collection.mutable.ArrayBuffer
import firrtl._
import firrtl.Utils._
+import firrtl.Mappers._
import firrtl.Serialize._
import firrtl.PrimOps._
import firrtl.WrappedExpression._
@@ -99,7 +100,7 @@ object ToWorkingIR extends Pass {
def name = "Working IR"
def run (c:Circuit): Circuit = {
def toExp (e:Expression) : Expression = {
- eMap(toExp _,e) match {
+ e map (toExp) match {
case e:Ref => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER)
case e:SubField => WSubField(e.exp, e.name, e.tpe, UNKNOWNGENDER)
case e:SubIndex => WSubIndex(e.exp, e.value, e.tpe, UNKNOWNGENDER)
@@ -108,9 +109,9 @@ object ToWorkingIR extends Pass {
}
}
def toStmt (s:Stmt) : Stmt = {
- eMap(toExp _,s) match {
+ s map (toExp) match {
case s:DefInstance => WDefInstance(s.info,s.name,s.module,UnknownType())
- case s => sMap(toStmt _,s)
+ case s => s map (toStmt)
}
}
val modulesx = c.modules.map { m =>
@@ -139,10 +140,10 @@ object ResolveKinds extends Pass {
def resolve_expr (e:Expression):Expression = {
e match {
case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender)
- case e => eMap(resolve_expr,e)
+ case e => e map (resolve_expr)
}
}
- def resolve_stmt (s:Stmt):Stmt = eMap(resolve_expr,sMap(resolve_stmt,s))
+ def resolve_stmt (s:Stmt):Stmt = s map (resolve_stmt) map (resolve_expr)
resolve_stmt(body)
}
@@ -157,7 +158,7 @@ object ResolveKinds extends Pass {
case s:DefMemory => kinds(s.name) = MemKind(s.readers ++ s.writers ++ s.readwriters)
case s => false
}
- sMap(find_stmt,s)
+ s map (find_stmt)
}
m.ports.foreach { p => kinds(p.name) = PortKind() }
m match {
@@ -206,7 +207,7 @@ object InferTypes extends Pass {
def infer_types (m:Module) : Module = {
val types = LinkedHashMap[String,Type]()
def infer_types_e (e:Expression) : Expression = {
- eMap(infer_types_e _,e) match {
+ e map (infer_types_e) match {
case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value))
case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender)
case e:WSubField => WSubField(e.exp,e.name,field_type(tpe(e.exp),e.name),e.gender)
@@ -223,22 +224,22 @@ object InferTypes extends Pass {
case s:DefRegister => {
val t = remove_unknowns(get_type(s))
types(s.name) = t
- eMap(infer_types_e _,set_type(s,t))
+ set_type(s,t) map (infer_types_e)
}
case s:DefWire => {
- val sx = eMap(infer_types_e _,s)
+ val sx = s map(infer_types_e)
val t = remove_unknowns(get_type(sx))
types(s.name) = t
set_type(sx,t)
}
case s:DefPoison => {
- val sx = eMap(infer_types_e _,s)
+ val sx = s map (infer_types_e)
val t = remove_unknowns(get_type(sx))
types(s.name) = t
set_type(sx,t)
}
case s:DefNode => {
- val sx = eMap(infer_types_e _,s)
+ val sx = s map (infer_types_e)
val t = remove_unknowns(get_type(sx))
types(s.name) = t
set_type(sx,t)
@@ -253,7 +254,7 @@ object InferTypes extends Pass {
types(s.name) = module_types(s.module)
WDefInstance(s.info,s.name,s.module,module_types(s.module))
}
- case s => eMap(infer_types_e _,sMap(infer_types_s,s))
+ case s => s map (infer_types_s) map (infer_types_e)
}
}
@@ -304,7 +305,7 @@ object ResolveGenders extends Pass {
val indexx = resolve_e(MALE)(e.index)
WSubAccess(expx,indexx,e.tpe,g)
}
- case e => eMap(resolve_e(g) _,e)
+ case e => e map (resolve_e(g))
}
}
@@ -324,7 +325,7 @@ object ResolveGenders extends Pass {
val expx = resolve_e(MALE)(s.exp)
BulkConnect(s.info,locx,expx)
}
- case s => sMap(resolve_s,eMap(resolve_e(MALE) _,s))
+ case s => s map (resolve_e(MALE)) map (resolve_s)
}
}
val modulesx = c.modules.map {
@@ -362,7 +363,7 @@ object InferWidths extends Pass {
h
}
def simplify (w:Width) : Width = {
- (wMap(simplify _,w)) match {
+ (w map (simplify)) match {
case (w:MinWidth) => {
val v = ArrayBuffer[Width]()
for (wx <- w.args) {
@@ -394,7 +395,7 @@ object InferWidths extends Pass {
//;println-all-debug(["Substituting for [" w "]"])
val wx = simplify(w)
//;println-all-debug(["After Simplify: [" wx "]"])
- (wMap(substitute(h) _,simplify(w))) match {
+ (simplify(w) map (substitute(h))) match {
case (w:VarWidth) => {
//;("matched println-debugvarwidth!")
if (h.contains(w.name)) {
@@ -413,14 +414,14 @@ object InferWidths extends Pass {
}
}
def b_sub (h:LinkedHashMap[String,Width])(w:Width) : Width = {
- (wMap(b_sub(h) _,w)) match {
+ (w map (b_sub(h))) match {
case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w
case (w) => w
}
}
def remove_cycle (n:String)(w:Width) : Width = {
//;println-all-debug(["Removing cycle for " n " inside " w])
- val wx = (wMap(remove_cycle(n) _,w)) match {
+ val wx = (w map (remove_cycle(n))) match {
case (w:MaxWidth) => MaxWidth(w.args.filter{ w => {
w match {
case (w:VarWidth) => !(n equals w.name)
@@ -438,7 +439,7 @@ object InferWidths extends Pass {
def self_rec (n:String,w:Width) : Boolean = {
var has = false
def look (w:Width) : Width = {
- (wMap(look _,w)) match {
+ (w map (look)) match {
case (w:VarWidth) => if (w.name == n) has = true
case (w) => w }
w }
@@ -587,14 +588,14 @@ object InferWidths extends Pass {
get_constraints_t(f1.tpe,f2.tpe,times(f1.flip,f)) }}}
case (t1:VectorType,t2:VectorType) => get_constraints_t(t1.tpe,t2.tpe,f) }}
def get_constraints_e (e:Expression) : Expression = {
- (eMap(get_constraints_e _,e)) match {
+ (e map (get_constraints_e)) match {
case (e:Mux) => {
constrain(width_BANG(e.cond),ONE)
constrain(ONE,width_BANG(e.cond))
e }
case (e) => e }}
def get_constraints (s:Stmt) : Stmt = {
- (eMap(get_constraints_e _,s)) match {
+ (s map (get_constraints_e)) match {
case (s:Connect) => {
val n = get_size(tpe(s.loc))
val ce_loc = create_exps(s.loc)
@@ -623,8 +624,8 @@ object InferWidths extends Pass {
case (s:Conditionally) => {
v += WGeq(width_BANG(s.pred),ONE)
v += WGeq(ONE,width_BANG(s.pred))
- sMap(get_constraints _,s) }
- case (s) => sMap(get_constraints _,s) }}
+ s map (get_constraints) }
+ case (s) => s map (get_constraints) }}
for (m <- c.modules) {
(m) match {
@@ -646,7 +647,7 @@ object PullMuxes extends Pass {
def name = "Pull Muxes"
def run (c:Circuit): Circuit = {
def pull_muxes_e (e:Expression) : Expression = {
- val ex = eMap(pull_muxes_e _,e) match {
+ val ex = e map (pull_muxes_e) match {
case (e:WRef) => e
case (e:WSubField) => {
e.exp match {
@@ -673,9 +674,9 @@ object PullMuxes extends Pass {
case (e:ValidIf) => e
case (e) => e
}
- eMap(pull_muxes_e _,ex)
+ ex map (pull_muxes_e)
}
- def pull_muxes (s:Stmt) : Stmt = eMap(pull_muxes_e _,sMap(pull_muxes _,s))
+ def pull_muxes (s:Stmt) : Stmt = s map (pull_muxes) map (pull_muxes_e)
val modulesx = c.modules.map {
m => {
mname = m.name
@@ -698,7 +699,7 @@ object ExpandConnects extends Pass {
val genders = LinkedHashMap[String,Gender]()
def expand_s (s:Stmt) : Stmt = {
def set_gender (e:Expression) : Expression = {
- eMap(set_gender _,e) match {
+ e map (set_gender) match {
case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name))
case (e:WSubField) => {
val f = get_field(tpe(e.exp),e.name)
@@ -768,7 +769,7 @@ object ExpandConnects extends Pass {
}}
Begin(connects)
}
- case (s) => sMap(expand_s _,s)
+ case (s) => s map (expand_s)
}
}
@@ -845,7 +846,7 @@ object RemoveAccesses extends Pass {
def rec_has_access (e:Expression) : Expression = {
e match {
case (e:WSubAccess) => { ret = true; e }
- case (e) => eMap(rec_has_access _,e)
+ case (e) => e map (rec_has_access)
}
}
rec_has_access(e)
@@ -864,9 +865,9 @@ object RemoveAccesses extends Pass {
}
def remove_e (e:Expression) : Expression = { //NOT RECURSIVE (except primops) INTENTIONALLY!
e match {
- case (e:DoPrim) => eMap(remove_e,e)
- case (e:Mux) => eMap(remove_e,e)
- case (e:ValidIf) => eMap(remove_e,e)
+ case (e:DoPrim) => e map (remove_e)
+ case (e:Mux) => e map (remove_e)
+ case (e:ValidIf) => e map (remove_e)
case (e:SIntValue) => e
case (e:UIntValue) => e
case e => {
@@ -910,7 +911,7 @@ object RemoveAccesses extends Pass {
Connect(s.info,locx,remove_e(s.exp))
} else { Connect(s.info,s.loc,remove_e(s.exp)) }
}
- case (s) => sMap(remove_s,eMap(remove_e,s))
+ case (s) => s map (remove_e) map (remove_s)
}
stmts += sx
if (stmts.size != 1) Begin(stmts) else stmts(0)
@@ -979,7 +980,7 @@ object ExpandWhens extends Pass {
}
Begin(Seq(s,Begin(voids)))
}
- case (s) => sMap(void_all_s _,s)
+ case (s) => s map (void_all_s)
}
}
val voids = ArrayBuffer[Stmt]()
@@ -1003,7 +1004,7 @@ object ExpandWhens extends Pass {
def prefetch (s:Stmt) : Stmt = {
(s) match {
case (s:Connect) => exps += s.loc; s
- case (s) => sMap(prefetch _,s)
+ case (s) => s map(prefetch)
}
}
prefetch(s.conseq)
@@ -1042,7 +1043,7 @@ object ExpandWhens extends Pass {
simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en))
}
}
- case (s) => sMap(expand_whens(netlist,p) _, s)
+ case (s) => s map(expand_whens(netlist,p))
}
s
}
@@ -1063,7 +1064,7 @@ object ExpandWhens extends Pass {
def replace_void (e:Expression)(rvalue:Expression) : Expression = {
(rvalue) match {
case (rv:WVoid) => e
- case (rv) => eMap(replace_void(e) _,rv)
+ case (rv) => rv map (replace_void(e))
}
}
def create (s:Stmt) : Stmt = {
@@ -1091,7 +1092,7 @@ object ExpandWhens extends Pass {
}
}
case (_:DefPoison|_:DefNode) => stmts += s
- case (s) => sMap(create _,s)
+ case (s) => s map(create)
}
s
}
@@ -1131,7 +1132,7 @@ object ConstProp extends Pass {
def name = "Constant Propogation"
var mname = ""
def const_prop_e (e:Expression) : Expression = {
- eMap(const_prop_e _,e) match {
+ e map (const_prop_e) match {
case (e:DoPrim) => {
e.op match {
case SHIFT_RIGHT_OP => {
@@ -1173,7 +1174,7 @@ object ConstProp extends Pass {
case (e) => e
}
}
- def const_prop_s (s:Stmt) : Stmt = eMap(const_prop_e _, sMap(const_prop_s _,s))
+ def const_prop_s (s:Stmt) : Stmt = s map (const_prop_s) map (const_prop_e)
def run (c:Circuit): Circuit = {
val modulesx = c.modules.map{ m => {
m match {
@@ -1202,7 +1203,7 @@ object VerilogWrap extends Pass {
def name = "Verilog Wrap"
var mname = ""
def v_wrap_e (e:Expression) : Expression = {
- eMap(v_wrap_e _,e) match {
+ e map (v_wrap_e) match {
case (e:DoPrim) => {
def a0 () = e.args(0)
if (e.op == TAIL_OP) {
@@ -1220,7 +1221,7 @@ object VerilogWrap extends Pass {
case (e) => e
}
}
- def v_wrap_s (s:Stmt) : Stmt = eMap(v_wrap_e _,sMap(v_wrap_s _,s))
+ def v_wrap_s (s:Stmt) : Stmt = s map (v_wrap_s) map (v_wrap_e)
def run (c:Circuit): Circuit = {
val modulesx = c.modules.map{ m => {
(m) match {
@@ -1248,19 +1249,19 @@ object SplitExp extends Pass {
WRef(n,tpe(e),kind(e),gender(e))
}
def split_exp_e (i:Int)(e:Expression) : Expression = {
- eMap(split_exp_e(i + 1) _,e) match {
+ e map (split_exp_e(i + 1)) match {
case (e:DoPrim) => if (i > 0) split(e) else e
case (e) => e
}
}
s match {
- case (s:Begin) => sMap(split_exp_s _,s)
+ case (s:Begin) => s map (split_exp_s)
case (s:Print) => {
- val sx = eMap(split_exp_e(1) _,s)
+ val sx = s map (split_exp_e(1))
v += sx; sx
}
case (s) => {
- val sx = eMap(split_exp_e(0) _,s)
+ val sx = s map (split_exp_e(0))
v += sx; sx
}
}
@@ -1289,11 +1290,11 @@ object VerilogRename extends Pass {
def verilog_rename_e (e:Expression) : Expression = {
(e) match {
case (e:WRef) => WRef(verilog_rename_n(e.name),e.tpe,kind(e),gender(e))
- case (e) => eMap(verilog_rename_e,e)
+ case (e) => e map (verilog_rename_e)
}
}
def verilog_rename_s (s:Stmt) : Stmt = {
- stMap(verilog_rename_n _,eMap(verilog_rename_e _,sMap(verilog_rename_s _,s)))
+ s map (verilog_rename_s) map (verilog_rename_e) map (verilog_rename_n)
}
val modulesx = c.modules.map{ m => {
val portsx = m.ports.map{ p => {
@@ -1341,7 +1342,7 @@ object LowerTypes extends Pass {
def expand_name (e:Expression) : Seq[String] = {
val names = ArrayBuffer[String]()
def expand_name_e (e:Expression) : Expression = {
- (eMap(expand_name_e _,e)) match {
+ (e map (expand_name_e)) match {
case (e:WRef) => names += e.name
case (e:WSubField) => names += e.name
case (e:WSubIndex) => names += e.value.toString
@@ -1418,9 +1419,9 @@ object LowerTypes extends Pass {
case (k) => WRef(lowered_name(e),tpe(e),kind(e),gender(e))
}
}
- case (e:DoPrim) => eMap(lower_types_e _,e)
- case (e:Mux) => eMap(lower_types_e _,e)
- case (e:ValidIf) => eMap(lower_types_e _,e)
+ case (e:DoPrim) => e map (lower_types_e)
+ case (e:Mux) => e map (lower_types_e)
+ case (e:ValidIf) => e map (lower_types_e)
}
}
(s) match {
@@ -1476,7 +1477,7 @@ object LowerTypes extends Pass {
}
}
case (s:IsInvalid) => {
- val sx = eMap(lower_types_e _,s).as[IsInvalid].get
+ val sx = (s map (lower_types_e)).as[IsInvalid].get
kind(sx.exp) match {
case (k:MemKind) => {
val es = lower_mem(sx.exp)
@@ -1486,7 +1487,7 @@ object LowerTypes extends Pass {
}
}
case (s:Connect) => {
- val sx = eMap(lower_types_e _,s).as[Connect].get
+ val sx = (s map (lower_types_e)).as[Connect].get
kind(sx.loc) match {
case (k:MemKind) => {
val es = lower_mem(sx.loc)
@@ -1507,7 +1508,7 @@ object LowerTypes extends Pass {
}
if (n == 1) nodes(0) else Begin(nodes)
}
- case (s) => eMap(lower_types_e _,sMap(lower_types _,s))
+ case (s) => s map (lower_types) map (lower_types_e)
}
}
@@ -1567,7 +1568,7 @@ object CInferTypes extends Pass {
def infer_types (m:Module) : Module = {
val types = LinkedHashMap[String,Type]()
def infer_types_e (e:Expression) : Expression = {
- (eMap(infer_types_e _,e)) match {
+ (e map (infer_types_e)) match {
case (e:Ref) => Ref(e.name, types.getOrElse(e.name,UnknownType()))
case (e:SubField) => SubField(e.exp,e.name,field_type(tpe(e.exp),e.name))
case (e:SubIndex) => SubIndex(e.exp,e.value,sub_type(tpe(e.exp)))
@@ -1582,7 +1583,7 @@ object CInferTypes extends Pass {
(s) match {
case (s:DefRegister) => {
types(s.name) = s.tpe
- eMap(infer_types_e _,s)
+ s map (infer_types_e)
s
}
case (s:DefWire) => {
@@ -1594,7 +1595,7 @@ object CInferTypes extends Pass {
s
}
case (s:DefNode) => {
- val sx = eMap(infer_types_e _,s)
+ val sx = s map (infer_types_e)
val t = get_type(sx)
types(s.name) = t
sx
@@ -1616,7 +1617,7 @@ object CInferTypes extends Pass {
types(s.name) = module_types.getOrElse(s.module,UnknownType())
s
}
- case (s) => eMap(infer_types_e _,sMap(infer_types_s _,s))
+ case (s) => s map(infer_types_s) map (infer_types_e)
}
}
for (p <- m.ports) {
@@ -1644,7 +1645,7 @@ object CInferMDir extends Pass {
def infer_mdir (m:Module) : Module = {
val mports = LinkedHashMap[String,MPortDir]()
def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = {
- (eMap(infer_mdir_e(dir) _,e)) match {
+ (e map (infer_mdir_e(dir))) match {
case (e:Ref) => {
if (mports.contains(e.name)) {
val new_mport_dir = {
@@ -1678,7 +1679,7 @@ object CInferMDir extends Pass {
(s) match {
case (s:CDefMPort) => {
mports(s.name) = s.direction
- eMap(infer_mdir_e(MRead) _,s)
+ s map (infer_mdir_e(MRead))
}
case (s:Connect) => {
infer_mdir_e(MRead)(s.exp)
@@ -1690,14 +1691,14 @@ object CInferMDir extends Pass {
infer_mdir_e(MWrite)(s.loc)
s
}
- case (s) => eMap(infer_mdir_e(MRead) _, sMap(infer_mdir_s,s))
+ case (s) => s map (infer_mdir_s) map (infer_mdir_e(MRead))
}
}
def set_mdir_s (s:Stmt) : Stmt = {
(s) match {
case (s:CDefMPort) =>
CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name))
- case (s) => sMap(set_mdir_s _,s)
+ case (s) => s map (set_mdir_s)
}
}
(m) match {
@@ -1760,7 +1761,7 @@ object RemoveCHIRRTL extends Pass {
hash(s.mem) = mports
s
}
- case (s) => sMap(collect_mports _,s)
+ case (s) => s map (collect_mports)
}
}
def collect_refs (s:Stmt) : Stmt = {
@@ -1840,7 +1841,7 @@ object RemoveCHIRRTL extends Pass {
}
Begin(stmts)
}
- case (s) => sMap(collect_refs _,s)
+ case (s) => s map (collect_refs)
}
}
def remove_chirrtl_s (s:Stmt) : Stmt = {
@@ -1863,11 +1864,11 @@ object RemoveCHIRRTL extends Pass {
} else e
}
case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.exp),remove_chirrtl_e(MALE)(e.index),e.tpe)
- case (e) => eMap(remove_chirrtl_e(g) _,e)
+ case (e) => e map (remove_chirrtl_e(g))
}
}
def get_mask (e:Expression) : Expression = {
- (eMap(get_mask _,e)) match {
+ (e map (get_mask)) match {
case (e:Ref) => {
if (repl.contains(e.name)) {
val vt = repl(e.name)
@@ -1917,7 +1918,7 @@ object RemoveCHIRRTL extends Pass {
if (stmts.size > 1) Begin(stmts)
else stmts(0)
}
- case (s) => eMap(remove_chirrtl_e(MALE) _, sMap(remove_chirrtl_s,s))
+ case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE))
}
}
collect_mports(m.body)