diff options
| author | azidar | 2016-01-23 06:58:59 -0800 |
|---|---|---|
| committer | azidar | 2016-01-23 06:58:59 -0800 |
| commit | 99062792e5006dbf4c6b1f97da9121bbd6217c7a (patch) | |
| tree | b8cf5153edecc108a65894da5b00c03c05e9fffb /src | |
| parent | cce5603ac7f5765434ec8239053b1fde74a2c67f (diff) | |
Changed chirrtl to not require known mask values
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/stanza/chirrtl.stanza | 177 | ||||
| -rw-r--r-- | src/main/stanza/compilers.stanza | 6 | ||||
| -rw-r--r-- | src/main/stanza/errors.stanza | 3 | ||||
| -rw-r--r-- | src/main/stanza/firrtl-ir.stanza | 19 | ||||
| -rw-r--r-- | src/main/stanza/ir-parser.stanza | 7 | ||||
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 77 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 67 |
7 files changed, 227 insertions, 129 deletions
diff --git a/src/main/stanza/chirrtl.stanza b/src/main/stanza/chirrtl.stanza index 049d0b0d..fd39c001 100644 --- a/src/main/stanza/chirrtl.stanza +++ b/src/main/stanza/chirrtl.stanza @@ -3,11 +3,146 @@ defpackage firrtl/chirrtl : import verse import firrtl/ir2 import firrtl/ir-utils + import firrtl/primops -public defstruct ToIR <: Pass -public defmethod pass (b:ToIR) -> (Circuit -> Circuit) : to-ir -public defmethod name (b:ToIR) -> String : "To FIRRTL" -public defmethod short-name (b:ToIR) -> String : "to-firrtl" +; CHIRRTL Additional IR Nodes +public definterface MPortDir +public val MRead = new MPortDir +public val MWrite = new MPortDir +public val MReadWrite = new MPortDir + +defmethod print (o:OutputStream, m:MPortDir) : + switch { m == _ } : + MRead : print(o,"read") + MWrite : print(o,"write") + MReadWrite : print(o,"rdwr") + +public defstruct CDefMemory <: Stmt : ;LOW + info: FileInfo with: (as-method => true) + name: Symbol + type: Type + size: Int + seq?: True|False +public defstruct CDefMPort <: Stmt : + info: FileInfo with: (as-method => true) + name: Symbol + type: Type + mem: Symbol + exps: List<Expression> + direction: MPortDir + +defmethod print (o:OutputStream,c:CDefMemory) : + if seq?(c) : print-all(o, ["smem " name(c) " : " type(c) "[" size(c) "]"]) + else : print-all(o, ["cmem " name(c) " : " type(c) "[" size(c) "]"]) +defmethod map (f: Type -> Type, c:CDefMemory) -> CDefMemory : + CDefMemory(info(c),name(c),f(type(c)),size(c),seq?(c)) +defmethod map (f: Symbol -> Symbol, c:CDefMemory) -> CDefMemory : + CDefMemory(info(c),f(name(c)),type(c),size(c),seq?(c)) + +defmethod print (o:OutputStream,c:CDefMPort) : + print-all(o, [direction(c) " mport " name(c) " = " mem(c) "[" exps(c)[0] "], " exps(c)[1]]) +defmethod map (f: Expression -> Expression, c:CDefMPort) -> CDefMPort : + CDefMPort(info(c),name(c),type(c),mem(c),map(f,exps(c)),direction(c)) +defmethod map (f: Type -> Type, c:CDefMPort) -> CDefMPort : + CDefMPort(info(c),name(c),f(type(c)),mem(c),exps(c),direction(c)) +defmethod map (f: Symbol -> Symbol, c:CDefMPort) -> CDefMPort : + CDefMPort(info(c),f(name(c)),type(c),mem(c),exps(c),direction(c)) + + +;======================= Infer Chirrtl Types ====================== +public defstruct CInferTypes <: Pass +public defmethod pass (b:CInferTypes) -> (Circuit -> Circuit) : infer-types +public defmethod name (b:CInferTypes) -> String : "Infer CTypes" +public defmethod short-name (b:CInferTypes) -> String : "infer-ctypes" + +;--------------- Utils ----------------- + +defn set-type (s:Stmt,t:Type) -> Stmt : + match(s) : + (s:DefWire) : DefWire(info(s),name(s),t) + (s:DefRegister) : DefRegister(info(s),name(s),t,clock(s),reset(s),init(s)) + (s:CDefMemory) : CDefMemory(info(s),name(s),t,size(s),seq?(s)) + (s:CDefMPort) : CDefMPort(info(s),name(s),t,mem(s),exps(s),direction(s)) + (s:DefNode) : s + (s:DefPoison) : DefPoison(info(s),name(s),t) + +defn to-field (p:Port) -> Field : + if direction(p) == OUTPUT : Field(name(p),DEFAULT,type(p)) + else if direction(p) == INPUT : Field(name(p),REVERSE,type(p)) + else : error("Shouldn't be here") +defn module-type (m:Module) -> Type : + BundleType(for p in ports(m) map : to-field(p)) +defn field-type (v:Type,s:Symbol) -> Type : + match(v) : + (v:BundleType) : + val ft = for p in fields(v) find : name(p) == s + if ft != false : type(ft as Field) + else : UnknownType() + (v) : UnknownType() +defn sub-type (v:Type) -> Type : + match(v) : + (v:VectorType) : type(v) + (v) : UnknownType() + +;--------------- Pass ----------------- + +defn infer-types (c:Circuit) -> Circuit : + val module-types = HashTable<Symbol,Type>(symbol-hash) + defn infer-types (m:Module) -> Module : + val types = HashTable<Symbol,Type>(symbol-hash) + defn infer-types-e (e:Expression) -> Expression : + match(map(infer-types-e,e)) : + (e:Ref) : Ref(name(e), types[name(e)]) + (e:SubField) : SubField(exp(e),name(e),field-type(type(exp(e)),name(e))) + (e:SubIndex) : SubIndex(exp(e),value(e),sub-type(type(exp(e)))) + (e:SubAccess) : SubAccess(exp(e),index(e),sub-type(type(exp(e)))) + (e:DoPrim) : set-primop-type(e) + (e:UIntValue|SIntValue) : e + defn infer-types-s (s:Stmt) -> Stmt : + match(s) : + (s:DefRegister) : + types[name(s)] = type(s) + map(infer-types-e,s) + s + (s:DefWire|DefPoison) : + types[name(s)] = type(s) + s + (s:DefNode) : + val s* = map(infer-types-e,s) + val t = type(value(s*)) + types[name(s*)] = t + s* + (s:CDefMPort) : + val t = types[mem(s)] + types[name(s)] = t + CDefMPort(info(s),name(s),t,mem(s),exps(s),direction(s)) + (s:CDefMemory) : + types[name(s)] = type(s) + s + (s:DefInstance) : + types[name(s)] = module-types[module(s)] + s + (s) : map{infer-types-e,_} $ map(infer-types-s,s) + for p in ports(m) do : + types[name(p)] = type(p) + match(m) : + (m:InModule) : + InModule(info(m),name(m),ports(m),infer-types-s(body(m))) + (m:ExModule) : m + + ; MAIN + for m in modules(c) do : + module-types[name(m)] = module-type(m) + Circuit{info(c), _, main(c) } $ + for m in modules(c) map : + infer-types(m) + +;========================================== + +public defstruct RemoveCHIRRTL <: Pass +public defmethod pass (b:RemoveCHIRRTL) -> (Circuit -> Circuit) : remove-chirrtl +public defmethod name (b:RemoveCHIRRTL) -> String : "Remove CHIRRTL" +public defmethod short-name (b:RemoveCHIRRTL) -> String : "remove-chirrtl" defstruct MPort : name : Symbol @@ -36,8 +171,8 @@ defn create-exps (e:Expression) -> List<Expression> : for i in 0 to size(t) map-append : create-exps(SubIndex(e,i,type(t))) -defn to-ir (c:Circuit) : - defn to-ir-m (m:InModule) -> InModule : +defn remove-chirrtl (c:Circuit) : + defn remove-chirrtl-m (m:InModule) -> InModule : val hash = HashTable<Symbol,MPorts>(symbol-hash) val sh = get-sym-hash(m,keys(v-keywords)) val repl = HashTable<Symbol,DataRef>(symbol-hash) @@ -126,14 +261,12 @@ defn to-ir (c:Circuit) : add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),exps(s)[0])) for x in ens do : add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),one)) - for x in masks do : - add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),exps(s)[2])) Begin $ to-list $ stmts (s) : map(collect-refs,s) - defn to-ir-s (s:Stmt) -> Stmt : + defn remove-chirrtl-s (s:Stmt) -> Stmt : var has-write-mport? = false - defn to-ir-e (e:Expression,g:Gender) -> Expression : - match(map(to-ir-e{_,g},e)) : + defn remove-chirrtl-e (e:Expression,g:Gender) -> Expression : + match(map(remove-chirrtl-e{_,g},e)) : (e:Ref) : if key?(repl,name(e)) : val vt = repl[name(e)] @@ -156,37 +289,37 @@ defn to-ir (c:Circuit) : match(s) : (s:Connect) : val stmts = Vector<Stmt>() - val roc* = to-ir-e(exp(s),MALE) - val loc* = to-ir-e(loc(s),FEMALE) + val roc* = remove-chirrtl-e(exp(s),MALE) + val loc* = remove-chirrtl-e(loc(s),FEMALE) add(stmts,Connect(info(s),loc*,roc*)) if has-write-mport? : val e = get-mask(loc(s)) for x in create-exps(e) do : - add(stmts,Connect(info(s),x,UInt(1))) - if length(stmts > 1) : Begin(to-list(stmts)) + add(stmts,Connect(info(s),x,one)) + if length(stmts) > 1 : Begin(to-list(stmts)) else : stmts[0] (s:BulkConnect) : val stmts = Vector<Stmt>() - val loc* = to-ir-e(loc(s),FEMALE) - val roc* = to-ir-e(exp(s),MALE) + val loc* = remove-chirrtl-e(loc(s),FEMALE) + val roc* = remove-chirrtl-e(exp(s),MALE) add(stmts,BulkConnect(info(s),loc*,roc*)) if has-write-mport? : val ls = get-valid-points(type(loc(s)),type(exp(s)),DEFAULT,DEFAULT) val locs = create-exps(get-mask(loc(s))) for x in ls do : val loc* = locs[x[0]] - add(stmts,Connect(info(s),loc*,UInt(1))) - if length(stmts > 1) : Begin(to-list(stmts)) + add(stmts,Connect(info(s),loc*,one)) + if length(stmts) > 1 : Begin(to-list(stmts)) else : stmts[0] - (s) : map(to-ir-e{_,MALE}, map(to-ir-s,s)) + (s) : map(remove-chirrtl-e{_,MALE}, map(remove-chirrtl-s,s)) collect-mports(body(m)) val s* = collect-refs(body(m)) - InModule(info(m),name(m), ports(m), to-ir-s(s*)) + InModule(info(m),name(m), ports(m), remove-chirrtl-s(s*)) Circuit(info(c),modules*, main(c)) where : val modules* = for m in modules(c) map : match(m) : - (m:InModule) : to-ir-m(m) + (m:InModule) : remove-chirrtl-m(m) (m:ExModule) : m diff --git a/src/main/stanza/compilers.stanza b/src/main/stanza/compilers.stanza index 4878dce5..167efc26 100644 --- a/src/main/stanza/compilers.stanza +++ b/src/main/stanza/compilers.stanza @@ -50,8 +50,8 @@ public defmethod passes (c:StandardVerilog) -> List<Pass> : ;RemoveSpecialChars() ;TempElimination() ; Needs to check number of uses ;=============== - InferTypes() - ToIR() + CInferTypes() + RemoveCHIRRTL() ;=============== CheckHighForm() ;=============== @@ -122,7 +122,7 @@ public defmethod backend (c:StandardLoFIRRTL) -> List<Pass> : public defmethod passes (c:StandardLoFIRRTL) -> List<Pass> : to-list $ [ ;=============== - ToIR() + RemoveCHIRRTL() ;=============== CheckHighForm() ;=============== diff --git a/src/main/stanza/errors.stanza b/src/main/stanza/errors.stanza index 67dfef8d..7bac85bb 100644 --- a/src/main/stanza/errors.stanza +++ b/src/main/stanza/errors.stanza @@ -665,8 +665,9 @@ public defn check-genders (c:Circuit) -> Circuit : (t) : if f == REVERSE : f? = true t flip-rec(t,DEFAULT) - val has-flip? = flip?(t) + f? + val has-flip? = flip?(type(e)) ;println(e) ;println(gender) ;println(desired) diff --git a/src/main/stanza/firrtl-ir.stanza b/src/main/stanza/firrtl-ir.stanza index d9586df9..ace1e76f 100644 --- a/src/main/stanza/firrtl-ir.stanza +++ b/src/main/stanza/firrtl-ir.stanza @@ -161,25 +161,6 @@ public defstruct Print <: Stmt : ;LOW public defstruct Empty <: Stmt ;LOW -; CHIRRTL Features -public definterface MPortDir -public val MRead = new MPortDir -public val MWrite = new MPortDir -public val MReadWrite = new MPortDir - -public defstruct CDefMemory <: Stmt : ;LOW - info: FileInfo with: (as-method => true) - name: Symbol - type: Type - size: Int - seq?: True|False -public defstruct CDefMPort <: Stmt : - info: FileInfo with: (as-method => true) - name: Symbol - mem: Symbol - exps: List<Expression> - direction: MPortDir - public definterface Type public defstruct UIntType <: Type : width: Width diff --git a/src/main/stanza/ir-parser.stanza b/src/main/stanza/ir-parser.stanza index d891227f..951d341c 100644 --- a/src/main/stanza/ir-parser.stanza +++ b/src/main/stanza/ir-parser.stanza @@ -6,6 +6,7 @@ defpackage firrtl/parser : import firrtl/lexer import bigint2 import firrtl/ir-utils + import firrtl/chirrtl ;======= Convenience Types =========== definterface MStat @@ -264,11 +265,11 @@ defsyntax firrtl : stmt = (smem ?name:#id! #:! ?t:#vectype! ) : CDefMemory(first-info(form),name,type(t),size(t),true) stmt = (read mport ?name:#id! #=! ?mem:#id! (@get ?index:#exp!) ?clk:#exp!) : - CDefMPort(first-info(form),name,mem,list(index,clk),MRead) + CDefMPort(first-info(form),name,UnknownType(),mem,list(index,clk),MRead) stmt = (write mport ?name:#id! #=! ?mem:#id! (@get ?index:#exp!) ?clk:#exp!) : - CDefMPort(first-info(form),name,mem,list(index,clk),MWrite) + CDefMPort(first-info(form),name,UnknownType(),mem,list(index,clk),MWrite) stmt = (rdwr mport ?name:#id! #=! ?mem:#id! (@get ?index:#exp!) ?clk:#exp!) : - CDefMPort(first-info(form),name,mem,list(index,clk),MReadWrite) + CDefMPort(first-info(form),name,UnknownType(),mem,list(index,clk),MReadWrite) stmt = (mem ?name:#id! #:! (?ms:#mstat ...)) : defn grab (f:MStat -> True|False) : diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index 2f4bf973..7af75b57 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -148,6 +148,69 @@ public defn list-hash (l:List) -> Int : turn-on-debug(false) i +;===== Type Expansion Algorithms ========= +public defn times (f1:Flip,f2:Flip) -> Flip : + switch {_ == f2} : + DEFAULT : f1 + REVERSE : swap(f1) +public defn swap (f:Flip) -> Flip : + switch {_ == f} : + DEFAULT : REVERSE + REVERSE : DEFAULT + +public defn get-size (t:Type) -> Int : + val x = match(t) : + (t:BundleType) : + var sum = 0 + for f in fields(t) do : + sum = sum + get-size(type(f)) + sum + (t:VectorType) : size(t) * get-size(type(t)) + (t) : 1 + x +public defn get-valid-points (t1:Type,t2:Type,flip1:Flip,flip2:Flip) -> List<[Int,Int]> : + ;println-all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) + match(t1,t2) : + (t1:UIntType,t2:UIntType) : + if flip1 == flip2 : list([0, 0]) + else: list() + (t1:SIntType,t2:SIntType) : + if flip1 == flip2 : list([0, 0]) + else: list() + (t1:BundleType,t2:BundleType) : + val points = Vector<[Int,Int]>() + var ilen = 0 + var jlen = 0 + for i in 0 to length(fields(t1)) do : + for j in 0 to length(fields(t2)) do : + ;println(i) + ;println(j) + ;println(ilen) + ;println(jlen) + val f1 = fields(t1)[i] + val f2 = fields(t2)[j] + if name(f1) == name(f2) : + val ls = get-valid-points(type(f1),type(f2),flip1 * flip(f1), + flip2 * flip(f2)) + for x in ls do : + add(points,[x[0] + ilen, x[1] + jlen]) + println(points) + jlen = jlen + get-size(type(fields(t2)[j])) + ilen = ilen + get-size(type(fields(t1)[i])) + jlen = 0 + to-list(points) + (t1:VectorType,t2:VectorType) : + val points = Vector<[Int,Int]>() + var ilen = 0 + var jlen = 0 + for i in 0 to min(size(t1),size(t2)) do : + val ls = get-valid-points(type(t1),type(t2),flip1,flip2) + for x in ls do : + add(points,[x[0] + ilen, x[1] + jlen]) + ilen = ilen + get-size(type(t1)) + jlen = jlen + get-size(type(t2)) + to-list(points) + ;============= Useful functions ============== public defn create-mask (n:Symbol,dt:Type) -> Field : Field{n,DEFAULT,_} $ match(dt) : @@ -350,22 +413,9 @@ defmethod print (o:OutputStream, c:Stmt) : print-all(o, ["printf(" clk(c) ", " en(c) ", "]) ;" print-all(o, join(List(escape(string(c)),args(c)), ", ")) print(o, ")") - (c:CDefMemory) : - if seq?(c) : - print-all(o, ["smem " name(c) " : " type(c) "[" size(c) "]"]) - else : - print-all(o, ["cmem " name(c) " : " type(c) "[" size(c) "]"]) - (c:CDefMPort) : - print-all(o, [direction(c) " mport " name(c) " = " mem(c) "[" exps(c)[0] "], " exps(c)[1]]) if not c typeof Conditionally|Begin|Empty: print-debug(o,c) -defmethod print (o:OutputStream, m:MPortDir) : - switch { m == _ } : - MRead : print(o,"read") - MWrite : print(o,"write") - MReadWrite : print(o,"rdwr") - defmethod print (o:OutputStream, t:Type) : match(t) : (t:UnknownType) : @@ -466,7 +516,6 @@ defmethod map (f: Expression -> Expression, c:Stmt) -> Stmt : (c:BulkConnect) : BulkConnect(info(c),f(loc(c)), f(exp(c))) (c:Stop) : Stop(info(c),ret(c),f(clk(c)),f(en(c))) (c:Print) : Print(info(c),string(c),map(f,args(c)),f(clk(c)),f(en(c))) - (c:CDefMPort) : CDefMPort(info(c),name(c), mem(c), f(clock(c)), f(reset(c)), f(init(c))) (c) : c public defmulti map<?T> (f: Stmt -> Stmt, c:?T&Stmt) -> T diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 4f40c8b7..ee650bde 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -444,10 +444,6 @@ defn swap (g:Gender) -> Gender : MALE : FEMALE FEMALE : MALE BI-GENDER : BI-GENDER -defn swap (f:Flip) -> Flip : - switch {_ == f} : - DEFAULT : REVERSE - REVERSE : DEFAULT defn swap (d:Direction) -> Direction : switch {_ == d} : OUTPUT : INPUT @@ -464,11 +460,6 @@ public defn times (flip:Flip,g:Gender) -> Gender : switch {_ == flip} : DEFAULT : g REVERSE : swap(g) -public defn times (f1:Flip,f2:Flip) -> Flip : - switch {_ == f2} : - DEFAULT : f1 - REVERSE : swap(f1) - defn to-field (p:Port) -> Field : if direction(p) == OUTPUT : Field(name(p),DEFAULT,type(p)) else if direction(p) == INPUT : Field(name(p),REVERSE,type(p)) @@ -713,10 +704,6 @@ defn infer-types (c:Circuit) -> Circuit : val types = HashTable<Symbol,Type>(symbol-hash) defn infer-types-e (e:Expression) -> Expression : match(map(infer-types-e,e)) : - (e:Ref) : Ref(name(e), types[name(e)]) - (e:SubField) : SubField(exp(e),name(e),field-type(type(exp(e)),name(e))) - (e:SubIndex) : SubIndex(exp(e),value(e),sub-type(type(exp(e)))) - (e:SubAccess) : SubAccess(exp(e),index(e),sub-type(type(exp(e)))) (e:WRef) : WRef(name(e), types[name(e)],kind(e),gender(e)) (e:WSubField) : WSubField(exp(e),name(e),field-type(type(exp(e)),name(e)),gender(e)) (e:WSubIndex) : WSubIndex(exp(e),value(e),sub-type(type(exp(e))),gender(e)) @@ -853,17 +840,6 @@ public defmethod short-name (b:ExpandConnects) -> String : "expand-connects" ;---------------- UTILS ------------------ defn get-size (e:Expression) -> Int : get-size(type(e)) -defn get-size (t:Type) -> Int : - val x = match(t) : - (t:BundleType) : - var sum = 0 - for f in fields(t) do : - sum = sum + get-size(type(f)) - sum - (t:VectorType) : size(t) * get-size(type(t)) - (t) : 1 - x - val hashed-get-flip = HashTable<List,Flip>(list-hash) defn get-flip (t:Type, i:Int, f:Flip) -> Flip : if key?(hashed-get-flip,list(t,i,f)) : hashed-get-flip[list(t,i,f)] @@ -904,49 +880,6 @@ defn get-point (e:Expression) -> Int : value(e) * get-size(e) (e:WSubAccess) : get-point(exp(e)) -defn get-valid-points (t1:Type,t2:Type,flip1:Flip,flip2:Flip) -> List<[Int,Int]> : - ;println-all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) - match(t1,t2) : - (t1:UIntType,t2:UIntType) : - if flip1 == flip2 : list([0, 0]) - else: list() - (t1:SIntType,t2:SIntType) : - if flip1 == flip2 : list([0, 0]) - else: list() - (t1:BundleType,t2:BundleType) : - val points = Vector<[Int,Int]>() - var ilen = 0 - var jlen = 0 - for i in 0 to length(fields(t1)) do : - for j in 0 to length(fields(t2)) do : - ;println(i) - ;println(j) - ;println(ilen) - ;println(jlen) - val f1 = fields(t1)[i] - val f2 = fields(t2)[j] - if name(f1) == name(f2) : - val ls = get-valid-points(type(f1),type(f2),flip1 * flip(f1), - flip2 * flip(f2)) - for x in ls do : - add(points,[x[0] + ilen, x[1] + jlen]) - println(points) - jlen = jlen + get-size(type(fields(t2)[j])) - ilen = ilen + get-size(type(fields(t1)[i])) - jlen = 0 - to-list(points) - (t1:VectorType,t2:VectorType) : - val points = Vector<[Int,Int]>() - var ilen = 0 - var jlen = 0 - for i in 0 to min(size(t1),size(t2)) do : - val ls = get-valid-points(type(t1),type(t2),flip1,flip2) - for x in ls do : - add(points,[x[0] + ilen, x[1] + jlen]) - ilen = ilen + get-size(type(t1)) - jlen = jlen + get-size(type(t2)) - to-list(points) - val hashed-create-exps = HashTable<Expression,List<Expression>>(exp-hash) defn create-exps (n:Symbol, t:Type) -> List<Expression> : |
