diff options
| author | azidar | 2016-02-08 22:47:43 -0800 |
|---|---|---|
| committer | azidar | 2016-02-09 18:57:07 -0800 |
| commit | a9afec2145fe27a26c51fca7e169495114c5108d (patch) | |
| tree | 39b232e7bd67cec9c8a65807d92c51b5a44ad764 /src | |
| parent | 32f26d3939980644ddd573c1fcf1dd985a150947 (diff) | |
Added chirrtl passes, need to update parser
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Compiler.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Driver.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 24 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Passes.scala | 490 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 100 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 33 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 512 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 4 |
8 files changed, 809 insertions, 358 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 2998232f..78ea644d 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -23,6 +23,9 @@ object VerilogCompiler extends Compiler { val passes = Seq( //CheckHighForm, //FromCHIRRTL, + CInferTypes, + CInferMDir, + RemoveCHIRRTL, ToWorkingIR, ResolveKinds, InferTypes, diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 14861f5e..69eb163c 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -10,7 +10,6 @@ import com.typesafe.scalalogging.LazyLogging import Utils._ import DebugUtils._ -import Passes._ object Driver extends LazyLogging { private val usage = """ diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 85a2d759..499ab6db 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -13,7 +13,7 @@ import Utils._ import firrtl.passes._ import WrappedExpression._ // Datastructures -import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ArrayBuffer trait Emitter extends LazyLogging { @@ -34,7 +34,7 @@ object VerilogEmitter extends Emitter { def wref (n:String,t:Type) = WRef(n,t,ExpKind(),UNKNOWNGENDER) def escape (s:String) : String = { val sx = ArrayBuffer[String]() - sx += "\"" + //sx += '"'.toString var percent:Boolean = false for (c <- s) { if (c == '\n') sx += "\\n" @@ -43,7 +43,7 @@ object VerilogEmitter extends Emitter { } percent = (c == '%') } - sx += "\"" + //sx += '"'.toString sx.reduce(_ + _) } def remove_root (ex:Expression) : Expression = { @@ -190,7 +190,7 @@ object VerilogEmitter extends Emitter { case (t:SIntType) => Seq(cast(a0())) } } - case NOT_OP => Seq("~",a0()) + case NOT_OP => Seq("~ ",a0()) case AND_OP => Seq(cast_as(a0())," & ", cast_as(a1())) case OR_OP => Seq(cast_as(a0())," | ", cast_as(a1())) case XOR_OP => Seq(cast_as(a0())," ^ ", cast_as(a1())) @@ -236,14 +236,13 @@ object VerilogEmitter extends Emitter { def emit_verilog (m:InModule) : Module = { mname = m.name - val netlist = HashMap[WrappedExpression,Expression]() + val netlist = LinkedHashMap[WrappedExpression,Expression]() val simlist = ArrayBuffer[Stmt]() - val namehash = sym_hash def build_netlist (s:Stmt) : Stmt = { s match { case (s:Connect) => netlist(s.loc) = s.exp case (s:IsInvalid) => { - val n = firrtl_gensym("GEN",namehash) + val n = firrtl_gensym_module(mname) val e = wref(n,tpe(s.exp)) netlist(s.exp) = e } @@ -261,7 +260,7 @@ object VerilogEmitter extends Emitter { val declares = ArrayBuffer[Seq[Any]]() val instdeclares = ArrayBuffer[Seq[Any]]() val assigns = ArrayBuffer[Seq[Any]]() - val at_clock = HashMap[Expression,ArrayBuffer[Seq[Any]]]() + val at_clock = LinkedHashMap[Expression,ArrayBuffer[Seq[Any]]]() val initials = ArrayBuffer[Seq[Any]]() val simulates = ArrayBuffer[Seq[Any]]() def declare (b:String,n:String,t:Type) = { @@ -320,11 +319,6 @@ object VerilogEmitter extends Emitter { declare("wire",lowered_name(e),tpe(e)) val ex = WRef(lowered_name(e),tpe(e),kind(e),gender(e)) if (gender(e) == FEMALE) { - if (lowered_name(e) == "interconnect_clk") { - for (x <- netlist) { - print("(" + x._1.e1.serialize() + " -> " + x._2.e1.serialize() + ")") - } - } assign(ex,netlist(e)) } } @@ -341,13 +335,13 @@ object VerilogEmitter extends Emitter { Seq("$fdisplay(32'h80000002,\"",ret,"\");$finish;") } def printf (str:String,args:Seq[Expression]) : Seq[Any] = { - val strx = (Seq(escape(str)) ++ args).reduce(_ + "," + _) + val strx = (Seq(escape(str)) ++ args).reduce(Seq(_, ",", _)) Seq("$fwrite(32'h80000002,",strx,");") } def delay (e:Expression, n:Int, clk:Expression) : Expression = { var ex = e for (i <- 0 until n) { - val name = firrtl_gensym("GEN",namehash) + val name = firrtl_gensym_module(mname) declare("reg",name,tpe(e)) val exx = WRef(name,tpe(e),ExpKind(),UNKNOWNGENDER) update(exx,ex,clk,one) diff --git a/src/main/scala/firrtl/Passes.scala b/src/main/scala/firrtl/Passes.scala index 28ad876c..babe18d0 100644 --- a/src/main/scala/firrtl/Passes.scala +++ b/src/main/scala/firrtl/Passes.scala @@ -10,248 +10,248 @@ import DebugUtils._ import PrimOps._ -@deprecated("This object will be replaced with package firrtl.passes") -object Passes extends LazyLogging { - - - // TODO Perhaps we should get rid of Logger since this map would be nice - ////private val defaultLogger = Logger() - //private def mapNameToPass = Map[String, Circuit => Circuit] ( - // "infer-types" -> inferTypes - //) - var mname = "" - //def nameToPass(name: String): Circuit => Circuit = { - //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name)) - //name match { - //case "to-working-ir" => toWorkingIr - //case "infer-types" => inferTypes - // errrrrrrrrrr... - //case "renameall" => renameall(Map()) - //} - //} - - private def toField(p: Port): Field = { - logger.debug(s"toField called on port ${p.serialize}") - p.direction match { - case INPUT => Field(p.name, REVERSE, p.tpe) - case OUTPUT => Field(p.name, DEFAULT, p.tpe) - } - } - // ============== RESOLVE ALL =================== - def resolve (c:Circuit) = {c - //val passes = Seq( - // toWorkingIr _, - // resolveKinds _, - // inferTypes _, - // resolveGenders _, - // pullMuxes _, - // expandConnects _, - // removeAccesses _) - //val names = Seq( - // "To Working IR", - // "Resolve Kinds", - // "Infer Types", - // "Resolve Genders", - // "Pull Muxes", - // "Expand Connects", - // "Remove Accesses") - //var c_BANG = c - //(names, passes).zipped.foreach { - // (n,p) => { - // println("Starting " + n) - // c_BANG = p(c_BANG) - // println(c_BANG.serialize()) - // println("Finished " + n) - // } - //} - //c_BANG - } - - - // ============== RESOLVE KINDS ================== - // =============================================== - - // ============== INFER TYPES ================== - - // ------------------ Utils ------------------------- - - -// =================== RESOLVE GENDERS ======================= - // =============================================== - - // =============== PULL MUXES ==================== - // =============================================== - - - - // ============ EXPAND CONNECTS ================== - // ---------------- UTILS ------------------ - - - //---------------- Pass --------------------- - - // =============================================== - - - - // ============ REMOVE ACCESSES ================== - // ---------------- UTILS ------------------ - - - /** INFER TYPES - * - * This pass infers the type field in all IR nodes by updating - * and passing an environment to all statements in pre-order - * traversal, and resolving types in expressions in post- - * order traversal. - * Type propagation for primary ops are defined in Primops.scala. - * Type errors are not checked in this pass, as this is - * postponed for a later/earlier pass. - */ - // input -> flip - //private type TypeMap = Map[String, Type] - //private val TypeMap = Map[String, Type]().withDefaultValue(UnknownType) - //private def getBundleSubtype(t: Type, name: String): Type = { - // t match { - // case b: BundleType => { - // val tpe = b.fields.find( _.name == name ) - // if (tpe.isEmpty) UnknownType - // else tpe.get.tpe - // } - // case _ => UnknownType - // } - //} - //private def getVectorSubtype(t: Type): Type = t.getType // Added for clarity - //// TODO Add genders - //private def inferExpTypes(typeMap: TypeMap)(exp: Expression): Expression = { - // logger.debug(s"inferTypes called on ${exp.getClass.getSimpleName}") - // exp.map(inferExpTypes(typeMap)) match { - // case e: UIntValue => e - // case e: SIntValue => e - // case e: Ref => Ref(e.name, typeMap(e.name)) - // case e: SubField => SubField(e.exp, e.name, getBundleSubtype(e.exp.getType, e.name)) - // case e: SubIndex => SubIndex(e.exp, e.value, getVectorSubtype(e.exp.getType)) - // case e: SubAccess => SubAccess(e.exp, e.index, getVectorSubtype(e.exp.getType)) - // case e: DoPrim => lowerAndTypePrimOp(e) - // case e: Expression => e - // } - //} - //private def inferTypes(typeMap: TypeMap, stmt: Stmt): (Stmt, TypeMap) = { - // logger.debug(s"inferTypes called on ${stmt.getClass.getSimpleName} ") - // stmt.map(inferExpTypes(typeMap)) match { - // case b: Begin => { - // var tMap = typeMap - // // TODO FIXME is map correctly called in sequential order - // val body = b.stmts.map { s => - // val ret = inferTypes(tMap, s) - // tMap = ret._2 - // ret._1 - // } - // (Begin(body), tMap) - // } - // case s: DefWire => (s, typeMap ++ Map(s.name -> s.tpe)) - // case s: DefRegister => (s, typeMap ++ Map(s.name -> s.tpe)) - // case s: DefMemory => (s, typeMap ++ Map(s.name -> s.dataType)) - // case s: DefInstance => (s, typeMap ++ Map(s.name -> typeMap(s.module))) - // case s: DefNode => (s, typeMap ++ Map(s.name -> s.value.getType)) - // case s: Conditionally => { // TODO Check: Assuming else block won't see when scope - // val (conseq, cMap) = inferTypes(typeMap, s.conseq) - // val (alt, aMap) = inferTypes(typeMap, s.alt) - // (Conditionally(s.info, s.pred, conseq, alt), cMap ++ aMap) - // } - // case s: Stmt => (s, typeMap) - // } - //} - //private def inferTypes(typeMap: TypeMap, m: Module): Module = { - // logger.debug(s"inferTypes called on module ${m.name}") - - // val pTypeMap = m.ports.map( p => p.name -> p.tpe ).toMap - - // Module(m.info, m.name, m.ports, inferTypes(typeMap ++ pTypeMap, m.stmt)._1) - //} - //def inferTypes(c: Circuit): Circuit = { - // logger.debug(s"inferTypes called on circuit ${c.name}") - - // // initialize typeMap with each module of circuit mapped to their bundled IO (ports converted to fields) - // val typeMap = c.modules.map(m => m.name -> BundleType(m.ports.map(toField(_)))).toMap - - // //val typeMap = c.modules.flatMap(buildTypeMap).toMap - // Circuit(c.info, c.name, c.modules.map(inferTypes(typeMap, _))) - //} - - //def renameall(s : String)(implicit map : Map[String,String]) : String = - // map getOrElse (s, s) - - //def renameall(e : Expression)(implicit map : Map[String,String]) : Expression = { - // logger.debug(s"renameall called on expression ${e.toString}") - // e match { - // case p : Ref => - // Ref(renameall(p.name), p.tpe) - // case p : SubField => - // SubField(renameall(p.exp), renameall(p.name), p.tpe) - // case p : SubIndex => - // SubIndex(renameall(p.exp), p.value, p.tpe) - // case p : SubAccess => - // SubAccess(renameall(p.exp), renameall(p.index), p.tpe) - // case p : Mux => - // Mux(renameall(p.cond), renameall(p.tval), renameall(p.fval), p.tpe) - // case p : ValidIf => - // ValidIf(renameall(p.cond), renameall(p.value), p.tpe) - // case p : DoPrim => - // println( p.args.map(x => renameall(x)) ) - // DoPrim(p.op, p.args.map(renameall), p.consts, p.tpe) - // case p : Expression => p - // } - //} - - //def renameall(s : Stmt)(implicit map : Map[String,String]) : Stmt = { - // logger.debug(s"renameall called on statement ${s.toString}") - - // s match { - // case p : DefWire => - // DefWire(p.info, renameall(p.name), p.tpe) - // case p: DefRegister => - // DefRegister(p.info, renameall(p.name), p.tpe, p.clock, p.reset, p.init) - // case p : DefMemory => - // DefMemory(p.info, renameall(p.name), p.dataType, p.depth, p.writeLatency, p.readLatency, - // p.readers, p.writers, p.readwriters) - // case p : DefInstance => - // DefInstance(p.info, renameall(p.name), renameall(p.module)) - // case p : DefNode => - // DefNode(p.info, renameall(p.name), renameall(p.value)) - // case p : Connect => - // Connect(p.info, renameall(p.loc), renameall(p.exp)) - // case p : BulkConnect => - // BulkConnect(p.info, renameall(p.loc), renameall(p.exp)) - // case p : IsInvalid => - // IsInvalid(p.info, renameall(p.exp)) - // case p : Stop => - // Stop(p.info, p.ret, renameall(p.clk), renameall(p.en)) - // case p : Print => - // Print(p.info, p.string, p.args.map(renameall), renameall(p.clk), renameall(p.en)) - // case p : Conditionally => - // Conditionally(p.info, renameall(p.pred), renameall(p.conseq), renameall(p.alt)) - // case p : Begin => - // Begin(p.stmts.map(renameall)) - // case p : Stmt => p - // } - //} - - //def renameall(p : Port)(implicit map : Map[String,String]) : Port = { - // logger.debug(s"renameall called on port ${p.name}") - // Port(p.info, renameall(p.name), p.dir, p.tpe) - //} - - //def renameall(m : Module)(implicit map : Map[String,String]) : Module = { - // logger.debug(s"renameall called on module ${m.name}") - // Module(m.info, renameall(m.name), m.ports.map(renameall(_)), renameall(m.stmt)) - //} - - //def renameall(map : Map[String,String]) : Circuit => Circuit = { - // c => { - // implicit val imap = map - // logger.debug(s"renameall called on circuit ${c.name} with %{renameto}") - // Circuit(c.info, renameall(c.name), c.modules.map(renameall(_))) - // } - //} -} +//@deprecated("This object will be replaced with package firrtl.passes") +//object Passes extends LazyLogging { +// +// +// // TODO Perhaps we should get rid of Logger since this map would be nice +// ////private val defaultLogger = Logger() +// //private def mapNameToPass = Map[String, Circuit => Circuit] ( +// // "infer-types" -> inferTypes +// //) +// var mname = "" +// //def nameToPass(name: String): Circuit => Circuit = { +// //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name)) +// //name match { +// //case "to-working-ir" => toWorkingIr +// //case "infer-types" => inferTypes +// // errrrrrrrrrr... +// //case "renameall" => renameall(Map()) +// //} +// //} +// +// private def toField(p: Port): Field = { +// logger.debug(s"toField called on port ${p.serialize}") +// p.direction match { +// case INPUT => Field(p.name, REVERSE, p.tpe) +// case OUTPUT => Field(p.name, DEFAULT, p.tpe) +// } +// } +// // ============== RESOLVE ALL =================== +// def resolve (c:Circuit) = {c +// //val passes = Seq( +// // toWorkingIr _, +// // resolveKinds _, +// // inferTypes _, +// // resolveGenders _, +// // pullMuxes _, +// // expandConnects _, +// // removeAccesses _) +// //val names = Seq( +// // "To Working IR", +// // "Resolve Kinds", +// // "Infer Types", +// // "Resolve Genders", +// // "Pull Muxes", +// // "Expand Connects", +// // "Remove Accesses") +// //var c_BANG = c +// //(names, passes).zipped.foreach { +// // (n,p) => { +// // println("Starting " + n) +// // c_BANG = p(c_BANG) +// // println(c_BANG.serialize()) +// // println("Finished " + n) +// // } +// //} +// //c_BANG +// } +// +// +// // ============== RESOLVE KINDS ================== +// // =============================================== +// +// // ============== INFER TYPES ================== +// +// // ------------------ Utils ------------------------- +// +// +//// =================== RESOLVE GENDERS ======================= +// // =============================================== +// +// // =============== PULL MUXES ==================== +// // =============================================== +// +// +// +// // ============ EXPAND CONNECTS ================== +// // ---------------- UTILS ------------------ +// +// +// //---------------- Pass --------------------- +// +// // =============================================== +// +// +// +// // ============ REMOVE ACCESSES ================== +// // ---------------- UTILS ------------------ +// +// +// /** INFER TYPES +// * +// * This pass infers the type field in all IR nodes by updating +// * and passing an environment to all statements in pre-order +// * traversal, and resolving types in expressions in post- +// * order traversal. +// * Type propagation for primary ops are defined in Primops.scala. +// * Type errors are not checked in this pass, as this is +// * postponed for a later/earlier pass. +// */ +// // input -> flip +// //private type TypeMap = Map[String, Type] +// //private val TypeMap = Map[String, Type]().withDefaultValue(UnknownType) +// //private def getBundleSubtype(t: Type, name: String): Type = { +// // t match { +// // case b: BundleType => { +// // val tpe = b.fields.find( _.name == name ) +// // if (tpe.isEmpty) UnknownType +// // else tpe.get.tpe +// // } +// // case _ => UnknownType +// // } +// //} +// //private def getVectorSubtype(t: Type): Type = t.getType // Added for clarity +// //// TODO Add genders +// //private def inferExpTypes(typeMap: TypeMap)(exp: Expression): Expression = { +// // logger.debug(s"inferTypes called on ${exp.getClass.getSimpleName}") +// // exp.map(inferExpTypes(typeMap)) match { +// // case e: UIntValue => e +// // case e: SIntValue => e +// // case e: Ref => Ref(e.name, typeMap(e.name)) +// // case e: SubField => SubField(e.exp, e.name, getBundleSubtype(e.exp.getType, e.name)) +// // case e: SubIndex => SubIndex(e.exp, e.value, getVectorSubtype(e.exp.getType)) +// // case e: SubAccess => SubAccess(e.exp, e.index, getVectorSubtype(e.exp.getType)) +// // case e: DoPrim => lowerAndTypePrimOp(e) +// // case e: Expression => e +// // } +// //} +// //private def inferTypes(typeMap: TypeMap, stmt: Stmt): (Stmt, TypeMap) = { +// // logger.debug(s"inferTypes called on ${stmt.getClass.getSimpleName} ") +// // stmt.map(inferExpTypes(typeMap)) match { +// // case b: Begin => { +// // var tMap = typeMap +// // // TODO FIXME is map correctly called in sequential order +// // val body = b.stmts.map { s => +// // val ret = inferTypes(tMap, s) +// // tMap = ret._2 +// // ret._1 +// // } +// // (Begin(body), tMap) +// // } +// // case s: DefWire => (s, typeMap ++ Map(s.name -> s.tpe)) +// // case s: DefRegister => (s, typeMap ++ Map(s.name -> s.tpe)) +// // case s: DefMemory => (s, typeMap ++ Map(s.name -> s.dataType)) +// // case s: DefInstance => (s, typeMap ++ Map(s.name -> typeMap(s.module))) +// // case s: DefNode => (s, typeMap ++ Map(s.name -> s.value.getType)) +// // case s: Conditionally => { // TODO Check: Assuming else block won't see when scope +// // val (conseq, cMap) = inferTypes(typeMap, s.conseq) +// // val (alt, aMap) = inferTypes(typeMap, s.alt) +// // (Conditionally(s.info, s.pred, conseq, alt), cMap ++ aMap) +// // } +// // case s: Stmt => (s, typeMap) +// // } +// //} +// //private def inferTypes(typeMap: TypeMap, m: Module): Module = { +// // logger.debug(s"inferTypes called on module ${m.name}") +// +// // val pTypeMap = m.ports.map( p => p.name -> p.tpe ).toMap +// +// // Module(m.info, m.name, m.ports, inferTypes(typeMap ++ pTypeMap, m.stmt)._1) +// //} +// //def inferTypes(c: Circuit): Circuit = { +// // logger.debug(s"inferTypes called on circuit ${c.name}") +// +// // // initialize typeMap with each module of circuit mapped to their bundled IO (ports converted to fields) +// // val typeMap = c.modules.map(m => m.name -> BundleType(m.ports.map(toField(_)))).toMap +// +// // //val typeMap = c.modules.flatMap(buildTypeMap).toMap +// // Circuit(c.info, c.name, c.modules.map(inferTypes(typeMap, _))) +// //} +// +// //def renameall(s : String)(implicit map : Map[String,String]) : String = +// // map getOrElse (s, s) +// +// //def renameall(e : Expression)(implicit map : Map[String,String]) : Expression = { +// // logger.debug(s"renameall called on expression ${e.toString}") +// // e match { +// // case p : Ref => +// // Ref(renameall(p.name), p.tpe) +// // case p : SubField => +// // SubField(renameall(p.exp), renameall(p.name), p.tpe) +// // case p : SubIndex => +// // SubIndex(renameall(p.exp), p.value, p.tpe) +// // case p : SubAccess => +// // SubAccess(renameall(p.exp), renameall(p.index), p.tpe) +// // case p : Mux => +// // Mux(renameall(p.cond), renameall(p.tval), renameall(p.fval), p.tpe) +// // case p : ValidIf => +// // ValidIf(renameall(p.cond), renameall(p.value), p.tpe) +// // case p : DoPrim => +// // println( p.args.map(x => renameall(x)) ) +// // DoPrim(p.op, p.args.map(renameall), p.consts, p.tpe) +// // case p : Expression => p +// // } +// //} +// +// //def renameall(s : Stmt)(implicit map : Map[String,String]) : Stmt = { +// // logger.debug(s"renameall called on statement ${s.toString}") +// +// // s match { +// // case p : DefWire => +// // DefWire(p.info, renameall(p.name), p.tpe) +// // case p: DefRegister => +// // DefRegister(p.info, renameall(p.name), p.tpe, p.clock, p.reset, p.init) +// // case p : DefMemory => +// // DefMemory(p.info, renameall(p.name), p.dataType, p.depth, p.writeLatency, p.readLatency, +// // p.readers, p.writers, p.readwriters) +// // case p : DefInstance => +// // DefInstance(p.info, renameall(p.name), renameall(p.module)) +// // case p : DefNode => +// // DefNode(p.info, renameall(p.name), renameall(p.value)) +// // case p : Connect => +// // Connect(p.info, renameall(p.loc), renameall(p.exp)) +// // case p : BulkConnect => +// // BulkConnect(p.info, renameall(p.loc), renameall(p.exp)) +// // case p : IsInvalid => +// // IsInvalid(p.info, renameall(p.exp)) +// // case p : Stop => +// // Stop(p.info, p.ret, renameall(p.clk), renameall(p.en)) +// // case p : Print => +// // Print(p.info, p.string, p.args.map(renameall), renameall(p.clk), renameall(p.en)) +// // case p : Conditionally => +// // Conditionally(p.info, renameall(p.pred), renameall(p.conseq), renameall(p.alt)) +// // case p : Begin => +// // Begin(p.stmts.map(renameall)) +// // case p : Stmt => p +// // } +// //} +// +// //def renameall(p : Port)(implicit map : Map[String,String]) : Port = { +// // logger.debug(s"renameall called on port ${p.name}") +// // Port(p.info, renameall(p.name), p.dir, p.tpe) +// //} +// +// //def renameall(m : Module)(implicit map : Map[String,String]) : Module = { +// // logger.debug(s"renameall called on module ${m.name}") +// // Module(m.info, renameall(m.name), m.ports.map(renameall(_)), renameall(m.stmt)) +// //} +// +// //def renameall(map : Map[String,String]) : Circuit => Circuit = { +// // c => { +// // implicit val imap = map +// // logger.debug(s"renameall called on circuit ${c.name} with %{renameto}") +// // Circuit(c.info, renameall(c.name), c.modules.map(renameall(_))) +// // } +// //} +//} diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index a0170b00..339d112c 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -14,13 +14,15 @@ package firrtl import scala.collection.mutable.StringBuilder import java.io.PrintWriter import PrimOps._ +import WrappedExpression._ +import firrtl.WrappedType._ import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap //import scala.reflect.runtime.universe._ object Utils { - - // Is there a more elegant way to do this? +// +// // Is there a more elegant way to do this? private type FlagMap = Map[String, Boolean] private val FlagMap = Map[String, Boolean]().withDefaultValue(false) implicit class WithAs[T](x: T) { @@ -37,7 +39,7 @@ object Utils { def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt val gen_names = Map[String,Int]() val delin = "_" - val sym_hash = HashMap[String,Int]() + val sym_hash = LinkedHashMap[String,LinkedHashMap[String,Int]]() def BoolType () = { UIntType(IntWidth(1)) } val one = UIntValue(BigInt(1),IntWidth(1)) val zero = UIntValue(BigInt(0),IntWidth(1)) @@ -50,9 +52,15 @@ object Utils { val ix = if (i < 0) ((-1 * i) - 1) else i ceil_log2(ix + 1) + 1 } - def firrtl_gensym (s:String):String = { firrtl_gensym(s,HashMap[String,Int]()) } - def firrtl_gensym (sym_hash:HashMap[String,Int]):String = { firrtl_gensym("gen",sym_hash) } - def firrtl_gensym (s:String,sym_hash:HashMap[String,Int]):String = { + def firrtl_gensym (s:String):String = { firrtl_gensym(s,LinkedHashMap[String,Int]()) } + def firrtl_gensym (sym_hash:LinkedHashMap[String,Int]):String = { firrtl_gensym("GEN",sym_hash) } + def firrtl_gensym_module (s:String):String = { + val sh = sym_hash.getOrElse(s,LinkedHashMap[String,Int]()) + val name = firrtl_gensym("GEN",sh) + sym_hash(s) = sh + name + } + def firrtl_gensym (s:String,sym_hash:LinkedHashMap[String,Int]):String = { if (sym_hash contains s) { val num = sym_hash(s) + 1 sym_hash += (s -> num) @@ -62,26 +70,26 @@ object Utils { (s + delin + 0) } } - def AND (e1:Expression,e2:Expression) : Expression = { - if (e1 == e2) e1 - else if ((e1 == zero) | (e2 == zero)) zero - else if (e1 == one) e2 - else if (e2 == one) e1 - else DoPrim(AND_OP,Seq(e1,e2),Seq(),UIntType(IntWidth(1))) + def AND (e1:WrappedExpression,e2:WrappedExpression) : Expression = { + if (e1 == e2) e1.e1 + else if ((e1 == we(zero)) | (e2 == we(zero))) zero + else if (e1 == we(one)) e2.e1 + else if (e2 == we(one)) e1.e1 + else DoPrim(AND_OP,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) } - def OR (e1:Expression,e2:Expression) : Expression = { - if (e1 == e2) e1 - else if ((e1 == one) | (e2 == one)) one - else if (e1 == zero) e2 - else if (e2 == zero) e1 - else DoPrim(OR_OP,Seq(e1,e2),Seq(),UIntType(IntWidth(1))) + def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = { + if (e1 == e2) e1.e1 + else if ((e1 == we(one)) | (e2 == we(one))) one + else if (e1 == we(zero)) e2.e1 + else if (e2 == we(zero)) e1.e1 + else DoPrim(OR_OP,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) } def EQV (e1:Expression,e2:Expression) : Expression = { DoPrim(EQUAL_OP,Seq(e1,e2),Seq(),tpe(e1)) } - def NOT (e1:Expression) : Expression = { - if (e1 == one) zero - else if (e1 == zero) one - else DoPrim(EQUAL_OP,Seq(e1,zero),Seq(),UIntType(IntWidth(1))) + def NOT (e1:WrappedExpression) : Expression = { + if (e1 == we(one)) zero + else if (e1 == we(zero)) one + else DoPrim(EQUAL_OP,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1))) } @@ -185,6 +193,20 @@ object Utils { } //============== TYPES ================ + def mux_type (e1:Expression,e2:Expression) : Type = mux_type(tpe(e1),tpe(e2)) + def mux_type (t1:Type,t2:Type) : Type = { + if (wt(t1) == wt(t2)) { + (t1,t2) match { + case (t1:UIntType,t2:UIntType) => UIntType(UnknownWidth()) + case (t1:SIntType,t2:SIntType) => SIntType(UnknownWidth()) + case (t1:VectorType,t2:VectorType) => VectorType(mux_type(t1.tpe,t2.tpe),t1.size) + case (t1:BundleType,t2:BundleType) => + BundleType((t1.fields,t2.fields).zipped.map((f1,f2) => { + Field(f1.name,f1.flip,mux_type(f1.tpe,f2.tpe)) + })) + } + } else UnknownType() + } def mux_type_and_widths (e1:Expression,e2:Expression) : Type = mux_type_and_widths(tpe(e1),tpe(e2)) def mux_type_and_widths (t1:Type,t2:Type) : Type = { def wmax (w1:Width,w2:Width) : Width = { @@ -226,7 +248,7 @@ object Utils { } } -//===================================== +////===================================== def widthBANG (t:Type) : Width = { t match { case t:UIntType => t.width @@ -288,7 +310,7 @@ object Utils { def serialize(implicit flags: FlagMap = FlagMap): String = op.getString } -// =============== EXPANSION FUNCTIONS ================ +//// =============== EXPANSION FUNCTIONS ================ def get_size (t:Type) : Int = { t match { case (t:BundleType) => { @@ -486,6 +508,10 @@ object Utils { } def tpe (e:Expression) : Type = e match { + case e:Ref => e.tpe + case e:SubField => e.tpe + case e:SubIndex => e.tpe + case e:SubAccess => e.tpe case e:WRef => e.tpe case e:WSubField => e.tpe case e:WSubIndex => e.tpe @@ -581,6 +607,7 @@ object Utils { case i: IsInvalid => IsInvalid(i.info, f(i.exp)) case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) + case c: CDefMPort => CDefMPort(c.info,c.name,c.tpe,c.mem,c.exps.map(f),c.direction) case s: Stmt => s } def eMap(f: Expression => Expression, exp:Expression): Expression = @@ -621,6 +648,8 @@ object Utils { case c:DefWire => DefWire(c.info,c.name,f(c.tpe)) case c:DefRegister => DefRegister(c.info,c.name,f(c.tpe),c.clock,c.reset,c.init) case c:DefMemory => DefMemory(c.info,c.name, f(c.data_type), c.depth, c.write_latency, c.read_latency, c.readers, c.writers, c.readwriters) + case c:CDefMemory => CDefMemory(c.info,c.name, f(c.tpe), c.size, c.seq) + case c:CDefMPort => CDefMPort(c.info,c.name, f(c.tpe), c.mem, c.exps,c.direction) case c => c } } @@ -657,6 +686,8 @@ object Utils { case (c:DefNode) => DefNode(c.info,f(c.name),c.value) case (c:DefInstance) => DefInstance(c.info,f(c.name), c.module) case (c:WDefInstance) => WDefInstance(c.info,f(c.name), c.module,c.tpe) + case (c:CDefMemory) => CDefMemory(c.info,f(c.name),c.tpe,c.size,c.seq) + case (c:CDefMPort) => CDefMPort(c.info,f(c.name),c.tpe,c.mem,c.exps,c.direction) case (c) => c } } @@ -689,9 +720,9 @@ object Utils { // } // } //} - //def get-sym-hash (m:InModule) : HashMap[String,Int] = { get-sym-hash(m,Seq()) } - //def get-sym-hash (m:InModule,keywords:Seq[String]) : HashMap[String,Int] = { - // val sym-hash = HashMap[String,Int]() + //def get-sym-hash (m:InModule) : LinkedHashMap[String,Int] = { get-sym-hash(m,Seq()) } + //def get-sym-hash (m:InModule,keywords:Seq[String]) : LinkedHashMap[String,Int] = { + // val sym-hash = LinkedHashMap[String,Int]() // for (k <- keywords) { sym-hash += (k -> 0) } // def add-name (s:String) : String = { // val sx = to-string(s) @@ -847,6 +878,19 @@ object Utils { case p: Print => s"printf(${p.clk.serialize}, ${p.en.serialize}, ${p.string}" + (if (p.args.nonEmpty) p.args.map(_.serialize).mkString(", ", ", ", "") else "") + ")" case s:Empty => "skip" + case s:CDefMemory => { + if (s.seq) s"smem ${s.name} : ${s.tpe} [${s.size}]" + else s"cmem ${s.name} : ${s.tpe} [${s.size}]" + } + case s:CDefMPort => { + val dir = s.direction match { + case MInfer => "infer" + case MRead => "read" + case MWrite => "write" + case MReadWrite => "rdwr" + } + s"${dir} mport ${s.name} = ${s.mem}[${s.exps(0)}], s.exps(1)" + } } ret + debug(stmt) } diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index bbe6a235..eaa4166b 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -4,6 +4,7 @@ package firrtl import scala.collection.Seq import Utils._ import WrappedExpression._ +import WrappedWidth._ trait Kind case class WireKind() extends Kind @@ -42,10 +43,8 @@ class WrappedExpression (val e1:Expression) { we match { case (we:WrappedExpression) => { (e1,we.e1) match { - case (e1:UIntValue,e2:UIntValue) => if (e1.value == e2.value) true else false - // TODO is this necessary? width(e1) == width(e2) - case (e1:SIntValue,e2:SIntValue) => if (e1.value == e2.value) true else false - // TODO is this necessary? width(e1) == width(e2) + case (e1:UIntValue,e2:UIntValue) => if (e1.value == e2.value) eqw(e1.width,e2.width) else false + case (e1:SIntValue,e2:SIntValue) => if (e1.value == e2.value) eqw(e1.width,e2.width) else false case (e1:WRef,e2:WRef) => e1.name equals e2.name case (e1:WSubField,e2:WSubField) => (e1.name equals e2.name) && weq(e1.exp,e2.exp) case (e1:WSubIndex,e2:WSubIndex) => (e1.value == e2.value) && weq(e1.exp,e2.exp) @@ -78,6 +77,10 @@ case class MaxWidth(args:Seq[Width]) extends Width case class MinWidth(args:Seq[Width]) extends Width case class ExpWidth(arg1:Width) extends Width +object WrappedType { + def apply (t:Type) = new WrappedType(t) + def wt (t:Type) = apply(t) +} class WrappedType (val t:Type) { def wt (tx:Type) = new WrappedType(tx) override def equals (o:Any) : Boolean = { @@ -104,6 +107,13 @@ class WrappedType (val t:Type) { } } } + +object WrappedWidth { + def eqw (w1:Width,w2:Width) : Boolean = { + (new WrappedWidth(w1)) == (new WrappedWidth(w2)) + } +} + class WrappedWidth (val w:Width) { override def toString = { w match { @@ -117,9 +127,6 @@ class WrappedWidth (val w:Width) { case (w:UnknownWidth) => "?" } } - def eq (w1:Width,w2:Width) : Boolean = { - (new WrappedWidth(w1)) == (new WrappedWidth(w2)) - } def ww (w:Width) : WrappedWidth = new WrappedWidth(w) override def equals (o:Any) : Boolean = { o match { @@ -132,7 +139,7 @@ class WrappedWidth (val w:Width) { else { for (a1 <- w1.args) { var found = false - for (a2 <- w2.args) { if (eq(a1,a2)) found = true } + for (a2 <- w2.args) { if (eqw(a1,a2)) found = true } if (found == false) ret = false } } @@ -144,7 +151,7 @@ class WrappedWidth (val w:Width) { else { for (a1 <- w1.args) { var found = false - for (a2 <- w2.args) { if (eq(a1,a2)) found = true } + for (a2 <- w2.args) { if (eqw(a1,a2)) found = true } if (found == false) ret = false } } @@ -177,4 +184,12 @@ object WGeq { def apply (loc:Width,exp:Width) = new WGeq(loc,exp) } +trait MPortDir +case object MInfer extends MPortDir +case object MRead extends MPortDir +case object MWrite extends MPortDir +case object MReadWrite extends MPortDir + +case class CDefMemory (val info: FileInfo, val name: String, val tpe: Type, val size: Int, val seq: Boolean) extends Stmt +case class CDefMPort (val info: FileInfo, val name: String, val tpe: Type, val mem: String, val exps: Seq[Expression], val direction: MPortDir) extends Stmt diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 6c77d35d..a6b53e86 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -9,12 +9,13 @@ import scala.sys.process._ import scala.io.Source // Datastructures -import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.ArrayBuffer import firrtl._ import firrtl.Utils._ import firrtl.PrimOps._ +import firrtl.WrappedExpression._ trait Pass extends LazyLogging { def name: String @@ -107,7 +108,7 @@ object ResolveKinds extends Pass { def name = "Resolve Kinds" def run (c:Circuit): Circuit = { def resolve_kinds (m:Module, c:Circuit):Module = { - val kinds = HashMap[String,Kind]() + val kinds = LinkedHashMap[String,Kind]() def resolve (body:Stmt) = { def resolve_expr (e:Expression):Expression = { e match { @@ -157,7 +158,7 @@ object ResolveKinds extends Pass { object InferTypes extends Pass { private var mname = "" def name = "Infer Types" - val width_name_hash = HashMap[String,Int]() + val width_name_hash = LinkedHashMap[String,Int]() def set_type (s:Stmt,t:Type) : Stmt = { s match { case s:DefWire => DefWire(s.info,s.name,t) @@ -175,9 +176,9 @@ object InferTypes extends Pass { } def remove_unknowns (t:Type): Type = mapr(remove_unknowns_w _,t) def run (c:Circuit): Circuit = { - val module_types = HashMap[String,Type]() + val module_types = LinkedHashMap[String,Type]() def infer_types (m:Module) : Module = { - val types = HashMap[String,Type]() + val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { eMap(infer_types_e _,e) match { case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value)) @@ -328,10 +329,10 @@ object CheckGenders extends Pass with StanzaPass { object InferWidths extends Pass { def name = "Infer Widths" var mname = "" - def solve_constraints (l:Seq[WGeq]) : HashMap[String,Width] = { + def solve_constraints (l:Seq[WGeq]) : LinkedHashMap[String,Width] = { def unique (ls:Seq[Width]) : Seq[Width] = ls.map(w => new WrappedWidth(w)).distinct.map(_.w) - def make_unique (ls:Seq[WGeq]) : HashMap[String,Width] = { - val h = HashMap[String,Width]() + def make_unique (ls:Seq[WGeq]) : LinkedHashMap[String,Width] = { + val h = LinkedHashMap[String,Width]() for (g <- ls) { (g.loc) match { case (w:VarWidth) => { @@ -369,10 +370,10 @@ object InferWidths extends Pass { case (w1,w2) => w }} case (w:ExpWidth) => { (w.arg1) match { - case (w1:IntWidth) => IntWidth((2 ^ w1.width) - 1) + case (w1:IntWidth) => IntWidth(BigInt((scala.math.pow(2,w1.width.toDouble) - 1).toLong)) case (w1) => w }} case (w) => w } } - def substitute (h:HashMap[String,Width])(w:Width) : Width = { + def substitute (h:LinkedHashMap[String,Width])(w:Width) : Width = { //;println-all-debug(["Substituting for [" w "]"]) val wx = simplify(w) //;println-all-debug(["After Simplify: [" wx "]"]) @@ -394,7 +395,7 @@ object InferWidths extends Pass { //;println-all-debug(["not varwidth!" w]) } } - def b_sub (h:HashMap[String,Width])(w:Width) : Width = { + def b_sub (h:LinkedHashMap[String,Width])(w:Width) : Width = { (wMap(b_sub(h) _,w)) match { case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w case (w) => w @@ -433,54 +434,44 @@ object InferWidths extends Pass { //; 2) Remove Cycles //; 3) Move to solved if not self-recursive val u = make_unique(l) - /* - println("======== UNIQUE CONSTRAINTS ========") - for (x <- u) { println(x) } - println("====================================") - */ + + //println("======== UNIQUE CONSTRAINTS ========") + //for (x <- u) { println(x) } + //println("====================================") + - val f = HashMap[String,Width]() + val f = LinkedHashMap[String,Width]() val o = ArrayBuffer[String]() for (x <- u) { - /* - println("==== SOLUTIONS TABLE ====") - for (x <- f) println(x) - println("=========================") - */ + //println("==== SOLUTIONS TABLE ====") + //for (x <- f) println(x) + //println("=========================") val (n, e) = (x._1, x._2) val e_sub = substitute(f)(e) - /* - println("Solving " + n + " => " + e) - println("After Substitute: " + n + " => " + e_sub) - println("==== SOLUTIONS TABLE (Post Substitute) ====") - for (x <- f) println(x) - println("=========================") - */ + //println("Solving " + n + " => " + e) + //println("After Substitute: " + n + " => " + e_sub) + //println("==== SOLUTIONS TABLE (Post Substitute) ====") + //for (x <- f) println(x) + //println("=========================") val ex = remove_cycle(n)(e_sub) - /* - println("After Remove Cycle: " + n + " => " + ex) - */ + //println("After Remove Cycle: " + n + " => " + ex) if (!self_rec(n,ex)) { - /* - println("Not rec!: " + n + " => " + ex) - println("Adding [" + n + "=>" + ex + "] to Solutions Table") - */ + //println("Not rec!: " + n + " => " + ex) + //println("Adding [" + n + "=>" + ex + "] to Solutions Table") o += n f(n) = ex } } - /* - println("Forward Solved Constraints") - for (x <- f) println(x) - */ + //println("Forward Solved Constraints") + //for (x <- f) println(x) //; Backwards Solve - val b = HashMap[String,Width]() + val b = LinkedHashMap[String,Width]() for (i <- 0 until o.size) { val n = o(o.size - 1 - i) /* @@ -510,7 +501,7 @@ object InferWidths extends Pass { case (t:ClockType) => IntWidth(1) case (t) => error("No width!"); IntWidth(-1) } } def width_BANG (e:Expression) : Width = width_BANG(tpe(e)) - def reduce_var_widths (c:Circuit,h:HashMap[String,Width]) : Circuit = { + def reduce_var_widths (c:Circuit,h:LinkedHashMap[String,Width]) : Circuit = { def evaluate (w:Width) : Width = { def apply_2 (a:Option[BigInt],b:Option[BigInt], f: (BigInt,BigInt) => BigInt) : Option[BigInt] = { (a,b) match { @@ -525,6 +516,7 @@ object InferWidths extends Pass { } def max (a:BigInt,b:BigInt) : BigInt = if (a >= b) a else b def min (a:BigInt,b:BigInt) : BigInt = if (a >= b) b else a + def pow (a:BigInt,b:BigInt) : BigInt = BigInt((scala.math.pow(a.toDouble,b.toDouble) - 1).toLong) def solve (w:Width) : Option[BigInt] = { (w) match { case (w:VarWidth) => { @@ -539,7 +531,7 @@ object InferWidths extends Pass { case (w:MinWidth) => apply_l(w.args.map(solve _),min) case (w:PlusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ + _}) case (w:MinusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ - _}) - case (w:ExpWidth) => apply_2(Some(BigInt(2)),solve(w.arg1),{(x,y) => (x ^ y) - BigInt(1)}) + case (w:ExpWidth) => apply_2(Some(BigInt(2)),solve(w.arg1),pow) case (w:IntWidth) => Some(w.width) case (w) => println(w); error("Shouldn't be here"); None; } @@ -691,7 +683,7 @@ object ExpandConnects extends Pass { def run (c:Circuit): Circuit = { def expand_connects (m:InModule) : InModule = { mname = m.name - val genders = HashMap[String,Gender]() + val genders = LinkedHashMap[String,Gender]() def expand_s (s:Stmt) : Stmt = { def set_gender (e:Expression) : Expression = { eMap(set_gender _,e) match { @@ -854,7 +846,7 @@ object RemoveAccesses extends Pass { def remove_s (s:Stmt) : Stmt = { val stmts = ArrayBuffer[Stmt]() def create_temp (e:Expression) : Expression = { - val n = firrtl_gensym("GEN",sh) + val n = firrtl_gensym_module(mname) stmts += DefWire(info(s),n,tpe(e)) WRef(n,tpe(e),kind(e),gender(e)) } @@ -897,7 +889,7 @@ object RemoveAccesses extends Pass { if (has_access(s.loc)) { val ls = get_locations(s.loc) val locx = - if (ls.size == 1 & ls(0).guard == one) s.loc + if (ls.size == 1 & weq(ls(0).guard,one)) s.loc else { val temp = create_temp(s.loc) for (x <- ls) { stmts += Conditionally(s.info,x.guard,Connect(s.info,x.base,temp),Empty()) } @@ -930,12 +922,12 @@ object ExpandWhens extends Pass { def name = "Expand Whens" var mname = "" // ; ========== Expand When Utilz ========== - def add (hash:HashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = { + def add (hash:LinkedHashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = { hash += (key -> value) } - def get_entries (hash:HashMap[WrappedExpression,Expression],exps:Seq[Expression]) : HashMap[WrappedExpression,Expression] = { - val hashx = HashMap[WrappedExpression,Expression]() + def get_entries (hash:LinkedHashMap[WrappedExpression,Expression],exps:Seq[Expression]) : LinkedHashMap[WrappedExpression,Expression] = { + val hashx = LinkedHashMap[WrappedExpression,Expression]() exps.foreach { e => { val value = hash.get(e) value match { @@ -987,10 +979,10 @@ object ExpandWhens extends Pass { val bodyx = void_all_s(m.body) InModule(m.info,m.name,m.ports,Begin(Seq(Begin(voids),bodyx))) } - def expand_whens (m:InModule) : Tuple2[HashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = { + def expand_whens (m:InModule) : Tuple2[LinkedHashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = { val simlist = ArrayBuffer[Stmt]() mname = m.name - def expand_whens (netlist:HashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = { + def expand_whens (netlist:LinkedHashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = { (s) match { case (s:Connect) => netlist(s.loc) = s.exp case (s:IsInvalid) => netlist(s.exp) = WInvalid() @@ -1025,14 +1017,14 @@ object ExpandWhens extends Pass { } } case (s:Print) => { - if (p == one) { + if (weq(p,one)) { simlist += s } else { simlist += Print(s.info,s.string,s.args,s.clk,AND(p,s.en)) } } case (s:Stop) => { - if (p == one) { + if (weq(p,one)) { simlist += s } else { simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en)) @@ -1042,7 +1034,7 @@ object ExpandWhens extends Pass { } s } - val netlist = HashMap[WrappedExpression,Expression]() + val netlist = LinkedHashMap[WrappedExpression,Expression]() expand_whens(netlist,one)(m.body) //println("Netlist:") @@ -1052,7 +1044,7 @@ object ExpandWhens extends Pass { ( netlist, simlist ) } - def create_module (netlist:HashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = { + def create_module (netlist:LinkedHashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = { mname = m.name val stmts = ArrayBuffer[Stmt]() val connections = ArrayBuffer[Stmt]() @@ -1242,10 +1234,9 @@ object SplitExp extends Pass { def split_exp (m:InModule) : InModule = { mname = m.name val v = ArrayBuffer[Stmt]() - val sh = sym_hash def split_exp_s (s:Stmt) : Stmt = { def split (e:Expression) : Expression = { - val n = firrtl_gensym("GEN",sh) + val n = firrtl_gensym_module(mname) v += DefNode(info(s),n,e) WRef(n,tpe(e),kind(e),gender(e)) } @@ -1385,7 +1376,7 @@ object LowerTypes extends Pass { //;------------- Pass ------------------ def lower_types (m:Module) : Module = { - val mdt = HashMap[String,Type]() + val mdt = LinkedHashMap[String,Type]() mname = m.name def lower_types (s:Stmt) : Stmt = { def lower_mem (e:Expression) : Seq[Expression] = { @@ -1522,3 +1513,408 @@ object LowerTypes extends Pass { } } +object CInferTypes extends Pass { + def name = "CInfer Types" + var mname = "" + def set_type (s:Stmt,t:Type) : Stmt = { + (s) match { + case (s:DefWire) => DefWire(s.info,s.name,t) + case (s:DefRegister) => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) + case (s:CDefMemory) => CDefMemory(s.info,s.name,t,s.size,s.seq) + case (s:CDefMPort) => CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) + case (s:DefNode) => s + case (s:DefPoison) => DefPoison(s.info,s.name,t) + } + } + + def to_field (p:Port) : Field = { + if (p.direction == OUTPUT) Field(p.name,DEFAULT,p.tpe) + else if (p.direction == INPUT) Field(p.name,REVERSE,p.tpe) + else error("Shouldn't be here"); Field(p.name,REVERSE,p.tpe) + } + def module_type (m:Module) : Type = BundleType(m.ports.map(p => to_field(p))) + def field_type (v:Type,s:String) : Type = { + (v) match { + case (v:BundleType) => { + val ft = v.fields.find(p => p.name == s) + if (ft != None) ft.get.tpe + else UnknownType() + } + case (v) => UnknownType() + } + } + def sub_type (v:Type) : Type = + (v) match { + case (v:VectorType) => v.tpe + case (v) => UnknownType() + } + def run (c:Circuit) : Circuit = { + val module_types = LinkedHashMap[String,Type]() + def infer_types (m:Module) : Module = { + val types = LinkedHashMap[String,Type]() + def infer_types_e (e:Expression) : Expression = { + (eMap(infer_types_e _,e)) match { + case (e:Ref) => Ref(e.name, types.getOrElse(e.name,UnknownType())) + case (e:SubField) => SubField(e.exp,e.name,field_type(tpe(e.exp),e.name)) + case (e:SubIndex) => SubIndex(e.exp,e.value,sub_type(tpe(e.exp))) + case (e:SubAccess) => SubAccess(e.exp,e.index,sub_type(tpe(e.exp))) + case (e:DoPrim) => set_primop_type(e) + case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval)) + case (e:ValidIf) => ValidIf(e.cond,e.value,tpe(e.value)) + case (_:UIntValue|_:SIntValue) => e + } + } + def infer_types_s (s:Stmt) : Stmt = { + (s) match { + case (s:DefRegister) => { + types(s.name) = s.tpe + eMap(infer_types_e _,s) + s + } + case (s:DefWire) => { + types(s.name) = s.tpe + s + } + case (s:DefPoison) => { + types(s.name) = s.tpe + s + } + case (s:DefNode) => { + val sx = eMap(infer_types_e _,s) + val t = get_type(sx) + types(s.name) = t + sx + } + case (s:DefMemory) => { + types(s.name) = get_type(s) + s + } + case (s:CDefMPort) => { + val t = types.getOrElse(s.mem,UnknownType()) + types(s.name) = t + CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) + } + case (s:CDefMemory) => { + types(s.name) = s.tpe + s + } + case (s:DefInstance) => { + types(s.name) = module_types.getOrElse(s.module,UnknownType()) + s + } + case (s) => eMap(infer_types_e _,sMap(infer_types_s _,s)) + } + } + for (p <- m.ports) { + types(p.name) = p.tpe + } + (m) match { + case (m:InModule) => InModule(m.info,m.name,m.ports,infer_types_s(m.body)) + case (m:ExModule) => m + } + } + + //; MAIN + for (m <- c.modules) { + module_types(m.name) = module_type(m) + } + val modulesx = c.modules.map(m => infer_types(m)) + Circuit(c.info, modulesx, c.main) + } +} + +object CInferMDir extends Pass { + def name = "CInfer MDir" + var mname = "" + def run (c:Circuit) : Circuit = { + def infer_mdir (m:Module) : Module = { + val mports = LinkedHashMap[String,MPortDir]() + def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = { + (eMap(infer_mdir_e(dir) _,e)) match { + case (e:Ref) => { + if (mports.contains(e.name)) { + val new_mport_dir = { + (mports(e.name),dir) match { + case (MInfer,MInfer) => error("Shouldn't be here") + case (MInfer,MWrite) => MWrite + case (MInfer,MRead) => MRead + case (MInfer,MReadWrite) => MReadWrite + case (MWrite,MInfer) => error("Shouldn't be here") + case (MWrite,MWrite) => MWrite + case (MWrite,MRead) => MReadWrite + case (MWrite,MReadWrite) => MReadWrite + case (MRead,MInfer) => error("Shouldn't be here") + case (MRead,MWrite) => MReadWrite + case (MRead,MRead) => MRead + case (MRead,MReadWrite) => MReadWrite + case (MReadWrite,MInfer) => error("Shouldn't be here") + case (MReadWrite,MWrite) => MReadWrite + case (MReadWrite,MRead) => MReadWrite + case (MReadWrite,MReadWrite) => MReadWrite + } + } + mports(e.name) = new_mport_dir + } + e + } + case (e) => e + } + } + def infer_mdir_s (s:Stmt) : Stmt = { + (s) match { + case (s:CDefMPort) => { + mports(s.name) = s.direction + eMap(infer_mdir_e(MRead) _,s) + } + case (s:Connect) => { + infer_mdir_e(MRead)(s.exp) + infer_mdir_e(MWrite)(s.loc) + s + } + case (s:BulkConnect) => { + infer_mdir_e(MRead)(s.exp) + infer_mdir_e(MWrite)(s.loc) + s + } + case (s) => eMap(infer_mdir_e(MRead) _, sMap(infer_mdir_s,s)) + } + } + def set_mdir_s (s:Stmt) : Stmt = { + (s) match { + case (s:CDefMPort) => + CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name)) + case (s) => sMap(set_mdir_s _,s) + } + } + (m) match { + case (m:InModule) => { + infer_mdir_s(m.body) + InModule(m.info,m.name,m.ports,set_mdir_s(m.body)) + } + case (m:ExModule) => m + } + } + + //; MAIN + Circuit(c.info, c.modules.map(m => infer_mdir(m)), c.main) + } +} + +case class MPort( val name : String, val clk : Expression) +case class MPorts( val readers : ArrayBuffer[MPort], val writers : ArrayBuffer[MPort], val readwriters : ArrayBuffer[MPort]) +case class DataRef( val exp : Expression, val male : String, val female : String, val mask : String, val rdwrite : Boolean) + +object RemoveCHIRRTL extends Pass { + def name = "Remove CHIRRTL" + var mname = "" + def create_exps (e:Expression) : Seq[Expression] = { + (e) match { + case (e:Mux)=> + (create_exps(e.tval),create_exps(e.fval)).zipped.map((e1,e2) => { + Mux(e.cond,e1,e2,mux_type(e1,e2)) + }) + case (e:ValidIf) => + create_exps(e.value).map(e1 => { + ValidIf(e.cond,e1,tpe(e1)) + }) + case (e) => (tpe(e)) match { + case (_:UIntType|_:SIntType|_:ClockType) => Seq(e) + case (t:BundleType) => + t.fields.flatMap(f => create_exps(SubField(e,f.name,f.tpe))) + case (t:VectorType)=> + (0 until t.size).flatMap(i => create_exps(SubIndex(e,i,t.tpe))) + case (t:UnknownType) => Seq(e) + } + } + } + def run (c:Circuit) : Circuit = { + def remove_chirrtl_m (m:InModule) : InModule = { + val hash = LinkedHashMap[String,MPorts]() + val repl = LinkedHashMap[String,DataRef]() + val ut = UnknownType() + val mport_types = LinkedHashMap[String,Type]() + def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]()) + def collect_mports (s:Stmt) : Stmt = { + (s) match { + case (s:CDefMPort) => { + val mports = hash.getOrElse(s.mem,EMPs()) + s.direction match { + case MRead => mports.readers += MPort(s.name,s.exps(1)) + case MWrite => mports.writers += MPort(s.name,s.exps(1)) + case MReadWrite => mports.readwriters += MPort(s.name,s.exps(1)) + } + hash(s.mem) = mports + s + } + case (s) => sMap(collect_mports _,s) + } + } + def collect_refs (s:Stmt) : Stmt = { + (s) match { + case (s:CDefMemory) => { + mport_types(s.name) = s.tpe + val stmts = ArrayBuffer[Stmt]() + val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size)))) + val tdata = s.tpe + def set_poison (vec:Seq[MPort],addr:String) : Unit = { + for (r <- vec ) { + stmts += IsInvalid(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),addr,taddr)) + stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),"clk",taddr),r.clk) + } + } + def set_enable (vec:Seq[MPort],en:String) : Unit = { + for (r <- vec ) { + stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),en,taddr),zero) + }} + def set_wmode (vec:Seq[MPort],wmode:String) : Unit = { + for (r <- vec) { + stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),wmode,taddr),zero) + }} + def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = { + val tmask = create_mask(s.tpe) + for (r <- vec ) { + stmts += IsInvalid(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),data,tdata)) + for (x <- create_exps(SubField(SubField(Ref(s.name,ut),r.name,ut),mask,tmask)) ) { + stmts += Connect(s.info,x,zero) + }}} + val rds = (hash.getOrElse(s.name,EMPs())).readers + set_poison(rds,"addr") + set_enable(rds,"en") + val wrs = (hash.getOrElse(s.name,EMPs())).writers + set_poison(wrs,"addr") + set_enable(wrs,"en") + set_write(wrs,"data","mask") + val rws = (hash.getOrElse(s.name,EMPs())).readwriters + set_poison(rws,"addr") + set_wmode(rws,"wmode") + set_enable(rws,"en") + set_write(rws,"data","mask") + val read_l = if (s.seq) 1 else 0 + val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name)) + Begin(Seq(mem,Begin(stmts))) + } + case (s:CDefMPort) => { + mport_types(s.name) = mport_types(s.mem) + val addrs = ArrayBuffer[String]() + val ens = ArrayBuffer[String]() + val masks = ArrayBuffer[String]() + s.direction match { + case MReadWrite => { + repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"rdata","data","mask",true) + addrs += "addr" + ens += "en" + masks += "mask" + } + case MWrite => { + repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"data","data","mask",false) + addrs += "addr" + ens += "en" + masks += "mask" + } + case _ => { + repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"data","data","blah",false) + addrs += "addr" + ens += "en" + } + } + val stmts = ArrayBuffer[Stmt]() + for (x <- addrs ) { + stmts += Connect(s.info,SubField(SubField(Ref(s.mem,ut),s.name,ut),x,ut),s.exps(0)) + } + for (x <- ens ) { + stmts += Connect(s.info,SubField(SubField(Ref(s.mem,ut),s.name,ut),x,ut),one) + } + Begin(stmts) + } + case (s) => sMap(collect_refs _,s) + } + } + def remove_chirrtl_s (s:Stmt) : Stmt = { + var has_write_mport = false + var has_readwrite_mport:Option[Expression] = None + def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = { + (e) match { + case (e:Ref) => { + if (repl.contains(e.name)) { + val vt = repl(e.name) + g match { + case MALE => SubField(vt.exp,vt.male,e.tpe) + case FEMALE => { + has_write_mport = true + if (vt.rdwrite == true) + has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1)))) + SubField(vt.exp,vt.female,e.tpe) + } + } + } else e + } + case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.exp),remove_chirrtl_e(MALE)(e.index),e.tpe) + case (e) => eMap(remove_chirrtl_e(g) _,e) + } + } + def get_mask (e:Expression) : Expression = { + (eMap(get_mask _,e)) match { + case (e:Ref) => { + if (repl.contains(e.name)) { + val vt = repl(e.name) + val t = create_mask(e.tpe) + SubField(vt.exp,vt.mask,t) + } else e + } + case (e) => e + } + } + (s) match { + case (s:Connect) => { + val stmts = ArrayBuffer[Stmt]() + val rocx = remove_chirrtl_e(MALE)(s.exp) + val locx = remove_chirrtl_e(FEMALE)(s.loc) + stmts += Connect(s.info,locx,rocx) + if (has_write_mport) { + val e = get_mask(s.loc) + for (x <- create_exps(e) ) { + stmts += Connect(s.info,x,one) + } + if (has_readwrite_mport != None) { + val wmode = has_readwrite_mport.get + stmts += Connect(s.info,wmode,one) + } + } + if (stmts.size > 1) Begin(stmts) + else stmts(0) + } + case (s:BulkConnect) => { + val stmts = ArrayBuffer[Stmt]() + val locx = remove_chirrtl_e(FEMALE)(s.loc) + val rocx = remove_chirrtl_e(MALE)(s.exp) + stmts += BulkConnect(s.info,locx,rocx) + if (has_write_mport != false) { + val ls = get_valid_points(tpe(s.loc),tpe(s.exp),DEFAULT,DEFAULT) + val locs = create_exps(get_mask(s.loc)) + for (x <- ls ) { + val locx = locs(x._1) + stmts += Connect(s.info,locx,one) + } + if (has_readwrite_mport != None) { + val wmode = has_readwrite_mport.get + stmts += Connect(s.info,wmode,one) + } + } + if (stmts.size > 1) Begin(stmts) + else stmts(0) + } + case (s) => eMap(remove_chirrtl_e(MALE) _, sMap(remove_chirrtl_s,s)) + } + } + collect_mports(m.body) + val sx = collect_refs(m.body) + InModule(m.info,m.name, m.ports, remove_chirrtl_s(sx)) + } + val modulesx = c.modules.map{ m => { + (m) match { + case (m:InModule) => remove_chirrtl_m(m) + case (m:ExModule) => m + }}} + Circuit(c.info,modulesx, c.main) + } +} diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 9d7a1b97..97302711 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -114,10 +114,10 @@ defmethod get-type (s:WDefInstance) -> Type : type(s) defmethod equal? (e1:Expression,e2:Expression) -> True|False : match(e1,e2) : (e1:UIntValue,e2:UIntValue) : - if value(e1) == value(e2) : width(e1) == width(e2) + if to-int(value(e1)) == to-int(value(e2)) : width(e1) == width(e2) else : false (e1:SIntValue,e2:SIntValue) : - if value(e1) == value(e2) : width(e1) == width(e2) + if to-int(value(e1)) == to-int(value(e2)) : width(e1) == width(e2) else : false (e1:WRef,e2:WRef) : name(e1) == name(e2) (e1:WSubField,e2:WSubField) : |
