diff options
| author | Adam Izraelevitz | 2018-11-27 13:28:12 -0800 |
|---|---|---|
| committer | GitHub | 2018-11-27 13:28:12 -0800 |
| commit | 17d1d2db772f90b039210874aadb11a8a807baba (patch) | |
| tree | f303cee0e5eeafffa73f93ee16a91be7aca1d34b /src | |
| parent | 82f62e04ed71d4507b72f784b3c230dda1262340 (diff) | |
Add foreach as alternative to map (#952)
* Added Foreachers
* Changed CheckTypes to use foreach
* Check widths now uses foreach
* Finished merge, added foreachers to added stmts
* Address reviewer feedback
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/AddDescriptionNodes.scala | 9 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 41 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 137 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckWidths.scala | 41 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 86 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/MemIR.scala | 5 | ||||
| -rw-r--r-- | src/main/scala/firrtl/traversals/Foreachers.scala | 113 |
8 files changed, 370 insertions, 65 deletions
diff --git a/src/main/scala/firrtl/AddDescriptionNodes.scala b/src/main/scala/firrtl/AddDescriptionNodes.scala index 6bae6857..1ed3259f 100644 --- a/src/main/scala/firrtl/AddDescriptionNodes.scala +++ b/src/main/scala/firrtl/AddDescriptionNodes.scala @@ -36,6 +36,11 @@ private case class DescribedStmt(description: Description, stmt: Statement) exte def mapType(f: Type => Type): Statement = this.copy(stmt = stmt.mapType(f)) def mapString(f: String => String): Statement = this.copy(stmt = stmt.mapString(f)) def mapInfo(f: Info => Info): Statement = this.copy(stmt = stmt.mapInfo(f)) + def foreachStmt(f: Statement => Unit): Unit = f(stmt) + def foreachExpr(f: Expression => Unit): Unit = stmt.foreachExpr(f) + def foreachType(f: Type => Unit): Unit = stmt.foreachType(f) + def foreachString(f: String => Unit): Unit = stmt.foreachString(f) + def foreachInfo(f: Info => Unit): Unit = stmt.foreachInfo(f) } private case class DescribedMod(description: Description, @@ -49,6 +54,10 @@ private case class DescribedMod(description: Description, def mapPort(f: Port => Port): DefModule = this.copy(mod = mod.mapPort(f)) def mapString(f: String => String): DefModule = this.copy(mod = mod.mapString(f)) def mapInfo(f: Info => Info): DefModule = this.copy(mod = mod.mapInfo(f)) + def foreachStmt(f: Statement => Unit): Unit = mod.foreachStmt(f) + def foreachPort(f: Port => Unit): Unit = mod.foreachPort(f) + def foreachString(f: String => Unit): Unit = mod.foreachString(f) + def foreachInfo(f: Info => Unit): Unit = mod.foreachInfo(f) } /** Wraps modules or statements with their respective described nodes. diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 578782ce..7be049ed 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -125,6 +125,9 @@ case class VRandom(width: BigInt) extends Expression { def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = Unit } class VerilogEmitter extends SeqTransform with Emitter { diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index f61fa41e..b96cd253 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -29,6 +29,9 @@ case class WRef(name: String, tpe: Type, kind: Kind, gender: Gender) extends Exp def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } object WRef { /** Creates a WRef from a Wire */ @@ -44,6 +47,9 @@ case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender) def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } object WSubField { def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UNKNOWNGENDER) @@ -54,12 +60,18 @@ case class WSubIndex(expr: Expression, value: Int, tpe: Type, gender: Gender) ex def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class WSubAccess(expr: Expression, index: Expression, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${expr.serialize}[${index.serialize}]" def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr), index = f(index)) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = { f(expr); f(index) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case object WVoid extends Expression { def tpe = UnknownType @@ -67,6 +79,9 @@ case object WVoid extends Expression { def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = Unit } case object WInvalid extends Expression { def tpe = UnknownType @@ -74,6 +89,9 @@ case object WInvalid extends Expression { def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = Unit } // Useful for splitting then remerging references case object EmptyExpression extends Expression { @@ -82,6 +100,9 @@ case object EmptyExpression extends Expression { def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = Unit } case class WDefInstance(info: Info, name: String, module: String, tpe: Type) extends Statement with IsDeclaration { def serialize: String = s"inst $name of $module" + info.serialize @@ -90,6 +111,11 @@ case class WDefInstance(info: Info, name: String, module: String, tpe: Type) ext def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) def mapInfo(f: Info => Info): Statement = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } object WDefInstance { def apply(name: String, module: String): WDefInstance = new WDefInstance(NoInfo, name, module, UnknownType) @@ -108,6 +134,11 @@ case class WDefInstanceConnector( def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) def mapInfo(f: Info => Info): Statement = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = portCons foreach { case (e1, e2) => (f(e1), f(e2)) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } // Resultant width is the same as the maximum input width @@ -285,6 +316,11 @@ case class CDefMemory( def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) def mapInfo(f: Info => Info): Statement = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class CDefMPort(info: Info, name: String, @@ -301,5 +337,10 @@ case class CDefMPort(info: Info, def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) def mapInfo(f: Info => Info): Statement = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = exps.foreach(f) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index faebc7b8..cdf8e194 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -112,24 +112,36 @@ abstract class Expression extends FirrtlNode { def mapExpr(f: Expression => Expression): Expression def mapType(f: Type => Type): Expression def mapWidth(f: Width => Width): Expression + def foreachExpr(f: Expression => Unit): Unit + def foreachType(f: Type => Unit): Unit + def foreachWidth(f: Width => Unit): Unit } case class Reference(name: String, tpe: Type) extends Expression with HasName { def serialize: String = name def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class SubField(expr: Expression, name: String, tpe: Type) extends Expression with HasName { def serialize: String = s"${expr.serialize}.$name" def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class SubIndex(expr: Expression, value: Int, tpe: Type) extends Expression { def serialize: String = s"${expr.serialize}[$value]" def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Expression { def serialize: String = s"${expr.serialize}[${index.serialize}]" @@ -137,18 +149,27 @@ case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Exp this.copy(expr = f(expr), index = f(index)) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = { f(expr); f(index) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression { def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})" def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(tval); f(fval) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression { def serialize: String = s"validif(${cond.serialize}, ${value.serialize})" def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(value) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } abstract class Literal extends Expression { val value: BigInt @@ -160,6 +181,9 @@ case class UIntLiteral(value: BigInt, width: Width) extends Literal { def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width)) + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = f(width) } object UIntLiteral { def minWidth(value: BigInt): Width = IntWidth(math.max(value.bitLength, 1)) @@ -171,6 +195,9 @@ case class SIntLiteral(value: BigInt, width: Width) extends Literal { def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width)) + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = f(width) } object SIntLiteral { def minWidth(value: BigInt): Width = IntWidth(value.bitLength + 1) @@ -185,6 +212,9 @@ case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Liter def mapExpr(f: Expression => Expression): Expression = this def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = FixedLiteral(value, f(width), f(point)) + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression { def serialize: String = op.serialize + "(" + @@ -192,6 +222,9 @@ case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: T def mapExpr(f: Expression => Expression): Expression = this.copy(args = args map f) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = args.foreach(f) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = Unit } abstract class Statement extends FirrtlNode { @@ -200,6 +233,11 @@ abstract class Statement extends FirrtlNode { def mapType(f: Type => Type): Statement def mapString(f: String => String): Statement def mapInfo(f: Info => Info): Statement + def foreachStmt(f: Statement => Unit): Unit + def foreachExpr(f: Expression => Unit): Unit + def foreachType(f: Type => Unit): Unit + def foreachString(f: String => Unit): Unit + def foreachInfo(f: Info => Unit): Unit } case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration { def serialize: String = s"wire $name : ${tpe.serialize}" + info.serialize @@ -208,6 +246,11 @@ case class DefWire(info: Info, name: String, tpe: Type) extends Statement with I def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe)) def mapString(f: String => String): Statement = DefWire(info, f(name), tpe) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class DefRegister( info: Info, @@ -225,7 +268,11 @@ case class DefRegister( def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { f(clock); f(reset); f(init) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class DefInstance(info: Info, name: String, module: String) extends Statement with IsDeclaration { def serialize: String = s"inst $name of $module" + info.serialize @@ -234,6 +281,11 @@ case class DefInstance(info: Info, name: String, module: String) extends Stateme def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = DefInstance(info, f(name), module) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class DefMemory( info: Info, @@ -263,6 +315,11 @@ case class DefMemory( def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) def mapString(f: String => String): Statement = this.copy(name = f(name)) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(dataType) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration { def serialize: String = s"node $name = ${value.serialize}" + info.serialize @@ -271,6 +328,11 @@ case class DefNode(info: Info, name: String, value: Expression) extends Statemen def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = DefNode(info, f(name), value) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = f(value) + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Conditionally( info: Info, @@ -287,6 +349,11 @@ case class Conditionally( def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = { f(conseq); f(alt) } + def foreachExpr(f: Expression => Unit): Unit = f(pred) + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Block(stmts: Seq[Statement]) extends Statement { def serialize: String = stmts map (_.serialize) mkString "\n" @@ -295,6 +362,11 @@ case class Block(stmts: Seq[Statement]) extends Statement { def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this + def foreachStmt(f: Statement => Unit): Unit = stmts.foreach(f) + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = Unit } case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <- ${expr.serialize}" + info.serialize @@ -303,6 +375,11 @@ case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) } + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${loc.serialize} <= ${expr.serialize}" + info.serialize @@ -311,6 +388,11 @@ case class Connect(info: Info, loc: Expression, expr: Expression) extends Statem def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) } + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo { def serialize: String = s"${expr.serialize} is invalid" + info.serialize @@ -319,6 +401,11 @@ case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInf def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with HasInfo { def serialize: String = "attach " + exprs.map(_.serialize).mkString("(", ", ", ")") @@ -327,6 +414,11 @@ case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with Has def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(f) + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo { def serialize: String = s"stop(${clk.serialize}, ${en.serialize}, $ret)" + info.serialize @@ -335,6 +427,11 @@ case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends S def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(en) } + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Print( info: Info, @@ -352,6 +449,11 @@ case class Print( def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = { args.foreach(f); f(clk); f(en) } + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = f(info) } case object EmptyStmt extends Statement { def serialize: String = "skip" @@ -360,6 +462,11 @@ case object EmptyStmt extends Statement { def mapType(f: Type => Type): Statement = this def mapString(f: String => String): Statement = this def mapInfo(f: Info => Info): Statement = this + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = Unit + def foreachString(f: String => Unit): Unit = Unit + def foreachInfo(f: Info => Unit): Unit = Unit } abstract class Width extends FirrtlNode { @@ -436,24 +543,30 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode abstract class Type extends FirrtlNode { def mapType(f: Type => Type): Type def mapWidth(f: Width => Width): Type + def foreachType(f: Type => Unit): Unit + def foreachWidth(f: Width => Unit): Unit } abstract class GroundType extends Type { val width: Width def mapType(f: Type => Type): Type = this + def foreachType(f: Type => Unit): Unit = Unit } object GroundType { def unapply(ground: GroundType): Option[Width] = Some(ground.width) } abstract class AggregateType extends Type { def mapWidth(f: Width => Width): Type = this + def foreachWidth(f: Width => Unit): Unit = Unit } case class UIntType(width: Width) extends GroundType { def serialize: String = "UInt" + width.serialize def mapWidth(f: Width => Width): Type = UIntType(f(width)) + def foreachWidth(f: Width => Unit): Unit = f(width) } case class SIntType(width: Width) extends GroundType { def serialize: String = "SInt" + width.serialize def mapWidth(f: Width => Width): Type = SIntType(f(width)) + def foreachWidth(f: Width => Unit): Unit = f(width) } case class FixedType(width: Width, point: Width) extends GroundType { override def serialize: String = { @@ -461,29 +574,36 @@ case class FixedType(width: Width, point: Width) extends GroundType { s"Fixed${width.serialize}$pstring" } def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point)) + def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } case class BundleType(fields: Seq[Field]) extends AggregateType { def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}" def mapType(f: Type => Type): Type = BundleType(fields map (x => x.copy(tpe = f(x.tpe)))) + def foreachType(f: Type => Unit): Unit = fields.foreach{ x => f(x.tpe) } } case class VectorType(tpe: Type, size: Int) extends AggregateType { def serialize: String = tpe.serialize + s"[$size]" def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe)) + def foreachType(f: Type => Unit): Unit = f(tpe) } case object ClockType extends GroundType { val width = IntWidth(1) def serialize: String = "Clock" def mapWidth(f: Width => Width): Type = this + def foreachWidth(f: Width => Unit): Unit = Unit } case class AnalogType(width: Width) extends GroundType { def serialize: String = "Analog" + width.serialize def mapWidth(f: Width => Width): Type = AnalogType(f(width)) + def foreachWidth(f: Width => Unit): Unit = f(width) } case object UnknownType extends Type { def serialize: String = "?" def mapType(f: Type => Type): Type = this def mapWidth(f: Width => Width): Type = this + def foreachType(f: Type => Unit): Unit = Unit + def foreachWidth(f: Width => Unit): Unit = Unit } /** [[Port]] Direction */ @@ -540,6 +660,10 @@ abstract class DefModule extends FirrtlNode with IsDeclaration { def mapPort(f: Port => Port): DefModule def mapString(f: String => String): DefModule def mapInfo(f: Info => Info): DefModule + def foreachStmt(f: Statement => Unit): Unit + def foreachPort(f: Port => Unit): Unit + def foreachString(f: String => Unit): Unit + def foreachInfo(f: Info => Unit): Unit } /** Internal Module * @@ -551,6 +675,10 @@ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) e def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) def mapString(f: String => String): DefModule = this.copy(name = f(name)) def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = f(body) + def foreachPort(f: Port => Unit): Unit = ports.foreach(f) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } /** External Module * @@ -569,6 +697,10 @@ case class ExtModule( def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) def mapString(f: String => String): DefModule = this.copy(name = f(name)) def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachPort(f: Port => Unit): Unit = ports.foreach(f) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends FirrtlNode with HasInfo { @@ -578,4 +710,7 @@ case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends Fi def mapModule(f: DefModule => DefModule): Circuit = this.copy(modules = modules map f) def mapString(f: String => String): Circuit = this.copy(main = f(main)) def mapInfo(f: Info => Info): Circuit = this.copy(f(info)) + def foreachModule(f: DefModule => Unit): Unit = modules foreach f + def foreachString(f: String => Unit): Unit = f(main) + def foreachInfo(f: Info => Unit): Unit = f(info) } diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 4a72b18c..061c6b16 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -5,9 +5,9 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.PrimOps._ -import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.Utils._ -import firrtl.annotations.{Target, TargetToken, CircuitTarget, ModuleTarget} +import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken} object CheckWidths extends Pass { /** The maximum allowed width for any circuit element */ @@ -36,7 +36,7 @@ object CheckWidths extends Pass { def run(c: Circuit): Circuit = { val errors = new Errors() - def check_width_w(info: Info, target: Target)(w: Width): Width = { + def check_width_w(info: Info, target: Target)(w: Width): Unit = { w match { case IntWidth(width) if width >= MaxWidth => errors.append(new WidthTooBig(info, target.serialize, width)) @@ -46,7 +46,6 @@ object CheckWidths extends Pass { case _ => errors append new UninferredWidth(info, target.prettyPrint(" ")) } - w } def hasWidth(tpe: Type): Boolean = tpe match { @@ -55,18 +54,18 @@ object CheckWidths extends Pass { case _ => throwInternalError(s"hasWidth - $tpe") } - def check_width_t(info: Info, target: Target)(t: Type): Type = { - val tx = t match { - case tt: BundleType => BundleType(tt.fields.map(check_width_f(info, target))) - case tt => tt map check_width_t(info, target) + def check_width_t(info: Info, target: Target)(t: Type): Unit = { + t match { + case tt: BundleType => tt.fields.foreach(check_width_f(info, target)) + case tt => tt foreach check_width_t(info, target) } - tx map check_width_w(info, target) + t foreach check_width_w(info, target) } - def check_width_f(info: Info, target: Target)(f: Field): Field = f - .copy(tpe = check_width_t(info, target.modify(tokens = target.tokens :+ TargetToken.Field(f.name)))(f.tpe)) + def check_width_f(info: Info, target: Target)(f: Field): Unit = + check_width_t(info, target.modify(tokens = target.tokens :+ TargetToken.Field(f.name)))(f.tpe) - def check_width_e(info: Info, target: Target)(e: Expression): Expression = { + def check_width_e(info: Info, target: Target)(e: Expression): Unit = { e match { case e: UIntLiteral => e.width match { case w: IntWidth if math.max(1, e.value.bitLength) > w.width => @@ -89,34 +88,36 @@ object CheckWidths extends Pass { case _ => } //e map check_width_t(info, mname) map check_width_e(info, mname) - e map check_width_e(info, target) + e foreach check_width_e(info, target) } - def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Statement = { + def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Unit = { val info = get_info(s) match { case NoInfo => minfo case x => x } val subRef = s match { case sx: HasName => target.ref(sx.name) case _ => target } - s map check_width_e(info, target) map check_width_s(info, target) map check_width_t(info, subRef) match { + s foreach check_width_e(info, target) + s foreach check_width_s(info, target) + s foreach check_width_t(info, subRef) + s match { case Attach(infox, exprs) => exprs.tail.foreach ( e => if (bitWidth(e.tpe) != bitWidth(exprs.head.tpe)) errors.append(new AttachWidthsNotEqual(infox, target.serialize, e.serialize, exprs.head.serialize)) ) - s case sx: DefRegister => sx.reset.tpe match { case UIntType(IntWidth(w)) if w == 1 => case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) } - s - case _ => s + case _ => } } - def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Port = p.copy(tpe = check_width_t(p.info, target)(p.tpe)) + def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target)(p.tpe) def check_width_m(circuit: CircuitTarget)(m: DefModule) { - m map check_width_p(m.info, circuit.module(m.name)) map check_width_s(m.info, circuit.module(m.name)) + m foreach check_width_p(m.info, circuit.module(m.name)) + m foreach check_width_s(m.info, circuit.module(m.name)) } c.modules foreach check_width_m(CircuitTarget(c.main)) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 4c7458bf..bc9d3a1c 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -7,6 +7,7 @@ import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.WrappedType._ object CheckHighForm extends Pass { @@ -116,32 +117,29 @@ object CheckHighForm extends Pass { case _ => // Do Nothing } - def checkHighFormW(info: Info, mname: String)(w: Width): Width = { + def checkHighFormW(info: Info, mname: String)(w: Width): Unit = { w match { - case wx: IntWidth if wx.width < 0 => - errors.append(new NegWidthException(info, mname)) + case wx: IntWidth if wx.width < 0 => errors.append(new NegWidthException(info, mname)) case wx => // Do nothing } - w } - def checkHighFormT(info: Info, mname: String)(t: Type): Type = - t map checkHighFormT(info, mname) match { - case tx: VectorType if tx.size < 0 => - errors.append(new NegVecSizeException(info, mname)) - t - case _ => t map checkHighFormW(info, mname) + def checkHighFormT(info: Info, mname: String)(t: Type): Unit = { + t foreach checkHighFormT(info, mname) + t match { + case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname)) + case _ => t foreach checkHighFormW(info, mname) } + } - def validSubexp(info: Info, mname: String)(e: Expression): Expression = { + def validSubexp(info: Info, mname: String)(e: Expression): Unit = { e match { case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error case _ => errors.append(new InvalidAccessException(info, mname)) } - e } - def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Expression = { + def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Unit = { e match { case ex: WRef if !names(ex.name) => errors.append(new UndeclaredReferenceException(info, mname, ex.name)) @@ -150,23 +148,23 @@ object CheckHighForm extends Pass { case ex: DoPrim => checkHighFormPrimop(info, mname, ex) case _: WRef | _: UIntLiteral | _: Mux | _: ValidIf => case ex: WSubAccess => validSubexp(info, mname)(ex.expr) - case ex => ex map validSubexp(info, mname) + case ex => ex foreach validSubexp(info, mname) } - (e map checkHighFormW(info, mname) - map checkHighFormT(info, mname) - map checkHighFormE(info, mname, names)) + e foreach checkHighFormW(info, mname) + e foreach checkHighFormT(info, mname) + e foreach checkHighFormE(info, mname, names) } - def checkName(info: Info, mname: String, names: NameSet)(name: String): String = { + def checkName(info: Info, mname: String, names: NameSet)(name: String): Unit = { if (names(name)) errors.append(new NotUniqueException(info, mname, name)) names += name - name } - def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Statement = { + def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Unit = { val info = get_info(s) match {case NoInfo => minfo case x => x} - s map checkName(info, mname, names) match { + s foreach checkName(info, mname, names) + s match { case sx: DefMemory => if (hasFlip(sx.dataType)) errors.append(new MemWithFlipException(info, mname, sx.name)) @@ -184,24 +182,23 @@ object CheckHighForm extends Pass { case sx: Print => checkFstring(info, mname, sx.string, sx.args.length) case sx => // Do Nothing } - (s map checkHighFormT(info, mname) - map checkHighFormE(info, mname, names) - map checkHighFormS(minfo, mname, names)) + s foreach checkHighFormT(info, mname) + s foreach checkHighFormE(info, mname, names) + s foreach checkHighFormS(minfo, mname, names) } - def checkHighFormP(mname: String, names: NameSet)(p: Port): Port = { + def checkHighFormP(mname: String, names: NameSet)(p: Port): Unit = { if (names(p.name)) errors.append(new NotUniqueException(NoInfo, mname, p.name)) names += p.name - (p.tpe map checkHighFormT(p.info, mname) - map checkHighFormW(p.info, mname)) - p + p.tpe foreach checkHighFormT(p.info, mname) + p.tpe foreach checkHighFormW(p.info, mname) } def checkHighFormM(m: DefModule) { val names = new NameSet - (m map checkHighFormP(m.name, names) - map checkHighFormS(m.info, m.name, names)) + m foreach checkHighFormP(m.name, names) + m foreach checkHighFormS(m.info, m.name, names) } c.modules foreach checkHighFormM @@ -333,7 +330,7 @@ object CheckTypes extends Pass { } } - def check_types_e(info:Info, mname: String)(e: Expression): Expression = { + def check_types_e(info:Info, mname: String)(e: Expression): Unit = { e match { case (e: WSubField) => e.expr.tpe match { case (t: BundleType) => t.fields find (_.name == e.name) match { @@ -377,7 +374,7 @@ object CheckTypes extends Pass { } case _ => } - e map check_types_e(info, mname) + e foreach check_types_e(info, mname) } def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { @@ -404,7 +401,7 @@ object CheckTypes extends Pass { } } - def check_types_s(minfo: Info, mname: String)(s: Statement): Statement = { + def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = { val info = get_info(s) match { case NoInfo => minfo case x => x } s match { case sx: Connect if wt(sx.loc.tpe) != wt(sx.expr.tpe) => @@ -457,10 +454,11 @@ object CheckTypes extends Pass { } case _ => } - s map check_types_e(info, mname) map check_types_s(info, mname) + s foreach check_types_e(info, mname) + s foreach check_types_s(info, mname) } - c.modules foreach (m => m map check_types_s(m.info, m.name)) + c.modules foreach (m => m foreach check_types_s(m.info, m.name)) errors.trigger() c } @@ -504,7 +502,7 @@ object CheckGenders extends Pass { flip_rec(t, Default) } - def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Expression = { + def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Unit = { val gender = get_gender(e,genders) (gender, desired) match { case (MALE, FEMALE) => @@ -516,19 +514,18 @@ object CheckGenders extends Pass { } case _ => } - e } - def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Expression = { + def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Unit = { e match { - case e: Mux => e map check_gender(info, mname, genders, MALE) - case e: DoPrim => e.args map check_gender(info, mname, genders, MALE) + case e: Mux => e foreach check_gender(info, mname, genders, MALE) + case e: DoPrim => e.args foreach check_gender(info, mname, genders, MALE) case _ => } - e map check_genders_e(info, mname, genders) + e foreach check_genders_e(info, mname, genders) } - def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Statement = { + def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Unit = { val info = get_info(s) match { case NoInfo => minfo case x => x } s match { case (s: DefWire) => genders(s.name) = BIGENDER @@ -555,13 +552,14 @@ object CheckGenders extends Pass { check_gender(info, mname, genders, MALE)(s.clk) case _ => } - s map check_genders_e(info, mname, genders) map check_genders_s(minfo, mname, genders) + s foreach check_genders_e(info, mname, genders) + s foreach check_genders_s(minfo, mname, genders) } for (m <- c.modules) { val genders = new GenderMap genders ++= (m.ports map (p => p.name -> to_gender(p.direction))) - m map check_genders_s(m.info, m.name, genders) + m foreach check_genders_s(m.info, m.name, genders) } errors.trigger() c diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala index 5fb837c1..a7ef9d43 100644 --- a/src/main/scala/firrtl/passes/memlib/MemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala @@ -31,4 +31,9 @@ case class DefAnnotatedMemory( writeLatency, readLatency, readers, writers, readwriters, readUnderWrite) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = Unit + def foreachExpr(f: Expression => Unit): Unit = Unit + def foreachType(f: Type => Unit): Unit = f(dataType) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } diff --git a/src/main/scala/firrtl/traversals/Foreachers.scala b/src/main/scala/firrtl/traversals/Foreachers.scala new file mode 100644 index 00000000..fdb02399 --- /dev/null +++ b/src/main/scala/firrtl/traversals/Foreachers.scala @@ -0,0 +1,113 @@ +// See LICENSE for license details. + +package firrtl.traversals + +import firrtl.ir._ +import language.implicitConversions + +/** Enables FIRRTL IR nodes to use foreach to traverse children IR nodes + */ +object Foreachers { + + /** Statement Foreachers */ + private trait StmtForMagnet { + def foreach(stmt: Statement): Unit + } + private object StmtForMagnet { + implicit def forStmt(f: Statement => Unit): StmtForMagnet = new StmtForMagnet { + def foreach(stmt: Statement): Unit = stmt foreachStmt f + } + implicit def forExp(f: Expression => Unit): StmtForMagnet = new StmtForMagnet { + def foreach(stmt: Statement): Unit = stmt foreachExpr f + } + implicit def forType(f: Type => Unit): StmtForMagnet = new StmtForMagnet { + def foreach(stmt: Statement) : Unit = stmt foreachType f + } + implicit def forString(f: String => Unit): StmtForMagnet = new StmtForMagnet { + def foreach(stmt: Statement): Unit = stmt foreachString f + } + implicit def forInfo(f: Info => Unit): StmtForMagnet = new StmtForMagnet { + def foreach(stmt: Statement): Unit = stmt foreachInfo f + } + } + implicit class StmtForeach(val _stmt: Statement) extends AnyVal { + // Using implicit types to allow overloading of function type to foreach, see StmtForMagnet above + def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => StmtForMagnet): Unit = magnet(f).foreach(_stmt) + } + + /** Expression Foreachers */ + private trait ExprForMagnet { + def foreach(expr: Expression): Unit + } + private object ExprForMagnet { + implicit def forExpr(f: Expression => Unit): ExprForMagnet = new ExprForMagnet { + def foreach(expr: Expression): Unit = expr foreachExpr f + } + implicit def forType(f: Type => Unit): ExprForMagnet = new ExprForMagnet { + def foreach(expr: Expression): Unit = expr foreachType f + } + implicit def forWidth(f: Width => Unit): ExprForMagnet = new ExprForMagnet { + def foreach(expr: Expression): Unit = expr foreachWidth f + } + } + implicit class ExprForeach(val _expr: Expression) extends AnyVal { + def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => ExprForMagnet): Unit = magnet(f).foreach(_expr) + } + + /** Type Foreachers */ + private trait TypeForMagnet { + def foreach(tpe: Type): Unit + } + private object TypeForMagnet { + implicit def forType(f: Type => Unit): TypeForMagnet = new TypeForMagnet { + def foreach(tpe: Type): Unit = tpe foreachType f + } + implicit def forWidth(f: Width => Unit): TypeForMagnet = new TypeForMagnet { + def foreach(tpe: Type): Unit = tpe foreachWidth f + } + } + implicit class TypeForeach(val _tpe: Type) extends AnyVal { + def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => TypeForMagnet): Unit = magnet(f).foreach(_tpe) + } + + /** Module Foreachers */ + private trait ModuleForMagnet { + def foreach(module: DefModule): Unit + } + private object ModuleForMagnet { + implicit def forStmt(f: Statement => Unit): ModuleForMagnet = new ModuleForMagnet { + def foreach(module: DefModule): Unit = module foreachStmt f + } + implicit def forPorts(f: Port => Unit): ModuleForMagnet = new ModuleForMagnet { + def foreach(module: DefModule): Unit = module foreachPort f + } + implicit def forString(f: String => Unit): ModuleForMagnet = new ModuleForMagnet { + def foreach(module: DefModule): Unit = module foreachString f + } + implicit def forInfo(f: Info => Unit): ModuleForMagnet = new ModuleForMagnet { + def foreach(module: DefModule): Unit = module foreachInfo f + } + } + implicit class ModuleForeach(val _module: DefModule) extends AnyVal { + def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => ModuleForMagnet): Unit = magnet(f).foreach(_module) + } + + /** Circuit Foreachers */ + private trait CircuitForMagnet { + def foreach(module: Circuit): Unit + } + private object CircuitForMagnet { + implicit def forModules(f: DefModule => Unit): CircuitForMagnet = new CircuitForMagnet { + def foreach(circuit: Circuit): Unit = circuit foreachModule f + } + implicit def forString(f: String => Unit): CircuitForMagnet = new CircuitForMagnet { + def foreach(circuit: Circuit): Unit = circuit foreachString f + } + implicit def forInfo(f: Info => Unit): CircuitForMagnet = new CircuitForMagnet { + def foreach(circuit: Circuit): Unit = circuit foreachInfo f + } + } + implicit class CircuitForeach(val _circuit: Circuit) extends AnyVal { + def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => CircuitForMagnet): Unit = magnet(f).foreach(_circuit) + } +} |
