defpackage firrtl/chirrtl :
   import core
   import verse
   import firrtl/ir2
   import firrtl/ir-utils
   import firrtl/primops

; CHIRRTL Additional IR Nodes

; Direction tag attached to a CHIRRTL memory port.
; MInfer is the placeholder that the CInferMDir pass replaces.
public definterface MPortDir
public val MInfer = new MPortDir
public val MRead = new MPortDir
public val MWrite = new MPortDir
public val MReadWrite = new MPortDir

; Prints the textual keyword for a memory port direction.
defmethod print (o:OutputStream, m:MPortDir) :
   switch { m == _ } :
      MInfer : print(o,"infer")
      MRead : print(o,"read")
      MWrite : print(o,"write")
      MReadWrite : print(o,"rdwr")

; CHIRRTL memory declaration. seq? selects the "smem" spelling when
; printed and a read latency of 1 (instead of 0) in remove-chirrtl.
public defstruct CDefMemory <: Stmt : ;LOW
   info: FileInfo with: (as-method => true)
   name: Symbol
   type: Type
   size: Int
   seq?: True|False

; CHIRRTL memory port declaration on memory `mem`.
; exps[0] is used as the address and exps[1] as the clock (see print and
; remove-chirrtl below).
public defstruct CDefMPort <: Stmt :
   info: FileInfo with: (as-method => true)
   name: Symbol
   type: Type
   mem: Symbol
   exps: List<Expression>
   direction: MPortDir

; Prints "smem" for seq? memories and "cmem" otherwise.
defmethod print (o:OutputStream,c:CDefMemory) :
   if seq?(c) : print-all(o, ["smem " name(c) " : " type(c) "[" size(c) "]"])
   else : print-all(o, ["cmem " name(c) " : " type(c) "[" size(c) "]"])

; Rebuilds the memory with f applied to its type.
defmethod map (f: Type -> Type, c:CDefMemory) -> CDefMemory :
   CDefMemory(info(c),name(c),f(type(c)),size(c),seq?(c))

; Rebuilds the memory with f applied to its name.
defmethod map (f: Symbol -> Symbol, c:CDefMemory) -> CDefMemory :
   CDefMemory(info(c),f(name(c)),type(c),size(c),seq?(c))

; Prints "<dir> mport <name> = <mem>[<exps[0]>], <exps[1]>".
defmethod print (o:OutputStream,c:CDefMPort) :
   print-all(o, [direction(c) " mport " name(c) " = " mem(c) "[" exps(c)[0] "], " exps(c)[1]])

; Rebuilds the port with f applied to every expression in exps.
defmethod map (f: Expression -> Expression, c:CDefMPort) -> CDefMPort :
   CDefMPort(info(c),name(c),type(c),mem(c),map(f,exps(c)),direction(c))

; Rebuilds the port with f applied to its type.
defmethod map (f: Type -> Type, c:CDefMPort) -> CDefMPort :
   CDefMPort(info(c),name(c),f(type(c)),mem(c),exps(c),direction(c))

; Rebuilds the port with f applied to its own name (mem is left untouched).
defmethod map (f: Symbol -> Symbol, c:CDefMPort) -> CDefMPort :
   CDefMPort(info(c),f(name(c)),type(c),mem(c),exps(c),direction(c))

;======================= Infer Chirrtl Types ======================
public defstruct CInferTypes <: Pass
public defmethod pass (b:CInferTypes) -> (Circuit -> Circuit) : infer-types
public defmethod name (b:CInferTypes) -> String : "CInfer Types"
public defmethod short-name (b:CInferTypes) -> String : "c-infer-types"

;--------------- Utils -----------------

