diff options
| author | jackkoenig | 2016-09-06 13:51:37 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-12 12:48:46 -0700 |
| commit | 4702bf9f257f954e19d1441b21e737f951ccfbcc (patch) | |
| tree | c9bc87f605d0352d5cca58b56e7eda501e200759 | |
| parent | 00bef01b6df158939406f3e744cbdda544823ae5 (diff) | |
Rework map functions as class methods
Changed code from match statements in Mappers.scala to methods on the various
IR classes. This allows custom IR nodes to implement the mapper functions and
thus work (ie. not match error) when map is called on them.
This also should have a marginal performance increase because of use of virtual
function calls rather than match statements.
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Mappers.scala | 149 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 55 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 118 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/PadWidths.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveEmpty.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveValidIf.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/SplitExpressions.scala | 7 |
9 files changed, 202 insertions, 148 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index b5d212e4..1610655b 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -59,6 +59,9 @@ case class VRandom(width: BigInt) extends Expression { def nWords = (width + 31) / 32 def realWidth = nWords * 32 def serialize: String = "RANDOM" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this } class VerilogEmitter extends Emitter { val tab = " " diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala index 5f073e0d..4eac68a5 100644 --- a/src/main/scala/firrtl/Mappers.scala +++ b/src/main/scala/firrtl/Mappers.scala @@ -38,56 +38,16 @@ object Mappers { } private object StmtMagnet { implicit def forStmt(f: Statement => Statement) = new StmtMagnet { - override def map(stmt: Statement): Statement = { - stmt match { - case s: Conditionally => Conditionally(s.info, s.pred, f(s.conseq), f(s.alt)) - case s: Block => Block(s.stmts.map(f)) - case s: Statement => s - } - } + override def map(stmt: Statement): Statement = stmt mapStmt f } implicit def forExp(f: Expression => Expression) = new StmtMagnet { - override def map(stmt: Statement): Statement = { - stmt match { - case s: DefRegister => DefRegister(s.info, s.name, s.tpe, f(s.clock), f(s.reset), f(s.init)) - case s: DefNode => DefNode(s.info, s.name, f(s.value)) - case s: Connect => Connect(s.info, f(s.loc), f(s.expr)) - case s: PartialConnect => PartialConnect(s.info, f(s.loc), f(s.expr)) - case s: Conditionally => Conditionally(s.info, f(s.pred), s.conseq, s.alt) - case s: IsInvalid => IsInvalid(s.info, f(s.expr)) - case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) - case s: Print => Print(s.info, s.string, s.args.map(f), f(s.clk), f(s.en)) - case s: CDefMPort => CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps.map(f),s.direction) - case s: Statement => s - } - } + override def map(stmt: Statement): Statement = stmt mapExpr f } implicit def forType(f: Type => Type) = new StmtMagnet { - override def map(stmt: Statement) : Statement = { - stmt match { - case s:DefWire => DefWire(s.info,s.name,f(s.tpe)) - case s:DefRegister => DefRegister(s.info,s.name,f(s.tpe),s.clock,s.reset,s.init) - case s:DefMemory => DefMemory(s.info,s.name, f(s.dataType), s.depth, s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters) - case s:CDefMemory => CDefMemory(s.info,s.name, f(s.tpe), s.size, s.seq) - case s:CDefMPort => CDefMPort(s.info,s.name, f(s.tpe), s.mem, s.exps,s.direction) - case s => s - } - } + override def map(stmt: Statement) : Statement = stmt mapType f } implicit def forString(f: String => String) = new StmtMagnet { - override def map(stmt: Statement): Statement = { - stmt match { - case s: DefWire => DefWire(s.info,f(s.name),s.tpe) - case s: DefRegister => DefRegister(s.info,f(s.name), s.tpe, s.clock, s.reset, s.init) - case s: DefMemory => DefMemory(s.info,f(s.name), s.dataType, s.depth, s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters) - case s: DefNode => DefNode(s.info,f(s.name),s.value) - case s: DefInstance => DefInstance(s.info,f(s.name), s.module) - case s: WDefInstance => WDefInstance(s.info,f(s.name), s.module,s.tpe) - case s: CDefMemory => CDefMemory(s.info,f(s.name),s.tpe,s.size,s.seq) - case s: CDefMPort => CDefMPort(s.info,f(s.name),s.tpe,s.mem,s.exps,s.direction) - case s => s - } - } + override def map(stmt: Statement): Statement = stmt mapString f } } implicit class StmtMap(stmt: Statement) { @@ -96,52 +56,22 @@ object Mappers { } // ********** Expression Mappers ********** - private trait ExpMagnet { - def map(exp: Expression): Expression + private trait ExprMagnet { + def map(expr: Expression): Expression } - private object ExpMagnet { - implicit def forExp(f: Expression => Expression) = new ExpMagnet { - override def map(exp: Expression): Expression = { - exp match { - case e: SubField => SubField(f(e.expr), e.name, e.tpe) - case e: SubIndex => SubIndex(f(e.expr), e.value, e.tpe) - case e: SubAccess => SubAccess(f(e.expr), f(e.index), e.tpe) - case e: Mux => Mux(f(e.cond), f(e.tval), f(e.fval), e.tpe) - case e: ValidIf => ValidIf(f(e.cond), f(e.value), e.tpe) - case e: DoPrim => DoPrim(e.op, e.args.map(f), e.consts, e.tpe) - case e: WSubField => WSubField(f(e.exp), e.name, e.tpe, e.gender) - case e: WSubIndex => WSubIndex(f(e.exp), e.value, e.tpe, e.gender) - case e: WSubAccess => WSubAccess(f(e.exp), f(e.index), e.tpe, e.gender) - case e: Expression => e - } - } + private object ExprMagnet { + implicit def forExpr(f: Expression => Expression) = new ExprMagnet { + override def map(expr: Expression): Expression = expr mapExpr f } - implicit def forType(f: Type => Type) = new ExpMagnet { - override def map(exp: Expression): Expression = { - exp match { - case e: DoPrim => DoPrim(e.op,e.args,e.consts,f(e.tpe)) - case e: Mux => Mux(e.cond,e.tval,e.fval,f(e.tpe)) - case e: ValidIf => ValidIf(e.cond,e.value,f(e.tpe)) - case e: WRef => WRef(e.name,f(e.tpe),e.kind,e.gender) - case e: WSubField => WSubField(e.exp,e.name,f(e.tpe),e.gender) - case e: WSubIndex => WSubIndex(e.exp,e.value,f(e.tpe),e.gender) - case e: WSubAccess => WSubAccess(e.exp,e.index,f(e.tpe),e.gender) - case e => e - } - } + implicit def forType(f: Type => Type) = new ExprMagnet { + override def map(expr: Expression): Expression = expr mapType f } - implicit def forWidth(f: Width => Width) = new ExpMagnet { - override def map(exp: Expression): Expression = { - exp match { - case e: UIntLiteral => UIntLiteral(e.value,f(e.width)) - case e: SIntLiteral => SIntLiteral(e.value,f(e.width)) - case e => e - } - } + implicit def forWidth(f: Width => Width) = new ExprMagnet { + override def map(expr: Expression): Expression = expr mapWidth f } } - implicit class ExpMap(exp: Expression) { - def map[T](f: T => T)(implicit magnet: (T => T) => ExpMagnet): Expression = magnet(f).map(exp) + implicit class ExprMap(expr: Expression) { + def map[T](f: T => T)(implicit magnet: (T => T) => ExprMagnet): Expression = magnet(f).map(expr) } // ********** Type Mappers ********** @@ -150,22 +80,10 @@ object Mappers { } private object TypeMagnet { implicit def forType(f: Type => Type) = new TypeMagnet { - override def map(tpe: Type): Type = { - tpe match { - case t: BundleType => BundleType(t.fields.map(p => Field(p.name, p.flip, f(p.tpe)))) - case t: VectorType => VectorType(f(t.tpe), t.size) - case t => t - } - } + override def map(tpe: Type): Type = tpe mapType f } implicit def forWidth(f: Width => Width) = new TypeMagnet { - override def map(tpe: Type): Type = { - tpe match { - case t: UIntType => UIntType(f(t.width)) - case t: SIntType => SIntType(f(t.width)) - case t => t - } - } + override def map(tpe: Type): Type = tpe mapWidth f } } implicit class TypeMap(tpe: Type) { @@ -178,15 +96,9 @@ object Mappers { } private object WidthMagnet { implicit def forWidth(f: Width => Width) = new WidthMagnet { - override def map(width: Width): Width = { - width match { - case w: MaxWidth => MaxWidth(w.args.map(f)) - case w: MinWidth => MinWidth(w.args.map(f)) - case w: PlusWidth => PlusWidth(f(w.arg1),f(w.arg2)) - case w: MinusWidth => MinusWidth(f(w.arg1),f(w.arg2)) - case w: ExpWidth => ExpWidth(f(w.arg1)) - case w => w - } + override def map(width: Width): Width = width match { + case mapable: HasMapWidth => mapable mapWidth f // WIR + case other => other // Standard IR nodes } } } @@ -200,28 +112,13 @@ object Mappers { } private object ModuleMagnet { implicit def forStmt(f: Statement => Statement) = new ModuleMagnet { - override def map(module: DefModule): DefModule = { - module match { - case m: Module => Module(m.info, m.name, m.ports, f(m.body)) - case m: ExtModule => m - } - } + override def map(module: DefModule): DefModule = module mapStmt f } implicit def forPorts(f: Port => Port) = new ModuleMagnet { - override def map(module: DefModule): DefModule = { - module match { - case m: Module => Module(m.info, m.name, m.ports.map(f), m.body) - case m: ExtModule => ExtModule(m.info, m.name, m.ports.map(f)) - } - } + override def map(module: DefModule): DefModule = module mapPort f } implicit def forString(f: String => String) = new ModuleMagnet { - override def map(module: DefModule): DefModule = { - module match { - case m: Module => Module(m.info, f(m.name), m.ports, m.body) - case m: ExtModule => ExtModule(m.info, f(m.name), m.ports) - } - } + override def map(module: DefModule): DefModule = module mapString f } } implicit class ModuleMap(module: DefModule) { diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 9583175e..9ff48446 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -51,31 +51,56 @@ case object UNKNOWNGENDER extends Gender case class WRef(name: String, tpe: Type, kind: Kind, gender: Gender) extends Expression { def serialize: String = name + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class WSubField(exp: Expression, name: String, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${exp.serialize}.$name" + def mapExpr(f: Expression => Expression): Expression = this.copy(exp = f(exp)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class WSubIndex(exp: Expression, value: Int, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${exp.serialize}[$value]" + def mapExpr(f: Expression => Expression): Expression = this.copy(exp = f(exp)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class WSubAccess(exp: Expression, index: Expression, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${exp.serialize}[${index.serialize}]" + def mapExpr(f: Expression => Expression): Expression = this.copy(exp = f(exp), index = f(index)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class WVoid() extends Expression { def tpe = UnknownType def serialize: String = "VOID" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this } case class WInvalid() extends Expression { def tpe = UnknownType def serialize: String = "INVALID" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this } // Useful for splitting then remerging references case object EmptyExpression extends Expression { def tpe = UnknownType def serialize: String = "EMPTY" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this } case class WDefInstance(info: Info, name: String, module: String, tpe: Type) extends Statement with IsDeclaration { def serialize: String = s"inst $name of $module" + info.serialize + def mapExpr(f: Expression => Expression): Statement = this + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } // Resultant width is the same as the maximum input width @@ -115,25 +140,33 @@ class WrappedExpression (val e1: Expression) { override def hashCode = e1.serialize.hashCode override def toString = e1.serialize } - -case class VarWidth(name: String) extends Width { +private[firrtl] sealed trait HasMapWidth { + def mapWidth(f: Width => Width): Width +} +case class VarWidth(name: String) extends Width with HasMapWidth { def serialize: String = name + def mapWidth(f: Width => Width): Width = this } -case class PlusWidth(arg1: Width, arg2: Width) extends Width { +case class PlusWidth(arg1: Width, arg2: Width) extends Width with HasMapWidth { def serialize: String = "(" + arg1.serialize + " + " + arg2.serialize + ")" + def mapWidth(f: Width => Width): Width = PlusWidth(f(arg1), f(arg2)) } -case class MinusWidth(arg1: Width, arg2: Width) extends Width { +case class MinusWidth(arg1: Width, arg2: Width) extends Width with HasMapWidth { def serialize: String = "(" + arg1.serialize + " - " + arg2.serialize + ")" + def mapWidth(f: Width => Width): Width = MinusWidth(f(arg1), f(arg2)) } -case class MaxWidth(args: Seq[Width]) extends Width { +case class MaxWidth(args: Seq[Width]) extends Width with HasMapWidth { def serialize: String = args map (_.serialize) mkString ("max(", ", ", ")") + def mapWidth(f: Width => Width): Width = MaxWidth(args map f) } -case class MinWidth(args: Seq[Width]) extends Width { +case class MinWidth(args: Seq[Width]) extends Width with HasMapWidth { def serialize: String = args map (_.serialize) mkString ("min(", ", ", ")") + def mapWidth(f: Width => Width): Width = MinWidth(args map f) } -case class ExpWidth(arg1: Width) extends Width { +case class ExpWidth(arg1: Width) extends Width with HasMapWidth { def serialize: String = "exp(" + arg1.serialize + " )" + def mapWidth(f: Width => Width): Width = ExpWidth(f(arg1)) } object WrappedType { @@ -234,6 +267,10 @@ case class CDefMemory( seq: Boolean) extends Statement { def serialize: String = (if (seq) "smem" else "cmem") + s" $name : ${tpe.serialize} [$size]" + info.serialize + def mapExpr(f: Expression => Expression): Statement = this + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } case class CDefMPort(info: Info, name: String, @@ -245,5 +282,9 @@ case class CDefMPort(info: Info, val dir = direction.serialize s"$dir mport $name = $mem[${exps(0).serialize}], ${exps(1).serialize}" + info.serialize } + def mapExpr(f: Expression => Expression): Statement = this.copy(exps = exps map f) + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index afe28634..8f1cabf6 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -68,24 +68,46 @@ abstract class PrimOp extends FirrtlNode { abstract class Expression extends FirrtlNode { def tpe: Type + def mapExpr(f: Expression => Expression): Expression + def mapType(f: Type => Type): Expression + def mapWidth(f: Width => Width): Expression } case class Reference(name: String, tpe: Type) extends Expression with HasName { def serialize: String = name + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class SubField(expr: Expression, name: String, tpe: Type) extends Expression with HasName { def serialize: String = s"${expr.serialize}.$name" + def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class SubIndex(expr: Expression, value: Int, tpe: Type) extends Expression { def serialize: String = s"${expr.serialize}[$value]" + def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Expression { def serialize: String = s"${expr.serialize}[${index.serialize}]" + def mapExpr(f: Expression => Expression): Expression = + this.copy(expr = f(expr), index = f(index)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression { def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})" + def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression { def serialize: String = s"validif(${cond.serialize}, ${value.serialize})" + def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } abstract class Literal extends Expression { val value: BigInt @@ -94,19 +116,37 @@ abstract class Literal extends Expression { case class UIntLiteral(value: BigInt, width: Width) extends Literal { def tpe = UIntType(width) def serialize = s"UInt${width.serialize}(" + Utils.serialize(value) + ")" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width)) } case class SIntLiteral(value: BigInt, width: Width) extends Literal { def tpe = SIntType(width) def serialize = s"SInt${width.serialize}(" + Utils.serialize(value) + ")" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width)) } case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression { def serialize: String = op.serialize + "(" + (args.map(_.serialize) ++ consts.map(_.toString)).mkString(", ") + ")" + def mapExpr(f: Expression => Expression): Expression = this.copy(args = args map f) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } -abstract class Statement extends FirrtlNode +abstract class Statement extends FirrtlNode { + def mapStmt(f: Statement => Statement): Statement + def mapExpr(f: Expression => Expression): Statement + def mapType(f: Type => Type): Statement + def mapString(f: String => String): Statement +} case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration { def serialize: String = s"wire $name : ${tpe.serialize}" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe)) + def mapString(f: String => String): Statement = DefWire(info, f(name), tpe) } case class DefRegister( info: Info, @@ -118,10 +158,19 @@ case class DefRegister( def serialize: String = s"reg $name : ${tpe.serialize}, ${clock.serialize} with :" + indent("\n" + s"reset => (${reset.serialize}, ${init.serialize})" + info.serialize) + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = + DefRegister(info, name, tpe, f(clock), f(reset), f(init)) + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } case class DefInstance(info: Info, name: String, module: String) extends Statement with IsDeclaration { def serialize: String = s"inst $name of $module" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = DefInstance(info, f(name), module) } case class DefMemory( info: Info, @@ -146,9 +195,17 @@ case class DefMemory( (writers map ("writer => " + _)) ++ (readwriters map ("readwriter => " + _)) ++ Seq("read-under-write => undefined")) mkString "\n") + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration { def serialize: String = s"node $name = ${value.serialize}" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = DefNode(info, f(name), value) } case class Conditionally( info: Info, @@ -160,21 +217,45 @@ case class Conditionally( indent("\n" + conseq.serialize) + (if (alt == EmptyStmt) "" else "\nelse :" + indent("\n" + alt.serialize)) + def mapStmt(f: Statement => Statement): Statement = Conditionally(info, pred, f(conseq), f(alt)) + def mapExpr(f: Expression => Expression): Statement = Conditionally(info, f(pred), conseq, alt) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class Block(stmts: Seq[Statement]) extends Statement { def serialize: String = stmts map (_.serialize) mkString "\n" + def mapStmt(f: Statement => Statement): Statement = Block(stmts map f) + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <- ${expr.serialize}" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <= ${expr.serialize}" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${expr.serialize} is invalid" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = IsInvalid(info, f(expr)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo { def serialize: String = s"stop(${clk.serialize}, ${en.serialize}, $ret)" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class Print( info: Info, @@ -187,9 +268,17 @@ case class Print( (args map (_.serialize)) "printf(" + (strs mkString ", ") + ")" + info.serialize } + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Print(info, string, args map f, f(clk), f(en)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case object EmptyStmt extends Statement { def serialize: String = "skip" + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } abstract class Width extends FirrtlNode { @@ -232,29 +321,43 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode def serialize: String = flip.serialize + name + " : " + tpe.serialize } -abstract class Type extends FirrtlNode +abstract class Type extends FirrtlNode { + def mapType(f: Type => Type): Type + def mapWidth(f: Width => Width): Type +} abstract class GroundType extends Type { val width: Width + def mapType(f: Type => Type): Type = this +} +abstract class AggregateType extends Type { + def mapWidth(f: Width => Width): Type = this } -abstract class AggregateType extends Type case class UIntType(width: Width) extends GroundType { def serialize: String = "UInt" + width.serialize + def mapWidth(f: Width => Width): Type = UIntType(f(width)) } case class SIntType(width: Width) extends GroundType { def serialize: String = "SInt" + width.serialize + def mapWidth(f: Width => Width): Type = SIntType(f(width)) } case class BundleType(fields: Seq[Field]) extends AggregateType { def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}" + def mapType(f: Type => Type): Type = + BundleType(fields map (x => x.copy(tpe = f(x.tpe)))) } case class VectorType(tpe: Type, size: Int) extends AggregateType { def serialize: String = tpe.serialize + s"[$size]" + def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe)) } case object ClockType extends GroundType { val width = IntWidth(1) def serialize: String = "Clock" + def mapWidth(f: Width => Width): Type = this } case object UnknownType extends Type { def serialize: String = "?" + def mapType(f: Type => Type): Type = this + def mapWidth(f: Width => Width): Type = this } /** [[Port]] Direction */ @@ -283,6 +386,9 @@ abstract class DefModule extends FirrtlNode with IsDeclaration { protected def serializeHeader(tpe: String): String = s"$tpe $name :" + info.serialize + indent(ports map ("\n" + _.serialize) mkString) + "\n" + def mapStmt(f: Statement => Statement): DefModule + def mapPort(f: Port => Port): DefModule + def mapString(f: String => String): DefModule } /** Internal Module * @@ -290,6 +396,9 @@ abstract class DefModule extends FirrtlNode with IsDeclaration { */ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) extends DefModule { def serialize: String = serializeHeader("module") + indent("\n" + body.serialize) + def mapStmt(f: Statement => Statement): DefModule = this.copy(body = f(body)) + def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) + def mapString(f: String => String): DefModule = this.copy(name = f(name)) } /** External Module * @@ -297,6 +406,9 @@ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) e */ case class ExtModule(info: Info, name: String, ports: Seq[Port]) extends DefModule { def serialize: String = serializeHeader("extmodule") + def mapStmt(f: Statement => Statement): DefModule = this + def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) + def mapString(f: String => String): DefModule = this.copy(name = f(name)) } case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends FirrtlNode with HasInfo { diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index c4529bd9..43c0ef1e 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -1,12 +1,12 @@ package firrtl package passes -// Datastructures -import scala.collection.mutable - import firrtl.ir._ +import firrtl.Mappers._ import firrtl.Annotations._ -import firrtl.Mappers.{ExpMap, StmtMap} + +// Datastructures +import scala.collection.mutable // Tags an annotation to be consumed by this pass case class InlineAnnotation(target: Named, tID: TransID) extends Annotation with Loose with Unstable { diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index f2117761..4cdcae59 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -1,10 +1,10 @@ package firrtl package passes -import firrtl.Mappers.{ExpMap, StmtMap} -import firrtl.Utils.long_BANG -import firrtl.PrimOps._ import firrtl.ir._ +import firrtl.PrimOps._ +import firrtl.Mappers._ +import firrtl.Utils.long_BANG // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { diff --git a/src/main/scala/firrtl/passes/RemoveEmpty.scala b/src/main/scala/firrtl/passes/RemoveEmpty.scala index 7ba2ef09..225e2222 100644 --- a/src/main/scala/firrtl/passes/RemoveEmpty.scala +++ b/src/main/scala/firrtl/passes/RemoveEmpty.scala @@ -2,7 +2,7 @@ package firrtl package passes import scala.collection.mutable -import firrtl.Mappers.{ExpMap, StmtMap} +import firrtl.Mappers._ import firrtl.ir._ object RemoveEmpty extends Pass { diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index a534cc50..e0a4b621 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -1,6 +1,6 @@ package firrtl package passes -import firrtl.Mappers.{ExpMap, StmtMap} +import firrtl.Mappers._ import firrtl.ir._ // Removes ValidIf as an optimization diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 90b92a35..31306046 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -1,11 +1,12 @@ package firrtl package passes -import firrtl.Mappers.{ExpMap, StmtMap} -import firrtl.Utils.{kind, gender, get_info} import firrtl.ir._ -import scala.collection.mutable +import firrtl.Mappers._ +import firrtl.Utils.{kind, gender, get_info} +// Datastructures +import scala.collection.mutable // Splits compound expressions into simple expressions // and named intermediate nodes |
