diff options
Diffstat (limited to 'src/main/stanza/chirrtl.stanza')
| -rw-r--r-- | src/main/stanza/chirrtl.stanza | 177 |
1 files changed, 155 insertions, 22 deletions
diff --git a/src/main/stanza/chirrtl.stanza b/src/main/stanza/chirrtl.stanza index 049d0b0d..fd39c001 100644 --- a/src/main/stanza/chirrtl.stanza +++ b/src/main/stanza/chirrtl.stanza @@ -3,11 +3,146 @@ defpackage firrtl/chirrtl : import verse import firrtl/ir2 import firrtl/ir-utils + import firrtl/primops -public defstruct ToIR <: Pass -public defmethod pass (b:ToIR) -> (Circuit -> Circuit) : to-ir -public defmethod name (b:ToIR) -> String : "To FIRRTL" -public defmethod short-name (b:ToIR) -> String : "to-firrtl" +; CHIRRTL Additional IR Nodes +public definterface MPortDir +public val MRead = new MPortDir +public val MWrite = new MPortDir +public val MReadWrite = new MPortDir + +defmethod print (o:OutputStream, m:MPortDir) : + switch { m == _ } : + MRead : print(o,"read") + MWrite : print(o,"write") + MReadWrite : print(o,"rdwr") + +public defstruct CDefMemory <: Stmt : ;LOW + info: FileInfo with: (as-method => true) + name: Symbol + type: Type + size: Int + seq?: True|False +public defstruct CDefMPort <: Stmt : + info: FileInfo with: (as-method => true) + name: Symbol + type: Type + mem: Symbol + exps: List<Expression> + direction: MPortDir + +defmethod print (o:OutputStream,c:CDefMemory) : + if seq?(c) : print-all(o, ["smem " name(c) " : " type(c) "[" size(c) "]"]) + else : print-all(o, ["cmem " name(c) " : " type(c) "[" size(c) "]"]) +defmethod map (f: Type -> Type, c:CDefMemory) -> CDefMemory : + CDefMemory(info(c),name(c),f(type(c)),size(c),seq?(c)) +defmethod map (f: Symbol -> Symbol, c:CDefMemory) -> CDefMemory : + CDefMemory(info(c),f(name(c)),type(c),size(c),seq?(c)) + +defmethod print (o:OutputStream,c:CDefMPort) : + print-all(o, [direction(c) " mport " name(c) " = " mem(c) "[" exps(c)[0] "], " exps(c)[1]]) +defmethod map (f: Expression -> Expression, c:CDefMPort) -> CDefMPort : + CDefMPort(info(c),name(c),type(c),mem(c),map(f,exps(c)),direction(c)) +defmethod map (f: Type -> Type, c:CDefMPort) -> CDefMPort : + CDefMPort(info(c),name(c),f(type(c)),mem(c),exps(c),direction(c)) +defmethod map (f: Symbol -> Symbol, c:CDefMPort) -> CDefMPort : + CDefMPort(info(c),f(name(c)),type(c),mem(c),exps(c),direction(c)) + + +;======================= Infer Chirrtl Types ====================== +public defstruct CInferTypes <: Pass +public defmethod pass (b:CInferTypes) -> (Circuit -> Circuit) : infer-types +public defmethod name (b:CInferTypes) -> String : "Infer CTypes" +public defmethod short-name (b:CInferTypes) -> String : "infer-ctypes" + +;--------------- Utils ----------------- + +defn set-type (s:Stmt,t:Type) -> Stmt : + match(s) : + (s:DefWire) : DefWire(info(s),name(s),t) + (s:DefRegister) : DefRegister(info(s),name(s),t,clock(s),reset(s),init(s)) + (s:CDefMemory) : CDefMemory(info(s),name(s),t,size(s),seq?(s)) + (s:CDefMPort) : CDefMPort(info(s),name(s),t,mem(s),exps(s),direction(s)) + (s:DefNode) : s + (s:DefPoison) : DefPoison(info(s),name(s),t) + +defn to-field (p:Port) -> Field : + if direction(p) == OUTPUT : Field(name(p),DEFAULT,type(p)) + else if direction(p) == INPUT : Field(name(p),REVERSE,type(p)) + else : error("Shouldn't be here") +defn module-type (m:Module) -> Type : + BundleType(for p in ports(m) map : to-field(p)) +defn field-type (v:Type,s:Symbol) -> Type : + match(v) : + (v:BundleType) : + val ft = for p in fields(v) find : name(p) == s + if ft != false : type(ft as Field) + else : UnknownType() + (v) : UnknownType() +defn sub-type (v:Type) -> Type : + match(v) : + (v:VectorType) : type(v) + (v) : UnknownType() + +;--------------- Pass ----------------- + +defn infer-types (c:Circuit) -> Circuit : + val module-types = HashTable<Symbol,Type>(symbol-hash) + defn infer-types (m:Module) -> Module : + val types = HashTable<Symbol,Type>(symbol-hash) + defn infer-types-e (e:Expression) -> Expression : + match(map(infer-types-e,e)) : + (e:Ref) : Ref(name(e), types[name(e)]) + (e:SubField) : SubField(exp(e),name(e),field-type(type(exp(e)),name(e))) + (e:SubIndex) : SubIndex(exp(e),value(e),sub-type(type(exp(e)))) + (e:SubAccess) : SubAccess(exp(e),index(e),sub-type(type(exp(e)))) + (e:DoPrim) : set-primop-type(e) + (e:UIntValue|SIntValue) : e + defn infer-types-s (s:Stmt) -> Stmt : + match(s) : + (s:DefRegister) : + types[name(s)] = type(s) + map(infer-types-e,s) + s + (s:DefWire|DefPoison) : + types[name(s)] = type(s) + s + (s:DefNode) : + val s* = map(infer-types-e,s) + val t = type(value(s*)) + types[name(s*)] = t + s* + (s:CDefMPort) : + val t = types[mem(s)] + types[name(s)] = t + CDefMPort(info(s),name(s),t,mem(s),exps(s),direction(s)) + (s:CDefMemory) : + types[name(s)] = type(s) + s + (s:DefInstance) : + types[name(s)] = module-types[module(s)] + s + (s) : map{infer-types-e,_} $ map(infer-types-s,s) + for p in ports(m) do : + types[name(p)] = type(p) + match(m) : + (m:InModule) : + InModule(info(m),name(m),ports(m),infer-types-s(body(m))) + (m:ExModule) : m + + ; MAIN + for m in modules(c) do : + module-types[name(m)] = module-type(m) + Circuit{info(c), _, main(c) } $ + for m in modules(c) map : + infer-types(m) + +;========================================== + +public defstruct RemoveCHIRRTL <: Pass +public defmethod pass (b:RemoveCHIRRTL) -> (Circuit -> Circuit) : remove-chirrtl +public defmethod name (b:RemoveCHIRRTL) -> String : "Remove CHIRRTL" +public defmethod short-name (b:RemoveCHIRRTL) -> String : "remove-chirrtl" defstruct MPort : name : Symbol @@ -36,8 +171,8 @@ defn create-exps (e:Expression) -> List<Expression> : for i in 0 to size(t) map-append : create-exps(SubIndex(e,i,type(t))) -defn to-ir (c:Circuit) : - defn to-ir-m (m:InModule) -> InModule : +defn remove-chirrtl (c:Circuit) : + defn remove-chirrtl-m (m:InModule) -> InModule : val hash = HashTable<Symbol,MPorts>(symbol-hash) val sh = get-sym-hash(m,keys(v-keywords)) val repl = HashTable<Symbol,DataRef>(symbol-hash) @@ -126,14 +261,12 @@ defn to-ir (c:Circuit) : add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),exps(s)[0])) for x in ens do : add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),one)) - for x in masks do : - add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),exps(s)[2])) Begin $ to-list $ stmts (s) : map(collect-refs,s) - defn to-ir-s (s:Stmt) -> Stmt : + defn remove-chirrtl-s (s:Stmt) -> Stmt : var has-write-mport? = false - defn to-ir-e (e:Expression,g:Gender) -> Expression : - match(map(to-ir-e{_,g},e)) : + defn remove-chirrtl-e (e:Expression,g:Gender) -> Expression : + match(map(remove-chirrtl-e{_,g},e)) : (e:Ref) : if key?(repl,name(e)) : val vt = repl[name(e)] @@ -156,37 +289,37 @@ defn to-ir (c:Circuit) : match(s) : (s:Connect) : val stmts = Vector<Stmt>() - val roc* = to-ir-e(exp(s),MALE) - val loc* = to-ir-e(loc(s),FEMALE) + val roc* = remove-chirrtl-e(exp(s),MALE) + val loc* = remove-chirrtl-e(loc(s),FEMALE) add(stmts,Connect(info(s),loc*,roc*)) if has-write-mport? : val e = get-mask(loc(s)) for x in create-exps(e) do : - add(stmts,Connect(info(s),x,UInt(1))) - if length(stmts > 1) : Begin(to-list(stmts)) + add(stmts,Connect(info(s),x,one)) + if length(stmts) > 1 : Begin(to-list(stmts)) else : stmts[0] (s:BulkConnect) : val stmts = Vector<Stmt>() - val loc* = to-ir-e(loc(s),FEMALE) - val roc* = to-ir-e(exp(s),MALE) + val loc* = remove-chirrtl-e(loc(s),FEMALE) + val roc* = remove-chirrtl-e(exp(s),MALE) add(stmts,BulkConnect(info(s),loc*,roc*)) if has-write-mport? : val ls = get-valid-points(type(loc(s)),type(exp(s)),DEFAULT,DEFAULT) val locs = create-exps(get-mask(loc(s))) for x in ls do : val loc* = locs[x[0]] - add(stmts,Connect(info(s),loc*,UInt(1))) - if length(stmts > 1) : Begin(to-list(stmts)) + add(stmts,Connect(info(s),loc*,one)) + if length(stmts) > 1 : Begin(to-list(stmts)) else : stmts[0] - (s) : map(to-ir-e{_,MALE}, map(to-ir-s,s)) + (s) : map(remove-chirrtl-e{_,MALE}, map(remove-chirrtl-s,s)) collect-mports(body(m)) val s* = collect-refs(body(m)) - InModule(info(m),name(m), ports(m), to-ir-s(s*)) + InModule(info(m),name(m), ports(m), remove-chirrtl-s(s*)) Circuit(info(c),modules*, main(c)) where : val modules* = for m in modules(c) map : match(m) : - (m:InModule) : to-ir-m(m) + (m:InModule) : remove-chirrtl-m(m) (m:ExModule) : m |
