aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/Mappers.scala
diff options
context:
space:
mode:
authorjackkoenig2016-09-06 13:51:37 -0700
committerDonggyu Kim2016-09-12 12:48:46 -0700
commit4702bf9f257f954e19d1441b21e737f951ccfbcc (patch)
treec9bc87f605d0352d5cca58b56e7eda501e200759 /src/main/scala/firrtl/Mappers.scala
parent00bef01b6df158939406f3e744cbdda544823ae5 (diff)
Rework map functions as class methods
Changed code from match statements in Mappers.scala to methods on the various IR classes. This allows custom IR nodes to implement the mapper functions and thus work (ie. not match error) when map is called on them. This also should have a marginal performance increase because of use of virtual function calls rather than match statements.
Diffstat (limited to 'src/main/scala/firrtl/Mappers.scala')
-rw-r--r--src/main/scala/firrtl/Mappers.scala149
1 files changed, 23 insertions, 126 deletions
diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala
index 5f073e0d..4eac68a5 100644
--- a/src/main/scala/firrtl/Mappers.scala
+++ b/src/main/scala/firrtl/Mappers.scala
@@ -38,56 +38,16 @@ object Mappers {
}
private object StmtMagnet {
implicit def forStmt(f: Statement => Statement) = new StmtMagnet {
- override def map(stmt: Statement): Statement = {
- stmt match {
- case s: Conditionally => Conditionally(s.info, s.pred, f(s.conseq), f(s.alt))
- case s: Block => Block(s.stmts.map(f))
- case s: Statement => s
- }
- }
+ override def map(stmt: Statement): Statement = stmt mapStmt f
}
implicit def forExp(f: Expression => Expression) = new StmtMagnet {
- override def map(stmt: Statement): Statement = {
- stmt match {
- case s: DefRegister => DefRegister(s.info, s.name, s.tpe, f(s.clock), f(s.reset), f(s.init))
- case s: DefNode => DefNode(s.info, s.name, f(s.value))
- case s: Connect => Connect(s.info, f(s.loc), f(s.expr))
- case s: PartialConnect => PartialConnect(s.info, f(s.loc), f(s.expr))
- case s: Conditionally => Conditionally(s.info, f(s.pred), s.conseq, s.alt)
- case s: IsInvalid => IsInvalid(s.info, f(s.expr))
- case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en))
- case s: Print => Print(s.info, s.string, s.args.map(f), f(s.clk), f(s.en))
- case s: CDefMPort => CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps.map(f),s.direction)
- case s: Statement => s
- }
- }
+ override def map(stmt: Statement): Statement = stmt mapExpr f
}
implicit def forType(f: Type => Type) = new StmtMagnet {
- override def map(stmt: Statement) : Statement = {
- stmt match {
- case s:DefWire => DefWire(s.info,s.name,f(s.tpe))
- case s:DefRegister => DefRegister(s.info,s.name,f(s.tpe),s.clock,s.reset,s.init)
- case s:DefMemory => DefMemory(s.info,s.name, f(s.dataType), s.depth, s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters)
- case s:CDefMemory => CDefMemory(s.info,s.name, f(s.tpe), s.size, s.seq)
- case s:CDefMPort => CDefMPort(s.info,s.name, f(s.tpe), s.mem, s.exps,s.direction)
- case s => s
- }
- }
+ override def map(stmt: Statement) : Statement = stmt mapType f
}
implicit def forString(f: String => String) = new StmtMagnet {
- override def map(stmt: Statement): Statement = {
- stmt match {
- case s: DefWire => DefWire(s.info,f(s.name),s.tpe)
- case s: DefRegister => DefRegister(s.info,f(s.name), s.tpe, s.clock, s.reset, s.init)
- case s: DefMemory => DefMemory(s.info,f(s.name), s.dataType, s.depth, s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters)
- case s: DefNode => DefNode(s.info,f(s.name),s.value)
- case s: DefInstance => DefInstance(s.info,f(s.name), s.module)
- case s: WDefInstance => WDefInstance(s.info,f(s.name), s.module,s.tpe)
- case s: CDefMemory => CDefMemory(s.info,f(s.name),s.tpe,s.size,s.seq)
- case s: CDefMPort => CDefMPort(s.info,f(s.name),s.tpe,s.mem,s.exps,s.direction)
- case s => s
- }
- }
+ override def map(stmt: Statement): Statement = stmt mapString f
}
}
implicit class StmtMap(stmt: Statement) {
@@ -96,52 +56,22 @@ object Mappers {
}
// ********** Expression Mappers **********
- private trait ExpMagnet {
- def map(exp: Expression): Expression
+ private trait ExprMagnet {
+ def map(expr: Expression): Expression
}
- private object ExpMagnet {
- implicit def forExp(f: Expression => Expression) = new ExpMagnet {
- override def map(exp: Expression): Expression = {
- exp match {
- case e: SubField => SubField(f(e.expr), e.name, e.tpe)
- case e: SubIndex => SubIndex(f(e.expr), e.value, e.tpe)
- case e: SubAccess => SubAccess(f(e.expr), f(e.index), e.tpe)
- case e: Mux => Mux(f(e.cond), f(e.tval), f(e.fval), e.tpe)
- case e: ValidIf => ValidIf(f(e.cond), f(e.value), e.tpe)
- case e: DoPrim => DoPrim(e.op, e.args.map(f), e.consts, e.tpe)
- case e: WSubField => WSubField(f(e.exp), e.name, e.tpe, e.gender)
- case e: WSubIndex => WSubIndex(f(e.exp), e.value, e.tpe, e.gender)
- case e: WSubAccess => WSubAccess(f(e.exp), f(e.index), e.tpe, e.gender)
- case e: Expression => e
- }
- }
+ private object ExprMagnet {
+ implicit def forExpr(f: Expression => Expression) = new ExprMagnet {
+ override def map(expr: Expression): Expression = expr mapExpr f
}
- implicit def forType(f: Type => Type) = new ExpMagnet {
- override def map(exp: Expression): Expression = {
- exp match {
- case e: DoPrim => DoPrim(e.op,e.args,e.consts,f(e.tpe))
- case e: Mux => Mux(e.cond,e.tval,e.fval,f(e.tpe))
- case e: ValidIf => ValidIf(e.cond,e.value,f(e.tpe))
- case e: WRef => WRef(e.name,f(e.tpe),e.kind,e.gender)
- case e: WSubField => WSubField(e.exp,e.name,f(e.tpe),e.gender)
- case e: WSubIndex => WSubIndex(e.exp,e.value,f(e.tpe),e.gender)
- case e: WSubAccess => WSubAccess(e.exp,e.index,f(e.tpe),e.gender)
- case e => e
- }
- }
+ implicit def forType(f: Type => Type) = new ExprMagnet {
+ override def map(expr: Expression): Expression = expr mapType f
}
- implicit def forWidth(f: Width => Width) = new ExpMagnet {
- override def map(exp: Expression): Expression = {
- exp match {
- case e: UIntLiteral => UIntLiteral(e.value,f(e.width))
- case e: SIntLiteral => SIntLiteral(e.value,f(e.width))
- case e => e
- }
- }
+ implicit def forWidth(f: Width => Width) = new ExprMagnet {
+ override def map(expr: Expression): Expression = expr mapWidth f
}
}
- implicit class ExpMap(exp: Expression) {
- def map[T](f: T => T)(implicit magnet: (T => T) => ExpMagnet): Expression = magnet(f).map(exp)
+ implicit class ExprMap(expr: Expression) {
+ def map[T](f: T => T)(implicit magnet: (T => T) => ExprMagnet): Expression = magnet(f).map(expr)
}
// ********** Type Mappers **********
@@ -150,22 +80,10 @@ object Mappers {
}
private object TypeMagnet {
implicit def forType(f: Type => Type) = new TypeMagnet {
- override def map(tpe: Type): Type = {
- tpe match {
- case t: BundleType => BundleType(t.fields.map(p => Field(p.name, p.flip, f(p.tpe))))
- case t: VectorType => VectorType(f(t.tpe), t.size)
- case t => t
- }
- }
+ override def map(tpe: Type): Type = tpe mapType f
}
implicit def forWidth(f: Width => Width) = new TypeMagnet {
- override def map(tpe: Type): Type = {
- tpe match {
- case t: UIntType => UIntType(f(t.width))
- case t: SIntType => SIntType(f(t.width))
- case t => t
- }
- }
+ override def map(tpe: Type): Type = tpe mapWidth f
}
}
implicit class TypeMap(tpe: Type) {
@@ -178,15 +96,9 @@ object Mappers {
}
private object WidthMagnet {
implicit def forWidth(f: Width => Width) = new WidthMagnet {
- override def map(width: Width): Width = {
- width match {
- case w: MaxWidth => MaxWidth(w.args.map(f))
- case w: MinWidth => MinWidth(w.args.map(f))
- case w: PlusWidth => PlusWidth(f(w.arg1),f(w.arg2))
- case w: MinusWidth => MinusWidth(f(w.arg1),f(w.arg2))
- case w: ExpWidth => ExpWidth(f(w.arg1))
- case w => w
- }
+ override def map(width: Width): Width = width match {
+ case mapable: HasMapWidth => mapable mapWidth f // WIR
+ case other => other // Standard IR nodes
}
}
}
@@ -200,28 +112,13 @@ object Mappers {
}
private object ModuleMagnet {
implicit def forStmt(f: Statement => Statement) = new ModuleMagnet {
- override def map(module: DefModule): DefModule = {
- module match {
- case m: Module => Module(m.info, m.name, m.ports, f(m.body))
- case m: ExtModule => m
- }
- }
+ override def map(module: DefModule): DefModule = module mapStmt f
}
implicit def forPorts(f: Port => Port) = new ModuleMagnet {
- override def map(module: DefModule): DefModule = {
- module match {
- case m: Module => Module(m.info, m.name, m.ports.map(f), m.body)
- case m: ExtModule => ExtModule(m.info, m.name, m.ports.map(f))
- }
- }
+ override def map(module: DefModule): DefModule = module mapPort f
}
implicit def forString(f: String => String) = new ModuleMagnet {
- override def map(module: DefModule): DefModule = {
- module match {
- case m: Module => Module(m.info, f(m.name), m.ports, m.body)
- case m: ExtModule => ExtModule(m.info, f(m.name), m.ports)
- }
- }
+ override def map(module: DefModule): DefModule = module mapString f
}
}
implicit class ModuleMap(module: DefModule) {