diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Mappers.scala | 149 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 55 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 118 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/PadWidths.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveEmpty.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveValidIf.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/SplitExpressions.scala | 7 |
9 files changed, 202 insertions, 148 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index b5d212e4..1610655b 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -59,6 +59,9 @@ case class VRandom(width: BigInt) extends Expression { def nWords = (width + 31) / 32 def realWidth = nWords * 32 def serialize: String = "RANDOM" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this } class VerilogEmitter extends Emitter { val tab = " " diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala index 5f073e0d..4eac68a5 100644 --- a/src/main/scala/firrtl/Mappers.scala +++ b/src/main/scala/firrtl/Mappers.scala @@ -38,56 +38,16 @@ object Mappers { } private object StmtMagnet { implicit def forStmt(f: Statement => Statement) = new StmtMagnet { - override def map(stmt: Statement): Statement = { - stmt match { - case s: Conditionally => Conditionally(s.info, s.pred, f(s.conseq), f(s.alt)) - case s: Block => Block(s.stmts.map(f)) - case s: Statement => s - } - } + override def map(stmt: Statement): Statement = stmt mapStmt f } implicit def forExp(f: Expression => Expression) = new StmtMagnet { - override def map(stmt: Statement): Statement = { - stmt match { - case s: DefRegister => DefRegister(s.info, s.name, s.tpe, f(s.clock), f(s.reset), f(s.init)) - case s: DefNode => DefNode(s.info, s.name, f(s.value)) - case s: Connect => Connect(s.info, f(s.loc), f(s.expr)) - case s: PartialConnect => PartialConnect(s.info, f(s.loc), f(s.expr)) - case s: Conditionally => Conditionally(s.info, f(s.pred), s.conseq, s.alt) - case s: IsInvalid => IsInvalid(s.info, f(s.expr)) - case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) - case s: Print => Print(s.info, s.string, s.args.map(f), f(s.clk), f(s.en)) - case s: CDefMPort => CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps.map(f),s.direction) - case s: Statement => s - } - } + override def map(stmt: Statement): Statement = stmt mapExpr f } implicit def forType(f: Type => Type) = new StmtMagnet { - override def map(stmt: Statement) : Statement = { - stmt match { - case s:DefWire => DefWire(s.info,s.name,f(s.tpe)) - case s:DefRegister => DefRegister(s.info,s.name,f(s.tpe),s.clock,s.reset,s.init) - case s:DefMemory => DefMemory(s.info,s.name, f(s.dataType), s.depth, s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters) - case s:CDefMemory => CDefMemory(s.info,s.name, f(s.tpe), s.size, s.seq) - case s:CDefMPort => CDefMPort(s.info,s.name, f(s.tpe), s.mem, s.exps,s.direction) - case s => s - } - } + override def map(stmt: Statement) : Statement = stmt mapType f } implicit def forString(f: String => String) = new StmtMagnet { - override def map(stmt: Statement): Statement = { - stmt match { - case s: DefWire => DefWire(s.info,f(s.name),s.tpe) - case s: DefRegister => DefRegister(s.info,f(s.name), s.tpe, s.clock, s.reset, s.init) - case s: DefMemory => DefMemory(s.info,f(s.name), s.dataType, s.depth, s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters) - case s: DefNode => DefNode(s.info,f(s.name),s.value) - case s: DefInstance => DefInstance(s.info,f(s.name), s.module) - case s: WDefInstance => WDefInstance(s.info,f(s.name), s.module,s.tpe) - case s: CDefMemory => CDefMemory(s.info,f(s.name),s.tpe,s.size,s.seq) - case s: CDefMPort => CDefMPort(s.info,f(s.name),s.tpe,s.mem,s.exps,s.direction) - case s => s - } - } + override def map(stmt: Statement): Statement = stmt mapString f } } implicit class StmtMap(stmt: Statement) { @@ -96,52 +56,22 @@ object Mappers { } // ********** Expression Mappers ********** - private trait ExpMagnet { - def map(exp: Expression): Expression + private trait ExprMagnet { + def map(expr: Expression): Expression } - private object ExpMagnet { - implicit def forExp(f: Expression => Expression) = new ExpMagnet { - override def map(exp: Expression): Expression = { - exp match { - case e: SubField => SubField(f(e.expr), e.name, e.tpe) - case e: SubIndex => SubIndex(f(e.expr), e.value, e.tpe) - case e: SubAccess => SubAccess(f(e.expr), f(e.index), e.tpe) - case e: Mux => Mux(f(e.cond), f(e.tval), f(e.fval), e.tpe) - case e: ValidIf => ValidIf(f(e.cond), f(e.value), e.tpe) - case e: DoPrim => DoPrim(e.op, e.args.map(f), e.consts, e.tpe) - case e: WSubField => WSubField(f(e.exp), e.name, e.tpe, e.gender) - case e: WSubIndex => WSubIndex(f(e.exp), e.value, e.tpe, e.gender) - case e: WSubAccess => WSubAccess(f(e.exp), f(e.index), e.tpe, e.gender) - case e: Expression => e - } - } + private object ExprMagnet { + implicit def forExpr(f: Expression => Expression) = new ExprMagnet { + override def map(expr: Expression): Expression = expr mapExpr f } - implicit def forType(f: Type => Type) = new ExpMagnet { - override def map(exp: Expression): Expression = { - exp match { - case e: DoPrim => DoPrim(e.op,e.args,e.consts,f(e.tpe)) - case e: Mux => Mux(e.cond,e.tval,e.fval,f(e.tpe)) - case e: ValidIf => ValidIf(e.cond,e.value,f(e.tpe)) - case e: WRef => WRef(e.name,f(e.tpe),e.kind,e.gender) - case e: WSubField => WSubField(e.exp,e.name,f(e.tpe),e.gender) - case e: WSubIndex => WSubIndex(e.exp,e.value,f(e.tpe),e.gender) - case e: WSubAccess => WSubAccess(e.exp,e.index,f(e.tpe),e.gender) - case e => e - } - } + implicit def forType(f: Type => Type) = new ExprMagnet { + override def map(expr: Expression): Expression = expr mapType f } - implicit def forWidth(f: Width => Width) = new ExpMagnet { - override def map(exp: Expression): Expression = { - exp match { - case e: UIntLiteral => UIntLiteral(e.value,f(e.width)) - case e: SIntLiteral => SIntLiteral(e.value,f(e.width)) - case e => e - } - } + implicit def forWidth(f: Width => Width) = new ExprMagnet { + override def map(expr: Expression): Expression = expr mapWidth f } } - implicit class ExpMap(exp: Expression) { - def map[T](f: T => T)(implicit magnet: (T => T) => ExpMagnet): Expression = magnet(f).map(exp) + implicit class ExprMap(expr: Expression) { + def map[T](f: T => T)(implicit magnet: (T => T) => ExprMagnet): Expression = magnet(f).map(expr) } // ********** Type Mappers ********** @@ -150,22 +80,10 @@ object Mappers { } private object TypeMagnet { implicit def forType(f: Type => Type) = new TypeMagnet { - override def map(tpe: Type): Type = { - tpe match { - case t: BundleType => BundleType(t.fields.map(p => Field(p.name, p.flip, f(p.tpe)))) - case t: VectorType => VectorType(f(t.tpe), t.size) - case t => t - } - } + override def map(tpe: Type): Type = tpe mapType f } implicit def forWidth(f: Width => Width) = new TypeMagnet { - override def map(tpe: Type): Type = { - tpe match { - case t: UIntType => UIntType(f(t.width)) - case t: SIntType => SIntType(f(t.width)) - case t => t - } - } + override def map(tpe: Type): Type = tpe mapWidth f } } implicit class TypeMap(tpe: Type) { @@ -178,15 +96,9 @@ object Mappers { } private object WidthMagnet { implicit def forWidth(f: Width => Width) = new WidthMagnet { - override def map(width: Width): Width = { - width match { - case w: MaxWidth => MaxWidth(w.args.map(f)) - case w: MinWidth => MinWidth(w.args.map(f)) - case w: PlusWidth => PlusWidth(f(w.arg1),f(w.arg2)) - case w: MinusWidth => MinusWidth(f(w.arg1),f(w.arg2)) - case w: ExpWidth => ExpWidth(f(w.arg1)) - case w => w - } + override def map(width: Width): Width = width match { + case mapable: HasMapWidth => mapable mapWidth f // WIR + case other => other // Standard IR nodes } } } @@ -200,28 +112,13 @@ object Mappers { } private object ModuleMagnet { implicit def forStmt(f: Statement => Statement) = new ModuleMagnet { - override def map(module: DefModule): DefModule = { - module match { - case m: Module => Module(m.info, m.name, m.ports, f(m.body)) - case m: ExtModule => m - } - } + override def map(module: DefModule): DefModule = module mapStmt f } implicit def forPorts(f: Port => Port) = new ModuleMagnet { - override def map(module: DefModule): DefModule = { - module match { - case m: Module => Module(m.info, m.name, m.ports.map(f), m.body) - case m: ExtModule => ExtModule(m.info, m.name, m.ports.map(f)) - } - } + override def map(module: DefModule): DefModule = module mapPort f } implicit def forString(f: String => String) = new ModuleMagnet { - override def map(module: DefModule): DefModule = { - module match { - case m: Module => Module(m.info, f(m.name), m.ports, m.body) - case m: ExtModule => ExtModule(m.info, f(m.name), m.ports) - } - } + override def map(module: DefModule): DefModule = module mapString f } } implicit class ModuleMap(module: DefModule) { diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 9583175e..9ff48446 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -51,31 +51,56 @@ case object UNKNOWNGENDER extends Gender case class WRef(name: String, tpe: Type, kind: Kind, gender: Gender) extends Expression { def serialize: String = name + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class WSubField(exp: Expression, name: String, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${exp.serialize}.$name" + def mapExpr(f: Expression => Expression): Expression = this.copy(exp = f(exp)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class WSubIndex(exp: Expression, value: Int, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${exp.serialize}[$value]" + def mapExpr(f: Expression => Expression): Expression = this.copy(exp = f(exp)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class WSubAccess(exp: Expression, index: Expression, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${exp.serialize}[${index.serialize}]" + def mapExpr(f: Expression => Expression): Expression = this.copy(exp = f(exp), index = f(index)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class WVoid() extends Expression { def tpe = UnknownType def serialize: String = "VOID" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this } case class WInvalid() extends Expression { def tpe = UnknownType def serialize: String = "INVALID" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this } // Useful for splitting then remerging references case object EmptyExpression extends Expression { def tpe = UnknownType def serialize: String = "EMPTY" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this } case class WDefInstance(info: Info, name: String, module: String, tpe: Type) extends Statement with IsDeclaration { def serialize: String = s"inst $name of $module" + info.serialize + def mapExpr(f: Expression => Expression): Statement = this + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } // Resultant width is the same as the maximum input width @@ -115,25 +140,33 @@ class WrappedExpression (val e1: Expression) { override def hashCode = e1.serialize.hashCode override def toString = e1.serialize } - -case class VarWidth(name: String) extends Width { +private[firrtl] sealed trait HasMapWidth { + def mapWidth(f: Width => Width): Width +} +case class VarWidth(name: String) extends Width with HasMapWidth { def serialize: String = name + def mapWidth(f: Width => Width): Width = this } -case class PlusWidth(arg1: Width, arg2: Width) extends Width { +case class PlusWidth(arg1: Width, arg2: Width) extends Width with HasMapWidth { def serialize: String = "(" + arg1.serialize + " + " + arg2.serialize + ")" + def mapWidth(f: Width => Width): Width = PlusWidth(f(arg1), f(arg2)) } -case class MinusWidth(arg1: Width, arg2: Width) extends Width { +case class MinusWidth(arg1: Width, arg2: Width) extends Width with HasMapWidth { def serialize: String = "(" + arg1.serialize + " - " + arg2.serialize + ")" + def mapWidth(f: Width => Width): Width = MinusWidth(f(arg1), f(arg2)) } -case class MaxWidth(args: Seq[Width]) extends Width { +case class MaxWidth(args: Seq[Width]) extends Width with HasMapWidth { def serialize: String = args map (_.serialize) mkString ("max(", ", ", ")") + def mapWidth(f: Width => Width): Width = MaxWidth(args map f) } -case class MinWidth(args: Seq[Width]) extends Width { +case class MinWidth(args: Seq[Width]) extends Width with HasMapWidth { def serialize: String = args map (_.serialize) mkString ("min(", ", ", ")") + def mapWidth(f: Width => Width): Width = MinWidth(args map f) } -case class ExpWidth(arg1: Width) extends Width { +case class ExpWidth(arg1: Width) extends Width with HasMapWidth { def serialize: String = "exp(" + arg1.serialize + " )" + def mapWidth(f: Width => Width): Width = ExpWidth(f(arg1)) } object WrappedType { @@ -234,6 +267,10 @@ case class CDefMemory( seq: Boolean) extends Statement { def serialize: String = (if (seq) "smem" else "cmem") + s" $name : ${tpe.serialize} [$size]" + info.serialize + def mapExpr(f: Expression => Expression): Statement = this + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } case class CDefMPort(info: Info, name: String, @@ -245,5 +282,9 @@ case class CDefMPort(info: Info, val dir = direction.serialize s"$dir mport $name = $mem[${exps(0).serialize}], ${exps(1).serialize}" + info.serialize } + def mapExpr(f: Expression => Expression): Statement = this.copy(exps = exps map f) + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index afe28634..8f1cabf6 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -68,24 +68,46 @@ abstract class PrimOp extends FirrtlNode { abstract class Expression extends FirrtlNode { def tpe: Type + def mapExpr(f: Expression => Expression): Expression + def mapType(f: Type => Type): Expression + def mapWidth(f: Width => Width): Expression } case class Reference(name: String, tpe: Type) extends Expression with HasName { def serialize: String = name + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class SubField(expr: Expression, name: String, tpe: Type) extends Expression with HasName { def serialize: String = s"${expr.serialize}.$name" + def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class SubIndex(expr: Expression, value: Int, tpe: Type) extends Expression { def serialize: String = s"${expr.serialize}[$value]" + def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Expression { def serialize: String = s"${expr.serialize}[${index.serialize}]" + def mapExpr(f: Expression => Expression): Expression = + this.copy(expr = f(expr), index = f(index)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression { def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})" + def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression { def serialize: String = s"validif(${cond.serialize}, ${value.serialize})" + def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } abstract class Literal extends Expression { val value: BigInt @@ -94,19 +116,37 @@ abstract class Literal extends Expression { case class UIntLiteral(value: BigInt, width: Width) extends Literal { def tpe = UIntType(width) def serialize = s"UInt${width.serialize}(" + Utils.serialize(value) + ")" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width)) } case class SIntLiteral(value: BigInt, width: Width) extends Literal { def tpe = SIntType(width) def serialize = s"SInt${width.serialize}(" + Utils.serialize(value) + ")" + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width)) } case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression { def serialize: String = op.serialize + "(" + (args.map(_.serialize) ++ consts.map(_.toString)).mkString(", ") + ")" + def mapExpr(f: Expression => Expression): Expression = this.copy(args = args map f) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this } -abstract class Statement extends FirrtlNode +abstract class Statement extends FirrtlNode { + def mapStmt(f: Statement => Statement): Statement + def mapExpr(f: Expression => Expression): Statement + def mapType(f: Type => Type): Statement + def mapString(f: String => String): Statement +} case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration { def serialize: String = s"wire $name : ${tpe.serialize}" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe)) + def mapString(f: String => String): Statement = DefWire(info, f(name), tpe) } case class DefRegister( info: Info, @@ -118,10 +158,19 @@ case class DefRegister( def serialize: String = s"reg $name : ${tpe.serialize}, ${clock.serialize} with :" + indent("\n" + s"reset => (${reset.serialize}, ${init.serialize})" + info.serialize) + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = + DefRegister(info, name, tpe, f(clock), f(reset), f(init)) + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } case class DefInstance(info: Info, name: String, module: String) extends Statement with IsDeclaration { def serialize: String = s"inst $name of $module" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = DefInstance(info, f(name), module) } case class DefMemory( info: Info, @@ -146,9 +195,17 @@ case class DefMemory( (writers map ("writer => " + _)) ++ (readwriters map ("readwriter => " + _)) ++ Seq("read-under-write => undefined")) mkString "\n") + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) } case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration { def serialize: String = s"node $name = ${value.serialize}" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = DefNode(info, f(name), value) } case class Conditionally( info: Info, @@ -160,21 +217,45 @@ case class Conditionally( indent("\n" + conseq.serialize) + (if (alt == EmptyStmt) "" else "\nelse :" + indent("\n" + alt.serialize)) + def mapStmt(f: Statement => Statement): Statement = Conditionally(info, pred, f(conseq), f(alt)) + def mapExpr(f: Expression => Expression): Statement = Conditionally(info, f(pred), conseq, alt) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class Block(stmts: Seq[Statement]) extends Statement { def serialize: String = stmts map (_.serialize) mkString "\n" + def mapStmt(f: Statement => Statement): Statement = Block(stmts map f) + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <- ${expr.serialize}" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <= ${expr.serialize}" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${expr.serialize} is invalid" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = IsInvalid(info, f(expr)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo { def serialize: String = s"stop(${clk.serialize}, ${en.serialize}, $ret)" + info.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case class Print( info: Info, @@ -187,9 +268,17 @@ case class Print( (args map (_.serialize)) "printf(" + (strs mkString ", ") + ")" + info.serialize } + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Print(info, string, args map f, f(clk), f(en)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } case object EmptyStmt extends Statement { def serialize: String = "skip" + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this } abstract class Width extends FirrtlNode { @@ -232,29 +321,43 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode def serialize: String = flip.serialize + name + " : " + tpe.serialize } -abstract class Type extends FirrtlNode +abstract class Type extends FirrtlNode { + def mapType(f: Type => Type): Type + def mapWidth(f: Width => Width): Type +} abstract class GroundType extends Type { val width: Width + def mapType(f: Type => Type): Type = this +} +abstract class AggregateType extends Type { + def mapWidth(f: Width => Width): Type = this } -abstract class AggregateType extends Type case class UIntType(width: Width) extends GroundType { def serialize: String = "UInt" + width.serialize + def mapWidth(f: Width => Width): Type = UIntType(f(width)) } case class SIntType(width: Width) extends GroundType { def serialize: String = "SInt" + width.serialize + def mapWidth(f: Width => Width): Type = SIntType(f(width)) } case class BundleType(fields: Seq[Field]) extends AggregateType { def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}" + def mapType(f: Type => Type): Type = + BundleType(fields map (x => x.copy(tpe = f(x.tpe)))) } case class VectorType(tpe: Type, size: Int) extends AggregateType { def serialize: String = tpe.serialize + s"[$size]" + def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe)) } case object ClockType extends GroundType { val width = IntWidth(1) def serialize: String = "Clock" + def mapWidth(f: Width => Width): Type = this } case object UnknownType extends Type { def serialize: String = "?" + def mapType(f: Type => Type): Type = this + def mapWidth(f: Width => Width): Type = this } /** [[Port]] Direction */ @@ -283,6 +386,9 @@ abstract class DefModule extends FirrtlNode with IsDeclaration { protected def serializeHeader(tpe: String): String = s"$tpe $name :" + info.serialize + indent(ports map ("\n" + _.serialize) mkString) + "\n" + def mapStmt(f: Statement => Statement): DefModule + def mapPort(f: Port => Port): DefModule + def mapString(f: String => String): DefModule } /** Internal Module * @@ -290,6 +396,9 @@ abstract class DefModule extends FirrtlNode with IsDeclaration { */ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) extends DefModule { def serialize: String = serializeHeader("module") + indent("\n" + body.serialize) + def mapStmt(f: Statement => Statement): DefModule = this.copy(body = f(body)) + def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) + def mapString(f: String => String): DefModule = this.copy(name = f(name)) } /** External Module * @@ -297,6 +406,9 @@ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) e */ case class ExtModule(info: Info, name: String, ports: Seq[Port]) extends DefModule { def serialize: String = serializeHeader("extmodule") + def mapStmt(f: Statement => Statement): DefModule = this + def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) + def mapString(f: String => String): DefModule = this.copy(name = f(name)) } case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends FirrtlNode with HasInfo { diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index c4529bd9..43c0ef1e 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -1,12 +1,12 @@ package firrtl package passes -// Datastructures -import scala.collection.mutable - import firrtl.ir._ +import firrtl.Mappers._ import firrtl.Annotations._ -import firrtl.Mappers.{ExpMap, StmtMap} + +// Datastructures +import scala.collection.mutable // Tags an annotation to be consumed by this pass case class InlineAnnotation(target: Named, tID: TransID) extends Annotation with Loose with Unstable { diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index f2117761..4cdcae59 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -1,10 +1,10 @@ package firrtl package passes -import firrtl.Mappers.{ExpMap, StmtMap} -import firrtl.Utils.long_BANG -import firrtl.PrimOps._ import firrtl.ir._ +import firrtl.PrimOps._ +import firrtl.Mappers._ +import firrtl.Utils.long_BANG // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { diff --git a/src/main/scala/firrtl/passes/RemoveEmpty.scala b/src/main/scala/firrtl/passes/RemoveEmpty.scala index 7ba2ef09..225e2222 100644 --- a/src/main/scala/firrtl/passes/RemoveEmpty.scala +++ b/src/main/scala/firrtl/passes/RemoveEmpty.scala @@ -2,7 +2,7 @@ package firrtl package passes import scala.collection.mutable -import firrtl.Mappers.{ExpMap, StmtMap} +import firrtl.Mappers._ import firrtl.ir._ object RemoveEmpty extends Pass { diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index a534cc50..e0a4b621 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -1,6 +1,6 @@ package firrtl package passes -import firrtl.Mappers.{ExpMap, StmtMap} +import firrtl.Mappers._ import firrtl.ir._ // Removes ValidIf as an optimization diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 90b92a35..31306046 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -1,11 +1,12 @@ package firrtl package passes -import firrtl.Mappers.{ExpMap, StmtMap} -import firrtl.Utils.{kind, gender, get_info} import firrtl.ir._ -import scala.collection.mutable +import firrtl.Mappers._ +import firrtl.Utils.{kind, gender, get_info} +// Datastructures +import scala.collection.mutable // Splits compound expressions into simple expressions // and named intermediate nodes |
