aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/Utils.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/Utils.scala')
-rw-r--r--src/main/scala/firrtl/Utils.scala280
1 files changed, 205 insertions, 75 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index ee974f11..647fb9c2 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -19,15 +19,33 @@ import PrimOps._
object Utils {
// Is there a more elegant way to do this?
- private type FlagMap = Map[Symbol, Boolean]
- private val FlagMap = Map[Symbol, Boolean]().withDefaultValue(false)
+ private type FlagMap = Map[String, Boolean]
+ private val FlagMap = Map[String, Boolean]().withDefaultValue(false)
+ val lnOf2 = scala.math.log(2) // natural log of 2
+ def ceil_log2(x: BigInt): BigInt = (x-1).bitLength
+
+ def create_mask (dt:Type) : Type = {
+ dt match {
+ case t:VectorType => VectorType(create_mask(t.tpe),t.size)
+ case t:BundleType => {
+ val fieldss = t.fields.map { f => Field(f.name,f.flip,create_mask(f.tpe)) }
+ BundleType(fieldss)
+ }
+ case t:UIntType => BoolType()
+ case t:SIntType => BoolType()
+ }
+ }
+
+ def error(str:String) = throw new FIRRTLException(str)
def debug(node: AST)(implicit flags: FlagMap): String = {
if (!flags.isEmpty) {
var str = ""
- if (flags('types)) {
+ if (flags("types")) {
val tpe = node.getType
- if( tpe != UnknownType ) str += s"@<t:${tpe.wipeWidth.serialize}>"
+ tpe match {
+ case t:UnknownType => str += s"@<t:${tpe.wipeWidth.serialize}>"
+ }
}
str
}
@@ -49,7 +67,7 @@ object Utils {
//case f: Field => f.getType
case t: Type => t.getType
case p: Port => p.getType
- case _ => UnknownType
+ case _ => UnknownType()
}
}
@@ -57,6 +75,145 @@ object Utils {
def serialize(implicit flags: FlagMap = FlagMap): String = op.getString
}
+
+// ACCESSORS =========
+ def gender (e:Expression) : Gender = {
+ e match {
+ case e:WRef => gender(e)
+ case e:WSubField => gender(e)
+ case e:WSubIndex => gender(e)
+ case e:WSubAccess => gender(e)
+ case e:PrimOp => MALE
+ case e:UIntValue => MALE
+ case e:SIntValue => MALE
+ case e:Mux => MALE
+ case e:ValidIf => MALE
+ case _ => error("Shouldn't be here")
+ }}
+ def get_gender (s:Stmt) : Gender =
+ s match {
+ case s:DefWire => BIGENDER
+ case s:DefRegister => BIGENDER
+ case s:WDefInstance => MALE
+ case s:DefNode => MALE
+ case s:DefInstance => MALE
+ case s:DefPoison => UNKNOWNGENDER
+ case s:DefMemory => MALE
+ case s:Begin => UNKNOWNGENDER
+ case s:Connect => UNKNOWNGENDER
+ case s:BulkConnect => UNKNOWNGENDER
+ case s:Stop => UNKNOWNGENDER
+ case s:Print => UNKNOWNGENDER
+ case s:Empty => UNKNOWNGENDER
+ case s:IsInvalid => UNKNOWNGENDER
+ }
+ def get_gender (p:Port) : Gender =
+ if (p.dir == Input) MALE else FEMALE
+ def kind (e:Expression) : Kind =
+ e match {
+ case e:WRef => e.kind
+ case e:WSubField => kind(e.exp)
+ case e:WSubIndex => kind(e.exp)
+ case e => ExpKind()
+ }
+ def tpe (e:Expression) : Type =
+ e match {
+ case e:WRef => e.tpe
+ case e:WSubField => e.tpe
+ case e:WSubIndex => e.tpe
+ case e:UIntValue => UIntType(e.width)
+ case e:SIntValue => SIntType(e.width)
+ case e:WVoid => UnknownType()
+ case e:WInvalid => UnknownType()
+ }
+ def get_type (s:Stmt) : Type = {
+ s match {
+ case s:DefWire => s.tpe
+ case s:DefPoison => s.tpe
+ case s:DefRegister => s.tpe
+ case s:DefNode => tpe(s.value)
+ case s:DefMemory => {
+ val depth = s.depth
+ val addr = Field("addr",Default,UIntType(IntWidth(ceil_log2(depth))))
+ val en = Field("en",Default,BoolType())
+ val clk = Field("clk",Default,ClockType())
+ val def_data = Field("data",Default,s.dataType)
+ val rev_data = Field("data",Reverse,s.dataType)
+ val mask = Field("mask",Default,create_mask(s.dataType))
+ val wmode = Field("wmode",Default,UIntType(IntWidth(1)))
+ val rdata = Field("rdata",Reverse,s.dataType)
+ val read_type = BundleType(Seq(rev_data,addr,en,clk))
+ val write_type = BundleType(Seq(def_data,mask,addr,en,clk))
+ val readwrite_type = BundleType(Seq(wmode,rdata,def_data,mask,addr,en,clk))
+
+ val mem_fields = Vector()
+ s.readers.foreach {x => mem_fields :+ Field(x,Reverse,read_type)}
+ s.writers.foreach {x => mem_fields :+ Field(x,Reverse,write_type)}
+ s.readwriters.foreach {x => mem_fields :+ Field(x,Reverse,readwrite_type)}
+ BundleType(mem_fields)
+ }
+ case s:DefInstance => UnknownType()
+ case _ => UnknownType()
+ }}
+
+ def sMap(f:Stmt => Stmt, stmt: Stmt): Stmt =
+ stmt match {
+ case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt))
+ case b: Begin => Begin(b.stmts.map(f))
+ case s: Stmt => s
+ }
+ def eMap(f:Expression => Expression, stmt:Stmt) : Stmt =
+ stmt match {
+ case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init))
+ case n: DefNode => DefNode(n.info, n.name, f(n.value))
+ case c: Connect => Connect(c.info, f(c.loc), f(c.exp))
+ case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp))
+ case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt)
+ case i: IsInvalid => IsInvalid(i.info, f(i.exp))
+ case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en))
+ case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en))
+ case s: Stmt => s
+ }
+ def eMap(f: Expression => Expression, exp:Expression): Expression =
+ exp match {
+ case s: SubField => SubField(f(s.exp), s.name, s.tpe)
+ case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe)
+ case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe)
+ case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe)
+ case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe)
+ case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe)
+ case s: WSubField => WSubField(f(s.exp), s.name, s.tpe, s.gender)
+ case s: WSubIndex => WSubIndex(f(s.exp), s.value, s.tpe, s.gender)
+ case s: WSubAccess => WSubAccess(f(s.exp), f(s.index), s.tpe, s.gender)
+ case e: Expression => e
+ }
+ //private trait StmtMagnet {
+ // def map(stmt: Stmt): Stmt
+ //}
+ //private object StmtMagnet {
+ // implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet {
+ // override def map(stmt: Stmt): Stmt =
+ // stmt match {
+ // case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt))
+ // case b: Begin => Begin(b.stmts.map(f))
+ // case s: Stmt => s
+ // }
+ // }
+ // implicit def forExp(f: Expression => Expression) = new StmtMagnet {
+ // override def map(stmt: Stmt): Stmt =
+ // stmt match {
+ // case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init))
+ // case n: DefNode => DefNode(n.info, n.name, f(n.value))
+ // case c: Connect => Connect(c.info, f(c.loc), f(c.exp))
+ // case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp))
+ // case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt)
+ // case i: IsInvalid => IsInvalid(i.info, f(i.exp))
+ // case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en))
+ // case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en))
+ // case s: Stmt => s
+ // }
+ // }
+ //}
implicit class ExpUtils(exp: Expression) {
def serialize(implicit flags: FlagMap = FlagMap): String = {
val ret = exp match {
@@ -70,64 +227,29 @@ object Utils {
case v: ValidIf => s"validif(${v.cond.serialize}, ${v.value.serialize})"
case p: DoPrim =>
s"${p.op.serialize}(" + (p.args.map(_.serialize) ++ p.consts.map(_.toString)).mkString(", ") + ")"
+ case r: WRef => r.name
+ case s: WSubField => s"${s.exp.serialize}.${s.name}"
+ case s: WSubIndex => s"${s.exp.serialize}[${s.value}]"
+ case s: WSubAccess => s"${s.exp.serialize}[${s.index.serialize}]"
}
ret + debug(exp)
}
-
- def map(f: Expression => Expression): Expression =
- exp match {
- case s: SubField => SubField(f(s.exp), s.name, s.tpe)
- case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe)
- case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe)
- case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe)
- case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe)
- case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe)
- case e: Expression => e
- }
-
- def getType(): Type = {
- exp match {
- case v: UIntValue => UIntType(UnknownWidth)
- case v: SIntValue => SIntType(UnknownWidth)
- case r: Ref => r.tpe
- case s: SubField => s.tpe
- case s: SubIndex => s.tpe
- case s: SubAccess => s.tpe
- case p: DoPrim => p.tpe
- case m: Mux => m.tpe
- case v: ValidIf => v.tpe
- }
- }
}
- // Some Scala implicit magic to solve type erasure on Stmt map function overloading
- private trait StmtMagnet {
- def map(stmt: Stmt): Stmt
- }
- private object StmtMagnet {
- implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet {
- override def map(stmt: Stmt): Stmt =
- stmt match {
- case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt))
- case b: Begin => Begin(b.stmts.map(f))
- case s: Stmt => s
- }
- }
- implicit def forExp(f: Expression => Expression) = new StmtMagnet {
- override def map(stmt: Stmt): Stmt =
- stmt match {
- case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init))
- case n: DefNode => DefNode(n.info, n.name, f(n.value))
- case c: Connect => Connect(c.info, f(c.loc), f(c.exp))
- case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp))
- case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt)
- case i: IsInvalid => IsInvalid(i.info, f(i.exp))
- case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en))
- case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en))
- case s: Stmt => s
- }
- }
- }
+ // def map(f: Expression => Expression): Expression =
+ // exp match {
+ // case s: SubField => SubField(f(s.exp), s.name, s.tpe)
+ // case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe)
+ // case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe)
+ // case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe)
+ // case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe)
+ // case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe)
+ // case s: WSubField => SubField(f(s.exp), s.name, s.tpe, s.gender)
+ // case s: WSubIndex => SubIndex(f(s.exp), s.value, s.tpe, s.gender)
+ // case s: WSubAccess => SubAccess(f(s.exp), f(s.index), s.tpe, s.gender)
+ // case e: Expression => e
+ // }
+ //}
implicit class StmtUtils(stmt: Stmt) {
def serialize(implicit flags: FlagMap = FlagMap): String =
@@ -141,6 +263,7 @@ object Utils {
}
str
case i: DefInstance => s"inst ${i.name} of ${i.module}"
+ case i: WDefInstance => s"inst ${i.name} of ${i.module}"
case m: DefMemory => {
val str = new StringBuilder(s"mem ${m.name} : " + newline)
withIndent {
@@ -164,11 +287,14 @@ object Utils {
case w: Conditionally => {
var str = new StringBuilder(s"when ${w.pred.serialize} : ")
withIndent { str ++= w.conseq.serialize }
- if( w.alt != Empty ) {
- str ++= newline + "else :"
- withIndent { str ++= w.alt.serialize }
+ w.alt match {
+ case s:Empty => str.result
+ case s => {
+ str ++= newline + "else :"
+ withIndent { str ++= w.alt.serialize }
+ str.result
+ }
}
- str.result
}
case b: Begin => {
val s = new StringBuilder
@@ -179,20 +305,20 @@ object Utils {
case s: Stop => s"stop(${s.clk.serialize}, ${s.en.serialize}, ${s.ret})"
case p: Print => s"printf(${p.clk.serialize}, ${p.en.serialize}, ${p.string}" +
(if (p.args.nonEmpty) p.args.map(_.serialize).mkString(", ", ", ", "") else "") + ")"
- case Empty => "skip"
+ case s:Empty => "skip"
}
ret + debug(stmt)
}
// Using implicit types to allow overloading of function type to map, see StmtMagnet above
- def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt)
+ //def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt)
def getType(): Type =
stmt match {
case s: DefWire => s.tpe
case s: DefRegister => s.tpe
case s: DefMemory => s.dataType
- case _ => UnknownType
+ case _ => UnknownType()
}
}
@@ -242,9 +368,9 @@ object Utils {
def serialize(implicit flags: FlagMap = FlagMap): String = {
val commas = ", " // for mkString in BundleType
val s = t match {
- case ClockType => "Clock"
+ case c:ClockType => "Clock"
//case UnknownType => "UnknownType"
- case UnknownType => "?"
+ case u:UnknownType => "?"
case t: UIntType => s"UInt${t.width.serialize}"
case t: SIntType => s"SInt${t.width.serialize}"
case t: BundleType => s"{ ${t.fields.map(_.serialize).mkString(commas)}}"
@@ -256,7 +382,7 @@ object Utils {
def getType(): Type =
t match {
case v: VectorType => v.tpe
- case tpe: Type => UnknownType
+ case tpe: Type => UnknownType()
}
def wipeWidth(): Type =
@@ -292,19 +418,23 @@ object Utils {
implicit class ModuleUtils(m: Module) {
def serialize(implicit flags: FlagMap = FlagMap): String = {
- var s = new StringBuilder(s"module ${m.name} : ")
- withIndent {
- s ++= m.ports.map(newline ++ _.serialize).mkString
- s ++= m.stmt.serialize
+ m match {
+ case m:InModule => {
+ var s = new StringBuilder(s"module ${m.name} : ")
+ withIndent {
+ s ++= m.ports.map(newline ++ _.serialize).mkString
+ s ++= m.body.serialize
+ }
+ s ++= debug(m)
+ s.toString
+ }
}
- s ++= debug(m)
- s.toString
}
}
implicit class CircuitUtils(c: Circuit) {
def serialize(implicit flags: FlagMap = FlagMap): String = {
- var s = new StringBuilder(s"circuit ${c.name} : ")
+ var s = new StringBuilder(s"circuit ${c.main} : ")
withIndent { s ++= newline ++ c.modules.map(_.serialize).mkString(newline + newline) }
s ++= newline ++ newline
s ++= debug(c)