; Returns a copy of declaration s carrying type t. DefNode is returned
; unchanged. Statement kinds not listed here have no branch.
defn set-type (s:Stmt,t:Type) -> Stmt :
   match(s) :
      (s:DefWire) : DefWire(info(s),name(s),t)
      (s:DefRegister) : DefRegister(info(s),name(s),t,clock(s),reset(s),init(s))
      (s:CDefMemory) : CDefMemory(info(s),name(s),t,size(s),seq?(s))
      (s:CDefMPort) : CDefMPort(info(s),name(s),t,mem(s),exps(s),direction(s))
      (s:DefNode) : s
      (s:DefPoison) : DefPoison(info(s),name(s),t)

; Converts a port to a bundle field: OUTPUT ports become DEFAULT fields,
; INPUT ports become REVERSE fields. Any other direction is an error.
defn to-field (p:Port) -> Field :
   if direction(p) == OUTPUT : Field(name(p),DEFAULT,type(p))
   else if direction(p) == INPUT : Field(name(p),REVERSE,type(p))
   else : error("Shouldn't be here")

; The type of a module as seen by an instance: a bundle of its ports.
defn module-type (m:Module) -> Type :
   BundleType(for p in ports(m) map : to-field(p))

; Type of field s inside bundle type v. UnknownType when v is not a
; bundle or has no field named s.
defn field-type (v:Type,s:Symbol) -> Type :
   match(v) :
      (v:BundleType) :
         val ft = for p in fields(v) find : name(p) == s
         if ft != false : type(ft as Field)
         else : UnknownType()
      (v) : UnknownType()

; Element type of vector type v. UnknownType when v is not a vector.
defn sub-type (v:Type) -> Type :
   match(v) :
      (v:VectorType) : type(v)
      (v) : UnknownType()

;--------------- Pass -----------------

; Annotates every expression in circuit c with a type, using the types
; recorded for declarations seen earlier in the same module and the
; bundle-of-ports type of each module for instances.
defn infer-types (c:Circuit) -> Circuit :
   ; module name -> bundle-of-ports type
   val module-types = HashTable<Symbol,Type>(symbol-hash)
   defn infer-types (m:Module) -> Module :
      ; declared name -> type, filled in while walking the body
      val types = HashTable<Symbol,Type>(symbol-hash)
      ; Bottom-up: children are typed first by the inner map, then this node.
      defn infer-types-e (e:Expression) -> Expression :
         match(map(infer-types-e,e)) :
            (e:Ref) : Ref(name(e), get?(types,name(e),UnknownType()))
            (e:SubField) : SubField(exp(e),name(e),field-type(type(exp(e)),name(e)))
            (e:SubIndex) : SubIndex(exp(e),value(e),sub-type(type(exp(e))))
            (e:SubAccess) : SubAccess(exp(e),index(e),sub-type(type(exp(e))))
            (e:DoPrim) : set-primop-type(e)
            ; The mux type is derived from BOTH branches (true and false value).
            (e:Mux) : Mux(cond(e),tval(e),fval(e),mux-type(tval(e),fval(e)))
            (e:ValidIf) : ValidIf(cond(e),value(e),type(value(e)))
            (e:UIntValue|SIntValue) : e
      defn infer-types-s (s:Stmt) -> Stmt :
         match(s) :
            (s:DefRegister) :
               types[name(s)] = type(s)
               ; Return the statement with its sub-expressions typed.
               map(infer-types-e,s)
            (s:DefWire|DefPoison) :
               types[name(s)] = type(s)
               s
            (s:DefNode) :
               ; A node takes the type of its (freshly typed) value.
               val s* = map(infer-types-e,s)
               val t = type(value(s*))
               types[name(s*)] = t
               s*
            (s:DefMemory) :
               types[name(s)] = get-type(s)
               s
            (s:CDefMPort) :
               ; A port takes the type recorded for its memory, or
               ; UnknownType when the memory has not been seen.
               val t = get?(types,mem(s),UnknownType())
               types[name(s)] = t
               CDefMPort(info(s),name(s),t,mem(s),exps(s),direction(s))
            (s:CDefMemory) :
               types[name(s)] = type(s)
               s
            (s:DefInstance) :
               types[name(s)] = get?(module-types,module(s),UnknownType())
               s
            ; Everything else: recurse into sub-statements, then type the
            ; expressions of the resulting statement.
            (s) : map{infer-types-e,_} $ map(infer-types-s,s)
      for p in ports(m) do :
         types[name(p)] = type(p)
      match(m) :
         (m:InModule) :
            InModule(info(m),name(m),ports(m),infer-types-s(body(m)))
         (m:ExModule) : m

   ; MAIN
   for m in modules(c) do :
      module-types[name(m)] = module-type(m)
   Circuit{info(c), _, main(c) } $
      for m in modules(c) map :
         infer-types(m)

