diff options
| author | azidar | 2016-01-29 18:14:41 -0800 |
|---|---|---|
| committer | azidar | 2016-02-09 18:55:25 -0800 |
| commit | 0181686fe4bdf24f9e22f406c43dbeb98789cb8b (patch) | |
| tree | 1b99f826c9f58a9119e030a0ec53de3b9a002c2f /src | |
| parent | e2177899c82e464f853e4daf8d23c11d27ca5157 (diff) | |
WIP. Got to-working-ir working
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Driver.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/IR.scala | 19 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Passes.scala | 314 | ||||
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 280 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Visitor.scala | 42 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 36 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 23 |
8 files changed, 457 insertions, 267 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 3eee19bd..2e525f7e 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -84,11 +84,11 @@ object Driver extends LazyLogging { // ===================================== StanzaPass("high-form-check"), // ===================================== + ScalaPass(toWorkingIr), +// ===================================== StanzaPass("to-working-ir"), // ===================================== - StanzaPass("resolve-kinds"), -// StanzaPass("infer-types"), - ScalaPass(inferTypes), + StanzaPass("infer-types"), StanzaPass("check-types"), StanzaPass("resolve-genders"), StanzaPass("check-genders"), diff --git a/src/main/scala/firrtl/IR.scala b/src/main/scala/firrtl/IR.scala index 858f48cf..6b93c763 100644 --- a/src/main/scala/firrtl/IR.scala +++ b/src/main/scala/firrtl/IR.scala @@ -14,6 +14,8 @@ case class FileInfo(file: String, line: Int, column: Int) extends Info { override def toString(): String = s"$file@$line.$column" } +case class FIRRTLException(str:String) extends Exception + trait AST trait PrimOp extends AST @@ -63,6 +65,7 @@ case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: T trait Stmt extends AST case class DefWire(info: Info, name: String, tpe: Type) extends Stmt +case class DefPoison(info: Info, name: String, tpe: Type) extends Stmt case class DefRegister(info: Info, name: String, tpe: Type, clock: Expression, reset: Expression, init: Expression) extends Stmt case class DefInstance(info: Info, name: String, module: String) extends Stmt case class DefMemory(info: Info, name: String, dataType: Type, depth: Int, writeLatency: Int, @@ -75,7 +78,7 @@ case class Connect(info: Info, loc: Expression, exp: Expression) extends Stmt case class IsInvalid(info: Info, exp: Expression) extends Stmt case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Stmt case class Print(info: Info, string: String, args: Seq[Expression], clk: Expression, en: Expression) extends Stmt -case object Empty extends Stmt +case class Empty() extends Stmt trait Width extends AST case class IntWidth(width: BigInt) extends Width @@ -92,8 +95,8 @@ case class UIntType(width: Width) extends Type case class SIntType(width: Width) extends Type case class BundleType(fields: Seq[Field]) extends Type case class VectorType(tpe: Type, size: BigInt) extends Type -case object ClockType extends Type -case object UnknownType extends Type +case class ClockType() extends Type +case class UnknownType() extends Type trait Direction extends AST case object Input extends Direction @@ -101,8 +104,14 @@ case object Output extends Direction case class Port(info: Info, name: String, dir: Direction, tpe: Type) extends AST -case class Module(info: Info, name: String, ports: Seq[Port], stmt: Stmt) extends AST +trait Module extends AST { + val info : Info + val name : String + val ports : Seq[Port] +} +case class InModule(info: Info, name: String, ports: Seq[Port], body: Stmt) extends Module +case class ExModule(info: Info, name: String, ports: Seq[Port]) extends Module -case class Circuit(info: Info, name: String, modules: Seq[Module]) extends AST +case class Circuit(info: Info, modules: Seq[Module], main: String) extends AST diff --git a/src/main/scala/firrtl/Passes.scala b/src/main/scala/firrtl/Passes.scala index 02016494..e28205f9 100644 --- a/src/main/scala/firrtl/Passes.scala +++ b/src/main/scala/firrtl/Passes.scala @@ -7,9 +7,6 @@ import Utils._ import DebugUtils._ import PrimOps._ - - - object Passes extends LazyLogging { // TODO Perhaps we should get rid of Logger since this map would be nice @@ -20,9 +17,10 @@ object Passes extends LazyLogging { def nameToPass(name: String): Circuit => Circuit = { //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name)) name match { - case "infer-types" => inferTypes + case "to-working-ir" => toWorkingIr + //case "infer-types" => inferTypes // errrrrrrrrrr... - case "renameall" => renameall(Map()) + //case "renameall" => renameall(Map()) } } @@ -34,6 +32,34 @@ object Passes extends LazyLogging { } } + // ============== TO WORKING IR ================== + def toWorkingIr (c:Circuit) = { + def toExp (e:Expression) : Expression = { + eMap(toExp _,e) match { + case e:Ref => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER) + case e:SubField => WSubField(e.exp, e.name, e.tpe, UNKNOWNGENDER) + case e:SubIndex => WSubIndex(e.exp, e.value, e.tpe, UNKNOWNGENDER) + case e:SubAccess => WSubAccess(e.exp, e.index, e.tpe, UNKNOWNGENDER) + case e => e + } + } + def toStmt (s:Stmt) : Stmt = { + eMap(toExp _,s) match { + case s:DefInstance => WDefInstance(s.info,s.name,s.module,UnknownType()) + case s => sMap(toStmt _,s) + } + } + val modulesx = c.modules.map { m => + m match { + case m:InModule => InModule(m.info,m.name, m.ports, toStmt(m.body)) + case m:ExModule => m + } + } + Circuit(c.info,modulesx,c.main) + } + // =============================================== + + /** INFER TYPES * * This pass infers the type field in all IR nodes by updating @@ -45,137 +71,149 @@ object Passes extends LazyLogging { * postponed for a later/earlier pass. */ // input -> flip - private def getBundleSubtype(t: Type, name: String): Type = { - t match { - case b: BundleType => { - val tpe = b.fields.find( _.name == name ) - if (tpe.isEmpty) UnknownType - else tpe.get.tpe - } - case _ => UnknownType - } - } - private def getVectorSubtype(t: Type): Type = t.getType // Added for clarity - private type TypeMap = Map[String, Type] - /*def inferTypes(c: Circuit): Circuit = { - val moduleTypeMap = Map[String, Type]().withDefaultValue(UnknownType) - def inferTypes(m: Module): Module = { - val typeMap = Map[String, Type]().withDefaultValue(UnknownType) - def inferExpTypes(exp: Expression): Expression = { - //logger.debug(s"inferTypes called on ${exp.getClass.getSimpleName}") - exp.map(inferExpTypes) match { - case e: UIntValue => e - case e: SIntValue => e - case e: Ref => Ref(e.name, typeMap(e.name)) - case e: SubField => SubField(e.exp, e.name, getBundleSubtype(e.exp.getType, e.name)) - case e: SubIndex => SubIndex(e.exp, e.value, getVectorSubtype(e.exp.getType)) - case e: SubAccess => SubAccess(e.exp, e.index, getVectorSubtype(e.exp.getType)) - case e: DoPrim => lowerAndTypePrimOp(e) - case e: Expression => e - } - } - def inferStmtTypes(stmt: Stmt): (Stmt) = { - //logger.debug(s"inferStmtTypes called on ${stmt.getClass.getSimpleName} ") - stmt match { - case s: DefWire => - typeMap(s.name) = s.tpe - s - case s: DefRegister => - typeMap(s.name) = get_tpe(s) - s - case s: DefMemory => - typeMap(s.name) = get_tpe(s) - s - case s: DefInstance => (s, typeMap ++ Map(s.name -> typeMap(s.module))) - case s: DefNode => (s, typeMap ++ Map(s.name -> s.value.getType)) - case s: s.map(inferStmtTypes) - }.map(inferExpTypes) - } - //logger.debug(s"inferTypes called on module ${m.name}") - - m.ports.for( p => typeMap(p.name) = p.tpe) - Module(m.info, m.name, m.ports, inferStmtTypes(m.stmt)) - } - //logger.debug(s"inferTypes called on circuit ${c.name}") - - // initialize typeMap with each module of circuit mapped to their bundled IO (ports converted to fields) - val typeMap = c.modules.map(m => m.name -> BundleType(m.ports.map(toField(_)))).toMap - Circuit(c.info, c.name, c.modules.map(inferTypes(typeMap, _))) - }*/ - - def renameall(s : String)(implicit map : Map[String,String]) : String = - map getOrElse (s, s) - - def renameall(e : Expression)(implicit map : Map[String,String]) : Expression = { - logger.debug(s"renameall called on expression ${e.toString}") - e match { - case p : Ref => - Ref(renameall(p.name), p.tpe) - case p : SubField => - SubField(renameall(p.exp), renameall(p.name), p.tpe) - case p : SubIndex => - SubIndex(renameall(p.exp), p.value, p.tpe) - case p : SubAccess => - SubAccess(renameall(p.exp), renameall(p.index), p.tpe) - case p : Mux => - Mux(renameall(p.cond), renameall(p.tval), renameall(p.fval), p.tpe) - case p : ValidIf => - ValidIf(renameall(p.cond), renameall(p.value), p.tpe) - case p : DoPrim => - println( p.args.map(x => renameall(x)) ) - DoPrim(p.op, p.args.map(renameall), p.consts, p.tpe) - case p : Expression => p - } - } - - def renameall(s : Stmt)(implicit map : Map[String,String]) : Stmt = { - logger.debug(s"renameall called on statement ${s.toString}") - - s match { - case p : DefWire => - DefWire(p.info, renameall(p.name), p.tpe) - case p: DefRegister => - DefRegister(p.info, renameall(p.name), p.tpe, p.clock, p.reset, p.init) - case p : DefMemory => - DefMemory(p.info, renameall(p.name), p.dataType, p.depth, p.writeLatency, p.readLatency, - p.readers, p.writers, p.readwriters) - case p : DefInstance => - DefInstance(p.info, renameall(p.name), renameall(p.module)) - case p : DefNode => - DefNode(p.info, renameall(p.name), renameall(p.value)) - case p : Connect => - Connect(p.info, renameall(p.loc), renameall(p.exp)) - case p : BulkConnect => - BulkConnect(p.info, renameall(p.loc), renameall(p.exp)) - case p : IsInvalid => - IsInvalid(p.info, renameall(p.exp)) - case p : Stop => - Stop(p.info, p.ret, renameall(p.clk), renameall(p.en)) - case p : Print => - Print(p.info, p.string, p.args.map(renameall), renameall(p.clk), renameall(p.en)) - case p : Conditionally => - Conditionally(p.info, renameall(p.pred), renameall(p.conseq), renameall(p.alt)) - case p : Begin => - Begin(p.stmts.map(renameall)) - case p : Stmt => p - } - } - - def renameall(p : Port)(implicit map : Map[String,String]) : Port = { - logger.debug(s"renameall called on port ${p.name}") - Port(p.info, renameall(p.name), p.dir, p.tpe) - } - - def renameall(m : Module)(implicit map : Map[String,String]) : Module = { - logger.debug(s"renameall called on module ${m.name}") - Module(m.info, renameall(m.name), m.ports.map(renameall(_)), renameall(m.stmt)) - } - - def renameall(map : Map[String,String]) : Circuit => Circuit = { - c => { - implicit val imap = map - logger.debug(s"renameall called on circuit ${c.name} with %{renameto}") - Circuit(c.info, renameall(c.name), c.modules.map(renameall(_))) - } - } + //private type TypeMap = Map[String, Type] + //private val TypeMap = Map[String, Type]().withDefaultValue(UnknownType) + //private def getBundleSubtype(t: Type, name: String): Type = { + // t match { + // case b: BundleType => { + // val tpe = b.fields.find( _.name == name ) + // if (tpe.isEmpty) UnknownType + // else tpe.get.tpe + // } + // case _ => UnknownType + // } + //} + //private def getVectorSubtype(t: Type): Type = t.getType // Added for clarity + //// TODO Add genders + //private def inferExpTypes(typeMap: TypeMap)(exp: Expression): Expression = { + // logger.debug(s"inferTypes called on ${exp.getClass.getSimpleName}") + // exp.map(inferExpTypes(typeMap)) match { + // case e: UIntValue => e + // case e: SIntValue => e + // case e: Ref => Ref(e.name, typeMap(e.name)) + // case e: SubField => SubField(e.exp, e.name, getBundleSubtype(e.exp.getType, e.name)) + // case e: SubIndex => SubIndex(e.exp, e.value, getVectorSubtype(e.exp.getType)) + // case e: SubAccess => SubAccess(e.exp, e.index, getVectorSubtype(e.exp.getType)) + // case e: DoPrim => lowerAndTypePrimOp(e) + // case e: Expression => e + // } + //} + //private def inferTypes(typeMap: TypeMap, stmt: Stmt): (Stmt, TypeMap) = { + // logger.debug(s"inferTypes called on ${stmt.getClass.getSimpleName} ") + // stmt.map(inferExpTypes(typeMap)) match { + // case b: Begin => { + // var tMap = typeMap + // // TODO FIXME is map correctly called in sequential order + // val body = b.stmts.map { s => + // val ret = inferTypes(tMap, s) + // tMap = ret._2 + // ret._1 + // } + // (Begin(body), tMap) + // } + // case s: DefWire => (s, typeMap ++ Map(s.name -> s.tpe)) + // case s: DefRegister => (s, typeMap ++ Map(s.name -> s.tpe)) + // case s: DefMemory => (s, typeMap ++ Map(s.name -> s.dataType)) + // case s: DefInstance => (s, typeMap ++ Map(s.name -> typeMap(s.module))) + // case s: DefNode => (s, typeMap ++ Map(s.name -> s.value.getType)) + // case s: Conditionally => { // TODO Check: Assuming else block won't see when scope + // val (conseq, cMap) = inferTypes(typeMap, s.conseq) + // val (alt, aMap) = inferTypes(typeMap, s.alt) + // (Conditionally(s.info, s.pred, conseq, alt), cMap ++ aMap) + // } + // case s: Stmt => (s, typeMap) + // } + //} + //private def inferTypes(typeMap: TypeMap, m: Module): Module = { + // logger.debug(s"inferTypes called on module ${m.name}") + + // val pTypeMap = m.ports.map( p => p.name -> p.tpe ).toMap + + // Module(m.info, m.name, m.ports, inferTypes(typeMap ++ pTypeMap, m.stmt)._1) + //} + //def inferTypes(c: Circuit): Circuit = { + // logger.debug(s"inferTypes called on circuit ${c.name}") + + // // initialize typeMap with each module of circuit mapped to their bundled IO (ports converted to fields) + // val typeMap = c.modules.map(m => m.name -> BundleType(m.ports.map(toField(_)))).toMap + + // //val typeMap = c.modules.flatMap(buildTypeMap).toMap + // Circuit(c.info, c.name, c.modules.map(inferTypes(typeMap, _))) + //} + + //def renameall(s : String)(implicit map : Map[String,String]) : String = + // map getOrElse (s, s) + + //def renameall(e : Expression)(implicit map : Map[String,String]) : Expression = { + // logger.debug(s"renameall called on expression ${e.toString}") + // e match { + // case p : Ref => + // Ref(renameall(p.name), p.tpe) + // case p : SubField => + // SubField(renameall(p.exp), renameall(p.name), p.tpe) + // case p : SubIndex => + // SubIndex(renameall(p.exp), p.value, p.tpe) + // case p : SubAccess => + // SubAccess(renameall(p.exp), renameall(p.index), p.tpe) + // case p : Mux => + // Mux(renameall(p.cond), renameall(p.tval), renameall(p.fval), p.tpe) + // case p : ValidIf => + // ValidIf(renameall(p.cond), renameall(p.value), p.tpe) + // case p : DoPrim => + // println( p.args.map(x => renameall(x)) ) + // DoPrim(p.op, p.args.map(renameall), p.consts, p.tpe) + // case p : Expression => p + // } + //} + + //def renameall(s : Stmt)(implicit map : Map[String,String]) : Stmt = { + // logger.debug(s"renameall called on statement ${s.toString}") + + // s match { + // case p : DefWire => + // DefWire(p.info, renameall(p.name), p.tpe) + // case p: DefRegister => + // DefRegister(p.info, renameall(p.name), p.tpe, p.clock, p.reset, p.init) + // case p : DefMemory => + // DefMemory(p.info, renameall(p.name), p.dataType, p.depth, p.writeLatency, p.readLatency, + // p.readers, p.writers, p.readwriters) + // case p : DefInstance => + // DefInstance(p.info, renameall(p.name), renameall(p.module)) + // case p : DefNode => + // DefNode(p.info, renameall(p.name), renameall(p.value)) + // case p : Connect => + // Connect(p.info, renameall(p.loc), renameall(p.exp)) + // case p : BulkConnect => + // BulkConnect(p.info, renameall(p.loc), renameall(p.exp)) + // case p : IsInvalid => + // IsInvalid(p.info, renameall(p.exp)) + // case p : Stop => + // Stop(p.info, p.ret, renameall(p.clk), renameall(p.en)) + // case p : Print => + // Print(p.info, p.string, p.args.map(renameall), renameall(p.clk), renameall(p.en)) + // case p : Conditionally => + // Conditionally(p.info, renameall(p.pred), renameall(p.conseq), renameall(p.alt)) + // case p : Begin => + // Begin(p.stmts.map(renameall)) + // case p : Stmt => p + // } + //} + + //def renameall(p : Port)(implicit map : Map[String,String]) : Port = { + // logger.debug(s"renameall called on port ${p.name}") + // Port(p.info, renameall(p.name), p.dir, p.tpe) + //} + + //def renameall(m : Module)(implicit map : Map[String,String]) : Module = { + // logger.debug(s"renameall called on module ${m.name}") + // Module(m.info, renameall(m.name), m.ports.map(renameall(_)), renameall(m.stmt)) + //} + + //def renameall(map : Map[String,String]) : Circuit => Circuit = { + // c => { + // implicit val imap = map + // logger.debug(s"renameall called on circuit ${c.name} with %{renameto}") + // Circuit(c.info, renameall(c.name), c.modules.map(renameall(_))) + // } + //} } diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index 0da0e01e..4732e756 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -56,14 +56,14 @@ object PrimOps extends LazyLogging { case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth) case (t1: SIntType, t2) => SIntType(UnknownWidth) case (t1, t2: SIntType) => SIntType(UnknownWidth) - case _ => UnknownType + case _ => UnknownType() } } def ofType(op: Expression): Type = { op.getType match { case t: UIntType => UIntType(UnknownWidth) case t: SIntType => SIntType(UnknownWidth) - case _ => UnknownType + case _ => UnknownType() } } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index ee974f11..647fb9c2 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -19,15 +19,33 @@ import PrimOps._ object Utils { // Is there a more elegant way to do this? - private type FlagMap = Map[Symbol, Boolean] - private val FlagMap = Map[Symbol, Boolean]().withDefaultValue(false) + private type FlagMap = Map[String, Boolean] + private val FlagMap = Map[String, Boolean]().withDefaultValue(false) + val lnOf2 = scala.math.log(2) // natural log of 2 + def ceil_log2(x: BigInt): BigInt = (x-1).bitLength + + def create_mask (dt:Type) : Type = { + dt match { + case t:VectorType => VectorType(create_mask(t.tpe),t.size) + case t:BundleType => { + val fieldss = t.fields.map { f => Field(f.name,f.flip,create_mask(f.tpe)) } + BundleType(fieldss) + } + case t:UIntType => BoolType() + case t:SIntType => BoolType() + } + } + + def error(str:String) = throw new FIRRTLException(str) def debug(node: AST)(implicit flags: FlagMap): String = { if (!flags.isEmpty) { var str = "" - if (flags('types)) { + if (flags("types")) { val tpe = node.getType - if( tpe != UnknownType ) str += s"@<t:${tpe.wipeWidth.serialize}>" + tpe match { + case t:UnknownType => str += s"@<t:${tpe.wipeWidth.serialize}>" + } } str } @@ -49,7 +67,7 @@ object Utils { //case f: Field => f.getType case t: Type => t.getType case p: Port => p.getType - case _ => UnknownType + case _ => UnknownType() } } @@ -57,6 +75,145 @@ object Utils { def serialize(implicit flags: FlagMap = FlagMap): String = op.getString } + +// ACCESSORS ========= + def gender (e:Expression) : Gender = { + e match { + case e:WRef => gender(e) + case e:WSubField => gender(e) + case e:WSubIndex => gender(e) + case e:WSubAccess => gender(e) + case e:PrimOp => MALE + case e:UIntValue => MALE + case e:SIntValue => MALE + case e:Mux => MALE + case e:ValidIf => MALE + case _ => error("Shouldn't be here") + }} + def get_gender (s:Stmt) : Gender = + s match { + case s:DefWire => BIGENDER + case s:DefRegister => BIGENDER + case s:WDefInstance => MALE + case s:DefNode => MALE + case s:DefInstance => MALE + case s:DefPoison => UNKNOWNGENDER + case s:DefMemory => MALE + case s:Begin => UNKNOWNGENDER + case s:Connect => UNKNOWNGENDER + case s:BulkConnect => UNKNOWNGENDER + case s:Stop => UNKNOWNGENDER + case s:Print => UNKNOWNGENDER + case s:Empty => UNKNOWNGENDER + case s:IsInvalid => UNKNOWNGENDER + } + def get_gender (p:Port) : Gender = + if (p.dir == Input) MALE else FEMALE + def kind (e:Expression) : Kind = + e match { + case e:WRef => e.kind + case e:WSubField => kind(e.exp) + case e:WSubIndex => kind(e.exp) + case e => ExpKind() + } + def tpe (e:Expression) : Type = + e match { + case e:WRef => e.tpe + case e:WSubField => e.tpe + case e:WSubIndex => e.tpe + case e:UIntValue => UIntType(e.width) + case e:SIntValue => SIntType(e.width) + case e:WVoid => UnknownType() + case e:WInvalid => UnknownType() + } + def get_type (s:Stmt) : Type = { + s match { + case s:DefWire => s.tpe + case s:DefPoison => s.tpe + case s:DefRegister => s.tpe + case s:DefNode => tpe(s.value) + case s:DefMemory => { + val depth = s.depth + val addr = Field("addr",Default,UIntType(IntWidth(ceil_log2(depth)))) + val en = Field("en",Default,BoolType()) + val clk = Field("clk",Default,ClockType()) + val def_data = Field("data",Default,s.dataType) + val rev_data = Field("data",Reverse,s.dataType) + val mask = Field("mask",Default,create_mask(s.dataType)) + val wmode = Field("wmode",Default,UIntType(IntWidth(1))) + val rdata = Field("rdata",Reverse,s.dataType) + val read_type = BundleType(Seq(rev_data,addr,en,clk)) + val write_type = BundleType(Seq(def_data,mask,addr,en,clk)) + val readwrite_type = BundleType(Seq(wmode,rdata,def_data,mask,addr,en,clk)) + + val mem_fields = Vector() + s.readers.foreach {x => mem_fields :+ Field(x,Reverse,read_type)} + s.writers.foreach {x => mem_fields :+ Field(x,Reverse,write_type)} + s.readwriters.foreach {x => mem_fields :+ Field(x,Reverse,readwrite_type)} + BundleType(mem_fields) + } + case s:DefInstance => UnknownType() + case _ => UnknownType() + }} + + def sMap(f:Stmt => Stmt, stmt: Stmt): Stmt = + stmt match { + case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) + case b: Begin => Begin(b.stmts.map(f)) + case s: Stmt => s + } + def eMap(f:Expression => Expression, stmt:Stmt) : Stmt = + stmt match { + case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) + case n: DefNode => DefNode(n.info, n.name, f(n.value)) + case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) + case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) + case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) + case i: IsInvalid => IsInvalid(i.info, f(i.exp)) + case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) + case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) + case s: Stmt => s + } + def eMap(f: Expression => Expression, exp:Expression): Expression = + exp match { + case s: SubField => SubField(f(s.exp), s.name, s.tpe) + case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) + case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) + case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) + case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) + case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) + case s: WSubField => WSubField(f(s.exp), s.name, s.tpe, s.gender) + case s: WSubIndex => WSubIndex(f(s.exp), s.value, s.tpe, s.gender) + case s: WSubAccess => WSubAccess(f(s.exp), f(s.index), s.tpe, s.gender) + case e: Expression => e + } + //private trait StmtMagnet { + // def map(stmt: Stmt): Stmt + //} + //private object StmtMagnet { + // implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet { + // override def map(stmt: Stmt): Stmt = + // stmt match { + // case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) + // case b: Begin => Begin(b.stmts.map(f)) + // case s: Stmt => s + // } + // } + // implicit def forExp(f: Expression => Expression) = new StmtMagnet { + // override def map(stmt: Stmt): Stmt = + // stmt match { + // case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) + // case n: DefNode => DefNode(n.info, n.name, f(n.value)) + // case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) + // case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) + // case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) + // case i: IsInvalid => IsInvalid(i.info, f(i.exp)) + // case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) + // case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) + // case s: Stmt => s + // } + // } + //} implicit class ExpUtils(exp: Expression) { def serialize(implicit flags: FlagMap = FlagMap): String = { val ret = exp match { @@ -70,64 +227,29 @@ object Utils { case v: ValidIf => s"validif(${v.cond.serialize}, ${v.value.serialize})" case p: DoPrim => s"${p.op.serialize}(" + (p.args.map(_.serialize) ++ p.consts.map(_.toString)).mkString(", ") + ")" + case r: WRef => r.name + case s: WSubField => s"${s.exp.serialize}.${s.name}" + case s: WSubIndex => s"${s.exp.serialize}[${s.value}]" + case s: WSubAccess => s"${s.exp.serialize}[${s.index.serialize}]" } ret + debug(exp) } - - def map(f: Expression => Expression): Expression = - exp match { - case s: SubField => SubField(f(s.exp), s.name, s.tpe) - case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) - case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) - case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) - case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) - case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) - case e: Expression => e - } - - def getType(): Type = { - exp match { - case v: UIntValue => UIntType(UnknownWidth) - case v: SIntValue => SIntType(UnknownWidth) - case r: Ref => r.tpe - case s: SubField => s.tpe - case s: SubIndex => s.tpe - case s: SubAccess => s.tpe - case p: DoPrim => p.tpe - case m: Mux => m.tpe - case v: ValidIf => v.tpe - } - } } - // Some Scala implicit magic to solve type erasure on Stmt map function overloading - private trait StmtMagnet { - def map(stmt: Stmt): Stmt - } - private object StmtMagnet { - implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet { - override def map(stmt: Stmt): Stmt = - stmt match { - case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) - case b: Begin => Begin(b.stmts.map(f)) - case s: Stmt => s - } - } - implicit def forExp(f: Expression => Expression) = new StmtMagnet { - override def map(stmt: Stmt): Stmt = - stmt match { - case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) - case n: DefNode => DefNode(n.info, n.name, f(n.value)) - case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) - case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) - case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) - case i: IsInvalid => IsInvalid(i.info, f(i.exp)) - case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) - case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) - case s: Stmt => s - } - } - } + // def map(f: Expression => Expression): Expression = + // exp match { + // case s: SubField => SubField(f(s.exp), s.name, s.tpe) + // case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) + // case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) + // case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) + // case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) + // case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) + // case s: WSubField => SubField(f(s.exp), s.name, s.tpe, s.gender) + // case s: WSubIndex => SubIndex(f(s.exp), s.value, s.tpe, s.gender) + // case s: WSubAccess => SubAccess(f(s.exp), f(s.index), s.tpe, s.gender) + // case e: Expression => e + // } + //} implicit class StmtUtils(stmt: Stmt) { def serialize(implicit flags: FlagMap = FlagMap): String = @@ -141,6 +263,7 @@ object Utils { } str case i: DefInstance => s"inst ${i.name} of ${i.module}" + case i: WDefInstance => s"inst ${i.name} of ${i.module}" case m: DefMemory => { val str = new StringBuilder(s"mem ${m.name} : " + newline) withIndent { @@ -164,11 +287,14 @@ object Utils { case w: Conditionally => { var str = new StringBuilder(s"when ${w.pred.serialize} : ") withIndent { str ++= w.conseq.serialize } - if( w.alt != Empty ) { - str ++= newline + "else :" - withIndent { str ++= w.alt.serialize } + w.alt match { + case s:Empty => str.result + case s => { + str ++= newline + "else :" + withIndent { str ++= w.alt.serialize } + str.result + } } - str.result } case b: Begin => { val s = new StringBuilder @@ -179,20 +305,20 @@ object Utils { case s: Stop => s"stop(${s.clk.serialize}, ${s.en.serialize}, ${s.ret})" case p: Print => s"printf(${p.clk.serialize}, ${p.en.serialize}, ${p.string}" + (if (p.args.nonEmpty) p.args.map(_.serialize).mkString(", ", ", ", "") else "") + ")" - case Empty => "skip" + case s:Empty => "skip" } ret + debug(stmt) } // Using implicit types to allow overloading of function type to map, see StmtMagnet above - def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt) + //def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt) def getType(): Type = stmt match { case s: DefWire => s.tpe case s: DefRegister => s.tpe case s: DefMemory => s.dataType - case _ => UnknownType + case _ => UnknownType() } } @@ -242,9 +368,9 @@ object Utils { def serialize(implicit flags: FlagMap = FlagMap): String = { val commas = ", " // for mkString in BundleType val s = t match { - case ClockType => "Clock" + case c:ClockType => "Clock" //case UnknownType => "UnknownType" - case UnknownType => "?" + case u:UnknownType => "?" case t: UIntType => s"UInt${t.width.serialize}" case t: SIntType => s"SInt${t.width.serialize}" case t: BundleType => s"{ ${t.fields.map(_.serialize).mkString(commas)}}" @@ -256,7 +382,7 @@ object Utils { def getType(): Type = t match { case v: VectorType => v.tpe - case tpe: Type => UnknownType + case tpe: Type => UnknownType() } def wipeWidth(): Type = @@ -292,19 +418,23 @@ object Utils { implicit class ModuleUtils(m: Module) { def serialize(implicit flags: FlagMap = FlagMap): String = { - var s = new StringBuilder(s"module ${m.name} : ") - withIndent { - s ++= m.ports.map(newline ++ _.serialize).mkString - s ++= m.stmt.serialize + m match { + case m:InModule => { + var s = new StringBuilder(s"module ${m.name} : ") + withIndent { + s ++= m.ports.map(newline ++ _.serialize).mkString + s ++= m.body.serialize + } + s ++= debug(m) + s.toString + } } - s ++= debug(m) - s.toString } } implicit class CircuitUtils(c: Circuit) { def serialize(implicit flags: FlagMap = FlagMap): String = { - var s = new StringBuilder(s"circuit ${c.name} : ") + var s = new StringBuilder(s"circuit ${c.main} : ") withIndent { s ++= newline ++ c.modules.map(_.serialize).mkString(newline + newline) } s ++= newline ++ newline s ++= debug(c) diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index ad8d24f2..96d77485 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -43,13 +43,13 @@ class Visitor(val fullFilename: String) extends FIRRTLBaseVisitor[AST] FileInfo(filename, ctx.getStart().getLine(), ctx.getStart().getCharPositionInLine()) private def visitCircuit[AST](ctx: FIRRTLParser.CircuitContext): Circuit = - Circuit(getInfo(ctx), ctx.id.getText, ctx.module.map(visitModule)) + Circuit(getInfo(ctx), ctx.module.map(visitModule), (ctx.id.getText)) private def visitModule[AST](ctx: FIRRTLParser.ModuleContext): Module = - Module(getInfo(ctx), ctx.id.getText, ctx.port.map(visitPort), visitBlock(ctx.block)) + InModule(getInfo(ctx), (ctx.id.getText), ctx.port.map(visitPort), visitBlock(ctx.block)) private def visitPort[AST](ctx: FIRRTLParser.PortContext): Port = - Port(getInfo(ctx), ctx.id.getText, visitDir(ctx.dir), visitType(ctx.`type`)) + Port(getInfo(ctx), (ctx.id.getText), visitDir(ctx.dir), visitType(ctx.`type`)) private def visitDir[AST](ctx: FIRRTLParser.DirContext): Direction = ctx.getText match { @@ -64,7 +64,7 @@ class Visitor(val fullFilename: String) extends FIRRTLBaseVisitor[AST] else UIntType( UnknownWidth ) case "SInt" => if (ctx.getChildCount > 1) SIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) else SIntType( UnknownWidth ) - case "Clock" => ClockType + case "Clock" => ClockType() case "{" => BundleType(ctx.field.map(visitField)) case _ => new VectorType( visitType(ctx.`type`), string2BigInt(ctx.IntLit.getText) ) } @@ -72,7 +72,7 @@ class Visitor(val fullFilename: String) extends FIRRTLBaseVisitor[AST] private def visitField[AST](ctx: FIRRTLParser.FieldContext): Field = { val flip = if(ctx.getChild(0).getText == "flip") Reverse else Default - Field(ctx.id.getText, flip, visitType(ctx.`type`)) + Field((ctx.id.getText), flip, visitType(ctx.`type`)) } @@ -107,10 +107,10 @@ class Visitor(val fullFilename: String) extends FIRRTLBaseVisitor[AST] } // Each memory field value has been left as ParseTree type, need to convert // TODO Improve? Remove dynamic typecast of data-type - DefMemory(info, ctx.id(0).getText, visitType(map("data-type").head.asInstanceOf[FIRRTLParser.TypeContext]), + DefMemory(info, (ctx.id(0).getText), visitType(map("data-type").head.asInstanceOf[FIRRTLParser.TypeContext]), string2Int(map("depth").head.getText), string2Int(map("write-latency").head.getText), - string2Int(map("read-latency").head.getText), map.getOrElse("reader", Seq()).map(_.getText), - map.getOrElse("writer", Seq()).map(_.getText), map.getOrElse("readwriter", Seq()).map(_.getText)) + string2Int(map("read-latency").head.getText), map.getOrElse("reader", Seq()).map(x => (x.getText)), + map.getOrElse("writer", Seq()).map(x => (x.getText)), map.getOrElse("readwriter", Seq()).map(x => (x.getText))) } // visitStmt @@ -118,25 +118,25 @@ class Visitor(val fullFilename: String) extends FIRRTLBaseVisitor[AST] val info = getInfo(ctx) ctx.getChild(0).getText match { - case "wire" => DefWire(info, ctx.id(0).getText, visitType(ctx.`type`(0))) + case "wire" => DefWire(info, (ctx.id(0).getText), visitType(ctx.`type`(0))) case "reg" => { - val name = ctx.id(0).getText + val name = (ctx.id(0).getText) val tpe = visitType(ctx.`type`(0)) val reset = if (ctx.exp(1) != null) visitExp(ctx.exp(1)) else UIntValue(0, IntWidth(1)) val init = if (ctx.exp(2) != null) visitExp(ctx.exp(2)) else Ref(name, tpe) DefRegister(info, name, tpe, visitExp(ctx.exp(0)), reset, init) } case "mem" => visitMem(ctx) - case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText) - case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx.exp(0))) + case "inst" => DefInstance(info, (ctx.id(0).getText), (ctx.id(1).getText)) + case "node" => DefNode(info, (ctx.id(0).getText), visitExp(ctx.exp(0))) case "when" => { - val alt = if (ctx.block.length > 1) visitBlock(ctx.block(1)) else Empty + val alt = if (ctx.block.length > 1) visitBlock(ctx.block(1)) else Empty() Conditionally(info, visitExp(ctx.exp(0)), visitBlock(ctx.block(0)), alt) } case "stop(" => Stop(info, string2Int(ctx.IntLit(0).getText), visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) case "printf(" => Print(info, ctx.StringLit.getText, ctx.exp.drop(2).map(visitExp), visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) - case "skip" => Empty + case "skip" => Empty() // If we don't match on the first child, try the next one case _ => { ctx.getChild(1).getText match { @@ -157,7 +157,7 @@ class Visitor(val fullFilename: String) extends FIRRTLBaseVisitor[AST] // - Add validif private def visitExp[AST](ctx: FIRRTLParser.ExpContext): Expression = if( ctx.getChildCount == 1 ) - Ref(ctx.getText, UnknownType) + Ref((ctx.getText), UnknownType()) else ctx.getChild(0).getText match { case "UInt" => { // This could be better @@ -174,17 +174,17 @@ class Visitor(val fullFilename: String) extends FIRRTLBaseVisitor[AST] else (UnknownWidth, string2BigInt(ctx.IntLit(0).getText)) SIntValue(value, width) } - case "validif(" => ValidIf(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType) - case "mux(" => Mux(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), visitExp(ctx.exp(2)), UnknownType) + case "validif(" => ValidIf(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType()) + case "mux(" => Mux(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), visitExp(ctx.exp(2)), UnknownType()) case _ => ctx.getChild(1).getText match { - case "." => new SubField(visitExp(ctx.exp(0)), ctx.id.getText, UnknownType) + case "." => new SubField(visitExp(ctx.exp(0)), (ctx.id.getText), UnknownType()) case "[" => if (ctx.exp(1) == null) - new SubIndex(visitExp(ctx.exp(0)), string2BigInt(ctx.IntLit(0).getText), UnknownType) - else new SubAccess(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType) + new SubIndex(visitExp(ctx.exp(0)), string2BigInt(ctx.IntLit(0).getText), UnknownType()) + else new SubAccess(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType()) // Assume primop case _ => DoPrim(visitPrimop(ctx.primop), ctx.exp.map(visitExp), - ctx.IntLit.map(x => string2BigInt(x.getText)), UnknownType) + ctx.IntLit.map(x => string2BigInt(x.getText)), UnknownType()) } } diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala new file mode 100644 index 00000000..27ec5516 --- /dev/null +++ b/src/main/scala/firrtl/WIR.scala @@ -0,0 +1,36 @@ + +package firrtl + +import scala.collection.Seq +import Utils._ + + +trait Kind +case class WireKind() extends Kind +case class PoisonKind() extends Kind +case class RegKind() extends Kind +case class InstanceKind() extends Kind +case class PortKind() extends Kind +case class NodeKind() extends Kind +case class MemKind(ports:Seq[String]) extends Kind +case class ExpKind() extends Kind + +trait Gender +case object MALE extends Gender +case object FEMALE extends Gender +case object BIGENDER extends Gender +case object UNKNOWNGENDER extends Gender + +case class BoolType() extends Type { UIntType(IntWidth(1)) } +case class WRef(name:String,tpe:Type,kind:Kind,gender:Gender) extends Expression +case class WSubField(exp:Expression,name:String,tpe:Type,gender:Gender) extends Expression +case class WSubIndex(exp:Expression,value:BigInt,tpe:Type,gender:Gender) extends Expression +case class WSubAccess(exp:Expression,index:Expression,tpe:Type,gender:Gender) extends Expression +case class WVoid() extends Expression +case class WInvalid() extends Expression + +case class WDefInstance(info:Info,name:String,module:String,tpe:Type) extends Stmt + +//case class IntWidth(width: BigInt) extends Width +//case object UnknownWidth extends Width + diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 7a94d59d..6e8d89ad 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -68,11 +68,6 @@ public defstruct WSubAccess <: Expression : index: Expression type: Type with: (as-method => true) gender: Gender with: (as-method => true) -defstruct WIndexer <: Expression : - exps: List<Expression> - index: Expression - type: Type with: (as-method => true) - gender : Gender with: (as-method => true) public defstruct WVoid <: Expression public defstruct WInvalid <: Expression public defstruct WDefInstance <: Stmt : @@ -102,11 +97,6 @@ defmethod kind (e:Expression) : (e:WRef) : kind(e) (e:WSubField) : kind(exp(e)) (e:WSubIndex) : kind(exp(e)) - (e:WIndexer) : - val k = kind(exps(e)[0]) - for x in exps(e) do : - if k != kind(x) : error("All kinds of exps of WIndexer must be the same") - k (e) : ExpKind() defmethod info (stmt:Begin) -> FileInfo : FileInfo() @@ -136,11 +126,6 @@ defmethod equal? (e1:Expression,e2:Expression) -> True|False : (index(e1) == index(e2)) and (exp(e1) == exp(e2)) (e1:WVoid,e2:WVoid) : true (e1:WInvalid,e2:WInvalid) : true - (e1:WIndexer,e2:WIndexer) : - var bool = (length(exps(e1)) == length(exps(e2))) - for (e1* in exps(e1),e2* in exps(e2)) do : - bool = bool and (e1* == e2*) - bool and (index(e1) == index(e2)) (e1:DoPrim,e2:DoPrim) : var are-equal? = op(e1) == op(e2) for (x in args(e1),y in args(e2)) do : @@ -294,9 +279,6 @@ defmethod print (o:OutputStream, e:WVoid) : defmethod print (o:OutputStream, e:WInvalid) : print(o,"INVALID") print-debug(o,e as ?) -defmethod print (o:OutputStream, c:WIndexer) : - print-all(o, [exps(c) "[" index(c) "]"]) - print-debug(o,c as ?) defmethod print (o:OutputStream, c:WDefInstance) : print-all(o, ["inst " name(c) " of " module(c) " : " type(c)]) print-debug(o,c as ?) @@ -795,10 +777,6 @@ defn resolve-genders (c:Circuit) : val exp* = resolve-e(exp(e),g) val index* = resolve-e(index(e),MALE) WSubAccess(exp*,index*,type(e),g) - (e:WIndexer) : - val exps* = map(resolve-e{_,g},exps(e)) - val index* = resolve-e(index(e),MALE) - WIndexer(exps*,index*,type(e),g) (e) : map(resolve-e{_,g},e) defn resolve-s (s:Stmt) -> Stmt : @@ -2280,7 +2258,6 @@ defn root-ref (e:Expression) -> WRef : match(e) : (e:WRef) : e (e:WSubField|WSubIndex|WSubAccess) : root-ref(exp(e)) - (e:WIndexer) : root-ref(exps(e)[0]) ;------------- Pass ------------------ |
