diff options
Diffstat (limited to 'src/main/scala/firrtl/Utils.scala')
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 280 |
1 files changed, 205 insertions, 75 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index ee974f11..647fb9c2 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -19,15 +19,33 @@ import PrimOps._ object Utils { // Is there a more elegant way to do this? - private type FlagMap = Map[Symbol, Boolean] - private val FlagMap = Map[Symbol, Boolean]().withDefaultValue(false) + private type FlagMap = Map[String, Boolean] + private val FlagMap = Map[String, Boolean]().withDefaultValue(false) + val lnOf2 = scala.math.log(2) // natural log of 2 + def ceil_log2(x: BigInt): BigInt = (x-1).bitLength + + def create_mask (dt:Type) : Type = { + dt match { + case t:VectorType => VectorType(create_mask(t.tpe),t.size) + case t:BundleType => { + val fieldss = t.fields.map { f => Field(f.name,f.flip,create_mask(f.tpe)) } + BundleType(fieldss) + } + case t:UIntType => BoolType() + case t:SIntType => BoolType() + } + } + + def error(str:String) = throw new FIRRTLException(str) def debug(node: AST)(implicit flags: FlagMap): String = { if (!flags.isEmpty) { var str = "" - if (flags('types)) { + if (flags("types")) { val tpe = node.getType - if( tpe != UnknownType ) str += s"@<t:${tpe.wipeWidth.serialize}>" + tpe match { + case t:UnknownType => str += s"@<t:${tpe.wipeWidth.serialize}>" + } } str } @@ -49,7 +67,7 @@ object Utils { //case f: Field => f.getType case t: Type => t.getType case p: Port => p.getType - case _ => UnknownType + case _ => UnknownType() } } @@ -57,6 +75,145 @@ object Utils { def serialize(implicit flags: FlagMap = FlagMap): String = op.getString } + +// ACCESSORS ========= + def gender (e:Expression) : Gender = { + e match { + case e:WRef => gender(e) + case e:WSubField => gender(e) + case e:WSubIndex => gender(e) + case e:WSubAccess => gender(e) + case e:PrimOp => MALE + case e:UIntValue => MALE + case e:SIntValue => MALE + case e:Mux => MALE + case e:ValidIf => MALE + case _ => error("Shouldn't be here") + }} + def get_gender (s:Stmt) : Gender = + s match { + case s:DefWire => BIGENDER + case s:DefRegister => BIGENDER + case s:WDefInstance => MALE + case s:DefNode => MALE + case s:DefInstance => MALE + case s:DefPoison => UNKNOWNGENDER + case s:DefMemory => MALE + case s:Begin => UNKNOWNGENDER + case s:Connect => UNKNOWNGENDER + case s:BulkConnect => UNKNOWNGENDER + case s:Stop => UNKNOWNGENDER + case s:Print => UNKNOWNGENDER + case s:Empty => UNKNOWNGENDER + case s:IsInvalid => UNKNOWNGENDER + } + def get_gender (p:Port) : Gender = + if (p.dir == Input) MALE else FEMALE + def kind (e:Expression) : Kind = + e match { + case e:WRef => e.kind + case e:WSubField => kind(e.exp) + case e:WSubIndex => kind(e.exp) + case e => ExpKind() + } + def tpe (e:Expression) : Type = + e match { + case e:WRef => e.tpe + case e:WSubField => e.tpe + case e:WSubIndex => e.tpe + case e:UIntValue => UIntType(e.width) + case e:SIntValue => SIntType(e.width) + case e:WVoid => UnknownType() + case e:WInvalid => UnknownType() + } + def get_type (s:Stmt) : Type = { + s match { + case s:DefWire => s.tpe + case s:DefPoison => s.tpe + case s:DefRegister => s.tpe + case s:DefNode => tpe(s.value) + case s:DefMemory => { + val depth = s.depth + val addr = Field("addr",Default,UIntType(IntWidth(ceil_log2(depth)))) + val en = Field("en",Default,BoolType()) + val clk = Field("clk",Default,ClockType()) + val def_data = Field("data",Default,s.dataType) + val rev_data = Field("data",Reverse,s.dataType) + val mask = Field("mask",Default,create_mask(s.dataType)) + val wmode = Field("wmode",Default,UIntType(IntWidth(1))) + val rdata = Field("rdata",Reverse,s.dataType) + val read_type = BundleType(Seq(rev_data,addr,en,clk)) + val write_type = BundleType(Seq(def_data,mask,addr,en,clk)) + val readwrite_type = BundleType(Seq(wmode,rdata,def_data,mask,addr,en,clk)) + + val mem_fields = Vector() + s.readers.foreach {x => mem_fields :+ Field(x,Reverse,read_type)} + s.writers.foreach {x => mem_fields :+ Field(x,Reverse,write_type)} + s.readwriters.foreach {x => mem_fields :+ Field(x,Reverse,readwrite_type)} + BundleType(mem_fields) + } + case s:DefInstance => UnknownType() + case _ => UnknownType() + }} + + def sMap(f:Stmt => Stmt, stmt: Stmt): Stmt = + stmt match { + case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) + case b: Begin => Begin(b.stmts.map(f)) + case s: Stmt => s + } + def eMap(f:Expression => Expression, stmt:Stmt) : Stmt = + stmt match { + case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) + case n: DefNode => DefNode(n.info, n.name, f(n.value)) + case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) + case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) + case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) + case i: IsInvalid => IsInvalid(i.info, f(i.exp)) + case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) + case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) + case s: Stmt => s + } + def eMap(f: Expression => Expression, exp:Expression): Expression = + exp match { + case s: SubField => SubField(f(s.exp), s.name, s.tpe) + case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) + case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) + case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) + case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) + case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) + case s: WSubField => WSubField(f(s.exp), s.name, s.tpe, s.gender) + case s: WSubIndex => WSubIndex(f(s.exp), s.value, s.tpe, s.gender) + case s: WSubAccess => WSubAccess(f(s.exp), f(s.index), s.tpe, s.gender) + case e: Expression => e + } + //private trait StmtMagnet { + // def map(stmt: Stmt): Stmt + //} + //private object StmtMagnet { + // implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet { + // override def map(stmt: Stmt): Stmt = + // stmt match { + // case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) + // case b: Begin => Begin(b.stmts.map(f)) + // case s: Stmt => s + // } + // } + // implicit def forExp(f: Expression => Expression) = new StmtMagnet { + // override def map(stmt: Stmt): Stmt = + // stmt match { + // case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) + // case n: DefNode => DefNode(n.info, n.name, f(n.value)) + // case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) + // case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) + // case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) + // case i: IsInvalid => IsInvalid(i.info, f(i.exp)) + // case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) + // case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) + // case s: Stmt => s + // } + // } + //} implicit class ExpUtils(exp: Expression) { def serialize(implicit flags: FlagMap = FlagMap): String = { val ret = exp match { @@ -70,64 +227,29 @@ object Utils { case v: ValidIf => s"validif(${v.cond.serialize}, ${v.value.serialize})" case p: DoPrim => s"${p.op.serialize}(" + (p.args.map(_.serialize) ++ p.consts.map(_.toString)).mkString(", ") + ")" + case r: WRef => r.name + case s: WSubField => s"${s.exp.serialize}.${s.name}" + case s: WSubIndex => s"${s.exp.serialize}[${s.value}]" + case s: WSubAccess => s"${s.exp.serialize}[${s.index.serialize}]" } ret + debug(exp) } - - def map(f: Expression => Expression): Expression = - exp match { - case s: SubField => SubField(f(s.exp), s.name, s.tpe) - case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) - case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) - case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) - case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) - case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) - case e: Expression => e - } - - def getType(): Type = { - exp match { - case v: UIntValue => UIntType(UnknownWidth) - case v: SIntValue => SIntType(UnknownWidth) - case r: Ref => r.tpe - case s: SubField => s.tpe - case s: SubIndex => s.tpe - case s: SubAccess => s.tpe - case p: DoPrim => p.tpe - case m: Mux => m.tpe - case v: ValidIf => v.tpe - } - } } - // Some Scala implicit magic to solve type erasure on Stmt map function overloading - private trait StmtMagnet { - def map(stmt: Stmt): Stmt - } - private object StmtMagnet { - implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet { - override def map(stmt: Stmt): Stmt = - stmt match { - case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) - case b: Begin => Begin(b.stmts.map(f)) - case s: Stmt => s - } - } - implicit def forExp(f: Expression => Expression) = new StmtMagnet { - override def map(stmt: Stmt): Stmt = - stmt match { - case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) - case n: DefNode => DefNode(n.info, n.name, f(n.value)) - case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) - case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) - case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) - case i: IsInvalid => IsInvalid(i.info, f(i.exp)) - case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) - case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) - case s: Stmt => s - } - } - } + // def map(f: Expression => Expression): Expression = + // exp match { + // case s: SubField => SubField(f(s.exp), s.name, s.tpe) + // case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) + // case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) + // case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) + // case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) + // case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) + // case s: WSubField => SubField(f(s.exp), s.name, s.tpe, s.gender) + // case s: WSubIndex => SubIndex(f(s.exp), s.value, s.tpe, s.gender) + // case s: WSubAccess => SubAccess(f(s.exp), f(s.index), s.tpe, s.gender) + // case e: Expression => e + // } + //} implicit class StmtUtils(stmt: Stmt) { def serialize(implicit flags: FlagMap = FlagMap): String = @@ -141,6 +263,7 @@ object Utils { } str case i: DefInstance => s"inst ${i.name} of ${i.module}" + case i: WDefInstance => s"inst ${i.name} of ${i.module}" case m: DefMemory => { val str = new StringBuilder(s"mem ${m.name} : " + newline) withIndent { @@ -164,11 +287,14 @@ object Utils { case w: Conditionally => { var str = new StringBuilder(s"when ${w.pred.serialize} : ") withIndent { str ++= w.conseq.serialize } - if( w.alt != Empty ) { - str ++= newline + "else :" - withIndent { str ++= w.alt.serialize } + w.alt match { + case s:Empty => str.result + case s => { + str ++= newline + "else :" + withIndent { str ++= w.alt.serialize } + str.result + } } - str.result } case b: Begin => { val s = new StringBuilder @@ -179,20 +305,20 @@ object Utils { case s: Stop => s"stop(${s.clk.serialize}, ${s.en.serialize}, ${s.ret})" case p: Print => s"printf(${p.clk.serialize}, ${p.en.serialize}, ${p.string}" + (if (p.args.nonEmpty) p.args.map(_.serialize).mkString(", ", ", ", "") else "") + ")" - case Empty => "skip" + case s:Empty => "skip" } ret + debug(stmt) } // Using implicit types to allow overloading of function type to map, see StmtMagnet above - def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt) + //def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt) def getType(): Type = stmt match { case s: DefWire => s.tpe case s: DefRegister => s.tpe case s: DefMemory => s.dataType - case _ => UnknownType + case _ => UnknownType() } } @@ -242,9 +368,9 @@ object Utils { def serialize(implicit flags: FlagMap = FlagMap): String = { val commas = ", " // for mkString in BundleType val s = t match { - case ClockType => "Clock" + case c:ClockType => "Clock" //case UnknownType => "UnknownType" - case UnknownType => "?" + case u:UnknownType => "?" case t: UIntType => s"UInt${t.width.serialize}" case t: SIntType => s"SInt${t.width.serialize}" case t: BundleType => s"{ ${t.fields.map(_.serialize).mkString(commas)}}" @@ -256,7 +382,7 @@ object Utils { def getType(): Type = t match { case v: VectorType => v.tpe - case tpe: Type => UnknownType + case tpe: Type => UnknownType() } def wipeWidth(): Type = @@ -292,19 +418,23 @@ object Utils { implicit class ModuleUtils(m: Module) { def serialize(implicit flags: FlagMap = FlagMap): String = { - var s = new StringBuilder(s"module ${m.name} : ") - withIndent { - s ++= m.ports.map(newline ++ _.serialize).mkString - s ++= m.stmt.serialize + m match { + case m:InModule => { + var s = new StringBuilder(s"module ${m.name} : ") + withIndent { + s ++= m.ports.map(newline ++ _.serialize).mkString + s ++= m.body.serialize + } + s ++= debug(m) + s.toString + } } - s ++= debug(m) - s.toString } } implicit class CircuitUtils(c: Circuit) { def serialize(implicit flags: FlagMap = FlagMap): String = { - var s = new StringBuilder(s"circuit ${c.name} : ") + var s = new StringBuilder(s"circuit ${c.main} : ") withIndent { s ++= newline ++ c.modules.map(_.serialize).mkString(newline + newline) } s ++= newline ++ newline s ++= debug(c) |
