aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Izraelevitz2018-11-27 13:28:12 -0800
committerGitHub2018-11-27 13:28:12 -0800
commit17d1d2db772f90b039210874aadb11a8a807baba (patch)
treef303cee0e5eeafffa73f93ee16a91be7aca1d34b
parent82f62e04ed71d4507b72f784b3c230dda1262340 (diff)
Add foreach as alternative to map (#952)
* Added Foreachers * Changed CheckTypes to use foreach * Check widths now uses foreach * Finished merge, added foreachers to added stmts * Address reviewer feedback
-rw-r--r--src/main/scala/firrtl/AddDescriptionNodes.scala9
-rw-r--r--src/main/scala/firrtl/Emitter.scala3
-rw-r--r--src/main/scala/firrtl/WIR.scala41
-rw-r--r--src/main/scala/firrtl/ir/IR.scala137
-rw-r--r--src/main/scala/firrtl/passes/CheckWidths.scala41
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala86
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemIR.scala5
-rw-r--r--src/main/scala/firrtl/traversals/Foreachers.scala113
8 files changed, 370 insertions, 65 deletions
diff --git a/src/main/scala/firrtl/AddDescriptionNodes.scala b/src/main/scala/firrtl/AddDescriptionNodes.scala
index 6bae6857..1ed3259f 100644
--- a/src/main/scala/firrtl/AddDescriptionNodes.scala
+++ b/src/main/scala/firrtl/AddDescriptionNodes.scala
@@ -36,6 +36,11 @@ private case class DescribedStmt(description: Description, stmt: Statement) exte
def mapType(f: Type => Type): Statement = this.copy(stmt = stmt.mapType(f))
def mapString(f: String => String): Statement = this.copy(stmt = stmt.mapString(f))
def mapInfo(f: Info => Info): Statement = this.copy(stmt = stmt.mapInfo(f))
+ def foreachStmt(f: Statement => Unit): Unit = f(stmt)
+ def foreachExpr(f: Expression => Unit): Unit = stmt.foreachExpr(f)
+ def foreachType(f: Type => Unit): Unit = stmt.foreachType(f)
+ def foreachString(f: String => Unit): Unit = stmt.foreachString(f)
+ def foreachInfo(f: Info => Unit): Unit = stmt.foreachInfo(f)
}
private case class DescribedMod(description: Description,
@@ -49,6 +54,10 @@ private case class DescribedMod(description: Description,
def mapPort(f: Port => Port): DefModule = this.copy(mod = mod.mapPort(f))
def mapString(f: String => String): DefModule = this.copy(mod = mod.mapString(f))
def mapInfo(f: Info => Info): DefModule = this.copy(mod = mod.mapInfo(f))
+ def foreachStmt(f: Statement => Unit): Unit = mod.foreachStmt(f)
+ def foreachPort(f: Port => Unit): Unit = mod.foreachPort(f)
+ def foreachString(f: String => Unit): Unit = mod.foreachString(f)
+ def foreachInfo(f: Info => Unit): Unit = mod.foreachInfo(f)
}
/** Wraps modules or statements with their respective described nodes.
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index 578782ce..7be049ed 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -125,6 +125,9 @@ case class VRandom(width: BigInt) extends Expression {
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
class VerilogEmitter extends SeqTransform with Emitter {
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index f61fa41e..b96cd253 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -29,6 +29,9 @@ case class WRef(name: String, tpe: Type, kind: Kind, gender: Gender) extends Exp
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
object WRef {
/** Creates a WRef from a Wire */
@@ -44,6 +47,9 @@ case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender)
def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
object WSubField {
def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UNKNOWNGENDER)
@@ -54,12 +60,18 @@ case class WSubIndex(expr: Expression, value: Int, tpe: Type, gender: Gender) ex
def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case class WSubAccess(expr: Expression, index: Expression, tpe: Type, gender: Gender) extends Expression {
def serialize: String = s"${expr.serialize}[${index.serialize}]"
def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr), index = f(index))
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = { f(expr); f(index) }
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case object WVoid extends Expression {
def tpe = UnknownType
@@ -67,6 +79,9 @@ case object WVoid extends Expression {
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case object WInvalid extends Expression {
def tpe = UnknownType
@@ -74,6 +89,9 @@ case object WInvalid extends Expression {
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
// Useful for splitting then remerging references
case object EmptyExpression extends Expression {
@@ -82,6 +100,9 @@ case object EmptyExpression extends Expression {
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case class WDefInstance(info: Info, name: String, module: String, tpe: Type) extends Statement with IsDeclaration {
def serialize: String = s"inst $name of $module" + info.serialize
@@ -90,6 +111,11 @@ case class WDefInstance(info: Info, name: String, module: String, tpe: Type) ext
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
def mapInfo(f: Info => Info): Statement = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
object WDefInstance {
def apply(name: String, module: String): WDefInstance = new WDefInstance(NoInfo, name, module, UnknownType)
@@ -108,6 +134,11 @@ case class WDefInstanceConnector(
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
def mapInfo(f: Info => Info): Statement = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = portCons foreach { case (e1, e2) => (f(e1), f(e2)) }
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
// Resultant width is the same as the maximum input width
@@ -285,6 +316,11 @@ case class CDefMemory(
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
def mapInfo(f: Info => Info): Statement = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class CDefMPort(info: Info,
name: String,
@@ -301,5 +337,10 @@ case class CDefMPort(info: Info,
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
def mapInfo(f: Info => Info): Statement = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = exps.foreach(f)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index faebc7b8..cdf8e194 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -112,24 +112,36 @@ abstract class Expression extends FirrtlNode {
def mapExpr(f: Expression => Expression): Expression
def mapType(f: Type => Type): Expression
def mapWidth(f: Width => Width): Expression
+ def foreachExpr(f: Expression => Unit): Unit
+ def foreachType(f: Type => Unit): Unit
+ def foreachWidth(f: Width => Unit): Unit
}
case class Reference(name: String, tpe: Type) extends Expression with HasName {
def serialize: String = name
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case class SubField(expr: Expression, name: String, tpe: Type) extends Expression with HasName {
def serialize: String = s"${expr.serialize}.$name"
def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case class SubIndex(expr: Expression, value: Int, tpe: Type) extends Expression {
def serialize: String = s"${expr.serialize}[$value]"
def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Expression {
def serialize: String = s"${expr.serialize}[${index.serialize}]"
@@ -137,18 +149,27 @@ case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Exp
this.copy(expr = f(expr), index = f(index))
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = { f(expr); f(index) }
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression {
def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})"
def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe)
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(tval); f(fval) }
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression {
def serialize: String = s"validif(${cond.serialize}, ${value.serialize})"
def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe)
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(value) }
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
abstract class Literal extends Expression {
val value: BigInt
@@ -160,6 +181,9 @@ case class UIntLiteral(value: BigInt, width: Width) extends Literal {
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width))
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
object UIntLiteral {
def minWidth(value: BigInt): Width = IntWidth(math.max(value.bitLength, 1))
@@ -171,6 +195,9 @@ case class SIntLiteral(value: BigInt, width: Width) extends Literal {
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width))
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
object SIntLiteral {
def minWidth(value: BigInt): Width = IntWidth(value.bitLength + 1)
@@ -185,6 +212,9 @@ case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Liter
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = FixedLiteral(value, f(width), f(point))
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) }
}
case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression {
def serialize: String = op.serialize + "(" +
@@ -192,6 +222,9 @@ case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: T
def mapExpr(f: Expression => Expression): Expression = this.copy(args = args map f)
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = args.foreach(f)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
abstract class Statement extends FirrtlNode {
@@ -200,6 +233,11 @@ abstract class Statement extends FirrtlNode {
def mapType(f: Type => Type): Statement
def mapString(f: String => String): Statement
def mapInfo(f: Info => Info): Statement
+ def foreachStmt(f: Statement => Unit): Unit
+ def foreachExpr(f: Expression => Unit): Unit
+ def foreachType(f: Type => Unit): Unit
+ def foreachString(f: String => Unit): Unit
+ def foreachInfo(f: Info => Unit): Unit
}
case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration {
def serialize: String = s"wire $name : ${tpe.serialize}" + info.serialize
@@ -208,6 +246,11 @@ case class DefWire(info: Info, name: String, tpe: Type) extends Statement with I
def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe))
def mapString(f: String => String): Statement = DefWire(info, f(name), tpe)
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class DefRegister(
info: Info,
@@ -225,7 +268,11 @@ case class DefRegister(
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
-
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = { f(clock); f(reset); f(init) }
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class DefInstance(info: Info, name: String, module: String) extends Statement with IsDeclaration {
def serialize: String = s"inst $name of $module" + info.serialize
@@ -234,6 +281,11 @@ case class DefInstance(info: Info, name: String, module: String) extends Stateme
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = DefInstance(info, f(name), module)
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class DefMemory(
info: Info,
@@ -263,6 +315,11 @@ case class DefMemory(
def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
def mapString(f: String => String): Statement = this.copy(name = f(name))
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = f(dataType)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration {
def serialize: String = s"node $name = ${value.serialize}" + info.serialize
@@ -271,6 +328,11 @@ case class DefNode(info: Info, name: String, value: Expression) extends Statemen
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = DefNode(info, f(name), value)
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = f(value)
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Conditionally(
info: Info,
@@ -287,6 +349,11 @@ case class Conditionally(
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = { f(conseq); f(alt) }
+ def foreachExpr(f: Expression => Unit): Unit = f(pred)
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = Unit
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Block(stmts: Seq[Statement]) extends Statement {
def serialize: String = stmts map (_.serialize) mkString "\n"
@@ -295,6 +362,11 @@ case class Block(stmts: Seq[Statement]) extends Statement {
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this
+ def foreachStmt(f: Statement => Unit): Unit = stmts.foreach(f)
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = Unit
+ def foreachInfo(f: Info => Unit): Unit = Unit
}
case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo {
def serialize: String = s"${loc.serialize} <- ${expr.serialize}" + info.serialize
@@ -303,6 +375,11 @@ case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) }
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = Unit
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo {
def serialize: String = s"${loc.serialize} <= ${expr.serialize}" + info.serialize
@@ -311,6 +388,11 @@ case class Connect(info: Info, loc: Expression, expr: Expression) extends Statem
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) }
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = Unit
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo {
def serialize: String = s"${expr.serialize} is invalid" + info.serialize
@@ -319,6 +401,11 @@ case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInf
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = Unit
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with HasInfo {
def serialize: String = "attach " + exprs.map(_.serialize).mkString("(", ", ", ")")
@@ -327,6 +414,11 @@ case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with Has
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(f)
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = Unit
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo {
def serialize: String = s"stop(${clk.serialize}, ${en.serialize}, $ret)" + info.serialize
@@ -335,6 +427,11 @@ case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends S
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(en) }
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = Unit
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Print(
info: Info,
@@ -352,6 +449,11 @@ case class Print(
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = { args.foreach(f); f(clk); f(en) }
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = Unit
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case object EmptyStmt extends Statement {
def serialize: String = "skip"
@@ -360,6 +462,11 @@ case object EmptyStmt extends Statement {
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachString(f: String => Unit): Unit = Unit
+ def foreachInfo(f: Info => Unit): Unit = Unit
}
abstract class Width extends FirrtlNode {
@@ -436,24 +543,30 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode
abstract class Type extends FirrtlNode {
def mapType(f: Type => Type): Type
def mapWidth(f: Width => Width): Type
+ def foreachType(f: Type => Unit): Unit
+ def foreachWidth(f: Width => Unit): Unit
}
abstract class GroundType extends Type {
val width: Width
def mapType(f: Type => Type): Type = this
+ def foreachType(f: Type => Unit): Unit = Unit
}
object GroundType {
def unapply(ground: GroundType): Option[Width] = Some(ground.width)
}
abstract class AggregateType extends Type {
def mapWidth(f: Width => Width): Type = this
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case class UIntType(width: Width) extends GroundType {
def serialize: String = "UInt" + width.serialize
def mapWidth(f: Width => Width): Type = UIntType(f(width))
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
case class SIntType(width: Width) extends GroundType {
def serialize: String = "SInt" + width.serialize
def mapWidth(f: Width => Width): Type = SIntType(f(width))
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
case class FixedType(width: Width, point: Width) extends GroundType {
override def serialize: String = {
@@ -461,29 +574,36 @@ case class FixedType(width: Width, point: Width) extends GroundType {
s"Fixed${width.serialize}$pstring"
}
def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point))
+ def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) }
}
case class BundleType(fields: Seq[Field]) extends AggregateType {
def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}"
def mapType(f: Type => Type): Type =
BundleType(fields map (x => x.copy(tpe = f(x.tpe))))
+ def foreachType(f: Type => Unit): Unit = fields.foreach{ x => f(x.tpe) }
}
case class VectorType(tpe: Type, size: Int) extends AggregateType {
def serialize: String = tpe.serialize + s"[$size]"
def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe))
+ def foreachType(f: Type => Unit): Unit = f(tpe)
}
case object ClockType extends GroundType {
val width = IntWidth(1)
def serialize: String = "Clock"
def mapWidth(f: Width => Width): Type = this
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
case class AnalogType(width: Width) extends GroundType {
def serialize: String = "Analog" + width.serialize
def mapWidth(f: Width => Width): Type = AnalogType(f(width))
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
case object UnknownType extends Type {
def serialize: String = "?"
def mapType(f: Type => Type): Type = this
def mapWidth(f: Width => Width): Type = this
+ def foreachType(f: Type => Unit): Unit = Unit
+ def foreachWidth(f: Width => Unit): Unit = Unit
}
/** [[Port]] Direction */
@@ -540,6 +660,10 @@ abstract class DefModule extends FirrtlNode with IsDeclaration {
def mapPort(f: Port => Port): DefModule
def mapString(f: String => String): DefModule
def mapInfo(f: Info => Info): DefModule
+ def foreachStmt(f: Statement => Unit): Unit
+ def foreachPort(f: Port => Unit): Unit
+ def foreachString(f: String => Unit): Unit
+ def foreachInfo(f: Info => Unit): Unit
}
/** Internal Module
*
@@ -551,6 +675,10 @@ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) e
def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f)
def mapString(f: String => String): DefModule = this.copy(name = f(name))
def mapInfo(f: Info => Info): DefModule = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = f(body)
+ def foreachPort(f: Port => Unit): Unit = ports.foreach(f)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
/** External Module
*
@@ -569,6 +697,10 @@ case class ExtModule(
def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f)
def mapString(f: String => String): DefModule = this.copy(name = f(name))
def mapInfo(f: Info => Info): DefModule = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachPort(f: Port => Unit): Unit = ports.foreach(f)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends FirrtlNode with HasInfo {
@@ -578,4 +710,7 @@ case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends Fi
def mapModule(f: DefModule => DefModule): Circuit = this.copy(modules = modules map f)
def mapString(f: String => String): Circuit = this.copy(main = f(main))
def mapInfo(f: Info => Info): Circuit = this.copy(f(info))
+ def foreachModule(f: DefModule => Unit): Unit = modules foreach f
+ def foreachString(f: String => Unit): Unit = f(main)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala
index 4a72b18c..061c6b16 100644
--- a/src/main/scala/firrtl/passes/CheckWidths.scala
+++ b/src/main/scala/firrtl/passes/CheckWidths.scala
@@ -5,9 +5,9 @@ package firrtl.passes
import firrtl._
import firrtl.ir._
import firrtl.PrimOps._
-import firrtl.Mappers._
+import firrtl.traversals.Foreachers._
import firrtl.Utils._
-import firrtl.annotations.{Target, TargetToken, CircuitTarget, ModuleTarget}
+import firrtl.annotations.{CircuitTarget, ModuleTarget, Target, TargetToken}
object CheckWidths extends Pass {
/** The maximum allowed width for any circuit element */
@@ -36,7 +36,7 @@ object CheckWidths extends Pass {
def run(c: Circuit): Circuit = {
val errors = new Errors()
- def check_width_w(info: Info, target: Target)(w: Width): Width = {
+ def check_width_w(info: Info, target: Target)(w: Width): Unit = {
w match {
case IntWidth(width) if width >= MaxWidth =>
errors.append(new WidthTooBig(info, target.serialize, width))
@@ -46,7 +46,6 @@ object CheckWidths extends Pass {
case _ =>
errors append new UninferredWidth(info, target.prettyPrint(" "))
}
- w
}
def hasWidth(tpe: Type): Boolean = tpe match {
@@ -55,18 +54,18 @@ object CheckWidths extends Pass {
case _ => throwInternalError(s"hasWidth - $tpe")
}
- def check_width_t(info: Info, target: Target)(t: Type): Type = {
- val tx = t match {
- case tt: BundleType => BundleType(tt.fields.map(check_width_f(info, target)))
- case tt => tt map check_width_t(info, target)
+ def check_width_t(info: Info, target: Target)(t: Type): Unit = {
+ t match {
+ case tt: BundleType => tt.fields.foreach(check_width_f(info, target))
+ case tt => tt foreach check_width_t(info, target)
}
- tx map check_width_w(info, target)
+ t foreach check_width_w(info, target)
}
- def check_width_f(info: Info, target: Target)(f: Field): Field = f
- .copy(tpe = check_width_t(info, target.modify(tokens = target.tokens :+ TargetToken.Field(f.name)))(f.tpe))
+ def check_width_f(info: Info, target: Target)(f: Field): Unit =
+ check_width_t(info, target.modify(tokens = target.tokens :+ TargetToken.Field(f.name)))(f.tpe)
- def check_width_e(info: Info, target: Target)(e: Expression): Expression = {
+ def check_width_e(info: Info, target: Target)(e: Expression): Unit = {
e match {
case e: UIntLiteral => e.width match {
case w: IntWidth if math.max(1, e.value.bitLength) > w.width =>
@@ -89,34 +88,36 @@ object CheckWidths extends Pass {
case _ =>
}
//e map check_width_t(info, mname) map check_width_e(info, mname)
- e map check_width_e(info, target)
+ e foreach check_width_e(info, target)
}
- def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Statement = {
+ def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Unit = {
val info = get_info(s) match { case NoInfo => minfo case x => x }
val subRef = s match { case sx: HasName => target.ref(sx.name) case _ => target }
- s map check_width_e(info, target) map check_width_s(info, target) map check_width_t(info, subRef) match {
+ s foreach check_width_e(info, target)
+ s foreach check_width_s(info, target)
+ s foreach check_width_t(info, subRef)
+ s match {
case Attach(infox, exprs) =>
exprs.tail.foreach ( e =>
if (bitWidth(e.tpe) != bitWidth(exprs.head.tpe))
errors.append(new AttachWidthsNotEqual(infox, target.serialize, e.serialize, exprs.head.serialize))
)
- s
case sx: DefRegister =>
sx.reset.tpe match {
case UIntType(IntWidth(w)) if w == 1 =>
case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name))
}
- s
- case _ => s
+ case _ =>
}
}
- def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Port = p.copy(tpe = check_width_t(p.info, target)(p.tpe))
+ def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target)(p.tpe)
def check_width_m(circuit: CircuitTarget)(m: DefModule) {
- m map check_width_p(m.info, circuit.module(m.name)) map check_width_s(m.info, circuit.module(m.name))
+ m foreach check_width_p(m.info, circuit.module(m.name))
+ m foreach check_width_s(m.info, circuit.module(m.name))
}
c.modules foreach check_width_m(CircuitTarget(c.main))
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 4c7458bf..bc9d3a1c 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -7,6 +7,7 @@ import firrtl.ir._
import firrtl.PrimOps._
import firrtl.Utils._
import firrtl.Mappers._
+import firrtl.traversals.Foreachers._
import firrtl.WrappedType._
object CheckHighForm extends Pass {
@@ -116,32 +117,29 @@ object CheckHighForm extends Pass {
case _ => // Do Nothing
}
- def checkHighFormW(info: Info, mname: String)(w: Width): Width = {
+ def checkHighFormW(info: Info, mname: String)(w: Width): Unit = {
w match {
- case wx: IntWidth if wx.width < 0 =>
- errors.append(new NegWidthException(info, mname))
+ case wx: IntWidth if wx.width < 0 => errors.append(new NegWidthException(info, mname))
case wx => // Do nothing
}
- w
}
- def checkHighFormT(info: Info, mname: String)(t: Type): Type =
- t map checkHighFormT(info, mname) match {
- case tx: VectorType if tx.size < 0 =>
- errors.append(new NegVecSizeException(info, mname))
- t
- case _ => t map checkHighFormW(info, mname)
+ def checkHighFormT(info: Info, mname: String)(t: Type): Unit = {
+ t foreach checkHighFormT(info, mname)
+ t match {
+ case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname))
+ case _ => t foreach checkHighFormW(info, mname)
}
+ }
- def validSubexp(info: Info, mname: String)(e: Expression): Expression = {
+ def validSubexp(info: Info, mname: String)(e: Expression): Unit = {
e match {
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error
case _ => errors.append(new InvalidAccessException(info, mname))
}
- e
}
- def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Expression = {
+ def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Unit = {
e match {
case ex: WRef if !names(ex.name) =>
errors.append(new UndeclaredReferenceException(info, mname, ex.name))
@@ -150,23 +148,23 @@ object CheckHighForm extends Pass {
case ex: DoPrim => checkHighFormPrimop(info, mname, ex)
case _: WRef | _: UIntLiteral | _: Mux | _: ValidIf =>
case ex: WSubAccess => validSubexp(info, mname)(ex.expr)
- case ex => ex map validSubexp(info, mname)
+ case ex => ex foreach validSubexp(info, mname)
}
- (e map checkHighFormW(info, mname)
- map checkHighFormT(info, mname)
- map checkHighFormE(info, mname, names))
+ e foreach checkHighFormW(info, mname)
+ e foreach checkHighFormT(info, mname)
+ e foreach checkHighFormE(info, mname, names)
}
- def checkName(info: Info, mname: String, names: NameSet)(name: String): String = {
+ def checkName(info: Info, mname: String, names: NameSet)(name: String): Unit = {
if (names(name))
errors.append(new NotUniqueException(info, mname, name))
names += name
- name
}
- def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Statement = {
+ def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Unit = {
val info = get_info(s) match {case NoInfo => minfo case x => x}
- s map checkName(info, mname, names) match {
+ s foreach checkName(info, mname, names)
+ s match {
case sx: DefMemory =>
if (hasFlip(sx.dataType))
errors.append(new MemWithFlipException(info, mname, sx.name))
@@ -184,24 +182,23 @@ object CheckHighForm extends Pass {
case sx: Print => checkFstring(info, mname, sx.string, sx.args.length)
case sx => // Do Nothing
}
- (s map checkHighFormT(info, mname)
- map checkHighFormE(info, mname, names)
- map checkHighFormS(minfo, mname, names))
+ s foreach checkHighFormT(info, mname)
+ s foreach checkHighFormE(info, mname, names)
+ s foreach checkHighFormS(minfo, mname, names)
}
- def checkHighFormP(mname: String, names: NameSet)(p: Port): Port = {
+ def checkHighFormP(mname: String, names: NameSet)(p: Port): Unit = {
if (names(p.name))
errors.append(new NotUniqueException(NoInfo, mname, p.name))
names += p.name
- (p.tpe map checkHighFormT(p.info, mname)
- map checkHighFormW(p.info, mname))
- p
+ p.tpe foreach checkHighFormT(p.info, mname)
+ p.tpe foreach checkHighFormW(p.info, mname)
}
def checkHighFormM(m: DefModule) {
val names = new NameSet
- (m map checkHighFormP(m.name, names)
- map checkHighFormS(m.info, m.name, names))
+ m foreach checkHighFormP(m.name, names)
+ m foreach checkHighFormS(m.info, m.name, names)
}
c.modules foreach checkHighFormM
@@ -333,7 +330,7 @@ object CheckTypes extends Pass {
}
}
- def check_types_e(info:Info, mname: String)(e: Expression): Expression = {
+ def check_types_e(info:Info, mname: String)(e: Expression): Unit = {
e match {
case (e: WSubField) => e.expr.tpe match {
case (t: BundleType) => t.fields find (_.name == e.name) match {
@@ -377,7 +374,7 @@ object CheckTypes extends Pass {
}
case _ =>
}
- e map check_types_e(info, mname)
+ e foreach check_types_e(info, mname)
}
def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = {
@@ -404,7 +401,7 @@ object CheckTypes extends Pass {
}
}
- def check_types_s(minfo: Info, mname: String)(s: Statement): Statement = {
+ def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = {
val info = get_info(s) match { case NoInfo => minfo case x => x }
s match {
case sx: Connect if wt(sx.loc.tpe) != wt(sx.expr.tpe) =>
@@ -457,10 +454,11 @@ object CheckTypes extends Pass {
}
case _ =>
}
- s map check_types_e(info, mname) map check_types_s(info, mname)
+ s foreach check_types_e(info, mname)
+ s foreach check_types_s(info, mname)
}
- c.modules foreach (m => m map check_types_s(m.info, m.name))
+ c.modules foreach (m => m foreach check_types_s(m.info, m.name))
errors.trigger()
c
}
@@ -504,7 +502,7 @@ object CheckGenders extends Pass {
flip_rec(t, Default)
}
- def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Expression = {
+ def check_gender(info:Info, mname: String, genders: GenderMap, desired: Gender)(e:Expression): Unit = {
val gender = get_gender(e,genders)
(gender, desired) match {
case (MALE, FEMALE) =>
@@ -516,19 +514,18 @@ object CheckGenders extends Pass {
}
case _ =>
}
- e
}
- def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Expression = {
+ def check_genders_e (info:Info, mname: String, genders: GenderMap)(e:Expression): Unit = {
e match {
- case e: Mux => e map check_gender(info, mname, genders, MALE)
- case e: DoPrim => e.args map check_gender(info, mname, genders, MALE)
+ case e: Mux => e foreach check_gender(info, mname, genders, MALE)
+ case e: DoPrim => e.args foreach check_gender(info, mname, genders, MALE)
case _ =>
}
- e map check_genders_e(info, mname, genders)
+ e foreach check_genders_e(info, mname, genders)
}
- def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Statement = {
+ def check_genders_s(minfo: Info, mname: String, genders: GenderMap)(s: Statement): Unit = {
val info = get_info(s) match { case NoInfo => minfo case x => x }
s match {
case (s: DefWire) => genders(s.name) = BIGENDER
@@ -555,13 +552,14 @@ object CheckGenders extends Pass {
check_gender(info, mname, genders, MALE)(s.clk)
case _ =>
}
- s map check_genders_e(info, mname, genders) map check_genders_s(minfo, mname, genders)
+ s foreach check_genders_e(info, mname, genders)
+ s foreach check_genders_s(minfo, mname, genders)
}
for (m <- c.modules) {
val genders = new GenderMap
genders ++= (m.ports map (p => p.name -> to_gender(p.direction)))
- m map check_genders_s(m.info, m.name, genders)
+ m foreach check_genders_s(m.info, m.name, genders)
}
errors.trigger()
c
diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala
index 5fb837c1..a7ef9d43 100644
--- a/src/main/scala/firrtl/passes/memlib/MemIR.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala
@@ -31,4 +31,9 @@ case class DefAnnotatedMemory(
writeLatency, readLatency, readers, writers,
readwriters, readUnderWrite)
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = Unit
+ def foreachExpr(f: Expression => Unit): Unit = Unit
+ def foreachType(f: Type => Unit): Unit = f(dataType)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
diff --git a/src/main/scala/firrtl/traversals/Foreachers.scala b/src/main/scala/firrtl/traversals/Foreachers.scala
new file mode 100644
index 00000000..fdb02399
--- /dev/null
+++ b/src/main/scala/firrtl/traversals/Foreachers.scala
@@ -0,0 +1,113 @@
+// See LICENSE for license details.
+
+package firrtl.traversals
+
+import firrtl.ir._
+import language.implicitConversions
+
+/** Enables FIRRTL IR nodes to use foreach to traverse children IR nodes
+ */
+object Foreachers {
+
+ /** Statement Foreachers */
+ private trait StmtForMagnet {
+ def foreach(stmt: Statement): Unit
+ }
+ private object StmtForMagnet {
+ implicit def forStmt(f: Statement => Unit): StmtForMagnet = new StmtForMagnet {
+ def foreach(stmt: Statement): Unit = stmt foreachStmt f
+ }
+ implicit def forExp(f: Expression => Unit): StmtForMagnet = new StmtForMagnet {
+ def foreach(stmt: Statement): Unit = stmt foreachExpr f
+ }
+ implicit def forType(f: Type => Unit): StmtForMagnet = new StmtForMagnet {
+ def foreach(stmt: Statement) : Unit = stmt foreachType f
+ }
+ implicit def forString(f: String => Unit): StmtForMagnet = new StmtForMagnet {
+ def foreach(stmt: Statement): Unit = stmt foreachString f
+ }
+ implicit def forInfo(f: Info => Unit): StmtForMagnet = new StmtForMagnet {
+ def foreach(stmt: Statement): Unit = stmt foreachInfo f
+ }
+ }
+ implicit class StmtForeach(val _stmt: Statement) extends AnyVal {
+ // Using implicit types to allow overloading of function type to foreach, see StmtForMagnet above
+ def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => StmtForMagnet): Unit = magnet(f).foreach(_stmt)
+ }
+
+ /** Expression Foreachers */
+ private trait ExprForMagnet {
+ def foreach(expr: Expression): Unit
+ }
+ private object ExprForMagnet {
+ implicit def forExpr(f: Expression => Unit): ExprForMagnet = new ExprForMagnet {
+ def foreach(expr: Expression): Unit = expr foreachExpr f
+ }
+ implicit def forType(f: Type => Unit): ExprForMagnet = new ExprForMagnet {
+ def foreach(expr: Expression): Unit = expr foreachType f
+ }
+ implicit def forWidth(f: Width => Unit): ExprForMagnet = new ExprForMagnet {
+ def foreach(expr: Expression): Unit = expr foreachWidth f
+ }
+ }
+ implicit class ExprForeach(val _expr: Expression) extends AnyVal {
+ def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => ExprForMagnet): Unit = magnet(f).foreach(_expr)
+ }
+
+ /** Type Foreachers */
+ private trait TypeForMagnet {
+ def foreach(tpe: Type): Unit
+ }
+ private object TypeForMagnet {
+ implicit def forType(f: Type => Unit): TypeForMagnet = new TypeForMagnet {
+ def foreach(tpe: Type): Unit = tpe foreachType f
+ }
+ implicit def forWidth(f: Width => Unit): TypeForMagnet = new TypeForMagnet {
+ def foreach(tpe: Type): Unit = tpe foreachWidth f
+ }
+ }
+ implicit class TypeForeach(val _tpe: Type) extends AnyVal {
+ def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => TypeForMagnet): Unit = magnet(f).foreach(_tpe)
+ }
+
+ /** Module Foreachers */
+ private trait ModuleForMagnet {
+ def foreach(module: DefModule): Unit
+ }
+ private object ModuleForMagnet {
+ implicit def forStmt(f: Statement => Unit): ModuleForMagnet = new ModuleForMagnet {
+ def foreach(module: DefModule): Unit = module foreachStmt f
+ }
+ implicit def forPorts(f: Port => Unit): ModuleForMagnet = new ModuleForMagnet {
+ def foreach(module: DefModule): Unit = module foreachPort f
+ }
+ implicit def forString(f: String => Unit): ModuleForMagnet = new ModuleForMagnet {
+ def foreach(module: DefModule): Unit = module foreachString f
+ }
+ implicit def forInfo(f: Info => Unit): ModuleForMagnet = new ModuleForMagnet {
+ def foreach(module: DefModule): Unit = module foreachInfo f
+ }
+ }
+ implicit class ModuleForeach(val _module: DefModule) extends AnyVal {
+ def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => ModuleForMagnet): Unit = magnet(f).foreach(_module)
+ }
+
+ /** Circuit Foreachers */
+ private trait CircuitForMagnet {
+ def foreach(module: Circuit): Unit
+ }
+ private object CircuitForMagnet {
+ implicit def forModules(f: DefModule => Unit): CircuitForMagnet = new CircuitForMagnet {
+ def foreach(circuit: Circuit): Unit = circuit foreachModule f
+ }
+ implicit def forString(f: String => Unit): CircuitForMagnet = new CircuitForMagnet {
+ def foreach(circuit: Circuit): Unit = circuit foreachString f
+ }
+ implicit def forInfo(f: Info => Unit): CircuitForMagnet = new CircuitForMagnet {
+ def foreach(circuit: Circuit): Unit = circuit foreachInfo f
+ }
+ }
+ implicit class CircuitForeach(val _circuit: Circuit) extends AnyVal {
+ def foreach[T](f: T => Unit)(implicit magnet: (T => Unit) => CircuitForMagnet): Unit = magnet(f).foreach(_circuit)
+ }
+}