aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/ir
diff options
context:
space:
mode:
authorjackkoenig2016-09-06 13:51:37 -0700
committerDonggyu Kim2016-09-12 12:48:46 -0700
commit4702bf9f257f954e19d1441b21e737f951ccfbcc (patch)
treec9bc87f605d0352d5cca58b56e7eda501e200759 /src/main/scala/firrtl/ir
parent00bef01b6df158939406f3e744cbdda544823ae5 (diff)
Rework map functions as class methods
Changed code from match statements in Mappers.scala to methods on the various IR classes. This allows custom IR nodes to implement the mapper functions and thus work (ie. not match error) when map is called on them. This also should have a marginal performance increase because of use of virtual function calls rather than match statements.
Diffstat (limited to 'src/main/scala/firrtl/ir')
-rw-r--r--src/main/scala/firrtl/ir/IR.scala118
1 files changed, 115 insertions, 3 deletions
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index afe28634..8f1cabf6 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -68,24 +68,46 @@ abstract class PrimOp extends FirrtlNode {
abstract class Expression extends FirrtlNode {
def tpe: Type
+ def mapExpr(f: Expression => Expression): Expression
+ def mapType(f: Type => Type): Expression
+ def mapWidth(f: Width => Width): Expression
}
case class Reference(name: String, tpe: Type) extends Expression with HasName {
def serialize: String = name
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
}
case class SubField(expr: Expression, name: String, tpe: Type) extends Expression with HasName {
def serialize: String = s"${expr.serialize}.$name"
+ def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
}
case class SubIndex(expr: Expression, value: Int, tpe: Type) extends Expression {
def serialize: String = s"${expr.serialize}[$value]"
+ def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
}
case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Expression {
def serialize: String = s"${expr.serialize}[${index.serialize}]"
+ def mapExpr(f: Expression => Expression): Expression =
+ this.copy(expr = f(expr), index = f(index))
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
}
case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression {
def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})"
+ def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe)
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
}
case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression {
def serialize: String = s"validif(${cond.serialize}, ${value.serialize})"
+ def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe)
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
}
abstract class Literal extends Expression {
val value: BigInt
@@ -94,19 +116,37 @@ abstract class Literal extends Expression {
case class UIntLiteral(value: BigInt, width: Width) extends Literal {
def tpe = UIntType(width)
def serialize = s"UInt${width.serialize}(" + Utils.serialize(value) + ")"
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width))
}
case class SIntLiteral(value: BigInt, width: Width) extends Literal {
def tpe = SIntType(width)
def serialize = s"SInt${width.serialize}(" + Utils.serialize(value) + ")"
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width))
}
case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression {
def serialize: String = op.serialize + "(" +
(args.map(_.serialize) ++ consts.map(_.toString)).mkString(", ") + ")"
+ def mapExpr(f: Expression => Expression): Expression = this.copy(args = args map f)
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
}
-abstract class Statement extends FirrtlNode
+abstract class Statement extends FirrtlNode {
+ def mapStmt(f: Statement => Statement): Statement
+ def mapExpr(f: Expression => Expression): Statement
+ def mapType(f: Type => Type): Statement
+ def mapString(f: String => String): Statement
+}
case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration {
def serialize: String = s"wire $name : ${tpe.serialize}" + info.serialize
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe))
+ def mapString(f: String => String): Statement = DefWire(info, f(name), tpe)
}
case class DefRegister(
info: Info,
@@ -118,10 +158,19 @@ case class DefRegister(
def serialize: String =
s"reg $name : ${tpe.serialize}, ${clock.serialize} with :" +
indent("\n" + s"reset => (${reset.serialize}, ${init.serialize})" + info.serialize)
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement =
+ DefRegister(info, name, tpe, f(clock), f(reset), f(init))
+ def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
}
case class DefInstance(info: Info, name: String, module: String) extends Statement with IsDeclaration {
def serialize: String = s"inst $name of $module" + info.serialize
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = DefInstance(info, f(name), module)
}
case class DefMemory(
info: Info,
@@ -146,9 +195,17 @@ case class DefMemory(
(writers map ("writer => " + _)) ++
(readwriters map ("readwriter => " + _)) ++
Seq("read-under-write => undefined")) mkString "\n")
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
}
case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration {
def serialize: String = s"node $name = ${value.serialize}" + info.serialize
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = DefNode(info, f(name), value)
}
case class Conditionally(
info: Info,
@@ -160,21 +217,45 @@ case class Conditionally(
indent("\n" + conseq.serialize) +
(if (alt == EmptyStmt) ""
else "\nelse :" + indent("\n" + alt.serialize))
+ def mapStmt(f: Statement => Statement): Statement = Conditionally(info, pred, f(conseq), f(alt))
+ def mapExpr(f: Expression => Expression): Statement = Conditionally(info, f(pred), conseq, alt)
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
}
case class Block(stmts: Seq[Statement]) extends Statement {
def serialize: String = stmts map (_.serialize) mkString "\n"
+ def mapStmt(f: Statement => Statement): Statement = Block(stmts map f)
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
}
case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo {
def serialize: String = s"${loc.serialize} <- ${expr.serialize}" + info.serialize
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
}
case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo {
def serialize: String = s"${loc.serialize} <= ${expr.serialize}" + info.serialize
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
}
case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo {
def serialize: String = s"${expr.serialize} is invalid" + info.serialize
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = IsInvalid(info, f(expr))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
}
case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo {
def serialize: String = s"stop(${clk.serialize}, ${en.serialize}, $ret)" + info.serialize
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
}
case class Print(
info: Info,
@@ -187,9 +268,17 @@ case class Print(
(args map (_.serialize))
"printf(" + (strs mkString ", ") + ")" + info.serialize
}
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = Print(info, string, args map f, f(clk), f(en))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
}
case object EmptyStmt extends Statement {
def serialize: String = "skip"
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
}
abstract class Width extends FirrtlNode {
@@ -232,29 +321,43 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode
def serialize: String = flip.serialize + name + " : " + tpe.serialize
}
-abstract class Type extends FirrtlNode
+abstract class Type extends FirrtlNode {
+ def mapType(f: Type => Type): Type
+ def mapWidth(f: Width => Width): Type
+}
abstract class GroundType extends Type {
val width: Width
+ def mapType(f: Type => Type): Type = this
+}
+abstract class AggregateType extends Type {
+ def mapWidth(f: Width => Width): Type = this
}
-abstract class AggregateType extends Type
case class UIntType(width: Width) extends GroundType {
def serialize: String = "UInt" + width.serialize
+ def mapWidth(f: Width => Width): Type = UIntType(f(width))
}
case class SIntType(width: Width) extends GroundType {
def serialize: String = "SInt" + width.serialize
+ def mapWidth(f: Width => Width): Type = SIntType(f(width))
}
case class BundleType(fields: Seq[Field]) extends AggregateType {
def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}"
+ def mapType(f: Type => Type): Type =
+ BundleType(fields map (x => x.copy(tpe = f(x.tpe))))
}
case class VectorType(tpe: Type, size: Int) extends AggregateType {
def serialize: String = tpe.serialize + s"[$size]"
+ def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe))
}
case object ClockType extends GroundType {
val width = IntWidth(1)
def serialize: String = "Clock"
+ def mapWidth(f: Width => Width): Type = this
}
case object UnknownType extends Type {
def serialize: String = "?"
+ def mapType(f: Type => Type): Type = this
+ def mapWidth(f: Width => Width): Type = this
}
/** [[Port]] Direction */
@@ -283,6 +386,9 @@ abstract class DefModule extends FirrtlNode with IsDeclaration {
protected def serializeHeader(tpe: String): String =
s"$tpe $name :" + info.serialize +
indent(ports map ("\n" + _.serialize) mkString) + "\n"
+ def mapStmt(f: Statement => Statement): DefModule
+ def mapPort(f: Port => Port): DefModule
+ def mapString(f: String => String): DefModule
}
/** Internal Module
*
@@ -290,6 +396,9 @@ abstract class DefModule extends FirrtlNode with IsDeclaration {
*/
case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) extends DefModule {
def serialize: String = serializeHeader("module") + indent("\n" + body.serialize)
+ def mapStmt(f: Statement => Statement): DefModule = this.copy(body = f(body))
+ def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f)
+ def mapString(f: String => String): DefModule = this.copy(name = f(name))
}
/** External Module
*
@@ -297,6 +406,9 @@ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) e
*/
case class ExtModule(info: Info, name: String, ports: Seq[Port]) extends DefModule {
def serialize: String = serializeHeader("extmodule")
+ def mapStmt(f: Statement => Statement): DefModule = this
+ def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f)
+ def mapString(f: String => String): DefModule = this.copy(name = f(name))
}
case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends FirrtlNode with HasInfo {