diff options
| author | jackkoenig | 2016-03-01 12:15:28 -0800 |
|---|---|---|
| committer | jackkoenig | 2016-03-01 12:21:11 -0800 |
| commit | 079005f630590bdaf4671c9d8ab127b649cd61df (patch) | |
| tree | 94885d84691570e43a59684d9facf71e10bdab0f /src/main/scala/firrtl/passes/Passes.scala | |
| parent | aa2322eb09e9059ad1cdf066c3e7270e0b98679d (diff) | |
Move mapper functions to implicit methods on IR vertices.
Diffstat (limited to 'src/main/scala/firrtl/passes/Passes.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 137 |
1 files changed, 69 insertions, 68 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 8a2fb5c8..7490c479 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -40,6 +40,7 @@ import scala.collection.mutable.ArrayBuffer import firrtl._ import firrtl.Utils._ +import firrtl.Mappers._ import firrtl.Serialize._ import firrtl.PrimOps._ import firrtl.WrappedExpression._ @@ -99,7 +100,7 @@ object ToWorkingIR extends Pass { def name = "Working IR" def run (c:Circuit): Circuit = { def toExp (e:Expression) : Expression = { - eMap(toExp _,e) match { + e map (toExp) match { case e:Ref => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER) case e:SubField => WSubField(e.exp, e.name, e.tpe, UNKNOWNGENDER) case e:SubIndex => WSubIndex(e.exp, e.value, e.tpe, UNKNOWNGENDER) @@ -108,9 +109,9 @@ object ToWorkingIR extends Pass { } } def toStmt (s:Stmt) : Stmt = { - eMap(toExp _,s) match { + s map (toExp) match { case s:DefInstance => WDefInstance(s.info,s.name,s.module,UnknownType()) - case s => sMap(toStmt _,s) + case s => s map (toStmt) } } val modulesx = c.modules.map { m => @@ -139,10 +140,10 @@ object ResolveKinds extends Pass { def resolve_expr (e:Expression):Expression = { e match { case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender) - case e => eMap(resolve_expr,e) + case e => e map (resolve_expr) } } - def resolve_stmt (s:Stmt):Stmt = eMap(resolve_expr,sMap(resolve_stmt,s)) + def resolve_stmt (s:Stmt):Stmt = s map (resolve_stmt) map (resolve_expr) resolve_stmt(body) } @@ -157,7 +158,7 @@ object ResolveKinds extends Pass { case s:DefMemory => kinds(s.name) = MemKind(s.readers ++ s.writers ++ s.readwriters) case s => false } - sMap(find_stmt,s) + s map (find_stmt) } m.ports.foreach { p => kinds(p.name) = PortKind() } m match { @@ -206,7 +207,7 @@ object InferTypes extends Pass { def infer_types (m:Module) : Module = { val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { - eMap(infer_types_e _,e) match { + e map (infer_types_e) match { case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value)) case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender) case e:WSubField => WSubField(e.exp,e.name,field_type(tpe(e.exp),e.name),e.gender) @@ -223,22 +224,22 @@ object InferTypes extends Pass { case s:DefRegister => { val t = remove_unknowns(get_type(s)) types(s.name) = t - eMap(infer_types_e _,set_type(s,t)) + set_type(s,t) map (infer_types_e) } case s:DefWire => { - val sx = eMap(infer_types_e _,s) + val sx = s map(infer_types_e) val t = remove_unknowns(get_type(sx)) types(s.name) = t set_type(sx,t) } case s:DefPoison => { - val sx = eMap(infer_types_e _,s) + val sx = s map (infer_types_e) val t = remove_unknowns(get_type(sx)) types(s.name) = t set_type(sx,t) } case s:DefNode => { - val sx = eMap(infer_types_e _,s) + val sx = s map (infer_types_e) val t = remove_unknowns(get_type(sx)) types(s.name) = t set_type(sx,t) @@ -253,7 +254,7 @@ object InferTypes extends Pass { types(s.name) = module_types(s.module) WDefInstance(s.info,s.name,s.module,module_types(s.module)) } - case s => eMap(infer_types_e _,sMap(infer_types_s,s)) + case s => s map (infer_types_s) map (infer_types_e) } } @@ -304,7 +305,7 @@ object ResolveGenders extends Pass { val indexx = resolve_e(MALE)(e.index) WSubAccess(expx,indexx,e.tpe,g) } - case e => eMap(resolve_e(g) _,e) + case e => e map (resolve_e(g)) } } @@ -324,7 +325,7 @@ object ResolveGenders extends Pass { val expx = resolve_e(MALE)(s.exp) BulkConnect(s.info,locx,expx) } - case s => sMap(resolve_s,eMap(resolve_e(MALE) _,s)) + case s => s map (resolve_e(MALE)) map (resolve_s) } } val modulesx = c.modules.map { @@ -362,7 +363,7 @@ object InferWidths extends Pass { h } def simplify (w:Width) : Width = { - (wMap(simplify _,w)) match { + (w map (simplify)) match { case (w:MinWidth) => { val v = ArrayBuffer[Width]() for (wx <- w.args) { @@ -394,7 +395,7 @@ object InferWidths extends Pass { //;println-all-debug(["Substituting for [" w "]"]) val wx = simplify(w) //;println-all-debug(["After Simplify: [" wx "]"]) - (wMap(substitute(h) _,simplify(w))) match { + (simplify(w) map (substitute(h))) match { case (w:VarWidth) => { //;("matched println-debugvarwidth!") if (h.contains(w.name)) { @@ -413,14 +414,14 @@ object InferWidths extends Pass { } } def b_sub (h:LinkedHashMap[String,Width])(w:Width) : Width = { - (wMap(b_sub(h) _,w)) match { + (w map (b_sub(h))) match { case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w case (w) => w } } def remove_cycle (n:String)(w:Width) : Width = { //;println-all-debug(["Removing cycle for " n " inside " w]) - val wx = (wMap(remove_cycle(n) _,w)) match { + val wx = (w map (remove_cycle(n))) match { case (w:MaxWidth) => MaxWidth(w.args.filter{ w => { w match { case (w:VarWidth) => !(n equals w.name) @@ -438,7 +439,7 @@ object InferWidths extends Pass { def self_rec (n:String,w:Width) : Boolean = { var has = false def look (w:Width) : Width = { - (wMap(look _,w)) match { + (w map (look)) match { case (w:VarWidth) => if (w.name == n) has = true case (w) => w } w } @@ -587,14 +588,14 @@ object InferWidths extends Pass { get_constraints_t(f1.tpe,f2.tpe,times(f1.flip,f)) }}} case (t1:VectorType,t2:VectorType) => get_constraints_t(t1.tpe,t2.tpe,f) }} def get_constraints_e (e:Expression) : Expression = { - (eMap(get_constraints_e _,e)) match { + (e map (get_constraints_e)) match { case (e:Mux) => { constrain(width_BANG(e.cond),ONE) constrain(ONE,width_BANG(e.cond)) e } case (e) => e }} def get_constraints (s:Stmt) : Stmt = { - (eMap(get_constraints_e _,s)) match { + (s map (get_constraints_e)) match { case (s:Connect) => { val n = get_size(tpe(s.loc)) val ce_loc = create_exps(s.loc) @@ -623,8 +624,8 @@ object InferWidths extends Pass { case (s:Conditionally) => { v += WGeq(width_BANG(s.pred),ONE) v += WGeq(ONE,width_BANG(s.pred)) - sMap(get_constraints _,s) } - case (s) => sMap(get_constraints _,s) }} + s map (get_constraints) } + case (s) => s map (get_constraints) }} for (m <- c.modules) { (m) match { @@ -646,7 +647,7 @@ object PullMuxes extends Pass { def name = "Pull Muxes" def run (c:Circuit): Circuit = { def pull_muxes_e (e:Expression) : Expression = { - val ex = eMap(pull_muxes_e _,e) match { + val ex = e map (pull_muxes_e) match { case (e:WRef) => e case (e:WSubField) => { e.exp match { @@ -673,9 +674,9 @@ object PullMuxes extends Pass { case (e:ValidIf) => e case (e) => e } - eMap(pull_muxes_e _,ex) + ex map (pull_muxes_e) } - def pull_muxes (s:Stmt) : Stmt = eMap(pull_muxes_e _,sMap(pull_muxes _,s)) + def pull_muxes (s:Stmt) : Stmt = s map (pull_muxes) map (pull_muxes_e) val modulesx = c.modules.map { m => { mname = m.name @@ -698,7 +699,7 @@ object ExpandConnects extends Pass { val genders = LinkedHashMap[String,Gender]() def expand_s (s:Stmt) : Stmt = { def set_gender (e:Expression) : Expression = { - eMap(set_gender _,e) match { + e map (set_gender) match { case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name)) case (e:WSubField) => { val f = get_field(tpe(e.exp),e.name) @@ -768,7 +769,7 @@ object ExpandConnects extends Pass { }} Begin(connects) } - case (s) => sMap(expand_s _,s) + case (s) => s map (expand_s) } } @@ -845,7 +846,7 @@ object RemoveAccesses extends Pass { def rec_has_access (e:Expression) : Expression = { e match { case (e:WSubAccess) => { ret = true; e } - case (e) => eMap(rec_has_access _,e) + case (e) => e map (rec_has_access) } } rec_has_access(e) @@ -864,9 +865,9 @@ object RemoveAccesses extends Pass { } def remove_e (e:Expression) : Expression = { //NOT RECURSIVE (except primops) INTENTIONALLY! e match { - case (e:DoPrim) => eMap(remove_e,e) - case (e:Mux) => eMap(remove_e,e) - case (e:ValidIf) => eMap(remove_e,e) + case (e:DoPrim) => e map (remove_e) + case (e:Mux) => e map (remove_e) + case (e:ValidIf) => e map (remove_e) case (e:SIntValue) => e case (e:UIntValue) => e case e => { @@ -910,7 +911,7 @@ object RemoveAccesses extends Pass { Connect(s.info,locx,remove_e(s.exp)) } else { Connect(s.info,s.loc,remove_e(s.exp)) } } - case (s) => sMap(remove_s,eMap(remove_e,s)) + case (s) => s map (remove_e) map (remove_s) } stmts += sx if (stmts.size != 1) Begin(stmts) else stmts(0) @@ -979,7 +980,7 @@ object ExpandWhens extends Pass { } Begin(Seq(s,Begin(voids))) } - case (s) => sMap(void_all_s _,s) + case (s) => s map (void_all_s) } } val voids = ArrayBuffer[Stmt]() @@ -1003,7 +1004,7 @@ object ExpandWhens extends Pass { def prefetch (s:Stmt) : Stmt = { (s) match { case (s:Connect) => exps += s.loc; s - case (s) => sMap(prefetch _,s) + case (s) => s map(prefetch) } } prefetch(s.conseq) @@ -1042,7 +1043,7 @@ object ExpandWhens extends Pass { simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en)) } } - case (s) => sMap(expand_whens(netlist,p) _, s) + case (s) => s map(expand_whens(netlist,p)) } s } @@ -1063,7 +1064,7 @@ object ExpandWhens extends Pass { def replace_void (e:Expression)(rvalue:Expression) : Expression = { (rvalue) match { case (rv:WVoid) => e - case (rv) => eMap(replace_void(e) _,rv) + case (rv) => rv map (replace_void(e)) } } def create (s:Stmt) : Stmt = { @@ -1091,7 +1092,7 @@ object ExpandWhens extends Pass { } } case (_:DefPoison|_:DefNode) => stmts += s - case (s) => sMap(create _,s) + case (s) => s map(create) } s } @@ -1131,7 +1132,7 @@ object ConstProp extends Pass { def name = "Constant Propogation" var mname = "" def const_prop_e (e:Expression) : Expression = { - eMap(const_prop_e _,e) match { + e map (const_prop_e) match { case (e:DoPrim) => { e.op match { case SHIFT_RIGHT_OP => { @@ -1173,7 +1174,7 @@ object ConstProp extends Pass { case (e) => e } } - def const_prop_s (s:Stmt) : Stmt = eMap(const_prop_e _, sMap(const_prop_s _,s)) + def const_prop_s (s:Stmt) : Stmt = s map (const_prop_s) map (const_prop_e) def run (c:Circuit): Circuit = { val modulesx = c.modules.map{ m => { m match { @@ -1202,7 +1203,7 @@ object VerilogWrap extends Pass { def name = "Verilog Wrap" var mname = "" def v_wrap_e (e:Expression) : Expression = { - eMap(v_wrap_e _,e) match { + e map (v_wrap_e) match { case (e:DoPrim) => { def a0 () = e.args(0) if (e.op == TAIL_OP) { @@ -1220,7 +1221,7 @@ object VerilogWrap extends Pass { case (e) => e } } - def v_wrap_s (s:Stmt) : Stmt = eMap(v_wrap_e _,sMap(v_wrap_s _,s)) + def v_wrap_s (s:Stmt) : Stmt = s map (v_wrap_s) map (v_wrap_e) def run (c:Circuit): Circuit = { val modulesx = c.modules.map{ m => { (m) match { @@ -1248,19 +1249,19 @@ object SplitExp extends Pass { WRef(n,tpe(e),kind(e),gender(e)) } def split_exp_e (i:Int)(e:Expression) : Expression = { - eMap(split_exp_e(i + 1) _,e) match { + e map (split_exp_e(i + 1)) match { case (e:DoPrim) => if (i > 0) split(e) else e case (e) => e } } s match { - case (s:Begin) => sMap(split_exp_s _,s) + case (s:Begin) => s map (split_exp_s) case (s:Print) => { - val sx = eMap(split_exp_e(1) _,s) + val sx = s map (split_exp_e(1)) v += sx; sx } case (s) => { - val sx = eMap(split_exp_e(0) _,s) + val sx = s map (split_exp_e(0)) v += sx; sx } } @@ -1289,11 +1290,11 @@ object VerilogRename extends Pass { def verilog_rename_e (e:Expression) : Expression = { (e) match { case (e:WRef) => WRef(verilog_rename_n(e.name),e.tpe,kind(e),gender(e)) - case (e) => eMap(verilog_rename_e,e) + case (e) => e map (verilog_rename_e) } } def verilog_rename_s (s:Stmt) : Stmt = { - stMap(verilog_rename_n _,eMap(verilog_rename_e _,sMap(verilog_rename_s _,s))) + s map (verilog_rename_s) map (verilog_rename_e) map (verilog_rename_n) } val modulesx = c.modules.map{ m => { val portsx = m.ports.map{ p => { @@ -1341,7 +1342,7 @@ object LowerTypes extends Pass { def expand_name (e:Expression) : Seq[String] = { val names = ArrayBuffer[String]() def expand_name_e (e:Expression) : Expression = { - (eMap(expand_name_e _,e)) match { + (e map (expand_name_e)) match { case (e:WRef) => names += e.name case (e:WSubField) => names += e.name case (e:WSubIndex) => names += e.value.toString @@ -1418,9 +1419,9 @@ object LowerTypes extends Pass { case (k) => WRef(lowered_name(e),tpe(e),kind(e),gender(e)) } } - case (e:DoPrim) => eMap(lower_types_e _,e) - case (e:Mux) => eMap(lower_types_e _,e) - case (e:ValidIf) => eMap(lower_types_e _,e) + case (e:DoPrim) => e map (lower_types_e) + case (e:Mux) => e map (lower_types_e) + case (e:ValidIf) => e map (lower_types_e) } } (s) match { @@ -1476,7 +1477,7 @@ object LowerTypes extends Pass { } } case (s:IsInvalid) => { - val sx = eMap(lower_types_e _,s).as[IsInvalid].get + val sx = (s map (lower_types_e)).as[IsInvalid].get kind(sx.exp) match { case (k:MemKind) => { val es = lower_mem(sx.exp) @@ -1486,7 +1487,7 @@ object LowerTypes extends Pass { } } case (s:Connect) => { - val sx = eMap(lower_types_e _,s).as[Connect].get + val sx = (s map (lower_types_e)).as[Connect].get kind(sx.loc) match { case (k:MemKind) => { val es = lower_mem(sx.loc) @@ -1507,7 +1508,7 @@ object LowerTypes extends Pass { } if (n == 1) nodes(0) else Begin(nodes) } - case (s) => eMap(lower_types_e _,sMap(lower_types _,s)) + case (s) => s map (lower_types) map (lower_types_e) } } @@ -1567,7 +1568,7 @@ object CInferTypes extends Pass { def infer_types (m:Module) : Module = { val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { - (eMap(infer_types_e _,e)) match { + (e map (infer_types_e)) match { case (e:Ref) => Ref(e.name, types.getOrElse(e.name,UnknownType())) case (e:SubField) => SubField(e.exp,e.name,field_type(tpe(e.exp),e.name)) case (e:SubIndex) => SubIndex(e.exp,e.value,sub_type(tpe(e.exp))) @@ -1582,7 +1583,7 @@ object CInferTypes extends Pass { (s) match { case (s:DefRegister) => { types(s.name) = s.tpe - eMap(infer_types_e _,s) + s map (infer_types_e) s } case (s:DefWire) => { @@ -1594,7 +1595,7 @@ object CInferTypes extends Pass { s } case (s:DefNode) => { - val sx = eMap(infer_types_e _,s) + val sx = s map (infer_types_e) val t = get_type(sx) types(s.name) = t sx @@ -1616,7 +1617,7 @@ object CInferTypes extends Pass { types(s.name) = module_types.getOrElse(s.module,UnknownType()) s } - case (s) => eMap(infer_types_e _,sMap(infer_types_s _,s)) + case (s) => s map(infer_types_s) map (infer_types_e) } } for (p <- m.ports) { @@ -1644,7 +1645,7 @@ object CInferMDir extends Pass { def infer_mdir (m:Module) : Module = { val mports = LinkedHashMap[String,MPortDir]() def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = { - (eMap(infer_mdir_e(dir) _,e)) match { + (e map (infer_mdir_e(dir))) match { case (e:Ref) => { if (mports.contains(e.name)) { val new_mport_dir = { @@ -1678,7 +1679,7 @@ object CInferMDir extends Pass { (s) match { case (s:CDefMPort) => { mports(s.name) = s.direction - eMap(infer_mdir_e(MRead) _,s) + s map (infer_mdir_e(MRead)) } case (s:Connect) => { infer_mdir_e(MRead)(s.exp) @@ -1690,14 +1691,14 @@ object CInferMDir extends Pass { infer_mdir_e(MWrite)(s.loc) s } - case (s) => eMap(infer_mdir_e(MRead) _, sMap(infer_mdir_s,s)) + case (s) => s map (infer_mdir_s) map (infer_mdir_e(MRead)) } } def set_mdir_s (s:Stmt) : Stmt = { (s) match { case (s:CDefMPort) => CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name)) - case (s) => sMap(set_mdir_s _,s) + case (s) => s map (set_mdir_s) } } (m) match { @@ -1760,7 +1761,7 @@ object RemoveCHIRRTL extends Pass { hash(s.mem) = mports s } - case (s) => sMap(collect_mports _,s) + case (s) => s map (collect_mports) } } def collect_refs (s:Stmt) : Stmt = { @@ -1840,7 +1841,7 @@ object RemoveCHIRRTL extends Pass { } Begin(stmts) } - case (s) => sMap(collect_refs _,s) + case (s) => s map (collect_refs) } } def remove_chirrtl_s (s:Stmt) : Stmt = { @@ -1863,11 +1864,11 @@ object RemoveCHIRRTL extends Pass { } else e } case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.exp),remove_chirrtl_e(MALE)(e.index),e.tpe) - case (e) => eMap(remove_chirrtl_e(g) _,e) + case (e) => e map (remove_chirrtl_e(g)) } } def get_mask (e:Expression) : Expression = { - (eMap(get_mask _,e)) match { + (e map (get_mask)) match { case (e:Ref) => { if (repl.contains(e.name)) { val vt = repl(e.name) @@ -1917,7 +1918,7 @@ object RemoveCHIRRTL extends Pass { if (stmts.size > 1) Begin(stmts) else stmts(0) } - case (s) => eMap(remove_chirrtl_e(MALE) _, sMap(remove_chirrtl_s,s)) + case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE)) } } collect_mports(m.body) |
