diff options
| author | azidar | 2015-04-10 16:56:36 -0700 |
|---|---|---|
| committer | azidar | 2015-04-10 16:56:36 -0700 |
| commit | 75be996a7d26778aa6ed3c02db617b4f0516537c (patch) | |
| tree | 4ad612a5907a9ea98deea42588f05bfe26d3878c /src | |
| parent | a604e0789a85d8b3c5d6def2f9860047f479b68a (diff) | |
Almost finished width inference, takes too long/infinite loop for gcd
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 16 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 182 |
2 files changed, 162 insertions, 36 deletions
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index 1f33547e..0bce0a90 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -289,6 +289,22 @@ defmethod map (f: Type -> Type, c:Stmt) -> Stmt : (c:DefMemory) : DefRegister(name(c),f(type(c))) (c) : c +public defmulti mapr<?T> (f: Width -> Width, t:?T&Type) -> T +defmethod mapr (f: Width -> Width, t:Type) -> Type : + defn apply-t (t:Type) -> Type : + map{f,_} $ map(apply-t,t) + apply-t(t) + +public defmulti mapr<?T> (f: Width -> Width, s:?T&Stmt) -> T +defmethod mapr (f: Width -> Width, s:Stmt) -> Stmt : + defn apply-t (t:Type) -> Type : mapr(f,t) + defn apply-e (e:Expression) -> Expression : + map{f,_} $ map{apply-t,_} $ map(apply-e,e) + defn apply-s (s:Stmt) -> Stmt : + map{apply-t,_} $ map{apply-e,_} $ map(apply-s,s) + apply-s(s) + + ;================= HELPER FUNCTIONS USING MAP =================== ; These don't work properly.. ;public defmulti do (f:Expression -> ?, e:Expression) -> False diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 313191c8..e9e5bfc6 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -1436,6 +1436,14 @@ defstruct MaxWidth <: Width : arg1 : Width arg2 : Width +public defmulti map<?T> (f: Width -> Width, w:?T&Width) -> T +defmethod map (f: Width -> Width, w:Width) -> Width : + match(w) : + (w:MaxWidth) : MaxWidth(f(arg1(w)),f(arg2(w))) + (w:PlusWidth) : PlusWidth(f(arg1(w)),f(arg2(w))) + (w:MinusWidth) : MinusWidth(f(arg1(w)),f(arg2(w))) + (w) : w + defmethod print (o:OutputStream, w:VarWidth) : print(o,name(w)) defmethod print (o:OutputStream, w:MaxWidth) : @@ -1449,18 +1457,115 @@ definterface Constraint defstruct WGeq <: Constraint : loc : Width exp : Width -defstruct WEq <: Constraint : - loc : Width - exp : Width defmethod print (o:OutputStream, c:WGeq) : print-all(o,[ loc(c) " >= " exp(c)]) -defmethod print (o:OutputStream, c:WEq) : - print-all(o,[ loc(c) " = " exp(c)]) -;defn remove-var-widths (c:Circuit) -> Circuit : - -;defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> : +defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> : + defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False : + for x in h any? : key(x) == n + defn unique (ls:List<WGeq>) -> HashTable<Symbol,Width> : + val h = HashTable<Symbol,Width>(symbol-hash) + for g in ls do : + match(loc(g)) : + (w:VarWidth) : + val n = name(w) + if contains?(n,h) : h[n] = MaxWidth(exp(g),h[n]) + else : h[n] = exp(g) + (w) : w + h + defn substitute (w:Width,h:HashTable<Symbol,Width>) -> Width : + match(map(substitute{_,h},w)) : + (w:VarWidth) : + if contains?(name(w),h) : + val t = substitute(h[name(w)],h) + h[name(w)] = t + t + else : w + (w) : w + defn remove-cycle (n:Symbol,w:Width) -> Width : + match(map(remove-cycle{n,_},w)) : + (w:MaxWidth) : + match(arg1(w),arg2(w)) : + (v1:VarWidth,v2:VarWidth) : + if name(v1) == n : arg2(w) + else if name(v2) == n : arg1(w) + else : w + (v:VarWidth,_) : + if name(v) == n : arg2(w) + else : w + (_,v:VarWidth) : + if name(v) == n : arg1(w) + else : w + (v1,v2) : w + (w) : w + defn self-rec? (n:Symbol,w:Width) -> True|False : + var has? = false + defn look (w:Width) -> Width : + match(map(look,w)) : + (w:VarWidth) : if name(w) == n : has? = true + (w) : w + w + look(w) + has? + defn evaluate (h:HashTable<Symbol,Width>) -> HashTable<Symbol,Int> : + defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False : + if a typeof Int and b typeof Int : f(a as Int, b as Int) + else : false + defn max (a:Int,b:Int) -> Int : + if a > b : a + else : b + defn solve (w:Width) -> Int|False : + match(w) : + (w:VarWidth) : false + (w:MaxWidth) : apply(solve(arg1(w)),solve(arg2(w)),max) + (w:PlusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ + _}) + (w:MinusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ - _}) + (w:IntWidth) : width(w) + (w) : error("Shouldn't be here") + + val i = HashTable<Symbol,Int>(symbol-hash) + for x in h do : + val s = solve(value(x)) + if s typeof Int : i[key(x)] = s as Int + i + + ; Forward solve + ; Returns a solved list where each constraint undergoes: + ; 1) Continuous Solving (using triangular solving) + ; 2) Remove Cycles + ; 3) Move to solved if not self-recursive + val u = unique(l) + println("Unique Constraints") + for x in u do : println(x) + + val f = HashTable<Symbol,Width>(symbol-hash) + val o = Vector<Symbol>() + for x in u do : + val [n e] = [key(x) value(x)] + val e* = substitute(e,f) + ;val e* = remove-cycle{n,_} $ substitute(e,f) + ;if not self-rec?(n,e*) : + add(o,n) + f[n] = e* + + println("Forward Solved Constraints") + for x in f do : println(x) + + ; Backwards Solve + ;val b = HashTable<Symbol,Width>(symbol-hash) + ;for i in (length(o) - 1) through 0 by -1 do : + ; val n = o[i] + ; b[n] = substitute(f[n],b) + + ;println("Backwards Solved Constraints") + ;for x in b do : println(x) + ;; Evaluate + ;val e = evaluate(b) + ;println("Evaluated Constraints") + ;for x in e do : println(x) + ;e + HashTable<Symbol,Int>(symbol-hash) defn width! (t:Type) -> Width : match(t) : @@ -1550,8 +1655,7 @@ defn prim-width (e:DoPrim,v:Vector<WGeq>) -> Width : BIT-SELECT-OP : IntWidth(1) BITS-SELECT-OP : IntWidth(consts(e)[0] - consts(e)[1]) -defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> : - val v = Vector<WGeq>() +defn gen-constraints (c:Circuit, m:Module, v:Vector<WGeq>) -> Vector<WGeq> : val h = HashTable<Symbol,Type>(symbol-hash) defn get-width (e:Expression) -> Width : @@ -1565,28 +1669,25 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> : (k) : [width! $ h[name(e)], width! $ type(e)] - add(v,WGeq(wdec,wref)) add(v,WGeq(wref,wdec)) - wdec + add(v,WGeq(wdec,wref)) + wref (e:WSubfield) : ;assumes only subfields are instances val wdec = width! $ bundle-field-type(h[name(exp(e) as WRef)],name(e)) val wref = width! $ bundle-field-type(type(exp(e)),name(e)) - add(v,WGeq(wdec,wref)) add(v,WGeq(wref,wdec)) - wdec + add(v,WGeq(wdec,wref)) + wref (e:WIndex) : error("Shouldn't be here") (e:UIntValue) : width(e) (e:SIntValue) : width(e) (e:DoPrim) : prim-width(e,v) (e:ReadPort|WritePort) : add(v,WGeq(get-width(enable(e)),IntWidth(1))) - add(v,WGeq(IntWidth(1),get-width(enable(e)))) get-width(mem(e)) (e:Register) : add(v,WGeq(get-width(enable(e)),IntWidth(1))) - add(v,WGeq(IntWidth(1),get-width(enable(e)))) val w = width!(type(e)) - add(v,WGeq(get-width(value(e)),w)) add(v,WGeq(w,get-width(value(e)))) w defn gen-constraints (s:Stmt) -> Stmt : @@ -1606,42 +1707,51 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> : for p in ports(m) do : h[name(p)] = type(p) gen-constraints(body(m)) - to-list(v) + v +defn replace-var-widths (c:Circuit,h:HashTable<Symbol,Int>) -> Circuit : + defn replace-var-widths-w (w:Width) -> Width : + defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False : + for x in h any? : key(x) == n + match(w) : + (w:VarWidth) : + if contains?(name(w),h) : IntWidth(h[name(w)]) + else: w + (w) : w + val modules* = for m in modules(c) map : + Module{name(m),_,body(m)} $ + for p in ports(m) map : + Port(name(p),direction(p),mapr(replace-var-widths-w,type(p))) + + val modules** = for m in modules* map : + Module(name(m),ports(m),mapr(replace-var-widths-w,body(m))) + Circuit(modules**,main(c)) defn remove-unknown-widths (c:Circuit) -> Circuit : - defn remove-unknown-widths-wid (w:Width) -> Width : + defn remove-unknown-widths-w (w:Width) -> Width : match(w) : (w:UnknownWidth) : VarWidth(gensym(`w)) (w) : w - defn remove-unknown-widths-type (t:Type) -> Type : - map{remove-unknown-widths-wid,_} $ - map(remove-unknown-widths-type,t) - defn remove-unknown-widths-exp (e:Expression) -> Expression : - map{remove-unknown-widths-wid,_} $ - map{remove-unknown-widths-type,_} $ - map(remove-unknown-widths-exp,e) - defn remove-unknown-widths-stmt (s:Stmt) -> Stmt : - map{remove-unknown-widths-type,_} $ - map{remove-unknown-widths-exp,_} $ - map(remove-unknown-widths-stmt,s) val modules* = for m in modules(c) map : Module{name(m),_,body(m)} $ for p in ports(m) map : - Port(name(p),direction(p),remove-unknown-widths-type(type(p))) + Port(name(p),direction(p),mapr(remove-unknown-widths-w,type(p))) val modules** = for m in modules* map : - Module(name(m),ports(m),remove-unknown-widths-stmt(body(m))) + Module(name(m),ports(m),mapr(remove-unknown-widths-w,body(m))) Circuit(modules**,main(c)) defn infer-widths (c:Circuit) -> Circuit : val c* = remove-unknown-widths(c) + val v = Vector<WGeq>() for m in modules(c*) do : - val l = gen-constraints(c*,m) - for x in l do : println(x) - c* - ;val h = solve-constraints(l) + gen-constraints(c*,m,v) + for x in v do : println(x) + val h = solve-constraints(to-list(v)) + println("Solved Constraints") + ;for x in h do : println(x) ;replace-var-widths(c*,h) + c* |