;==========================================
public defstruct CInferMDir <: Pass
public defmethod pass (b:CInferMDir) -> (Circuit -> Circuit) : infer-mdir
public defmethod name (b:CInferMDir) -> String : "CInfer MDir"
public defmethod short-name (b:CInferMDir) -> String : "cinfer-mdir"

; Replaces the direction of every CDefMPort in circuit c with the
; direction implied by how the port is referenced: references on the
; right-hand side of a connect (and in any other expression position)
; count as reads, references on the left-hand side count as writes.
defn infer-mdir (c:Circuit) -> Circuit :
   defn infer-mdir (m:Module) -> Module :
      ; port name -> direction accumulated so far
      val mports = HashTable<Symbol,MPortDir>(symbol-hash)
      ; Records a use with direction dir for every mport referenced in e.
      ; The expression itself is returned unchanged.
      defn infer-mdir-e (e:Expression,dir:MPortDir) -> Expression :
         match(map(infer-mdir-e{_,dir},e)) :
            (e:Ref) :
               if key?(mports,name(e)) :
                  ; Merge table, keyed by [direction so far, direction of this use].
                  val new_mport_dir = switch fn ([x,y]) : mports[name(e)] == x and dir == y :
                     [MInfer,MInfer] : error("Shouldn't be here")
                     [MInfer,MWrite] : MWrite
                     [MInfer,MRead] : MRead
                     [MInfer,MReadWrite] : MReadWrite
                     [MWrite,MInfer] : error("Shouldn't be here")
                     [MWrite,MWrite] : MWrite
                     [MWrite,MRead] : MReadWrite
                     [MWrite,MReadWrite] : MReadWrite
                     [MRead,MInfer] : error("Shouldn't be here")
                     [MRead,MWrite] : MReadWrite
                     [MRead,MRead] : MRead
                     [MRead,MReadWrite] : MReadWrite
                     [MReadWrite,MInfer] : error("Shouldn't be here")
                     [MReadWrite,MWrite] : MReadWrite
                     [MReadWrite,MRead] : MReadWrite
                     [MReadWrite,MReadWrite] : MReadWrite
                  mports[name(e)] = new_mport_dir
               e
            (e) : e
      ; First walk: registers each port, then records every use.
      defn infer-mdir-s (s:Stmt) -> Stmt :
         match(s) :
            (s:CDefMPort) :
               mports[name(s)] = direction(s)
               map(infer-mdir-e{_,MRead},s)
            (s:Connect|BulkConnect) :
               ; Called only for their effect on mports.
               infer-mdir-e(exp(s),MRead)
               infer-mdir-e(loc(s),MWrite)
               s
            (s) : map{infer-mdir-e{_,MRead},_} $ map(infer-mdir-s,s)
      ; Second walk: writes the accumulated direction back into each port.
      defn set-mdir-s (s:Stmt) -> Stmt :
         match(s) :
            (s:CDefMPort) :
               CDefMPort(info(s),name(s),type(s),mem(s),exps(s),mports[name(s)])
            (s) : map(set-mdir-s,s)
      match(m) :
         (m:InModule) :
            ; Result ignored: this walk only fills mports.
            infer-mdir-s(body(m))
            InModule(info(m),name(m),ports(m),set-mdir-s(body(m)))
         (m:ExModule) : m

   ; MAIN
   Circuit{info(c), _, main(c) } $
      for m in modules(c) map :
         infer-mdir(m)

