diff options
| author | Adam Izraelevitz | 2018-11-27 13:28:12 -0800 |
|---|---|---|
| committer | GitHub | 2018-11-27 13:28:12 -0800 |
| commit | 17d1d2db772f90b039210874aadb11a8a807baba (patch) | |
| tree | f303cee0e5eeafffa73f93ee16a91be7aca1d34b /src/main/scala/firrtl/ir | |
| parent | 82f62e04ed71d4507b72f784b3c230dda1262340 (diff) | |
Add foreach as alternative to map (#952)
* Added Foreachers
* Changed CheckTypes to use foreach
* Check widths now uses foreach
* Finished merge, added foreachers to added stmts
* Address reviewer feedback
Diffstat (limited to 'src/main/scala/firrtl/ir')
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 137 |
1 files changed, 136 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index faebc7b8..cdf8e194 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -112,24 +112,36 @@ abstract class Expression extends FirrtlNode { def mapExpr(f: Expression => Expression): Expression def mapType(f: Type => Type): Expression def mapWidth(f: Width => Width): Expression + def foreachExpr(f: Expression => Unit): Unit + def foreachType(f: Type => Unit): Unit + def foreachWidth(f: Width => Unit): Unit } case class Reference(name: String, tpe: Type) extends Expression with HasName { def serialize: String = name def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class SubField(expr: Expression, name: String, tpe: Type) extends Expression with HasName { def serialize: String = s"${expr.serialize}.$name" def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class SubIndex(expr: Expression, value: Int, tpe: Type) extends Expression { def serialize: String = s"${expr.serialize}[$value]" def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Expression { def serialize: String = s"${expr.serialize}[${index.serialize}]" @@ -137,18 +149,27 @@ case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Exp this.copy(expr = f(expr), index = f(index)) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = { f(expr); f(index) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression { def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})" def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(tval); f(fval) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression { def serialize: String = s"validif(${cond.serialize}, ${value.serialize})" def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(value) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } abstract class Literal extends Expression { val value: BigInt @@ -160,6 +181,9 @@ case class UIntLiteral(value: BigInt, width: Width) extends Literal { def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width)) + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = f(width) } object UIntLiteral { def minWidth(value: BigInt): Width = IntWidth(math.max(value.bitLength, 1)) @@ -171,6 +195,9 @@ case class SIntLiteral(value: BigInt, width: Width) extends Literal { def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width)) + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = f(width) } object SIntLiteral { def minWidth(value: BigInt): Width = IntWidth(value.bitLength + 1) @@ -185,6 +212,9 @@ case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Liter def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = FixedLiteral(value, f(width), f(point)) + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression { def serialize: String = op.serialize + "(" + @@ -192,6 +222,9 @@ case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: T def mapExpr(f: Expression => Expression): Expression = this.copy(args = args map f) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = args.foreach(f) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } abstract class Statement extends FirrtlNode { @@ -200,6 +233,11 @@ abstract class Statement extends FirrtlNode { def mapType(f: Type => Type): Statement def mapString(f: String => String): Statement def mapInfo(f: Info => Info): Statement + def foreachStmt(f: Statement => Unit): Unit + def foreachExpr(f: Expression => Unit): Unit + def foreachType(f: Type => Unit): Unit + def foreachString(f: String => Unit): Unit + def foreachInfo(f: Info => Unit): Unit } case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration { def serialize: String = s"wire $name : ${tpe.serialize}" + info.serialize @@ -208,6 +246,11 @@ case class DefWire(info: Info, name: String, tpe: Type) extends Statement with I def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe)) def mapString(f: String => String): Statement = DefWire(info, f(name), tpe) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class DefRegister( info: Info, @@ -225,7 +268,11 @@ case class DefRegister( def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { f(clock); f(reset); f(init) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class DefInstance(info: Info, name: String, module: String) extends Statement with IsDeclaration { def serialize: String = s"inst $name of $module" + info.serialize @@ -234,6 +281,11 @@ case class DefInstance(info: Info, name: String, module: String) extends Stateme def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = DefInstance(info, f(name), module) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class DefMemory( info: Info, @@ -263,6 +315,11 @@ case class DefMemory( def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) def mapString(f: String => String): Statement = this.copy(name = f(name)) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(dataType) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration { def serialize: String = s"node $name = ${value.serialize}" + info.serialize @@ -271,6 +328,11 @@ case class DefNode(info: Info, name: String, value: Expression) extends Statemen def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = DefNode(info, f(name), value) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = f(value) + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Conditionally( info: Info, @@ -287,6 +349,11 @@ case class Conditionally( def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = { f(conseq); f(alt) } + def foreachExpr(f: Expression => Unit): Unit = f(pred) + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Block(stmts: Seq[Statement]) extends Statement { def serialize: String = stmts map (_.serialize) mkString "\n" @@ -295,6 +362,11 @@ case class Block(stmts: Seq[Statement]) extends Statement { def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this + def foreachStmt(f: Statement => Unit): Unit = stmts.foreach(f) + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = Unit } case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <- ${expr.serialize}" + info.serialize @@ -303,6 +375,11 @@ case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) } + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <= ${expr.serialize}" + info.serialize @@ -311,6 +388,11 @@ case class Connect(info: Info, loc: Expression, expr: Expression) extends Statem def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) } + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${expr.serialize} is invalid" + info.serialize @@ -319,6 +401,11 @@ case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInf def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with HasInfo { def serialize: String = "attach " + exprs.map(_.serialize).mkString("(", ", ", ")") @@ -327,6 +414,11 @@ case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with Has def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(f) + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo { def serialize: String = s"stop(${clk.serialize}, ${en.serialize}, $ret)" + info.serialize @@ -335,6 +427,11 @@ case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends S def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(en) } + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Print( info: Info, @@ -352,6 +449,11 @@ case class Print( def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { args.foreach(f); f(clk); f(en) } + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case object EmptyStmt extends Statement { def serialize: String = "skip" @@ -360,6 +462,11 @@ case object EmptyStmt extends Statement { def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = Unit } abstract class Width extends FirrtlNode { @@ -436,24 +543,30 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode abstract class Type extends FirrtlNode { def mapType(f: Type => Type): Type def mapWidth(f: Width => Width): Type + def foreachType(f: Type => Unit): Unit + def foreachWidth(f: Width => Unit): Unit } abstract class GroundType extends Type { val width: Width def mapType(f: Type => Type): Type = this + def foreachType(f: Type => Unit): Unit = Unit } object GroundType { def unapply(ground: GroundType): Option[Width] = Some(ground.width) } abstract class AggregateType extends Type { def mapWidth(f: Width => Width): Type = this + def foreachWidth(f: Width => Unit): Unit = Unit } case class UIntType(width: Width) extends GroundType { def serialize: String = "UInt" + width.serialize def mapWidth(f: Width => Width): Type = UIntType(f(width)) + def foreachWidth(f: Width => Unit): Unit = f(width) } case class SIntType(width: Width) extends GroundType { def serialize: String = "SInt" + width.serialize def mapWidth(f: Width => Width): Type = SIntType(f(width)) + def foreachWidth(f: Width => Unit): Unit = f(width) } case class FixedType(width: Width, point: Width) extends GroundType { override def serialize: String = { @@ -461,29 +574,36 @@ case class FixedType(width: Width, point: Width) extends GroundType { s"Fixed${width.serialize}$pstring" } def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point)) + def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } case class BundleType(fields: Seq[Field]) extends AggregateType { def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}" def mapType(f: Type => Type): Type = BundleType(fields map (x => x.copy(tpe = f(x.tpe)))) + def foreachType(f: Type => Unit): Unit = fields.foreach{ x => f(x.tpe) } } case class VectorType(tpe: Type, size: Int) extends AggregateType { def serialize: String = tpe.serialize + s"[$size]" def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe)) + def foreachType(f: Type => Unit): Unit = f(tpe) } case object ClockType extends GroundType { val width = IntWidth(1) def serialize: String = "Clock" def mapWidth(f: Width => Width): Type = this + def foreachWidth(f: Width => Unit): Unit = Unit } case class AnalogType(width: Width) extends GroundType { def serialize: String = "Analog" + width.serialize def mapWidth(f: Width => Width): Type = AnalogType(f(width)) + def foreachWidth(f: Width => Unit): Unit = f(width) } case object UnknownType extends Type { def serialize: String = "?" def mapType(f: Type => Type): Type = this def mapWidth(f: Width => Width): Type = this + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = Unit } /** [[Port]] Direction */ @@ -540,6 +660,10 @@ abstract class DefModule extends FirrtlNode with IsDeclaration { def mapPort(f: Port => Port): DefModule def mapString(f: String => String): DefModule def mapInfo(f: Info => Info): DefModule + def foreachStmt(f: Statement => Unit): Unit + def foreachPort(f: Port => Unit): Unit + def foreachString(f: String => Unit): Unit + def foreachInfo(f: Info => Unit): Unit } /** Internal Module * @@ -551,6 +675,10 @@ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) e def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) def mapString(f: String => String): DefModule = this.copy(name = f(name)) def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = f(body) + def foreachPort(f: Port => Unit): Unit = ports.foreach(f) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } /** External Module * @@ -569,6 +697,10 @@ case class ExtModule( def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) def mapString(f: String => String): DefModule = this.copy(name = f(name)) def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachPort(f: Port => Unit): Unit = ports.foreach(f) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends FirrtlNode with HasInfo { @@ -578,4 +710,7 @@ case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends Fi def mapModule(f: DefModule => DefModule): Circuit = this.copy(modules = modules map f) def mapString(f: String => String): Circuit = this.copy(main = f(main)) def mapInfo(f: Info => Info): Circuit = this.copy(f(info)) + def foreachModule(f: DefModule => Unit): Unit = modules foreach f + def foreachString(f: String => Unit): Unit = f(main) + def foreachInfo(f: Info => Unit): Unit = f(info) } |
