diff options
| author | azidar | 2015-04-10 16:56:36 -0700 |
|---|---|---|
| committer | azidar | 2015-04-10 16:56:36 -0700 |
| commit | 75be996a7d26778aa6ed3c02db617b4f0516537c (patch) | |
| tree | 4ad612a5907a9ea98deea42588f05bfe26d3878c | |
| parent | a604e0789a85d8b3c5d6def2f9860047f479b68a (diff) | |
Almost finished width inference, takes too long/infinite loop for gcd
| -rw-r--r-- | TODO | 2 | ||||
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 16 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 182 | ||||
| -rw-r--r-- | test/passes/infer-widths/simple.fir | 5 |
4 files changed, 165 insertions, 40 deletions
@@ -1,7 +1,7 @@ TODO Change parser to use <> syntax (and update all tests) (patrick) Think about on-reset, add it - Think about expanding mems, more complicated than first glance! + Think about expanding mems, more complicated than first glance! Change DefMem to be size, and element type Think about max(op1,op2) for muxes. Make instances always male, flip the bundles on declaration Talk to palmer/patrick about how writing passes is going to be supported diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index 1f33547e..0bce0a90 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -289,6 +289,22 @@ defmethod map (f: Type -> Type, c:Stmt) -> Stmt : (c:DefMemory) : DefRegister(name(c),f(type(c))) (c) : c +public defmulti mapr<?T> (f: Width -> Width, t:?T&Type) -> T +defmethod mapr (f: Width -> Width, t:Type) -> Type : + defn apply-t (t:Type) -> Type : + map{f,_} $ map(apply-t,t) + apply-t(t) + +public defmulti mapr<?T> (f: Width -> Width, s:?T&Stmt) -> T +defmethod mapr (f: Width -> Width, s:Stmt) -> Stmt : + defn apply-t (t:Type) -> Type : mapr(f,t) + defn apply-e (e:Expression) -> Expression : + map{f,_} $ map{apply-t,_} $ map(apply-e,e) + defn apply-s (s:Stmt) -> Stmt : + map{apply-t,_} $ map{apply-e,_} $ map(apply-s,s) + apply-s(s) + + ;================= HELPER FUNCTIONS USING MAP =================== ; These don't work properly.. ;public defmulti do (f:Expression -> ?, e:Expression) -> False diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 313191c8..e9e5bfc6 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -1436,6 +1436,14 @@ defstruct MaxWidth <: Width : arg1 : Width arg2 : Width +public defmulti map<?T> (f: Width -> Width, w:?T&Width) -> T +defmethod map (f: Width -> Width, w:Width) -> Width : + match(w) : + (w:MaxWidth) : MaxWidth(f(arg1(w)),f(arg2(w))) + (w:PlusWidth) : PlusWidth(f(arg1(w)),f(arg2(w))) + (w:MinusWidth) : MinusWidth(f(arg1(w)),f(arg2(w))) + (w) : w + defmethod print (o:OutputStream, w:VarWidth) : print(o,name(w)) defmethod print (o:OutputStream, w:MaxWidth) : @@ -1449,18 +1457,115 @@ definterface Constraint defstruct WGeq <: Constraint : loc : Width exp : Width -defstruct WEq <: Constraint : - loc : Width - exp : Width defmethod print (o:OutputStream, c:WGeq) : print-all(o,[ loc(c) " >= " exp(c)]) -defmethod print (o:OutputStream, c:WEq) : - print-all(o,[ loc(c) " = " exp(c)]) -;defn remove-var-widths (c:Circuit) -> Circuit : - -;defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> : +defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> : + defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False : + for x in h any? : key(x) == n + defn unique (ls:List<WGeq>) -> HashTable<Symbol,Width> : + val h = HashTable<Symbol,Width>(symbol-hash) + for g in ls do : + match(loc(g)) : + (w:VarWidth) : + val n = name(w) + if contains?(n,h) : h[n] = MaxWidth(exp(g),h[n]) + else : h[n] = exp(g) + (w) : w + h + defn substitute (w:Width,h:HashTable<Symbol,Width>) -> Width : + match(map(substitute{_,h},w)) : + (w:VarWidth) : + if contains?(name(w),h) : + val t = substitute(h[name(w)],h) + h[name(w)] = t + t + else : w + (w) : w + defn remove-cycle (n:Symbol,w:Width) -> Width : + match(map(remove-cycle{n,_},w)) : + (w:MaxWidth) : + match(arg1(w),arg2(w)) : + (v1:VarWidth,v2:VarWidth) : + if name(v1) == n : arg2(w) + else if name(v2) == n : arg1(w) + else : w + (v:VarWidth,_) : + if name(v) == n : arg2(w) + else : w + (_,v:VarWidth) : + if name(v) == n : arg1(w) + else : w + (v1,v2) : w + (w) : w + defn self-rec? (n:Symbol,w:Width) -> True|False : + var has? = false + defn look (w:Width) -> Width : + match(map(look,w)) : + (w:VarWidth) : if name(w) == n : has? = true + (w) : w + w + look(w) + has? + defn evaluate (h:HashTable<Symbol,Width>) -> HashTable<Symbol,Int> : + defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False : + if a typeof Int and b typeof Int : f(a as Int, b as Int) + else : false + defn max (a:Int,b:Int) -> Int : + if a > b : a + else : b + defn solve (w:Width) -> Int|False : + match(w) : + (w:VarWidth) : false + (w:MaxWidth) : apply(solve(arg1(w)),solve(arg2(w)),max) + (w:PlusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ + _}) + (w:MinusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ - _}) + (w:IntWidth) : width(w) + (w) : error("Shouldn't be here") + + val i = HashTable<Symbol,Int>(symbol-hash) + for x in h do : + val s = solve(value(x)) + if s typeof Int : i[key(x)] = s as Int + i + + ; Forward solve + ; Returns a solved list where each constraint undergoes: + ; 1) Continuous Solving (using triangular solving) + ; 2) Remove Cycles + ; 3) Move to solved if not self-recursive + val u = unique(l) + println("Unique Constraints") + for x in u do : println(x) + + val f = HashTable<Symbol,Width>(symbol-hash) + val o = Vector<Symbol>() + for x in u do : + val [n e] = [key(x) value(x)] + val e* = substitute(e,f) + ;val e* = remove-cycle{n,_} $ substitute(e,f) + ;if not self-rec?(n,e*) : + add(o,n) + f[n] = e* + + println("Forward Solved Constraints") + for x in f do : println(x) + + ; Backwards Solve + ;val b = HashTable<Symbol,Width>(symbol-hash) + ;for i in (length(o) - 1) through 0 by -1 do : + ; val n = o[i] + ; b[n] = substitute(f[n],b) + + ;println("Backwards Solved Constraints") + ;for x in b do : println(x) + ;; Evaluate + ;val e = evaluate(b) + ;println("Evaluated Constraints") + ;for x in e do : println(x) + ;e + HashTable<Symbol,Int>(symbol-hash) defn width! (t:Type) -> Width : match(t) : @@ -1550,8 +1655,7 @@ defn prim-width (e:DoPrim,v:Vector<WGeq>) -> Width : BIT-SELECT-OP : IntWidth(1) BITS-SELECT-OP : IntWidth(consts(e)[0] - consts(e)[1]) -defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> : - val v = Vector<WGeq>() +defn gen-constraints (c:Circuit, m:Module, v:Vector<WGeq>) -> Vector<WGeq> : val h = HashTable<Symbol,Type>(symbol-hash) defn get-width (e:Expression) -> Width : @@ -1565,28 +1669,25 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> : (k) : [width! $ h[name(e)], width! $ type(e)] - add(v,WGeq(wdec,wref)) add(v,WGeq(wref,wdec)) - wdec + add(v,WGeq(wdec,wref)) + wref (e:WSubfield) : ;assumes only subfields are instances val wdec = width! $ bundle-field-type(h[name(exp(e) as WRef)],name(e)) val wref = width! $ bundle-field-type(type(exp(e)),name(e)) - add(v,WGeq(wdec,wref)) add(v,WGeq(wref,wdec)) - wdec + add(v,WGeq(wdec,wref)) + wref (e:WIndex) : error("Shouldn't be here") (e:UIntValue) : width(e) (e:SIntValue) : width(e) (e:DoPrim) : prim-width(e,v) (e:ReadPort|WritePort) : add(v,WGeq(get-width(enable(e)),IntWidth(1))) - add(v,WGeq(IntWidth(1),get-width(enable(e)))) get-width(mem(e)) (e:Register) : add(v,WGeq(get-width(enable(e)),IntWidth(1))) - add(v,WGeq(IntWidth(1),get-width(enable(e)))) val w = width!(type(e)) - add(v,WGeq(get-width(value(e)),w)) add(v,WGeq(w,get-width(value(e)))) w defn gen-constraints (s:Stmt) -> Stmt : @@ -1606,42 +1707,51 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> : for p in ports(m) do : h[name(p)] = type(p) gen-constraints(body(m)) - to-list(v) + v +defn replace-var-widths (c:Circuit,h:HashTable<Symbol,Int>) -> Circuit : + defn replace-var-widths-w (w:Width) -> Width : + defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False : + for x in h any? : key(x) == n + match(w) : + (w:VarWidth) : + if contains?(name(w),h) : IntWidth(h[name(w)]) + else: w + (w) : w + val modules* = for m in modules(c) map : + Module{name(m),_,body(m)} $ + for p in ports(m) map : + Port(name(p),direction(p),mapr(replace-var-widths-w,type(p))) + + val modules** = for m in modules* map : + Module(name(m),ports(m),mapr(replace-var-widths-w,body(m))) + Circuit(modules**,main(c)) defn remove-unknown-widths (c:Circuit) -> Circuit : - defn remove-unknown-widths-wid (w:Width) -> Width : + defn remove-unknown-widths-w (w:Width) -> Width : match(w) : (w:UnknownWidth) : VarWidth(gensym(`w)) (w) : w - defn remove-unknown-widths-type (t:Type) -> Type : - map{remove-unknown-widths-wid,_} $ - map(remove-unknown-widths-type,t) - defn remove-unknown-widths-exp (e:Expression) -> Expression : - map{remove-unknown-widths-wid,_} $ - map{remove-unknown-widths-type,_} $ - map(remove-unknown-widths-exp,e) - defn remove-unknown-widths-stmt (s:Stmt) -> Stmt : - map{remove-unknown-widths-type,_} $ - map{remove-unknown-widths-exp,_} $ - map(remove-unknown-widths-stmt,s) val modules* = for m in modules(c) map : Module{name(m),_,body(m)} $ for p in ports(m) map : - Port(name(p),direction(p),remove-unknown-widths-type(type(p))) + Port(name(p),direction(p),mapr(remove-unknown-widths-w,type(p))) val modules** = for m in modules* map : - Module(name(m),ports(m),remove-unknown-widths-stmt(body(m))) + Module(name(m),ports(m),mapr(remove-unknown-widths-w,body(m))) Circuit(modules**,main(c)) defn infer-widths (c:Circuit) -> Circuit : val c* = remove-unknown-widths(c) + val v = Vector<WGeq>() for m in modules(c*) do : - val l = gen-constraints(c*,m) - for x in l do : println(x) - c* - ;val h = solve-constraints(l) + gen-constraints(c*,m,v) + for x in v do : println(x) + val h = solve-constraints(to-list(v)) + println("Solved Constraints") + ;for x in h do : println(x) ;replace-var-widths(c*,h) + c* diff --git a/test/passes/infer-widths/simple.fir b/test/passes/infer-widths/simple.fir index f98d98da..fcd08ac6 100644 --- a/test/passes/infer-widths/simple.fir +++ b/test/passes/infer-widths/simple.fir @@ -3,10 +3,9 @@ ;CHECK: Infer Widths circuit top : module top : - wire e : UInt - wire x : UInt + wire e : UInt(30) reg y : UInt - y := mux-uu(e, UInt(1), equal-uu(gt-uu(x, x), UInt(0))) + y := e ; CHECK: Finished Infer Widths |