;==========================================
public defstruct RemoveCHIRRTL <: Pass
public defmethod pass (b:RemoveCHIRRTL) -> (Circuit -> Circuit) : remove-chirrtl
public defmethod name (b:RemoveCHIRRTL) -> String : "Remove CHIRRTL"
public defmethod short-name (b:RemoveCHIRRTL) -> String : "remove-chirrtl"

; One memory port: its name and its clock expression.
defstruct MPort :
   name : Symbol
   clk : Expression

; The ports of one memory, grouped by direction.
defstruct MPorts :
   readers : Vector<MPort>
   writers : Vector<MPort>
   readwriters : Vector<MPort>

; Replacement record for references to a memory port:
;   exp    - the expression the port name is replaced under
;   male   - field selected when the reference is used as a source
;   female - field selected when the reference is used as a sink
;   mask   - field selected by get-mask
defstruct DataRef :
   exp : Expression
   male : Symbol
   female : Symbol
   mask : Symbol

public definterface Gender
public val MALE = new Gender
public val FEMALE = new Gender

; Expands e into the list of its ground-typed leaf expressions.
; Mux and ValidIf are distributed over the leaves of their operands.
; Expressions of UnknownType are returned as a single leaf.
defn create-exps (e:Expression) -> List<Expression> :
   match(e) :
      (e:Mux) :
         for (e1 in create-exps(tval(e)), e2 in create-exps(fval(e))) map :
            Mux(cond(e),e1,e2,mux-type(e1,e2))
      (e:ValidIf) :
         for e1 in create-exps(value(e)) map :
            ValidIf(cond(e),e1,type(e1))
      (e) :
         match(type(e)) :
            (t:UIntType|SIntType|ClockType) : list(e)
            (t:BundleType) :
               for f in fields(t) map-append :
                  create-exps(SubField(e,name(f),type(f)))
            (t:VectorType) :
               for i in 0 to size(t) map-append :
                  create-exps(SubIndex(e,i,type(t)))
            (t:UnknownType) : list(e)

