aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/stanza/chirrtl.stanza3
-rw-r--r--src/main/stanza/ir-utils.stanza37
-rw-r--r--src/main/stanza/passes.stanza195
3 files changed, 125 insertions, 110 deletions
diff --git a/src/main/stanza/chirrtl.stanza b/src/main/stanza/chirrtl.stanza
index 7d6d9d3f..2ac76a05 100644
--- a/src/main/stanza/chirrtl.stanza
+++ b/src/main/stanza/chirrtl.stanza
@@ -114,6 +114,9 @@ defn infer-types (c:Circuit) -> Circuit :
val t = type(value(s*))
types[name(s*)] = t
s*
+ (s:DefMemory) :
+ types[name(s)] = get-type(s)
+ s
(s:CDefMPort) :
val t = types[mem(s)]
types[name(s)] = t
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza
index fa4b296f..59a4b659 100644
--- a/src/main/stanza/ir-utils.stanza
+++ b/src/main/stanza/ir-utils.stanza
@@ -157,6 +157,43 @@ public defn swap (f:Flip) -> Flip :
switch {_ == f} :
DEFAULT : REVERSE
REVERSE : DEFAULT
+
+public defmulti get-type (s:Stmt) -> Type
+public defmethod get-type (s:Stmt) -> Type :
+ match(s) :
+ (s:DefWire|DefPoison|DefRegister) : type(s)
+ (s:DefNode) : type(value(s))
+ (s:DefMemory) :
+ val depth = depth(s)
+ ; Fields
+ val addr = Field(`addr,DEFAULT,UIntType(IntWidth(ceil-log2(depth))))
+ val en = Field(`en,DEFAULT,BoolType())
+ val clk = Field(`clk,DEFAULT,ClockType())
+ val def-data = Field(`data,DEFAULT,data-type(s))
+ val rev-data = Field(`data,REVERSE,data-type(s))
+ val rdata = Field(`rdata,REVERSE,data-type(s))
+ val wdata = Field(`wdata,DEFAULT,data-type(s))
+ val mask = Field(`mask,DEFAULT,create-mask(data-type(s)))
+ val wmask = Field(`wmask,DEFAULT,create-mask(data-type(s)))
+ val ren = Field(`ren,DEFAULT,UIntType(IntWidth(1)))
+ val wen = Field(`wen,DEFAULT,UIntType(IntWidth(1)))
+ val raddr = Field(`raddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth))))
+ val waddr = Field(`waddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth))))
+
+ val read-type = BundleType(to-list([rev-data,addr,en,clk]))
+ val write-type = BundleType(to-list([def-data,mask,addr,en,clk]))
+ val readwrite-type = BundleType(to-list([wdata,wmask,waddr,wen,rdata,raddr,ren,clk]))
+
+ val mem-fields = Vector<Field>()
+ for x in readers(s) do :
+ add(mem-fields,Field(x,DEFAULT,read-type))
+ for x in writers(s) do :
+ add(mem-fields,Field(x,DEFAULT,write-type))
+ for x in readwriters(s) do :
+ add(mem-fields,Field(x,DEFAULT,readwrite-type))
+ BundleType(to-list(mem-fields))
+ (s:DefInstance) : UnknownType()
+ (s:Begin|Connect|BulkConnect|Stop|Print|Empty) : UnknownType()
public defn get-size (t:Type) -> Int :
val x = match(t) :
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index 570e271e..47e8711e 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -116,41 +116,7 @@ defmethod info (stmt:Empty) -> FileInfo : FileInfo()
defmethod type (exp:UIntValue) -> Type : UIntType(width(exp))
defmethod type (exp:SIntValue) -> Type : SIntType(width(exp))
-defn get-type (s:Stmt) -> Type :
- match(s) :
- (s:DefWire|DefPoison|DefRegister|WDefInstance) : type(s)
- (s:DefNode) : type(value(s))
- (s:DefMemory) :
- val depth = depth(s)
- ; Fields
- val addr = Field(`addr,DEFAULT,UIntType(IntWidth(ceil-log2(depth))))
- val en = Field(`en,DEFAULT,BoolType())
- val clk = Field(`clk,DEFAULT,ClockType())
- val def-data = Field(`data,DEFAULT,data-type(s))
- val rev-data = Field(`data,REVERSE,data-type(s))
- val rdata = Field(`rdata,REVERSE,data-type(s))
- val wdata = Field(`wdata,DEFAULT,data-type(s))
- val mask = Field(`mask,DEFAULT,create-mask(data-type(s)))
- val wmask = Field(`wmask,DEFAULT,create-mask(data-type(s)))
- val ren = Field(`ren,DEFAULT,UIntType(IntWidth(1)))
- val wen = Field(`wen,DEFAULT,UIntType(IntWidth(1)))
- val raddr = Field(`raddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth))))
- val waddr = Field(`waddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth))))
-
- val read-type = BundleType(to-list([rev-data,addr,en,clk]))
- val write-type = BundleType(to-list([def-data,mask,addr,en,clk]))
- val readwrite-type = BundleType(to-list([wdata,wmask,waddr,wen,rdata,raddr,ren,clk]))
-
- val mem-fields = Vector<Field>()
- for x in readers(s) do :
- add(mem-fields,Field(x,DEFAULT,read-type))
- for x in writers(s) do :
- add(mem-fields,Field(x,DEFAULT,write-type))
- for x in readwriters(s) do :
- add(mem-fields,Field(x,DEFAULT,readwrite-type))
- BundleType(to-list(mem-fields))
- (s:DefInstance) : UnknownType()
- (s:Begin|Connect|BulkConnect|Stop|Print|Empty) : UnknownType()
+defmethod get-type (s:WDefInstance) -> Type : type(s)
defmethod equal? (e1:Expression,e2:Expression) -> True|False :
match(e1,e2) :
@@ -840,31 +806,27 @@ public defmethod short-name (b:ExpandConnects) -> String : "expand-connects"
;---------------- UTILS ------------------
defn get-size (e:Expression) -> Int : get-size(type(e))
-val hashed-get-flip = HashTable<List,Flip>(list-hash)
defn get-flip (t:Type, i:Int, f:Flip) -> Flip :
- if key?(hashed-get-flip,list(t,i,f)) : hashed-get-flip[list(t,i,f)]
- else :
- if i >= get-size(t) : error("Shouldn't be here")
- val x = match(t) :
- (t:UIntType|SIntType|ClockType) : f
- (t:BundleType) : label<Flip> ret :
- var n = i
- for x in fields(t) do :
- if n < get-size(type(x)) :
- ret(get-flip(type(x),n,flip(x) * f))
- else :
- n = n - get-size(type(x))
- error("Shouldn't be here")
- (t:VectorType) : label<Flip> ret :
- var n = i
- for j in 0 to size(t) do :
- if n < get-size(type(t)) :
- ret(get-flip(type(t),n,f))
- else :
- n = n - get-size(type(t))
- error("Shouldn't be here")
- hashed-get-flip[list(t,i,f)] = x
- x
+ if i >= get-size(t) : error("Shouldn't be here")
+ val x = match(t) :
+ (t:UIntType|SIntType|ClockType) : f
+ (t:BundleType) : label<Flip> ret :
+ var n = i
+ for x in fields(t) do :
+ if n < get-size(type(x)) :
+ ret(get-flip(type(x),n,flip(x) * f))
+ else :
+ n = n - get-size(type(x))
+ error("Shouldn't be here")
+ (t:VectorType) : label<Flip> ret :
+ var n = i
+ for j in 0 to size(t) do :
+ if n < get-size(type(t)) :
+ ret(get-flip(type(t),n,f))
+ else :
+ n = n - get-size(type(t))
+ error("Shouldn't be here")
+ x
defn get-point (e:Expression) -> Int :
match(e) :
@@ -881,7 +843,6 @@ defn get-point (e:Expression) -> Int :
(e:WSubAccess) :
get-point(exp(e))
-val hashed-create-exps = HashTable<Expression,List<Expression>>(exp-hash)
defn create-exps (n:Symbol, t:Type) -> List<Expression> :
create-exps(WRef(n,t,ExpKind(),UNKNOWN-GENDER))
defn create-exps (e:Expression) -> List<Expression> :
@@ -894,23 +855,31 @@ defn create-exps (e:Expression) -> List<Expression> :
for i in 0 to size(t) map-append :
create-exps(WSubIndex(e,i,type(t),gender(e)))
+defn gexp-hash (e:Expression) -> Int :
+ turn-off-debug(false)
+ val ls = to-list([mname `.... e `.... gender(e) `.... type(e)])
+ ;val ls = to-list([e `.... gender(e) `.... type(e)])
+ val i = symbol-hash(symbol-join(ls))
+ ;val i = symbol-hash(to-symbol(to-string(e)))
+ turn-on-debug(false)
+ i
+val hashed-create-exps = HashTable<Expression,List<Expression>>(gexp-hash)
defn fast-create-exps (n:Symbol, t:Type) -> List<Expression> :
fast-create-exps(WRef(n,t,ExpKind(),UNKNOWN-GENDER))
defn fast-create-exps (e:Expression) -> List<Expression> :
- if key?(hashed-create-exps,e) : hashed-create-exps[e]
+ if key?(hashed-create-exps,e) :
+ hashed-create-exps[e]
else :
- val es = Vector<Expression>()
- defn create-exps (e:Expression) -> False :
- match(type(e)) :
- (t:UIntType|SIntType|ClockType) : add(es,e)
- (t:BundleType) :
- for f in fields(t) do :
- create-exps(WSubField(e,name(f),type(f),gender(e) * flip(f)))
- (t:VectorType) :
- for i in 0 to size(t) do :
- create-exps(WSubIndex(e,i,type(t),gender(e)))
- create-exps(e)
- val x = to-list(es)
+ val es = Vector<List<Expression>>()
+ match(type(e)) :
+ (t:UIntType|SIntType|ClockType) : add(es,list(e))
+ (t:BundleType) :
+ for f in fields(t) do :
+ add(es,fast-create-exps(WSubField(e,name(f),type(f),gender(e) * flip(f))))
+ (t:VectorType) :
+ for i in 0 to size(t) do :
+ add(es,fast-create-exps(WSubIndex(e,i,type(t),gender(e))))
+ val x = append-all(es)
hashed-create-exps[e] = x
x
@@ -974,39 +943,34 @@ defstruct Location :
defmethod print (o:OutputStream,x:Location) :
print-all(o,["[" base(x) " , " guard(x) "]"])
-
-val hashed-locations = HashTable<Expression,List<Location>>(exp-hash)
defn get-locations (e:Expression) -> List<Location> :
- if key?(hashed-locations,e) : hashed-locations[e]
- else :
- val x = match(e) :
- (e:WRef) : map(Location{_,one},create-exps(e))
- (e:WSubIndex|WSubField) :
- val ls = get-locations(exp(e))
- val start = get-point(e)
- val end = start + get-size(e)
- val stride = get-size(exp(e))
- val ls* = Vector<Location>()
- var c = 0
- for i in 0 to length(ls) do :
- if (i % stride >= start and i % stride < end) :
- add(ls*,ls[i])
- to-list(ls*)
- (e:WSubAccess) :
- val ls = get-locations(exp(e))
- val stride = get-size(e)
- val wrap = size(type(exp(e)) as VectorType)
- val ls* = Vector<Location>()
- var c = 0
- for i in 0 to length(ls) do :
- if c % wrap == 0 : c = 0
- val base* = base(ls[i])
- val guard* = AND(guard(ls[i]),EQV(uint(c),index(e)))
- add(ls*,Location(base*,guard*))
- if (i + 1) % stride == 0 : c = c + 1
- to-list(ls*)
- hashed-locations[e] = x
- x
+ match(e) :
+ (e:WRef) : map(Location{_,one},create-exps(e))
+ (e:WSubIndex|WSubField) :
+ val ls = get-locations(exp(e))
+ val start = get-point(e)
+ val end = start + get-size(e)
+ val stride = get-size(exp(e))
+ val ls* = Vector<Location>()
+ var c = 0
+ for i in 0 to length(ls) do :
+ if (i % stride >= start and i % stride < end) :
+ add(ls*,ls[i])
+ to-list(ls*)
+ (e:WSubAccess) :
+ val ls = get-locations(exp(e))
+ val stride = get-size(e)
+ val wrap = size(type(exp(e)) as VectorType)
+ val ls* = Vector<Location>()
+ var c = 0
+ for i in 0 to length(ls) do :
+ if c % wrap == 0 : c = 0
+ val base* = base(ls[i])
+ val guard* = AND(guard(ls[i]),EQV(uint(c),index(e)))
+ add(ls*,Location(base*,guard*))
+ if (i + 1) % stride == 0 : c = c + 1
+ to-list(ls*)
+
defn has-access? (e:Expression) -> True|False :
var ret = false
defn rec-has-access (e:Expression) -> Expression :
@@ -1214,6 +1178,8 @@ defn expand-whens (c:Circuit) -> Circuit :
(m:InModule) :
val [netlist simlist] = expand-whens(m)
create-module(netlist,simlist,m)
+ for m in modules* do :
+ if name(m) == `RRArbiter_38 : print(m)
Circuit(info(c),modules*,main(c))
;;================ Module Duplication ==================
@@ -2099,6 +2065,16 @@ defn lower-data-mem (e:Expression) -> Expression :
WSubField(p,to-symbol(names[2]),UnknownType(),UNKNOWN-GENDER)
defn merge (a:Symbol,b:Symbol,x:Symbol) -> Symbol : symbol-join([a x b])
+val hashed-lowered-name = HashTable<Expression,Symbol>(gexp-hash)
+defn fast-lowered-name (e:Expression) -> Symbol :
+ val x = get?(hashed-lowered-name,e,false)
+ match(x) :
+ (x:Symbol) : x
+ (x:False) :
+ match(e) :
+ (e:WRef) : name(e)
+ (e:WSubField) : merge(fast-lowered-name(exp(e)),name(e),`_)
+ (e:WSubIndex) : merge(fast-lowered-name(exp(e)),to-symbol(value(e)),`_)
defn lowered-name (e:Expression) -> Symbol :
match(e) :
(e:WRef) : name(e)
@@ -2178,11 +2154,10 @@ defn lower-types (m:Module) -> Module :
else : s
(s) : map(lower-types,s)
- val ports* =
- for p in ports(m) map-append :
- val es = create-exps(WRef(name(p),type(p),PortKind(),to-gender(direction(p))))
- for e in es map :
- Port(info(p),lowered-name(e),to-dir(gender(e)),type(e))
+ val ports* = for p in ports(m) map-append :
+ val es = create-exps(WRef(name(p),type(p),PortKind(),to-gender(direction(p))))
+ for e in es map :
+ Port(info(p),lowered-name(e),to-dir(gender(e)),type(e))
match(m) :
(m:ExModule) : ExModule(info(m),name(m),ports*)
(m:InModule) : InModule(info(m),name(m),ports*,lower-types(body(m)))