diff options
Diffstat (limited to 'src/main/stanza')
| -rw-r--r-- | src/main/stanza/chirrtl.stanza | 3 | ||||
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 37 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 195 |
3 files changed, 125 insertions, 110 deletions
diff --git a/src/main/stanza/chirrtl.stanza b/src/main/stanza/chirrtl.stanza index 7d6d9d3f..2ac76a05 100644 --- a/src/main/stanza/chirrtl.stanza +++ b/src/main/stanza/chirrtl.stanza @@ -114,6 +114,9 @@ defn infer-types (c:Circuit) -> Circuit : val t = type(value(s*)) types[name(s*)] = t s* + (s:DefMemory) : + types[name(s)] = get-type(s) + s (s:CDefMPort) : val t = types[mem(s)] types[name(s)] = t diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index fa4b296f..59a4b659 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -157,6 +157,43 @@ public defn swap (f:Flip) -> Flip : switch {_ == f} : DEFAULT : REVERSE REVERSE : DEFAULT + +public defmulti get-type (s:Stmt) -> Type +public defmethod get-type (s:Stmt) -> Type : + match(s) : + (s:DefWire|DefPoison|DefRegister) : type(s) + (s:DefNode) : type(value(s)) + (s:DefMemory) : + val depth = depth(s) + ; Fields + val addr = Field(`addr,DEFAULT,UIntType(IntWidth(ceil-log2(depth)))) + val en = Field(`en,DEFAULT,BoolType()) + val clk = Field(`clk,DEFAULT,ClockType()) + val def-data = Field(`data,DEFAULT,data-type(s)) + val rev-data = Field(`data,REVERSE,data-type(s)) + val rdata = Field(`rdata,REVERSE,data-type(s)) + val wdata = Field(`wdata,DEFAULT,data-type(s)) + val mask = Field(`mask,DEFAULT,create-mask(data-type(s))) + val wmask = Field(`wmask,DEFAULT,create-mask(data-type(s))) + val ren = Field(`ren,DEFAULT,UIntType(IntWidth(1))) + val wen = Field(`wen,DEFAULT,UIntType(IntWidth(1))) + val raddr = Field(`raddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth)))) + val waddr = Field(`waddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth)))) + + val read-type = BundleType(to-list([rev-data,addr,en,clk])) + val write-type = BundleType(to-list([def-data,mask,addr,en,clk])) + val readwrite-type = BundleType(to-list([wdata,wmask,waddr,wen,rdata,raddr,ren,clk])) + + val mem-fields = Vector<Field>() + for x in readers(s) do : + add(mem-fields,Field(x,DEFAULT,read-type)) + for x in writers(s) do : + add(mem-fields,Field(x,DEFAULT,write-type)) + for x in readwriters(s) do : + add(mem-fields,Field(x,DEFAULT,readwrite-type)) + BundleType(to-list(mem-fields)) + (s:DefInstance) : UnknownType() + (s:Begin|Connect|BulkConnect|Stop|Print|Empty) : UnknownType() public defn get-size (t:Type) -> Int : val x = match(t) : diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 570e271e..47e8711e 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -116,41 +116,7 @@ defmethod info (stmt:Empty) -> FileInfo : FileInfo() defmethod type (exp:UIntValue) -> Type : UIntType(width(exp)) defmethod type (exp:SIntValue) -> Type : SIntType(width(exp)) -defn get-type (s:Stmt) -> Type : - match(s) : - (s:DefWire|DefPoison|DefRegister|WDefInstance) : type(s) - (s:DefNode) : type(value(s)) - (s:DefMemory) : - val depth = depth(s) - ; Fields - val addr = Field(`addr,DEFAULT,UIntType(IntWidth(ceil-log2(depth)))) - val en = Field(`en,DEFAULT,BoolType()) - val clk = Field(`clk,DEFAULT,ClockType()) - val def-data = Field(`data,DEFAULT,data-type(s)) - val rev-data = Field(`data,REVERSE,data-type(s)) - val rdata = Field(`rdata,REVERSE,data-type(s)) - val wdata = Field(`wdata,DEFAULT,data-type(s)) - val mask = Field(`mask,DEFAULT,create-mask(data-type(s))) - val wmask = Field(`wmask,DEFAULT,create-mask(data-type(s))) - val ren = Field(`ren,DEFAULT,UIntType(IntWidth(1))) - val wen = Field(`wen,DEFAULT,UIntType(IntWidth(1))) - val raddr = Field(`raddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth)))) - val waddr = Field(`waddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth)))) - - val read-type = BundleType(to-list([rev-data,addr,en,clk])) - val write-type = BundleType(to-list([def-data,mask,addr,en,clk])) - val readwrite-type = BundleType(to-list([wdata,wmask,waddr,wen,rdata,raddr,ren,clk])) - - val mem-fields = Vector<Field>() - for x in readers(s) do : - add(mem-fields,Field(x,DEFAULT,read-type)) - for x in writers(s) do : - add(mem-fields,Field(x,DEFAULT,write-type)) - for x in readwriters(s) do : - add(mem-fields,Field(x,DEFAULT,readwrite-type)) - BundleType(to-list(mem-fields)) - (s:DefInstance) : UnknownType() - (s:Begin|Connect|BulkConnect|Stop|Print|Empty) : UnknownType() +defmethod get-type (s:WDefInstance) -> Type : type(s) defmethod equal? (e1:Expression,e2:Expression) -> True|False : match(e1,e2) : @@ -840,31 +806,27 @@ public defmethod short-name (b:ExpandConnects) -> String : "expand-connects" ;---------------- UTILS ------------------ defn get-size (e:Expression) -> Int : get-size(type(e)) -val hashed-get-flip = HashTable<List,Flip>(list-hash) defn get-flip (t:Type, i:Int, f:Flip) -> Flip : - if key?(hashed-get-flip,list(t,i,f)) : hashed-get-flip[list(t,i,f)] - else : - if i >= get-size(t) : error("Shouldn't be here") - val x = match(t) : - (t:UIntType|SIntType|ClockType) : f - (t:BundleType) : label<Flip> ret : - var n = i - for x in fields(t) do : - if n < get-size(type(x)) : - ret(get-flip(type(x),n,flip(x) * f)) - else : - n = n - get-size(type(x)) - error("Shouldn't be here") - (t:VectorType) : label<Flip> ret : - var n = i - for j in 0 to size(t) do : - if n < get-size(type(t)) : - ret(get-flip(type(t),n,f)) - else : - n = n - get-size(type(t)) - error("Shouldn't be here") - hashed-get-flip[list(t,i,f)] = x - x + if i >= get-size(t) : error("Shouldn't be here") + val x = match(t) : + (t:UIntType|SIntType|ClockType) : f + (t:BundleType) : label<Flip> ret : + var n = i + for x in fields(t) do : + if n < get-size(type(x)) : + ret(get-flip(type(x),n,flip(x) * f)) + else : + n = n - get-size(type(x)) + error("Shouldn't be here") + (t:VectorType) : label<Flip> ret : + var n = i + for j in 0 to size(t) do : + if n < get-size(type(t)) : + ret(get-flip(type(t),n,f)) + else : + n = n - get-size(type(t)) + error("Shouldn't be here") + x defn get-point (e:Expression) -> Int : match(e) : @@ -881,7 +843,6 @@ defn get-point (e:Expression) -> Int : (e:WSubAccess) : get-point(exp(e)) -val hashed-create-exps = HashTable<Expression,List<Expression>>(exp-hash) defn create-exps (n:Symbol, t:Type) -> List<Expression> : create-exps(WRef(n,t,ExpKind(),UNKNOWN-GENDER)) defn create-exps (e:Expression) -> List<Expression> : @@ -894,23 +855,31 @@ defn create-exps (e:Expression) -> List<Expression> : for i in 0 to size(t) map-append : create-exps(WSubIndex(e,i,type(t),gender(e))) +defn gexp-hash (e:Expression) -> Int : + turn-off-debug(false) + val ls = to-list([mname `.... e `.... gender(e) `.... type(e)]) + ;val ls = to-list([e `.... gender(e) `.... type(e)]) + val i = symbol-hash(symbol-join(ls)) + ;val i = symbol-hash(to-symbol(to-string(e))) + turn-on-debug(false) + i +val hashed-create-exps = HashTable<Expression,List<Expression>>(gexp-hash) defn fast-create-exps (n:Symbol, t:Type) -> List<Expression> : fast-create-exps(WRef(n,t,ExpKind(),UNKNOWN-GENDER)) defn fast-create-exps (e:Expression) -> List<Expression> : - if key?(hashed-create-exps,e) : hashed-create-exps[e] + if key?(hashed-create-exps,e) : + hashed-create-exps[e] else : - val es = Vector<Expression>() - defn create-exps (e:Expression) -> False : - match(type(e)) : - (t:UIntType|SIntType|ClockType) : add(es,e) - (t:BundleType) : - for f in fields(t) do : - create-exps(WSubField(e,name(f),type(f),gender(e) * flip(f))) - (t:VectorType) : - for i in 0 to size(t) do : - create-exps(WSubIndex(e,i,type(t),gender(e))) - create-exps(e) - val x = to-list(es) + val es = Vector<List<Expression>>() + match(type(e)) : + (t:UIntType|SIntType|ClockType) : add(es,list(e)) + (t:BundleType) : + for f in fields(t) do : + add(es,fast-create-exps(WSubField(e,name(f),type(f),gender(e) * flip(f)))) + (t:VectorType) : + for i in 0 to size(t) do : + add(es,fast-create-exps(WSubIndex(e,i,type(t),gender(e)))) + val x = append-all(es) hashed-create-exps[e] = x x @@ -974,39 +943,34 @@ defstruct Location : defmethod print (o:OutputStream,x:Location) : print-all(o,["[" base(x) " , " guard(x) "]"]) - -val hashed-locations = HashTable<Expression,List<Location>>(exp-hash) defn get-locations (e:Expression) -> List<Location> : - if key?(hashed-locations,e) : hashed-locations[e] - else : - val x = match(e) : - (e:WRef) : map(Location{_,one},create-exps(e)) - (e:WSubIndex|WSubField) : - val ls = get-locations(exp(e)) - val start = get-point(e) - val end = start + get-size(e) - val stride = get-size(exp(e)) - val ls* = Vector<Location>() - var c = 0 - for i in 0 to length(ls) do : - if (i % stride >= start and i % stride < end) : - add(ls*,ls[i]) - to-list(ls*) - (e:WSubAccess) : - val ls = get-locations(exp(e)) - val stride = get-size(e) - val wrap = size(type(exp(e)) as VectorType) - val ls* = Vector<Location>() - var c = 0 - for i in 0 to length(ls) do : - if c % wrap == 0 : c = 0 - val base* = base(ls[i]) - val guard* = AND(guard(ls[i]),EQV(uint(c),index(e))) - add(ls*,Location(base*,guard*)) - if (i + 1) % stride == 0 : c = c + 1 - to-list(ls*) - hashed-locations[e] = x - x + match(e) : + (e:WRef) : map(Location{_,one},create-exps(e)) + (e:WSubIndex|WSubField) : + val ls = get-locations(exp(e)) + val start = get-point(e) + val end = start + get-size(e) + val stride = get-size(exp(e)) + val ls* = Vector<Location>() + var c = 0 + for i in 0 to length(ls) do : + if (i % stride >= start and i % stride < end) : + add(ls*,ls[i]) + to-list(ls*) + (e:WSubAccess) : + val ls = get-locations(exp(e)) + val stride = get-size(e) + val wrap = size(type(exp(e)) as VectorType) + val ls* = Vector<Location>() + var c = 0 + for i in 0 to length(ls) do : + if c % wrap == 0 : c = 0 + val base* = base(ls[i]) + val guard* = AND(guard(ls[i]),EQV(uint(c),index(e))) + add(ls*,Location(base*,guard*)) + if (i + 1) % stride == 0 : c = c + 1 + to-list(ls*) + defn has-access? (e:Expression) -> True|False : var ret = false defn rec-has-access (e:Expression) -> Expression : @@ -1214,6 +1178,8 @@ defn expand-whens (c:Circuit) -> Circuit : (m:InModule) : val [netlist simlist] = expand-whens(m) create-module(netlist,simlist,m) + for m in modules* do : + if name(m) == `RRArbiter_38 : print(m) Circuit(info(c),modules*,main(c)) ;;================ Module Duplication ================== @@ -2099,6 +2065,16 @@ defn lower-data-mem (e:Expression) -> Expression : WSubField(p,to-symbol(names[2]),UnknownType(),UNKNOWN-GENDER) defn merge (a:Symbol,b:Symbol,x:Symbol) -> Symbol : symbol-join([a x b]) +val hashed-lowered-name = HashTable<Expression,Symbol>(gexp-hash) +defn fast-lowered-name (e:Expression) -> Symbol : + val x = get?(hashed-lowered-name,e,false) + match(x) : + (x:Symbol) : x + (x:False) : + match(e) : + (e:WRef) : name(e) + (e:WSubField) : merge(fast-lowered-name(exp(e)),name(e),`_) + (e:WSubIndex) : merge(fast-lowered-name(exp(e)),to-symbol(value(e)),`_) defn lowered-name (e:Expression) -> Symbol : match(e) : (e:WRef) : name(e) @@ -2178,11 +2154,10 @@ defn lower-types (m:Module) -> Module : else : s (s) : map(lower-types,s) - val ports* = - for p in ports(m) map-append : - val es = create-exps(WRef(name(p),type(p),PortKind(),to-gender(direction(p)))) - for e in es map : - Port(info(p),lowered-name(e),to-dir(gender(e)),type(e)) + val ports* = for p in ports(m) map-append : + val es = create-exps(WRef(name(p),type(p),PortKind(),to-gender(direction(p)))) + for e in es map : + Port(info(p),lowered-name(e),to-dir(gender(e)),type(e)) match(m) : (m:ExModule) : ExModule(info(m),name(m),ports*) (m:InModule) : InModule(info(m),name(m),ports*,lower-types(body(m))) |