; Rewrites every InModule of c so that CDefMemory/CDefMPort statements
; are replaced by a DefMemory plus connects to its port fields.
; ExModules are returned unchanged.
defn remove-chirrtl (c:Circuit) :
   defn remove-chirrtl-m (m:InModule) -> InModule :
      ; memory name -> its ports, grouped by direction
      val hash = HashTable<Symbol,MPorts>(symbol-hash)
      val sh = get-sym-hash(m,keys(v-keywords))
      ; port name -> how references to it are replaced
      val repl = HashTable<Symbol,DataRef>(symbol-hash)
      val ut = UnknownType()
      ; memory or port name -> data type
      val mport-types = HashTable<Symbol,Type>(symbol-hash)
      defn EMPs () -> MPorts :
         MPorts(Vector<MPort>(),Vector<MPort>(),Vector<MPort>())
      ; Walk 1: groups every port under its memory by direction.
      defn collect-mports (s:Stmt) -> Stmt :
         match(s) :
            (s:CDefMPort) :
               val mports = get?(hash,mem(s),EMPs())
               ; NOTE(review): there is no branch for MInfer here; confirm
               ; what this switch does for a port whose direction was never
               ; inferred (e.g. a port that is declared but never used).
               switch { _ == direction(s) } :
                  MRead : add(readers(mports),MPort(name(s),exps(s)[1]))
                  MWrite : add(writers(mports),MPort(name(s),exps(s)[1]))
                  MReadWrite : add(readwriters(mports),MPort(name(s),exps(s)[1]))
               hash[mem(s)] = mports
               s
            (s) : map(collect-mports,s)
      ; Walk 2: replaces each CDefMemory by a DefMemory followed by default
      ; connects for all of its ports, and each CDefMPort by the connects
      ; that drive its address/enable (and rmode) fields. Also fills repl.
      defn collect-refs (s:Stmt) -> Stmt :
         match(s) :
            (s:CDefMemory) :
               mport-types[name(s)] = type(s)
               val stmts = Vector<Stmt>()
               ; Poison values used as the default address and write data.
               val naddr = firrtl-gensym(`GEN,sh)
               val taddr = UIntType(IntWidth(max(1,ceil-log2(size(s)))))
               add(stmts,DefPoison(info(s),naddr,taddr))
               val ndata = firrtl-gensym(`GEN,sh)
               val tdata = type(s)
               add(stmts,DefPoison(info(s),ndata,tdata))
               ; Default address := poison, clk := the port's clock.
               defn set-poison (vec:List<MPort>,addr:Symbol) -> False :
                  for r in vec do :
                     add(stmts,Connect(info(s),SubField(SubField(Ref(name(s),ut),name(r),ut),addr,taddr),Ref(naddr,taddr)))
                     add(stmts,Connect(info(s),SubField(SubField(Ref(name(s),ut),name(r),ut),`clk,taddr),clk(r)))
               ; Default enable := zero.
               defn set-enable (vec:List<MPort>,en:Symbol) -> False :
                  for r in vec do :
                     add(stmts,Connect(info(s),SubField(SubField(Ref(name(s),ut),name(r),ut),en,taddr),zero))
               ; Default rmode := one.
               defn set-rmode (vec:List<MPort>,rmode:Symbol) -> False :
                  for r in vec do :
                     add(stmts,Connect(info(s),SubField(SubField(Ref(name(s),ut),name(r),ut),rmode,taddr),one))
               ; Default write data := poison, every mask leaf := zero.
               defn set-write (vec:List<MPort>,data:Symbol,mask:Symbol) -> False :
                  val tmask = create-mask(type(s))
                  for r in vec do :
                     add(stmts,Connect(info(s),SubField(SubField(Ref(name(s),ut),name(r),ut),data,tdata),Ref(ndata,tdata)))
                     for x in create-exps(SubField(SubField(Ref(name(s),ut),name(r),ut),mask,tmask)) do :
                        add(stmts,Connect(info(s),x,zero))
               val rds = to-list $ readers $ get?(hash,name(s),EMPs())
               set-poison(rds,`addr)
               set-enable(rds,`en)
               val wrs = to-list $ writers $ get?(hash,name(s),EMPs())
               set-poison(wrs,`addr)
               set-enable(wrs,`en)
               set-write(wrs,`data,`mask)
               val rws = to-list $ readwriters $ get?(hash,name(s),EMPs())
               set-poison(rws,`addr)
               set-rmode(rws,`rmode)
               set-enable(rws,`en)
               set-write(rws,`data,`mask)
               val read-l =
                  if seq?(s) : 1
                  else : 0
               val mem = DefMemory(info(s),name(s),type(s),size(s),1,read-l,map(name,rds),map(name,wrs),map(name,rws))
               Begin $ List(mem,to-list(stmts))
            (s:CDefMPort) :
               mport-types[name(s)] = mport-types[mem(s)]
               val addrs = Vector<Symbol>()
               val ens = Vector<Symbol>()
               val masks = Vector<Symbol>()
               val rmodes = Vector<Symbol>()
               switch { _ == direction(s) } :
                  MReadWrite :
                     repl[name(s)] = DataRef(SubField(Ref(mem(s),ut),name(s),ut),`rdata,`data,`mask)
                     add(addrs,`addr)
                     add(ens,`en)
                     add(rmodes,`rmode)
                     add(masks,`mask)
                  MWrite :
                     repl[name(s)] = DataRef(SubField(Ref(mem(s),ut),name(s),ut),`data,`data,`mask)
                     add(addrs,`addr)
                     add(ens,`en)
                     add(masks,`mask)
                  else :
                     repl[name(s)] = DataRef(SubField(Ref(mem(s),ut),name(s),ut),`data,`data,`blah)
                     add(addrs,`addr)
                     add(ens,`en)
               val stmts = Vector<Stmt>()
               ; At the point of declaration: addr := exps[0], en := one,
               ; rmode := zero (readwrite ports only).
               for x in addrs do :
                  add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),exps(s)[0]))
               for x in ens do :
                  add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),one))
               for x in rmodes do :
                  add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),zero))
               Begin $ to-list $ stmts
            (s) : map(collect-refs,s)
      ; Walk 3: replaces references to ports by the matching memory field
      ; and, when a connect writes through a port, also sets its mask to one.
      defn remove-chirrtl-s (s:Stmt) -> Stmt :
         ; Set by remove-chirrtl-e when a port is referenced as a sink.
         ; It is local to this call, i.e. tracked per statement.
         var has-write-mport? = false
         defn remove-chirrtl-e (e:Expression,g:Gender) -> Expression :
            match(map(remove-chirrtl-e{_,g},e)) :
               (e:Ref) :
                  if key?(repl,name(e)) :
                     val vt = repl[name(e)]
                     switch { g == _ } :
                        MALE : SubField(exp(vt),male(vt),type(e))
                        FEMALE :
                           has-write-mport? = true
                           SubField(exp(vt),female(vt),type(e))
                  else : e
               (e) : e
         ; Same shape as e, but with each port reference replaced by the
         ; port's mask field.
         defn get-mask (e:Expression) -> Expression :
            match(map(get-mask,e)) :
               (e:Ref) :
                  if key?(repl,name(e)) :
                     val vt = repl[name(e)]
                     val t = create-mask(type(e))
                     SubField(exp(vt),mask(vt),t)
                  else : e
               (e) : e
         match(s) :
            (s:Connect) :
               val stmts = Vector<Stmt>()
               val roc* = remove-chirrtl-e(exp(s),MALE)
               val loc* = remove-chirrtl-e(loc(s),FEMALE)
               add(stmts,Connect(info(s),loc*,roc*))
               if has-write-mport? :
                  ; Enable every mask leaf of the written location.
                  val e = get-mask(loc(s))
                  for x in create-exps(e) do :
                     add(stmts,Connect(info(s),x,one))
               if length(stmts) > 1 : Begin(to-list(stmts))
               else : stmts[0]
            (s:BulkConnect) :
               val stmts = Vector<Stmt>()
               val loc* = remove-chirrtl-e(loc(s),FEMALE)
               val roc* = remove-chirrtl-e(exp(s),MALE)
               add(stmts,BulkConnect(info(s),loc*,roc*))
               if has-write-mport? :
                  ; Only the mask leaves selected by get-valid-points are enabled.
                  val ls = get-valid-points(type(loc(s)),type(exp(s)),DEFAULT,DEFAULT)
                  val locs = create-exps(get-mask(loc(s)))
                  for x in ls do :
                     val loc* = locs[x[0]]
                     add(stmts,Connect(info(s),loc*,one))
               if length(stmts) > 1 : Begin(to-list(stmts))
               else : stmts[0]
            (s) : map(remove-chirrtl-e{_,MALE}, map(remove-chirrtl-s,s))
      ; Result ignored: this walk only fills hash.
      collect-mports(body(m))
      val s* = collect-refs(body(m))
      InModule(info(m),name(m), ports(m), remove-chirrtl-s(s*))
   Circuit(info(c),modules*, main(c)) where :
      val modules* =
         for m in modules(c) map :
            match(m) :
               (m:InModule) : remove-chirrtl-m(m)
               (m:ExModule) : m