diff options
| author | jackkoenig | 2016-03-01 12:15:28 -0800 |
|---|---|---|
| committer | jackkoenig | 2016-03-01 12:21:11 -0800 |
| commit | 079005f630590bdaf4671c9d8ab127b649cd61df (patch) | |
| tree | 94885d84691570e43a59684d9facf71e10bdab0f /src | |
| parent | aa2322eb09e9059ad1cdf066c3e7270e0b98679d (diff) | |
Move mapper functions to implicit methods on IR vertices.
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 5 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Mappers.scala | 197 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 164 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 53 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 137 |
5 files changed, 301 insertions, 255 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index b099dd55..e9c90fd6 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -37,6 +37,7 @@ import scala.io.Source import Utils._ import firrtl.Serialize._ +import firrtl.Mappers._ import firrtl.passes._ import WrappedExpression._ // Datastructures @@ -279,7 +280,7 @@ object VerilogEmitter extends Emitter { val e = WRef(s.name,get_type(s),NodeKind(),MALE) netlist(e) = s.value } - case (s) => sMap(build_netlist,s) + case (s) => s map (build_netlist) } s } @@ -529,7 +530,7 @@ object VerilogEmitter extends Emitter { update(wmem_port,datax,clk,AND(AND(enx,maskx),wmode)) } } - case (s:Begin) => sMap(build_streams _,s) + case (s:Begin) => s map (build_streams) } s } diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala new file mode 100644 index 00000000..e8d9e072 --- /dev/null +++ b/src/main/scala/firrtl/Mappers.scala @@ -0,0 +1,197 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtl + +// TODO: Implement remaining mappers and recursive mappers +object Mappers { + + // ********** Stmt Mappers ********** + private trait StmtMagnet { + def map(stmt: Stmt): Stmt + } + private object StmtMagnet { + implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet { + override def map(stmt: Stmt): Stmt = { + stmt match { + case s: Conditionally => Conditionally(s.info, s.pred, f(s.conseq), f(s.alt)) + case s: Begin => Begin(s.stmts.map(f)) + case s: Stmt => s + } + } + } + implicit def forExp(f: Expression => Expression) = new StmtMagnet { + override def map(stmt: Stmt): Stmt = { + stmt match { + case s: DefRegister => DefRegister(s.info, s.name, s.tpe, f(s.clock), f(s.reset), f(s.init)) + case s: DefNode => DefNode(s.info, s.name, f(s.value)) + case s: Connect => Connect(s.info, f(s.loc), f(s.exp)) + case s: BulkConnect => BulkConnect(s.info, f(s.loc), f(s.exp)) + case s: Conditionally => Conditionally(s.info, f(s.pred), s.conseq, s.alt) + case s: IsInvalid => IsInvalid(s.info, f(s.exp)) + case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) + case s: Print => Print(s.info, s.string, s.args.map(f), f(s.clk), f(s.en)) + case s: CDefMPort => CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps.map(f),s.direction) + case s: Stmt => s + } + } + } + implicit def forType(f: Type => Type) = new StmtMagnet { + override def map(stmt: Stmt) : Stmt = { + stmt match { + case s:DefPoison => DefPoison(s.info,s.name,f(s.tpe)) + case s:DefWire => DefWire(s.info,s.name,f(s.tpe)) + case s:DefRegister => DefRegister(s.info,s.name,f(s.tpe),s.clock,s.reset,s.init) + case s:DefMemory => DefMemory(s.info,s.name, f(s.data_type), s.depth, s.write_latency, s.read_latency, s.readers, s.writers, s.readwriters) + case s:CDefMemory => CDefMemory(s.info,s.name, f(s.tpe), s.size, s.seq) + case s:CDefMPort => CDefMPort(s.info,s.name, f(s.tpe), s.mem, s.exps,s.direction) + case s => s + } + } + } + implicit def forString(f: String => String) = new StmtMagnet { + override def map(stmt: Stmt): Stmt = { + stmt match { + case s: DefWire => DefWire(s.info,f(s.name),s.tpe) + case s: DefPoison => DefPoison(s.info,f(s.name),s.tpe) + case s: DefRegister => DefRegister(s.info,f(s.name), s.tpe, s.clock, s.reset, s.init) + case s: DefMemory => DefMemory(s.info,f(s.name), s.data_type, s.depth, s.write_latency, s.read_latency, s.readers, s.writers, s.readwriters) + case s: DefNode => DefNode(s.info,f(s.name),s.value) + case s: DefInstance => DefInstance(s.info,f(s.name), s.module) + case s: WDefInstance => WDefInstance(s.info,f(s.name), s.module,s.tpe) + case s: CDefMemory => CDefMemory(s.info,f(s.name),s.tpe,s.size,s.seq) + case s: CDefMPort => CDefMPort(s.info,f(s.name),s.tpe,s.mem,s.exps,s.direction) + case s => s + } + } + } + } + implicit class StmtMap(stmt: Stmt) { + // Using implicit types to allow overloading of function type to map, see StmtMagnet above + def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt) + } + + // ********** Expression Mappers ********** + private trait ExpMagnet { + def map(exp: Expression): Expression + } + private object ExpMagnet { + implicit def forExp(f: Expression => Expression) = new ExpMagnet { + override def map(exp: Expression): Expression = { + exp match { + case e: SubField => SubField(f(e.exp), e.name, e.tpe) + case e: SubIndex => SubIndex(f(e.exp), e.value, e.tpe) + case e: SubAccess => SubAccess(f(e.exp), f(e.index), e.tpe) + case e: Mux => Mux(f(e.cond), f(e.tval), f(e.fval), e.tpe) + case e: ValidIf => ValidIf(f(e.cond), f(e.value), e.tpe) + case e: DoPrim => DoPrim(e.op, e.args.map(f), e.consts, e.tpe) + case e: WSubField => WSubField(f(e.exp), e.name, e.tpe, e.gender) + case e: WSubIndex => WSubIndex(f(e.exp), e.value, e.tpe, e.gender) + case e: WSubAccess => WSubAccess(f(e.exp), f(e.index), e.tpe, e.gender) + case e: Expression => e + } + } + } + implicit def forType(f: Type => Type) = new ExpMagnet { + override def map(exp: Expression): Expression = { + exp match { + case e: DoPrim => DoPrim(e.op,e.args,e.consts,f(e.tpe)) + case e: Mux => Mux(e.cond,e.tval,e.fval,f(e.tpe)) + case e: ValidIf => ValidIf(e.cond,e.value,f(e.tpe)) + case e: WRef => WRef(e.name,f(e.tpe),e.kind,e.gender) + case e: WSubField => WSubField(e.exp,e.name,f(e.tpe),e.gender) + case e: WSubIndex => WSubIndex(e.exp,e.value,f(e.tpe),e.gender) + case e: WSubAccess => WSubAccess(e.exp,e.index,f(e.tpe),e.gender) + case e => e + } + } + } + implicit def forWidth(f: Width => Width) = new ExpMagnet { + override def map(exp: Expression): Expression = { + exp match { + case e: UIntValue => UIntValue(e.value,f(e.width)) + case e: SIntValue => SIntValue(e.value,f(e.width)) + case e => e + } + } + } + } + implicit class ExpMap(exp: Expression) { + def map[T](f: T => T)(implicit magnet: (T => T) => ExpMagnet): Expression = magnet(f).map(exp) + } + + // ********** Type Mappers ********** + private trait TypeMagnet { + def map(tpe: Type): Type + } + private object TypeMagnet { + implicit def forType(f: Type => Type) = new TypeMagnet { + override def map(tpe: Type): Type = { + tpe match { + case t: BundleType => BundleType(t.fields.map(p => Field(p.name, p.flip, f(p.tpe)))) + case t: VectorType => VectorType(f(t.tpe), t.size) + case t => t + } + } + } + implicit def forWidth(f: Width => Width) = new TypeMagnet { + override def map(tpe: Type): Type = { + tpe match { + case t: UIntType => UIntType(f(t.width)) + case t: SIntType => SIntType(f(t.width)) + case t => t + } + } + } + } + implicit class TypeMap(tpe: Type) { + def map[T](f: T => T)(implicit magnet: (T => T) => TypeMagnet): Type = magnet(f).map(tpe) + } + + // ********** Width Mappers ********** + private trait WidthMagnet { + def map(width: Width): Width + } + private object WidthMagnet { + implicit def forWidth(f: Width => Width) = new WidthMagnet { + override def map(width: Width): Width = { + width match { + case w: MaxWidth => MaxWidth(w.args.map(f)) + case w: MinWidth => MinWidth(w.args.map(f)) + case w: PlusWidth => PlusWidth(f(w.arg1),f(w.arg2)) + case w: MinusWidth => MinusWidth(f(w.arg1),f(w.arg2)) + case w: ExpWidth => ExpWidth(f(w.arg1)) + case w => w + } + } + } + } + implicit class WidthMap(width: Width) { + def map[T](f: T => T)(implicit magnet: (T => T) => WidthMagnet): Width = magnet(f).map(width) + } + +} diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 517791e3..390a99a7 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -28,7 +28,6 @@ MODIFICATIONS. /* TODO * - Adopt style more similar to Chisel3 Emitter? - * - Find way to have generic map function instead of mapE and mapS under Stmt implicits */ /* TODO Richard @@ -42,6 +41,7 @@ import java.io.PrintWriter import PrimOps._ import WrappedExpression._ import firrtl.WrappedType._ +import firrtl.Mappers._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.LinkedHashMap //import scala.reflect.runtime.universe._ @@ -581,123 +581,15 @@ object Utils { }} -// =============== MAPPERS =================== - def sMap(f:Stmt => Stmt, stmt: Stmt): Stmt = - stmt match { - case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) - case b: Begin => { - val stmtsx = ArrayBuffer[Stmt]() - for (i <- 0 until b.stmts.size) { - stmtsx += f(b.stmts(i)) - } - Begin(stmtsx) - } - case s: Stmt => s - } - def eMap(f:Expression => Expression, stmt:Stmt) : Stmt = - stmt match { - case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) - case n: DefNode => DefNode(n.info, n.name, f(n.value)) - case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) - case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) - case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) - case i: IsInvalid => IsInvalid(i.info, f(i.exp)) - case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) - case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) - case c: CDefMPort => CDefMPort(c.info,c.name,c.tpe,c.mem,c.exps.map(f),c.direction) - case s: Stmt => s - } - def eMap(f: Expression => Expression, exp:Expression): Expression = - exp match { - case s: SubField => SubField(f(s.exp), s.name, s.tpe) - case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) - case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) - case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) - case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) - case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) - case s: WSubField => WSubField(f(s.exp), s.name, s.tpe, s.gender) - case s: WSubIndex => WSubIndex(f(s.exp), s.value, s.tpe, s.gender) - case s: WSubAccess => WSubAccess(f(s.exp), f(s.index), s.tpe, s.gender) - case e: Expression => e - } - def tMap (f: Type => Type, t:Type):Type = { - t match { - case t:BundleType => BundleType(t.fields.map(p => Field(p.name, p.flip, f(p.tpe)))) - case t:VectorType => VectorType(f(t.tpe), t.size) - case t => t - } - } - def tMap (f: Type => Type, c:Expression) : Expression = { - c match { - case c:DoPrim => DoPrim(c.op,c.args,c.consts,f(c.tpe)) - case c:Mux => Mux(c.cond,c.tval,c.fval,f(c.tpe)) - case c:ValidIf => ValidIf(c.cond,c.value,f(c.tpe)) - case c:WRef => WRef(c.name,f(c.tpe),c.kind,c.gender) - case c:WSubField => WSubField(c.exp,c.name,f(c.tpe),c.gender) - case c:WSubIndex => WSubIndex(c.exp,c.value,f(c.tpe),c.gender) - case c:WSubAccess => WSubAccess(c.exp,c.index,f(c.tpe),c.gender) - case c => c - } - } - def tMap (f: Type => Type, c:Stmt) : Stmt = { - c match { - case c:DefPoison => DefPoison(c.info,c.name,f(c.tpe)) - case c:DefWire => DefWire(c.info,c.name,f(c.tpe)) - case c:DefRegister => DefRegister(c.info,c.name,f(c.tpe),c.clock,c.reset,c.init) - case c:DefMemory => DefMemory(c.info,c.name, f(c.data_type), c.depth, c.write_latency, c.read_latency, c.readers, c.writers, c.readwriters) - case c:CDefMemory => CDefMemory(c.info,c.name, f(c.tpe), c.size, c.seq) - case c:CDefMPort => CDefMPort(c.info,c.name, f(c.tpe), c.mem, c.exps,c.direction) - case c => c - } - } - def wMap (f: Width => Width, c:Expression) : Expression = { - c match { - case c:UIntValue => UIntValue(c.value,f(c.width)) - case c:SIntValue => SIntValue(c.value,f(c.width)) - case c => c - } - } - def wMap (f: Width => Width, c:Type) : Type = { - c match { - case c:UIntType => UIntType(f(c.width)) - case c:SIntType => SIntType(f(c.width)) - case c => c - } - } - def wMap (f: Width => Width, w:Width) : Width = { - w match { - case w:MaxWidth => MaxWidth(w.args.map(f)) - case w:MinWidth => MinWidth(w.args.map(f)) - case w:PlusWidth => PlusWidth(f(w.arg1),f(w.arg2)) - case w:MinusWidth => MinusWidth(f(w.arg1),f(w.arg2)) - case w:ExpWidth => ExpWidth(f(w.arg1)) - case w => w - } - } - def stMap (f: String => String, c:Stmt) : Stmt = { - c match { - case (c:DefWire) => DefWire(c.info,f(c.name),c.tpe) - case (c:DefPoison) => DefPoison(c.info,f(c.name),c.tpe) - case (c:DefRegister) => DefRegister(c.info,f(c.name), c.tpe, c.clock, c.reset, c.init) - case (c:DefMemory) => DefMemory(c.info,f(c.name), c.data_type, c.depth, c.write_latency, c.read_latency, c.readers, c.writers, c.readwriters) - case (c:DefNode) => DefNode(c.info,f(c.name),c.value) - case (c:DefInstance) => DefInstance(c.info,f(c.name), c.module) - case (c:WDefInstance) => WDefInstance(c.info,f(c.name), c.module,c.tpe) - case (c:CDefMemory) => CDefMemory(c.info,f(c.name),c.tpe,c.size,c.seq) - case (c:CDefMPort) => CDefMPort(c.info,f(c.name),c.tpe,c.mem,c.exps,c.direction) - case (c) => c - } - } +// =============== RECURISVE MAPPERS =================== def mapr (f: Width => Width, t:Type) : Type = { - def apply_t (t:Type) : Type = wMap(f,tMap(apply_t _,t)) + def apply_t (t:Type) : Type = t map (apply_t) map (f) apply_t(t) } def mapr (f: Width => Width, s:Stmt) : Stmt = { def apply_t (t:Type) : Type = mapr(f,t) - def apply_e (e:Expression) : Expression = - wMap(f,tMap(apply_t _,eMap(apply_e _,e))) - def apply_s (s:Stmt) : Stmt = - tMap(apply_t _,eMap(apply_e _,sMap(apply_s _,s))) + def apply_e (e:Expression) : Expression = e map (apply_e) map (apply_t) map (f) + def apply_s (s:Stmt) : Stmt = s map (apply_s) map (apply_e) map (apply_t) apply_s(s) } val ONE = IntWidth(1) @@ -751,54 +643,8 @@ object Utils { // to-stmt(body(m)) // map(to-port,ports(m)) // sym-hash - //private trait StmtMagnet { - // def map(stmt: Stmt): Stmt - //} - //private object StmtMagnet { - // implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet { - // override def map(stmt: Stmt): Stmt = - // stmt match { - // case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) - // case b: Begin => Begin(b.stmts.map(f)) - // case s: Stmt => s - // } - // } - // implicit def forExp(f: Expression => Expression) = new StmtMagnet { - // override def map(stmt: Stmt): Stmt = - // stmt match { - // case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) - // case n: DefNode => DefNode(n.info, n.name, f(n.value)) - // case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) - // case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) - // case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) - // case i: IsInvalid => IsInvalid(i.info, f(i.exp)) - // case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) - // case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) - // case s: Stmt => s - // } - // } - //} - - // def map(f: Expression => Expression): Expression = - // exp match { - // case s: SubField => SubField(f(s.exp), s.name, s.tpe) - // case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) - // case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) - // case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) - // case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) - // case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) - // case s: WSubField => SubField(f(s.exp), s.name, s.tpe, s.gender) - // case s: WSubIndex => SubIndex(f(s.exp), s.value, s.tpe, s.gender) - // case s: WSubAccess => SubAccess(f(s.exp), f(s.index), s.tpe, s.gender) - // case e: Expression => e - // } - //} - implicit class StmtUtils(stmt: Stmt) { - // Using implicit types to allow overloading of function type to map, see StmtMagnet above - //def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt) - def getType(): Type = stmt match { case s: DefWire => s.tpe diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 7e9c199d..9f163b3e 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -35,6 +35,7 @@ import scala.collection.mutable.ArrayBuffer import firrtl._ import firrtl.Utils._ +import firrtl.Mappers._ import firrtl.Serialize._ import firrtl.PrimOps._ import firrtl.WrappedType._ @@ -107,7 +108,7 @@ object CheckHighForm extends Pass with LazyLogging { } } findFlip(t) - tMap(findFlip _, t) + t map (findFlip) has } @@ -192,12 +193,12 @@ object CheckHighForm extends Pass with LazyLogging { w } def checkHighFormT(t: Type): Type = { - tMap(checkHighFormT _, t) match { + t map (checkHighFormT) match { case t: VectorType => if (t.size < 0) errors.append(new NegVecSizeException) case _ => // Do nothing } - wMap(checkHighFormW _, t) + t map (checkHighFormW) } def checkHighFormM(m: Module): Module = { @@ -212,7 +213,7 @@ object CheckHighForm extends Pass with LazyLogging { } e } - eMap(checkHighFormE _, e) match { + e map (checkHighFormE) match { case e: WRef => if (!names.contains(e.name)) errors.append(new UndeclaredReferenceException(e.name)) case e: DoPrim => checkHighFormPrimop(e) @@ -223,10 +224,10 @@ object CheckHighForm extends Pass with LazyLogging { } case e: UIntValue => if (e.value < 0) errors.append(new NegUIntException) - case e => eMap(validSubexp _, e) + case e => e map (validSubexp) } - wMap(checkHighFormW _, e) - tMap(checkHighFormT _, e) + e map (checkHighFormW) + e map (checkHighFormT) e } def checkHighFormS(s: Stmt): Stmt = { @@ -239,9 +240,9 @@ object CheckHighForm extends Pass with LazyLogging { } sinfo = s.getInfo - stMap(checkName _, s) - tMap(checkHighFormT _, s) - eMap(checkHighFormE _, s) + s map (checkName) + s map (checkHighFormT) + s map (checkHighFormE) s match { case s: DefPoison => { if (hasFlip(s.tpe)) errors.append(new PoisonWithFlipException(s.name)) @@ -261,7 +262,7 @@ object CheckHighForm extends Pass with LazyLogging { case _ => // Do Nothing } - sMap(checkHighFormS _, s) + s map (checkHighFormS) } mname = m.name @@ -272,8 +273,8 @@ object CheckHighForm extends Pass with LazyLogging { // FIXME should we set sinfo here? names(p.name) = true val tpe = p.getType - tMap(checkHighFormT _, tpe) - wMap(checkHighFormW _, tpe) + tpe map (checkHighFormT) + tpe map (checkHighFormW) } m match { @@ -408,7 +409,7 @@ object CheckTypes extends Pass with LazyLogging { } } def check_types_e (info:Info)(e:Expression) : Expression = { - (eMap(check_types_e(info) _,e)) match { + (e map (check_types_e(info))) match { case (e:WRef) => e case (e:WSubField) => { (tpe(e.exp)) match { @@ -477,7 +478,7 @@ object CheckTypes extends Pass with LazyLogging { } def check_types_s (s:Stmt) : Stmt = { - eMap(check_types_e(get_info(s)) _,s) match { + s map (check_types_e(get_info(s))) match { case (s:Connect) => if (wt(tpe(s.loc)) != wt(tpe(s.exp))) errors += new InvalidConnect(s.info) case (s:BulkConnect) => if (!bulk_equals(tpe(s.loc),tpe(s.exp)) ) errors += new InvalidConnect(s.info) case (s:Stop) => { @@ -495,7 +496,7 @@ object CheckTypes extends Pass with LazyLogging { case (s:DefNode) => if (!passive(tpe(s.value)) ) errors += new NodePassiveType(s.info) case (s) => false } - sMap(check_types_s,s) + s map (check_types_s) } for (m <- c.modules ) { @@ -603,15 +604,15 @@ object CheckGenders extends Pass { } def check_genders_e (info:Info,genders:HashMap[String,Gender])(e:Expression) : Expression = { - eMap(check_genders_e(info,genders) _,e) + e map (check_genders_e(info,genders)) (e) match { case (e:WRef) => false case (e:WSubField) => false case (e:WSubIndex) => false case (e:WSubAccess) => false case (e:DoPrim) => for (e <- e.args ) { check_gender(info,genders,MALE)(e) } - case (e:Mux) => eMap(check_gender(info,genders,MALE) _,e) - case (e:ValidIf) => eMap(check_gender(info,genders,MALE) _,e) + case (e:Mux) => e map (check_gender(info,genders,MALE)) + case (e:ValidIf) => e map (check_gender(info,genders,MALE)) case (e:UIntValue) => false case (e:SIntValue) => false } @@ -619,8 +620,8 @@ object CheckGenders extends Pass { } def check_genders_s (genders:HashMap[String,Gender])(s:Stmt) : Stmt = { - eMap(check_genders_e(get_info(s),genders) _,s) - sMap(check_genders_s(genders) _,s) + s map (check_genders_e(get_info(s),genders)) + s map (check_genders_s(genders)) (s) match { case (s:DefWire) => genders(s.name) = BIGENDER case (s:DefPoison) => genders(s.name) = MALE @@ -692,7 +693,7 @@ object CheckWidths extends Pass with StanzaPass { w } def check_width_e (info:Info)(e:Expression) : Expression = { - (eMap(check_width_e(info) _,e)) match { + (e map (check_width_e(info))) match { case (e:UIntValue) => { (e.width) match { case (w:IntWidth) => @@ -717,9 +718,9 @@ object CheckWidths extends Pass with StanzaPass { e } def check_width_s (s:Stmt) : Stmt = { - eMap(check_width_e(get_info(s)) _,sMap(check_width_s _,s)) + s map (check_width_s) map (check_width_e(get_info(s))) def tm (t:Type) : Type = mapr(check_width_w(info(s)) _,t) - tMap(tm _,s) + s map (tm) } for (p <- m.ports) { @@ -761,7 +762,7 @@ object CheckInitialization extends Pass with StanzaPass { def has_void (e:Expression) : Expression = { (e) match { case (e:WVoid) => void = true; e - case (e) => eMap(has_void,e) + case (e) => e map (has_void) } } has_void(e) @@ -773,7 +774,7 @@ object CheckInitialization extends Pass with StanzaPass { if (has_voidQ(s.exp)) errors += new RefNotInitialized(s.info,get_name(s.loc)) s } - case (s) => sMap(check_init_s,s) + case (s) => s map (check_init_s) } } check_init_s(m.body) diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 8a2fb5c8..7490c479 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -40,6 +40,7 @@ import scala.collection.mutable.ArrayBuffer import firrtl._ import firrtl.Utils._ +import firrtl.Mappers._ import firrtl.Serialize._ import firrtl.PrimOps._ import firrtl.WrappedExpression._ @@ -99,7 +100,7 @@ object ToWorkingIR extends Pass { def name = "Working IR" def run (c:Circuit): Circuit = { def toExp (e:Expression) : Expression = { - eMap(toExp _,e) match { + e map (toExp) match { case e:Ref => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER) case e:SubField => WSubField(e.exp, e.name, e.tpe, UNKNOWNGENDER) case e:SubIndex => WSubIndex(e.exp, e.value, e.tpe, UNKNOWNGENDER) @@ -108,9 +109,9 @@ object ToWorkingIR extends Pass { } } def toStmt (s:Stmt) : Stmt = { - eMap(toExp _,s) match { + s map (toExp) match { case s:DefInstance => WDefInstance(s.info,s.name,s.module,UnknownType()) - case s => sMap(toStmt _,s) + case s => s map (toStmt) } } val modulesx = c.modules.map { m => @@ -139,10 +140,10 @@ object ResolveKinds extends Pass { def resolve_expr (e:Expression):Expression = { e match { case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender) - case e => eMap(resolve_expr,e) + case e => e map (resolve_expr) } } - def resolve_stmt (s:Stmt):Stmt = eMap(resolve_expr,sMap(resolve_stmt,s)) + def resolve_stmt (s:Stmt):Stmt = s map (resolve_stmt) map (resolve_expr) resolve_stmt(body) } @@ -157,7 +158,7 @@ object ResolveKinds extends Pass { case s:DefMemory => kinds(s.name) = MemKind(s.readers ++ s.writers ++ s.readwriters) case s => false } - sMap(find_stmt,s) + s map (find_stmt) } m.ports.foreach { p => kinds(p.name) = PortKind() } m match { @@ -206,7 +207,7 @@ object InferTypes extends Pass { def infer_types (m:Module) : Module = { val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { - eMap(infer_types_e _,e) match { + e map (infer_types_e) match { case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value)) case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender) case e:WSubField => WSubField(e.exp,e.name,field_type(tpe(e.exp),e.name),e.gender) @@ -223,22 +224,22 @@ object InferTypes extends Pass { case s:DefRegister => { val t = remove_unknowns(get_type(s)) types(s.name) = t - eMap(infer_types_e _,set_type(s,t)) + set_type(s,t) map (infer_types_e) } case s:DefWire => { - val sx = eMap(infer_types_e _,s) + val sx = s map(infer_types_e) val t = remove_unknowns(get_type(sx)) types(s.name) = t set_type(sx,t) } case s:DefPoison => { - val sx = eMap(infer_types_e _,s) + val sx = s map (infer_types_e) val t = remove_unknowns(get_type(sx)) types(s.name) = t set_type(sx,t) } case s:DefNode => { - val sx = eMap(infer_types_e _,s) + val sx = s map (infer_types_e) val t = remove_unknowns(get_type(sx)) types(s.name) = t set_type(sx,t) @@ -253,7 +254,7 @@ object InferTypes extends Pass { types(s.name) = module_types(s.module) WDefInstance(s.info,s.name,s.module,module_types(s.module)) } - case s => eMap(infer_types_e _,sMap(infer_types_s,s)) + case s => s map (infer_types_s) map (infer_types_e) } } @@ -304,7 +305,7 @@ object ResolveGenders extends Pass { val indexx = resolve_e(MALE)(e.index) WSubAccess(expx,indexx,e.tpe,g) } - case e => eMap(resolve_e(g) _,e) + case e => e map (resolve_e(g)) } } @@ -324,7 +325,7 @@ object ResolveGenders extends Pass { val expx = resolve_e(MALE)(s.exp) BulkConnect(s.info,locx,expx) } - case s => sMap(resolve_s,eMap(resolve_e(MALE) _,s)) + case s => s map (resolve_e(MALE)) map (resolve_s) } } val modulesx = c.modules.map { @@ -362,7 +363,7 @@ object InferWidths extends Pass { h } def simplify (w:Width) : Width = { - (wMap(simplify _,w)) match { + (w map (simplify)) match { case (w:MinWidth) => { val v = ArrayBuffer[Width]() for (wx <- w.args) { @@ -394,7 +395,7 @@ object InferWidths extends Pass { //;println-all-debug(["Substituting for [" w "]"]) val wx = simplify(w) //;println-all-debug(["After Simplify: [" wx "]"]) - (wMap(substitute(h) _,simplify(w))) match { + (simplify(w) map (substitute(h))) match { case (w:VarWidth) => { //;("matched println-debugvarwidth!") if (h.contains(w.name)) { @@ -413,14 +414,14 @@ object InferWidths extends Pass { } } def b_sub (h:LinkedHashMap[String,Width])(w:Width) : Width = { - (wMap(b_sub(h) _,w)) match { + (w map (b_sub(h))) match { case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w case (w) => w } } def remove_cycle (n:String)(w:Width) : Width = { //;println-all-debug(["Removing cycle for " n " inside " w]) - val wx = (wMap(remove_cycle(n) _,w)) match { + val wx = (w map (remove_cycle(n))) match { case (w:MaxWidth) => MaxWidth(w.args.filter{ w => { w match { case (w:VarWidth) => !(n equals w.name) @@ -438,7 +439,7 @@ object InferWidths extends Pass { def self_rec (n:String,w:Width) : Boolean = { var has = false def look (w:Width) : Width = { - (wMap(look _,w)) match { + (w map (look)) match { case (w:VarWidth) => if (w.name == n) has = true case (w) => w } w } @@ -587,14 +588,14 @@ object InferWidths extends Pass { get_constraints_t(f1.tpe,f2.tpe,times(f1.flip,f)) }}} case (t1:VectorType,t2:VectorType) => get_constraints_t(t1.tpe,t2.tpe,f) }} def get_constraints_e (e:Expression) : Expression = { - (eMap(get_constraints_e _,e)) match { + (e map (get_constraints_e)) match { case (e:Mux) => { constrain(width_BANG(e.cond),ONE) constrain(ONE,width_BANG(e.cond)) e } case (e) => e }} def get_constraints (s:Stmt) : Stmt = { - (eMap(get_constraints_e _,s)) match { + (s map (get_constraints_e)) match { case (s:Connect) => { val n = get_size(tpe(s.loc)) val ce_loc = create_exps(s.loc) @@ -623,8 +624,8 @@ object InferWidths extends Pass { case (s:Conditionally) => { v += WGeq(width_BANG(s.pred),ONE) v += WGeq(ONE,width_BANG(s.pred)) - sMap(get_constraints _,s) } - case (s) => sMap(get_constraints _,s) }} + s map (get_constraints) } + case (s) => s map (get_constraints) }} for (m <- c.modules) { (m) match { @@ -646,7 +647,7 @@ object PullMuxes extends Pass { def name = "Pull Muxes" def run (c:Circuit): Circuit = { def pull_muxes_e (e:Expression) : Expression = { - val ex = eMap(pull_muxes_e _,e) match { + val ex = e map (pull_muxes_e) match { case (e:WRef) => e case (e:WSubField) => { e.exp match { @@ -673,9 +674,9 @@ object PullMuxes extends Pass { case (e:ValidIf) => e case (e) => e } - eMap(pull_muxes_e _,ex) + ex map (pull_muxes_e) } - def pull_muxes (s:Stmt) : Stmt = eMap(pull_muxes_e _,sMap(pull_muxes _,s)) + def pull_muxes (s:Stmt) : Stmt = s map (pull_muxes) map (pull_muxes_e) val modulesx = c.modules.map { m => { mname = m.name @@ -698,7 +699,7 @@ object ExpandConnects extends Pass { val genders = LinkedHashMap[String,Gender]() def expand_s (s:Stmt) : Stmt = { def set_gender (e:Expression) : Expression = { - eMap(set_gender _,e) match { + e map (set_gender) match { case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name)) case (e:WSubField) => { val f = get_field(tpe(e.exp),e.name) @@ -768,7 +769,7 @@ object ExpandConnects extends Pass { }} Begin(connects) } - case (s) => sMap(expand_s _,s) + case (s) => s map (expand_s) } } @@ -845,7 +846,7 @@ object RemoveAccesses extends Pass { def rec_has_access (e:Expression) : Expression = { e match { case (e:WSubAccess) => { ret = true; e } - case (e) => eMap(rec_has_access _,e) + case (e) => e map (rec_has_access) } } rec_has_access(e) @@ -864,9 +865,9 @@ object RemoveAccesses extends Pass { } def remove_e (e:Expression) : Expression = { //NOT RECURSIVE (except primops) INTENTIONALLY! e match { - case (e:DoPrim) => eMap(remove_e,e) - case (e:Mux) => eMap(remove_e,e) - case (e:ValidIf) => eMap(remove_e,e) + case (e:DoPrim) => e map (remove_e) + case (e:Mux) => e map (remove_e) + case (e:ValidIf) => e map (remove_e) case (e:SIntValue) => e case (e:UIntValue) => e case e => { @@ -910,7 +911,7 @@ object RemoveAccesses extends Pass { Connect(s.info,locx,remove_e(s.exp)) } else { Connect(s.info,s.loc,remove_e(s.exp)) } } - case (s) => sMap(remove_s,eMap(remove_e,s)) + case (s) => s map (remove_e) map (remove_s) } stmts += sx if (stmts.size != 1) Begin(stmts) else stmts(0) @@ -979,7 +980,7 @@ object ExpandWhens extends Pass { } Begin(Seq(s,Begin(voids))) } - case (s) => sMap(void_all_s _,s) + case (s) => s map (void_all_s) } } val voids = ArrayBuffer[Stmt]() @@ -1003,7 +1004,7 @@ object ExpandWhens extends Pass { def prefetch (s:Stmt) : Stmt = { (s) match { case (s:Connect) => exps += s.loc; s - case (s) => sMap(prefetch _,s) + case (s) => s map(prefetch) } } prefetch(s.conseq) @@ -1042,7 +1043,7 @@ object ExpandWhens extends Pass { simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en)) } } - case (s) => sMap(expand_whens(netlist,p) _, s) + case (s) => s map(expand_whens(netlist,p)) } s } @@ -1063,7 +1064,7 @@ object ExpandWhens extends Pass { def replace_void (e:Expression)(rvalue:Expression) : Expression = { (rvalue) match { case (rv:WVoid) => e - case (rv) => eMap(replace_void(e) _,rv) + case (rv) => rv map (replace_void(e)) } } def create (s:Stmt) : Stmt = { @@ -1091,7 +1092,7 @@ object ExpandWhens extends Pass { } } case (_:DefPoison|_:DefNode) => stmts += s - case (s) => sMap(create _,s) + case (s) => s map(create) } s } @@ -1131,7 +1132,7 @@ object ConstProp extends Pass { def name = "Constant Propogation" var mname = "" def const_prop_e (e:Expression) : Expression = { - eMap(const_prop_e _,e) match { + e map (const_prop_e) match { case (e:DoPrim) => { e.op match { case SHIFT_RIGHT_OP => { @@ -1173,7 +1174,7 @@ object ConstProp extends Pass { case (e) => e } } - def const_prop_s (s:Stmt) : Stmt = eMap(const_prop_e _, sMap(const_prop_s _,s)) + def const_prop_s (s:Stmt) : Stmt = s map (const_prop_s) map (const_prop_e) def run (c:Circuit): Circuit = { val modulesx = c.modules.map{ m => { m match { @@ -1202,7 +1203,7 @@ object VerilogWrap extends Pass { def name = "Verilog Wrap" var mname = "" def v_wrap_e (e:Expression) : Expression = { - eMap(v_wrap_e _,e) match { + e map (v_wrap_e) match { case (e:DoPrim) => { def a0 () = e.args(0) if (e.op == TAIL_OP) { @@ -1220,7 +1221,7 @@ object VerilogWrap extends Pass { case (e) => e } } - def v_wrap_s (s:Stmt) : Stmt = eMap(v_wrap_e _,sMap(v_wrap_s _,s)) + def v_wrap_s (s:Stmt) : Stmt = s map (v_wrap_s) map (v_wrap_e) def run (c:Circuit): Circuit = { val modulesx = c.modules.map{ m => { (m) match { @@ -1248,19 +1249,19 @@ object SplitExp extends Pass { WRef(n,tpe(e),kind(e),gender(e)) } def split_exp_e (i:Int)(e:Expression) : Expression = { - eMap(split_exp_e(i + 1) _,e) match { + e map (split_exp_e(i + 1)) match { case (e:DoPrim) => if (i > 0) split(e) else e case (e) => e } } s match { - case (s:Begin) => sMap(split_exp_s _,s) + case (s:Begin) => s map (split_exp_s) case (s:Print) => { - val sx = eMap(split_exp_e(1) _,s) + val sx = s map (split_exp_e(1)) v += sx; sx } case (s) => { - val sx = eMap(split_exp_e(0) _,s) + val sx = s map (split_exp_e(0)) v += sx; sx } } @@ -1289,11 +1290,11 @@ object VerilogRename extends Pass { def verilog_rename_e (e:Expression) : Expression = { (e) match { case (e:WRef) => WRef(verilog_rename_n(e.name),e.tpe,kind(e),gender(e)) - case (e) => eMap(verilog_rename_e,e) + case (e) => e map (verilog_rename_e) } } def verilog_rename_s (s:Stmt) : Stmt = { - stMap(verilog_rename_n _,eMap(verilog_rename_e _,sMap(verilog_rename_s _,s))) + s map (verilog_rename_s) map (verilog_rename_e) map (verilog_rename_n) } val modulesx = c.modules.map{ m => { val portsx = m.ports.map{ p => { @@ -1341,7 +1342,7 @@ object LowerTypes extends Pass { def expand_name (e:Expression) : Seq[String] = { val names = ArrayBuffer[String]() def expand_name_e (e:Expression) : Expression = { - (eMap(expand_name_e _,e)) match { + (e map (expand_name_e)) match { case (e:WRef) => names += e.name case (e:WSubField) => names += e.name case (e:WSubIndex) => names += e.value.toString @@ -1418,9 +1419,9 @@ object LowerTypes extends Pass { case (k) => WRef(lowered_name(e),tpe(e),kind(e),gender(e)) } } - case (e:DoPrim) => eMap(lower_types_e _,e) - case (e:Mux) => eMap(lower_types_e _,e) - case (e:ValidIf) => eMap(lower_types_e _,e) + case (e:DoPrim) => e map (lower_types_e) + case (e:Mux) => e map (lower_types_e) + case (e:ValidIf) => e map (lower_types_e) } } (s) match { @@ -1476,7 +1477,7 @@ object LowerTypes extends Pass { } } case (s:IsInvalid) => { - val sx = eMap(lower_types_e _,s).as[IsInvalid].get + val sx = (s map (lower_types_e)).as[IsInvalid].get kind(sx.exp) match { case (k:MemKind) => { val es = lower_mem(sx.exp) @@ -1486,7 +1487,7 @@ object LowerTypes extends Pass { } } case (s:Connect) => { - val sx = eMap(lower_types_e _,s).as[Connect].get + val sx = (s map (lower_types_e)).as[Connect].get kind(sx.loc) match { case (k:MemKind) => { val es = lower_mem(sx.loc) @@ -1507,7 +1508,7 @@ object LowerTypes extends Pass { } if (n == 1) nodes(0) else Begin(nodes) } - case (s) => eMap(lower_types_e _,sMap(lower_types _,s)) + case (s) => s map (lower_types) map (lower_types_e) } } @@ -1567,7 +1568,7 @@ object CInferTypes extends Pass { def infer_types (m:Module) : Module = { val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { - (eMap(infer_types_e _,e)) match { + (e map (infer_types_e)) match { case (e:Ref) => Ref(e.name, types.getOrElse(e.name,UnknownType())) case (e:SubField) => SubField(e.exp,e.name,field_type(tpe(e.exp),e.name)) case (e:SubIndex) => SubIndex(e.exp,e.value,sub_type(tpe(e.exp))) @@ -1582,7 +1583,7 @@ object CInferTypes extends Pass { (s) match { case (s:DefRegister) => { types(s.name) = s.tpe - eMap(infer_types_e _,s) + s map (infer_types_e) s } case (s:DefWire) => { @@ -1594,7 +1595,7 @@ object CInferTypes extends Pass { s } case (s:DefNode) => { - val sx = eMap(infer_types_e _,s) + val sx = s map (infer_types_e) val t = get_type(sx) types(s.name) = t sx @@ -1616,7 +1617,7 @@ object CInferTypes extends Pass { types(s.name) = module_types.getOrElse(s.module,UnknownType()) s } - case (s) => eMap(infer_types_e _,sMap(infer_types_s _,s)) + case (s) => s map(infer_types_s) map (infer_types_e) } } for (p <- m.ports) { @@ -1644,7 +1645,7 @@ object CInferMDir extends Pass { def infer_mdir (m:Module) : Module = { val mports = LinkedHashMap[String,MPortDir]() def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = { - (eMap(infer_mdir_e(dir) _,e)) match { + (e map (infer_mdir_e(dir))) match { case (e:Ref) => { if (mports.contains(e.name)) { val new_mport_dir = { @@ -1678,7 +1679,7 @@ object CInferMDir extends Pass { (s) match { case (s:CDefMPort) => { mports(s.name) = s.direction - eMap(infer_mdir_e(MRead) _,s) + s map (infer_mdir_e(MRead)) } case (s:Connect) => { infer_mdir_e(MRead)(s.exp) @@ -1690,14 +1691,14 @@ object CInferMDir extends Pass { infer_mdir_e(MWrite)(s.loc) s } - case (s) => eMap(infer_mdir_e(MRead) _, sMap(infer_mdir_s,s)) + case (s) => s map (infer_mdir_s) map (infer_mdir_e(MRead)) } } def set_mdir_s (s:Stmt) : Stmt = { (s) match { case (s:CDefMPort) => CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name)) - case (s) => sMap(set_mdir_s _,s) + case (s) => s map (set_mdir_s) } } (m) match { @@ -1760,7 +1761,7 @@ object RemoveCHIRRTL extends Pass { hash(s.mem) = mports s } - case (s) => sMap(collect_mports _,s) + case (s) => s map (collect_mports) } } def collect_refs (s:Stmt) : Stmt = { @@ -1840,7 +1841,7 @@ object RemoveCHIRRTL extends Pass { } Begin(stmts) } - case (s) => sMap(collect_refs _,s) + case (s) => s map (collect_refs) } } def remove_chirrtl_s (s:Stmt) : Stmt = { @@ -1863,11 +1864,11 @@ object RemoveCHIRRTL extends Pass { } else e } case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.exp),remove_chirrtl_e(MALE)(e.index),e.tpe) - case (e) => eMap(remove_chirrtl_e(g) _,e) + case (e) => e map (remove_chirrtl_e(g)) } } def get_mask (e:Expression) : Expression = { - (eMap(get_mask _,e)) match { + (e map (get_mask)) match { case (e:Ref) => { if (repl.contains(e.name)) { val vt = repl(e.name) @@ -1917,7 +1918,7 @@ object RemoveCHIRRTL extends Pass { if (stmts.size > 1) Begin(stmts) else stmts(0) } - case (s) => eMap(remove_chirrtl_e(MALE) _, sMap(remove_chirrtl_s,s)) + case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE)) } } collect_mports(m.body) |
