diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/Passes.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 640 |
1 files changed, 0 insertions, 640 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index a4a7290e..b9808485 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -94,212 +94,6 @@ object ToWorkingIR extends Pass { } } -object ResolveKinds extends Pass { - private var mname = "" - def name = "Resolve Kinds" - def run (c:Circuit): Circuit = { - def resolve_kinds (m:DefModule, c:Circuit):DefModule = { - val kinds = LinkedHashMap[String,Kind]() - def resolve (body:Statement) = { - def resolve_expr (e:Expression):Expression = { - e match { - case e:WRef => WRef(e.name,e.tpe,kinds(e.name),e.gender) - case e => e map (resolve_expr) - } - } - def resolve_stmt (s:Statement):Statement = s map (resolve_stmt) map (resolve_expr) - resolve_stmt(body) - } - - def find (m:DefModule) = { - def find_stmt (s:Statement):Statement = { - s match { - case s:DefWire => kinds(s.name) = WireKind() - case s:DefNode => kinds(s.name) = NodeKind() - case s:DefRegister => kinds(s.name) = RegKind() - case s:WDefInstance => kinds(s.name) = InstanceKind() - case s:DefMemory => kinds(s.name) = MemKind(s.readers ++ s.writers ++ s.readwriters) - case s => false - } - s map (find_stmt) - } - m.ports.foreach { p => kinds(p.name) = PortKind() } - m match { - case m:Module => find_stmt(m.body) - case m:ExtModule => false - } - } - - mname = m.name - find(m) - m match { - case m:Module => { - val bodyx = resolve(m.body) - Module(m.info,m.name,m.ports,bodyx) - } - case m:ExtModule => ExtModule(m.info,m.name,m.ports) - } - } - val modulesx = c.modules.map(m => resolve_kinds(m,c)) - Circuit(c.info,modulesx,c.main) - } -} - -object InferTypes extends Pass { - private var mname = "" - def name = "Infer Types" - def set_type (s:Statement, t:Type) : Statement = { - s match { - case s:DefWire => DefWire(s.info,s.name,t) - case s:DefRegister => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) - case s:DefMemory => DefMemory(s.info,s.name,t,s.depth,s.writeLatency,s.readLatency,s.readers,s.writers,s.readwriters) - case s:DefNode => s - } - } - def remove_unknowns_w (w:Width)(implicit namespace: Namespace):Width = { - w match { - case UnknownWidth => VarWidth(namespace.newName("w")) - case w => w - } - } - def remove_unknowns (t:Type)(implicit n: Namespace): Type = mapr(remove_unknowns_w _,t) - def run (c:Circuit): Circuit = { - val module_types = LinkedHashMap[String,Type]() - implicit val wnamespace = Namespace() - def infer_types (m:DefModule) : DefModule = { - val types = LinkedHashMap[String,Type]() - def infer_types_e (e:Expression) : Expression = { - e map (infer_types_e) match { - case e:ValidIf => ValidIf(e.cond,e.value,e.value.tpe) - case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender) - case e:WSubField => WSubField(e.exp,e.name,field_type(e.exp.tpe,e.name),e.gender) - case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(e.exp.tpe),e.gender) - case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(e.exp.tpe),e.gender) - case e:DoPrim => set_primop_type(e) - case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval)) - case e:UIntLiteral => e - case e:SIntLiteral => e - } - } - def infer_types_s (s:Statement) : Statement = { - s match { - case s:DefRegister => { - val t = remove_unknowns(get_type(s)) - types(s.name) = t - set_type(s,t) map (infer_types_e) - } - case s:DefWire => { - val sx = s map(infer_types_e) - val t = remove_unknowns(get_type(sx)) - types(s.name) = t - set_type(sx,t) - } - case s:DefNode => { - val sx = s map (infer_types_e) - val t = remove_unknowns(get_type(sx)) - types(s.name) = t - set_type(sx,t) - } - case s:DefMemory => { - val t = remove_unknowns(get_type(s)) - types(s.name) = t - val dt = remove_unknowns(s.dataType) - set_type(s,dt) - } - case s:WDefInstance => { - types(s.name) = module_types(s.module) - WDefInstance(s.info,s.name,s.module,module_types(s.module)) - } - case s => s map (infer_types_s) map (infer_types_e) - } - } - - mname = m.name - m.ports.foreach(p => types(p.name) = p.tpe) - m match { - case m:Module => Module(m.info,m.name,m.ports,infer_types_s(m.body)) - case m:ExtModule => m - } - } - - val modulesx = c.modules.map { - m => { - mname = m.name - val portsx = m.ports.map(p => Port(p.info,p.name,p.direction,remove_unknowns(p.tpe))) - m match { - case m:Module => Module(m.info,m.name,portsx,m.body) - case m:ExtModule => ExtModule(m.info,m.name,portsx) - } - } - } - modulesx.foreach(m => module_types(m.name) = module_type(m)) - Circuit(c.info,modulesx.map({m => mname = m.name; infer_types(m)}) , c.main ) - } -} - -object ResolveGenders extends Pass { - private var mname = "" - def name = "Resolve Genders" - def run (c:Circuit): Circuit = { - def resolve_e (g:Gender)(e:Expression) : Expression = { - e match { - case e:WRef => WRef(e.name,e.tpe,e.kind,g) - case e:WSubField => { - val expx = - field_flip(e.exp.tpe,e.name) match { - case Default => resolve_e(g)(e.exp) - case Flip => resolve_e(swap(g))(e.exp) - } - WSubField(expx,e.name,e.tpe,g) - } - case e:WSubIndex => { - val expx = resolve_e(g)(e.exp) - WSubIndex(expx,e.value,e.tpe,g) - } - case e:WSubAccess => { - val expx = resolve_e(g)(e.exp) - val indexx = resolve_e(MALE)(e.index) - WSubAccess(expx,indexx,e.tpe,g) - } - case e => e map (resolve_e(g)) - } - } - - def resolve_s (s:Statement) : Statement = { - s match { - case s:IsInvalid => { - val expx = resolve_e(FEMALE)(s.expr) - IsInvalid(s.info,expx) - } - case s:Connect => { - val locx = resolve_e(FEMALE)(s.loc) - val expx = resolve_e(MALE)(s.expr) - Connect(s.info,locx,expx) - } - case s:PartialConnect => { - val locx = resolve_e(FEMALE)(s.loc) - val expx = resolve_e(MALE)(s.expr) - PartialConnect(s.info,locx,expx) - } - case s => s map (resolve_e(MALE)) map (resolve_s) - } - } - val modulesx = c.modules.map { - m => { - mname = m.name - m match { - case m:Module => { - val bodyx = resolve_s(m.body) - Module(m.info,m.name,m.ports,bodyx) - } - case m:ExtModule => m - } - } - } - Circuit(c.info,modulesx,c.main) - } -} - object PullMuxes extends Pass { def name = "Pull Muxes" def run(c: Circuit): Circuit = { @@ -536,437 +330,3 @@ object VerilogRename extends Pass { Circuit(c.info,modulesx,c.main) } } - -object CInferTypes extends Pass { - def name = "CInfer Types" - var mname = "" - def set_type (s:Statement, t:Type) : Statement = { - (s) match { - case (s:DefWire) => DefWire(s.info,s.name,t) - case (s:DefRegister) => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) - case (s:CDefMemory) => CDefMemory(s.info,s.name,t,s.size,s.seq) - case (s:CDefMPort) => CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) - case (s:DefNode) => s - } - } - - def to_field (p:Port) : Field = { - if (p.direction == Output) Field(p.name,Default,p.tpe) - else if (p.direction == Input) Field(p.name,Flip,p.tpe) - else error("Shouldn't be here"); Field(p.name,Flip,p.tpe) - } - def module_type (m:DefModule) : Type = BundleType(m.ports.map(p => to_field(p))) - def field_type (v:Type,s:String) : Type = { - (v) match { - case (v:BundleType) => { - val ft = v.fields.find(p => p.name == s) - if (ft != None) ft.get.tpe - else UnknownType - } - case (v) => UnknownType - } - } - def sub_type (v:Type) : Type = - (v) match { - case (v:VectorType) => v.tpe - case (v) => UnknownType - } - def run (c:Circuit) : Circuit = { - val module_types = LinkedHashMap[String,Type]() - def infer_types (m:DefModule) : DefModule = { - val types = LinkedHashMap[String,Type]() - def infer_types_e (e:Expression) : Expression = { - e map infer_types_e match { - case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType)) - case (e:SubField) => SubField(e.expr,e.name,field_type(e.expr.tpe,e.name)) - case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(e.expr.tpe)) - case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(e.expr.tpe)) - case (e:DoPrim) => set_primop_type(e) - case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval)) - case (e:ValidIf) => ValidIf(e.cond,e.value,e.value.tpe) - case (_:UIntLiteral | _:SIntLiteral) => e - } - } - def infer_types_s (s:Statement) : Statement = { - s match { - case (s:DefRegister) => { - types(s.name) = s.tpe - s map infer_types_e - s - } - case (s:DefWire) => { - types(s.name) = s.tpe - s - } - case (s:DefNode) => { - val sx = s map infer_types_e - val t = get_type(sx) - types(s.name) = t - sx - } - case (s:DefMemory) => { - types(s.name) = get_type(s) - s - } - case (s:CDefMPort) => { - val t = types.getOrElse(s.mem,UnknownType) - types(s.name) = t - CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) - } - case (s:CDefMemory) => { - types(s.name) = s.tpe - s - } - case (s:DefInstance) => { - types(s.name) = module_types.getOrElse(s.module,UnknownType) - s - } - case (s) => s map infer_types_s map infer_types_e - } - } - for (p <- m.ports) { - types(p.name) = p.tpe - } - m match { - case (m:Module) => Module(m.info,m.name,m.ports,infer_types_s(m.body)) - case (m:ExtModule) => m - } - } - - //; MAIN - for (m <- c.modules) { - module_types(m.name) = module_type(m) - } - val modulesx = c.modules.map(m => infer_types(m)) - Circuit(c.info, modulesx, c.main) - } -} - -object CInferMDir extends Pass { - def name = "CInfer MDir" - var mname = "" - def run (c:Circuit) : Circuit = { - def infer_mdir (m:DefModule) : DefModule = { - val mports = LinkedHashMap[String,MPortDir]() - def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = { - (e map (infer_mdir_e(dir))) match { - case (e:Reference) => { - if (mports.contains(e.name)) { - val new_mport_dir = { - (mports(e.name),dir) match { - case (MInfer,MInfer) => error("Shouldn't be here") - case (MInfer,MWrite) => MWrite - case (MInfer,MRead) => MRead - case (MInfer,MReadWrite) => MReadWrite - case (MWrite,MInfer) => error("Shouldn't be here") - case (MWrite,MWrite) => MWrite - case (MWrite,MRead) => MReadWrite - case (MWrite,MReadWrite) => MReadWrite - case (MRead,MInfer) => error("Shouldn't be here") - case (MRead,MWrite) => MReadWrite - case (MRead,MRead) => MRead - case (MRead,MReadWrite) => MReadWrite - case (MReadWrite,MInfer) => error("Shouldn't be here") - case (MReadWrite,MWrite) => MReadWrite - case (MReadWrite,MRead) => MReadWrite - case (MReadWrite,MReadWrite) => MReadWrite - } - } - mports(e.name) = new_mport_dir - } - e - } - case (e) => e - } - } - def infer_mdir_s (s:Statement) : Statement = { - (s) match { - case (s:CDefMPort) => { - mports(s.name) = s.direction - s map (infer_mdir_e(MRead)) - } - case (s:Connect) => { - infer_mdir_e(MRead)(s.expr) - infer_mdir_e(MWrite)(s.loc) - s - } - case (s:PartialConnect) => { - infer_mdir_e(MRead)(s.expr) - infer_mdir_e(MWrite)(s.loc) - s - } - case (s) => s map (infer_mdir_s) map (infer_mdir_e(MRead)) - } - } - def set_mdir_s (s:Statement) : Statement = { - (s) match { - case (s:CDefMPort) => - CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name)) - case (s) => s map (set_mdir_s) - } - } - (m) match { - case (m:Module) => { - infer_mdir_s(m.body) - Module(m.info,m.name,m.ports,set_mdir_s(m.body)) - } - case (m:ExtModule) => m - } - } - - //; MAIN - Circuit(c.info, c.modules.map(m => infer_mdir(m)), c.main) - } -} - -case class MPort( val name : String, val clk : Expression) -case class MPorts( val readers : ArrayBuffer[MPort], val writers : ArrayBuffer[MPort], val readwriters : ArrayBuffer[MPort]) -case class DataRef( val exp : Expression, val male : String, val female : String, val mask : String, val rdwrite : Boolean) - -object RemoveCHIRRTL extends Pass { - def name = "Remove CHIRRTL" - var mname = "" - def create_exps (e:Expression) : Seq[Expression] = e match { - case (e:Mux) => - val e1s = create_exps(e.tval) - val e2s = create_exps(e.fval) - (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2))) - case (e:ValidIf) => - create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe)) - case (e) => (e.tpe) match { - case (_:GroundType) => Seq(e) - case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => - exps ++ create_exps(SubField(e,f.name,f.tpe))) - case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => - exps ++ create_exps(SubIndex(e,i,t.tpe))) - case UnknownType => Seq(e) - } - } - def run (c:Circuit) : Circuit = { - def remove_chirrtl_m (m:Module) : Module = { - val hash = LinkedHashMap[String,MPorts]() - val repl = LinkedHashMap[String,DataRef]() - val raddrs = HashMap[String, Expression]() - val ut = UnknownType - val mport_types = LinkedHashMap[String,Type]() - val smems = HashSet[String]() - def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]()) - def collect_smems_and_mports (s:Statement) : Statement = { - (s) match { - case (s:CDefMemory) if s.seq => - smems += s.name - s - case (s:CDefMPort) => { - val mports = hash.getOrElse(s.mem,EMPs()) - s.direction match { - case MRead => mports.readers += MPort(s.name,s.exps(1)) - case MWrite => mports.writers += MPort(s.name,s.exps(1)) - case MReadWrite => mports.readwriters += MPort(s.name,s.exps(1)) - } - hash(s.mem) = mports - s - } - case (s) => s map (collect_smems_and_mports) - } - } - def collect_refs (s:Statement) : Statement = { - (s) match { - case (s:CDefMemory) => { - mport_types(s.name) = s.tpe - val stmts = ArrayBuffer[Statement]() - val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size)))) - val tdata = s.tpe - def set_poison (vec:Seq[MPort],addr:String) : Unit = { - for (r <- vec ) { - stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),addr,taddr)) - stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),"clk",taddr)) - } - } - def set_enable (vec:Seq[MPort],en:String) : Unit = { - for (r <- vec ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),en,taddr),zero) - }} - def set_wmode (vec:Seq[MPort],wmode:String) : Unit = { - for (r <- vec) { - stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),wmode,taddr),zero) - }} - def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = { - val tmask = create_mask(s.tpe) - for (r <- vec ) { - stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),data,tdata)) - for (x <- create_exps(SubField(SubField(Reference(s.name,ut),r.name,ut),mask,tmask)) ) { - stmts += Connect(s.info,x,zero) - }}} - val rds = (hash.getOrElse(s.name,EMPs())).readers - set_poison(rds,"addr") - set_enable(rds,"en") - val wrs = (hash.getOrElse(s.name,EMPs())).writers - set_poison(wrs,"addr") - set_enable(wrs,"en") - set_write(wrs,"data","mask") - val rws = (hash.getOrElse(s.name,EMPs())).readwriters - set_poison(rws,"addr") - set_wmode(rws,"wmode") - set_enable(rws,"en") - set_write(rws,"wdata","wmask") - val read_l = if (s.seq) 1 else 0 - val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name)) - Block(Seq(mem,Block(stmts))) - } - case (s:CDefMPort) => { - mport_types(s.name) = mport_types(s.mem) - val addrs = ArrayBuffer[String]() - val clks = ArrayBuffer[String]() - val ens = ArrayBuffer[String]() - val masks = ArrayBuffer[String]() - s.direction match { - case MReadWrite => { - repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"rdata","wdata","wmask",true) - addrs += "addr" - clks += "clk" - ens += "en" - masks += "wmask" - } - case MWrite => { - repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","mask",false) - addrs += "addr" - clks += "clk" - ens += "en" - masks += "mask" - } - case MRead => { - repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","blah",false) - addrs += "addr" - clks += "clk" - s.exps(0) match { - case e: Reference if smems(s.mem) => - raddrs(e.name) = SubField(SubField(Reference(s.mem,ut),s.name,ut),"en",ut) - case _ => ens += "en" - } - } - } - val stmts = ArrayBuffer[Statement]() - for (x <- addrs ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(0)) - } - for (x <- clks ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(1)) - } - for (x <- ens ) { - stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),one) - } - Block(stmts) - } - case (s) => s map (collect_refs) - } - } - def remove_chirrtl_s (s:Statement) : Statement = { - var has_write_mport = false - var has_read_mport: Option[Expression] = None - var has_readwrite_mport: Option[Expression] = None - def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = { - (e) match { - case (e:Reference) if repl contains e.name => - val vt = repl(e.name) - g match { - case MALE => SubField(vt.exp,vt.male,e.tpe) - case FEMALE => { - has_write_mport = true - if (vt.rdwrite) - has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1)))) - SubField(vt.exp,vt.female,e.tpe) - } - } - case (e:Reference) if g == FEMALE && (raddrs contains e.name) => - has_read_mport = Some(raddrs(e.name)) - e - case (e:Reference) => e - case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.expr),remove_chirrtl_e(MALE)(e.index),e.tpe) - case (e) => e map (remove_chirrtl_e(g)) - } - } - def get_mask (e:Expression) : Expression = { - (e map (get_mask)) match { - case (e:Reference) => { - if (repl.contains(e.name)) { - val vt = repl(e.name) - val t = create_mask(e.tpe) - SubField(vt.exp,vt.mask,t) - } else e - } - case (e) => e - } - } - (s) match { - case (s:DefNode) => { - val stmts = ArrayBuffer[Statement]() - val valuex = remove_chirrtl_e(MALE)(s.value) - stmts += DefNode(s.info,s.name,valuex) - has_read_mport match { - case None => - case Some(en) => stmts += Connect(s.info,en,one) - } - if (stmts.size > 1) Block(stmts) - else stmts(0) - } - case (s:Connect) => { - val stmts = ArrayBuffer[Statement]() - val rocx = remove_chirrtl_e(MALE)(s.expr) - val locx = remove_chirrtl_e(FEMALE)(s.loc) - stmts += Connect(s.info,locx,rocx) - has_read_mport match { - case None => - case Some(en) => stmts += Connect(s.info,en,one) - } - if (has_write_mport) { - val e = get_mask(s.loc) - for (x <- create_exps(e) ) { - stmts += Connect(s.info,x,one) - } - has_readwrite_mport match { - case None => - case Some(wmode) => stmts += Connect(s.info,wmode,one) - } - } - if (stmts.size > 1) Block(stmts) - else stmts(0) - } - case (s:PartialConnect) => { - val stmts = ArrayBuffer[Statement]() - val locx = remove_chirrtl_e(FEMALE)(s.loc) - val rocx = remove_chirrtl_e(MALE)(s.expr) - stmts += PartialConnect(s.info,locx,rocx) - has_read_mport match { - case None => - case Some(en) => stmts += Connect(s.info,en,one) - } - if (has_write_mport) { - val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) - val locs = create_exps(get_mask(s.loc)) - for (x <- ls ) { - val locx = locs(x._1) - stmts += Connect(s.info,locx,one) - } - has_readwrite_mport match { - case None => - case Some(wmode) => stmts += Connect(s.info,wmode,one) - } - } - if (stmts.size > 1) Block(stmts) - else stmts(0) - } - case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE)) - } - } - collect_smems_and_mports(m.body) - val sx = collect_refs(m.body) - Module(m.info,m.name, m.ports, remove_chirrtl_s(sx)) - } - val modulesx = c.modules.map{ m => { - (m) match { - case (m:Module) => remove_chirrtl_m(m) - case (m:ExtModule) => m - }}} - Circuit(c.info,modulesx, c.main) - } -} |
