diff options
Diffstat (limited to 'src/main/scala/firrtl/Passes.scala')
| -rw-r--r-- | src/main/scala/firrtl/Passes.scala | 246 |
1 files changed, 222 insertions, 24 deletions
diff --git a/src/main/scala/firrtl/Passes.scala b/src/main/scala/firrtl/Passes.scala index e28205f9..29b42d54 100644 --- a/src/main/scala/firrtl/Passes.scala +++ b/src/main/scala/firrtl/Passes.scala @@ -2,6 +2,7 @@ package firrtl import com.typesafe.scalalogging.LazyLogging +import scala.collection.mutable.HashMap import Utils._ import DebugUtils._ @@ -9,29 +10,50 @@ import PrimOps._ object Passes extends LazyLogging { - // TODO Perhaps we should get rid of Logger since this map would be nice - ////private val defaultLogger = Logger() - //private def mapNameToPass = Map[String, Circuit => Circuit] ( - // "infer-types" -> inferTypes - //) - def nameToPass(name: String): Circuit => Circuit = { - //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name)) - name match { - case "to-working-ir" => toWorkingIr - //case "infer-types" => inferTypes - // errrrrrrrrrr... - //case "renameall" => renameall(Map()) - } - } - - private def toField(p: Port): Field = { - logger.debug(s"toField called on port ${p.serialize}") - p.dir match { - case Input => Field(p.name, Reverse, p.tpe) - case Output => Field(p.name, Default, p.tpe) - } - } - + // TODO Perhaps we should get rid of Logger since this map would be nice + ////private val defaultLogger = Logger() + //private def mapNameToPass = Map[String, Circuit => Circuit] ( + // "infer-types" -> inferTypes + //) + def nameToPass(name: String): Circuit => Circuit = { + //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name)) + name match { + case "to-working-ir" => toWorkingIr + //case "infer-types" => inferTypes + // errrrrrrrrrr... + //case "renameall" => renameall(Map()) + } + } + + private def toField(p: Port): Field = { + logger.debug(s"toField called on port ${p.serialize}") + p.direction match { + case Input => Field(p.name, Reverse, p.tpe) + case Output => Field(p.name, Default, p.tpe) + } + } + // ============== RESOLVE ALL =================== + def resolve (c:Circuit) = { + val passes = Seq( + toWorkingIr _, + resolveKinds _, + inferTypes _) + val names = Seq( + "To Working IR", + "Resolve Kinds", + "Infer Types") + var c_BANG = c + (names, passes).zipped.foreach { + (n,p) => { + println("Starting " + n) + c_BANG = p(c_BANG) + println("Finished " + n) + } + } + c_BANG + } + + // ============== TO WORKING IR ================== def toWorkingIr (c:Circuit) = { def toExp (e:Expression) : Expression = { @@ -55,10 +77,186 @@ object Passes extends LazyLogging { case m:ExModule => m } } - Circuit(c.info,modulesx,c.main) + println("Before To Working IR") + println(c.serialize()) + val x = Circuit(c.info,modulesx,c.main) + println("After To Working IR") + println(x.serialize()) + x + } + + // =============================================== + + // ============== RESOLVE KINDS ================== + def resolveKinds (c:Circuit) = { + def resolve_kinds (m:Module, c:Circuit):Module = { + val kinds = HashMap[String,Kind]() + def resolve (body:Stmt) = { + def resolve_expr (e:Expression):Expression = { + e match { + case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender) + case e => eMap(resolve_expr,e) + } + } + def resolve_stmt (s:Stmt):Stmt = eMap(resolve_expr,sMap(resolve_stmt,s)) + resolve_stmt(body) + } + + def find (m:Module) = { + def find_stmt (s:Stmt):Stmt = { + s match { + case s:DefWire => kinds += (s.name -> WireKind()) + case s:DefPoison => kinds += (s.name -> PoisonKind()) + case s:DefNode => kinds += (s.name -> NodeKind()) + case s:DefRegister => kinds += (s.name -> RegKind()) + case s:WDefInstance => kinds += (s.name -> InstanceKind()) + case s:DefMemory => kinds += (s.name -> MemKind(s.readers ++ s.writers ++ s.readwriters)) + case s => false + } + sMap(find_stmt,s) + } + m.ports.foreach { p => kinds += (p.name -> PortKind()) } + println(kinds) + m match { + case m:InModule => find_stmt(m.body) + case m:ExModule => false + } + } + + find(m) + m match { + case m:InModule => { + val bodyx = resolve(m.body) + InModule(m.info,m.name,m.ports,bodyx) + } + case m:ExModule => ExModule(m.info,m.name,m.ports) + } + } + val modulesx = c.modules.map(m => resolve_kinds(m,c)) + println("Before Resolve Kinds") + println(c.serialize()) + val x = Circuit(c.info,modulesx,c.main) + println("After Resolve Kinds") + println(x.serialize()) + x } // =============================================== + // ============== INFER TYPES ================== + + // ------------------ Utils ------------------------- + + val width_name_hash = Map[String,Int]() + def set_type (s:Stmt,t:Type) : Stmt = { + s match { + case s:DefWire => DefWire(s.info,s.name,t) + case s:DefRegister => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) + case s:DefMemory => DefMemory(s.info,s.name,t,s.depth,s.write_latency,s.read_latency,s.readers,s.writers,s.readwriters) + case s:DefNode => s + case s:DefPoison => DefPoison(s.info,s.name,t) + } + } + def remove_unknowns_w (w:Width):Width = { + w match { + case w:UnknownWidth => VarWidth(firrtl_gensym("w",width_name_hash)) + case w => w + } + } + def remove_unknowns (t:Type): Type = mapr(remove_unknowns_w _,t) + def mapr (f: Width => Width, t:Type) : Type = { + def apply_t (t:Type) : Type = { + wMap(f,tMap(apply_t _,t)) + } + apply_t(t) + } + + + + // ------------------ Pass ------------------------- + + def inferTypes (c:Circuit) : Circuit = { + val module_types = HashMap[String,Type]() + def infer_types (m:Module) : Module = { + val types = HashMap[String,Type]() + def infer_types_e (e:Expression) : Expression = { + eMap(infer_types_e _,e) match { + case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value)) + case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender) + case e:WSubField => WSubField(e.exp,e.name,field_type(tpe(e.exp),e.name),e.gender) + case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(tpe(e.exp)),e.gender) + case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(tpe(e.exp)),e.gender) + case e:DoPrim => set_primop_type(e) + case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval)) + case e:UIntValue => e + case e:SIntValue => e + } + } + def infer_types_s (s:Stmt) : Stmt = { + s match { + case s:DefRegister => { + val t = remove_unknowns(get_type(s)) + types += (s.name -> t) + eMap(infer_types_e _,set_type(s,t)) + } + case s:DefWire => { + val sx = eMap(infer_types_e _,s) + val t = remove_unknowns(get_type(sx)) + types += (s.name -> t) + set_type(sx,t) + } + case s:DefPoison => { + val sx = eMap(infer_types_e _,s) + val t = remove_unknowns(get_type(sx)) + types += (s.name -> t) + set_type(sx,t) + } + case s:DefNode => { + val sx = eMap(infer_types_e _,s) + val t = remove_unknowns(get_type(sx)) + types += (s.name -> t) + set_type(sx,t) + } + case s:DefMemory => { + val t = remove_unknowns(get_type(s)) + types += (s.name -> t) + val dt = remove_unknowns(s.data_type) + set_type(s,dt) + } + case s:WDefInstance => { + types += (s.name -> module_types(s.module)) + WDefInstance(s.info,s.name,s.module,module_types(s.module)) + } + case s => eMap(infer_types_e _,sMap(infer_types_s,s)) + } + } + + m.ports.foreach(p => types += (p.name -> p.tpe)) + m match { + case m:InModule => InModule(m.info,m.name,m.ports,infer_types_s(m.body)) + case m:ExModule => m + } + } + + + // MAIN + val modulesx = c.modules.map { + m => { + val portsx = m.ports.map(p => Port(p.info,p.name,p.direction,remove_unknowns(p.tpe))) + m match { + case m:InModule => InModule(m.info,m.name,portsx,m.body) + case m:ExModule => ExModule(m.info,m.name,portsx) + } + } + } + + modulesx.foreach(m => module_types += (m.name -> module_type(m))) + println("Before Infer Types") + println(c.serialize()) + val x = Circuit(c.info,modulesx.map(m => infer_types(m)) , c.main ) + println("After Infer Types") + println(x.serialize()) + x + } /** INFER TYPES * |
