diff options
| author | azidar | 2015-04-13 17:51:00 -0700 |
|---|---|---|
| committer | azidar | 2015-04-13 17:51:00 -0700 |
| commit | c140b1ffbcf7fb5b2bb05e93388b2c79f2ddf9f9 (patch) | |
| tree | ea9621cbf742772c4f7c7bcf7ee09025402cb8d2 /src | |
| parent | e5e51130ebb109f9e433139cab098454da676b8f (diff) | |
Finished Infer Widths
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 49 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 506 |
2 files changed, 332 insertions, 223 deletions
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index 0bce0a90..a0379206 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -306,30 +306,31 @@ defmethod mapr (f: Width -> Width, s:Stmt) -> Stmt : ;================= HELPER FUNCTIONS USING MAP =================== -; These don't work properly.. -;public defmulti do (f:Expression -> ?, e:Expression) -> False -;defmethod do (f:Expression -> ?, e:Expression) -> False : -; for x in e map : -; f(x) -; x -; false -; -;public defmulti do (f:Expression -> ?, s:Stmt) -> False -;defmethod do (f:Expression -> ?, s:Stmt) -> False : -; defn f* (x:Expression) : -; f(x) -; x -; map(f*,s) -; false -; -;public defmulti do (f:Stmt -> ?, s:Stmt) -> False -;defmethod do (f:Stmt -> ?, s:Stmt) -> False : -; defn f* (x:Stmt) : -; f(x) -; x -; map(f*,s) -; false -; +public defmulti do (f:Expression -> ?, e:Expression) -> False +defmethod do (f:Expression -> ?, e:Expression) -> False : + defn f* (x:Expression) : + f(x) + x + map(f*,e) + false + +public defmulti do (f:Expression -> ?, s:Stmt) -> False +defmethod do (f:Expression -> ?, s:Stmt) -> False : + defn f* (x:Expression) : + f(x) + x + map(f*,s) + false + +public defmulti do (f:Stmt -> ?, s:Stmt) -> False +defmethod do (f:Stmt -> ?, s:Stmt) -> False : + defn f* (x:Stmt) : + f(x) + x + map(f*,s) + false + +; Not well defined - usually use dor on fields of a recursive type ;public defmulti dor (f:Expression -> ?, e:Expression) -> False ;defmethod dor (f:Expression -> ?, e:Expression) -> False : ; f(e) diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 79374ef2..8b393c13 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -1266,11 +1266,12 @@ defn expand-whens (assign:HashTable<Symbol,SymbolicValue>, add(decs,Connect(ref,reg)) (k:InstanceKind) : val s = stmts[n] as DefInstance - val x = split(to-string(n),'.') + val x = to-symbol(split(to-string(n),'.')[0]) val f = to-symbol(split(to-string(n),'.')[1]) - val ref = WSubfield(module(s),f,bundle-field-type(type(module(s)),f),FEMALE) + val ref = WRef(x,type(module(s)),k,FEMALE) + val sref = WSubfield(ref,f,bundle-field-type(type(module(s)),f),FEMALE) if has-nul?(assign[n]) : println("Uninitialized: ~" % [to-string(n)]);TODO actually collect error - else : add(decs,Connect(ref,to-exp(assign[n]))) + else : add(decs,Connect(sref,to-exp(assign[n]))) (k) : val s = stmts[n] as DefWire val ref = WRef(n,type(s),k,FEMALE) @@ -1433,13 +1434,12 @@ defstruct MinusWidth <: Width : arg1 : Width arg2 : Width defstruct MaxWidth <: Width : - arg1 : Width - arg2 : Width + args : List<Width> public defmulti map<?T> (f: Width -> Width, w:?T&Width) -> T defmethod map (f: Width -> Width, w:Width) -> Width : match(w) : - (w:MaxWidth) : MaxWidth(f(arg1(w)),f(arg2(w))) + (w:MaxWidth) : MaxWidth(map(f,args(w))) (w:PlusWidth) : PlusWidth(f(arg1(w)),f(arg2(w))) (w:MinusWidth) : MinusWidth(f(arg1(w)),f(arg2(w))) (w) : w @@ -1447,7 +1447,7 @@ defmethod map (f: Width -> Width, w:Width) -> Width : defmethod print (o:OutputStream, w:VarWidth) : print(o,name(w)) defmethod print (o:OutputStream, w:MaxWidth) : - print-all(o,["max(" arg1(w) "," arg2(w) ")"]) + print-all(o,["max" args(w)]) defmethod print (o:OutputStream, w:PlusWidth) : print-all(o,[ arg1(w) " + " arg2(w)]) defmethod print (o:OutputStream, w:MinusWidth) : @@ -1459,46 +1459,81 @@ defstruct WGeq <: Constraint : exp : Width defmethod print (o:OutputStream, c:WGeq) : print-all(o,[ loc(c) " >= " exp(c)]) +defmethod equal? (w1:Width,w2:Width) -> True|False : + match(w1,w2) : + (w1:VarWidth,w2:VarWidth) : name(w1) == name(w2) + (w1:MaxWidth,w2:MaxWidth) : + label<True|False> ret : + if not length(args(w1)) == length(args(w2)) : ret(false) + else : + for w in args(w1) do : + if not contains?(args(w2),w) : ret(false) + ret(true) + (w1:IntWidth,w2:IntWidth) : width(w1) == width(w2) + (w1,w2) : false +defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False : + if a typeof Int and b typeof Int : f(a as Int, b as Int) + else : false + +; TODO: I should make MaxWidth take a variable list of arguments, which would make it easier to write the simplify function. It looks like there isn't a bug in the algorithm, but simplification reallllly speeds it up. defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> : - defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False : - for x in h any? : key(x) == n - defn unique (ls:List<WGeq>) -> HashTable<Symbol,Width> : + defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False : key?(h,n) + defn make-unique (ls:List<WGeq>) -> HashTable<Symbol,Width> : val h = HashTable<Symbol,Width>(symbol-hash) for g in ls do : match(loc(g)) : (w:VarWidth) : val n = name(w) - if contains?(n,h) : h[n] = MaxWidth(exp(g),h[n]) + if contains?(n,h) : h[n] = MaxWidth(list(exp(g),h[n])) else : h[n] = exp(g) (w) : w h + defn simplify (w:Width) -> Width : + match(map(simplify,w)) : + (w:MaxWidth) : + val v = Vector<Width>() + for w* in args(w) do : + match(w*) : + (w:MaxWidth) : + for x in args(w) do : add(v,x) + (w) : add(v,w) + MaxWidth(unique(v)) + (w) : w defn substitute (w:Width,h:HashTable<Symbol,Width>) -> Width : - match(map(substitute{_,h},w)) : + ;println-all(["Substituting for [" w "]"]) + val w* = simplify(w) + ;println-all(["After Simplify: [" w "]"]) + match(map(substitute{_,h},simplify(w))) : (w:VarWidth) : + ;println("matched varwidth!") if contains?(name(w),h) : - val t = substitute(h[name(w)],h) + ;println("Contained!") + ;println-all(["Width: " w]) + ;println-all(["Accessed: " h[name(w)]]) + val t = simplify(substitute(h[name(w)],h)) + ;val t = h[name(w)] + ;println-all(["Width after sub: " t]) h[name(w)] = t t else : w + (w): + ;println-all(["not varwidth!" w]) + w + defn b-sub (w:Width,h:HashTable<Symbol,Width>) -> Width: + match(map(b-sub{_,h},w)) : + (w:VarWidth) : + if key?(h,name(w)) : h[name(w)] + else : w (w) : w defn remove-cycle (n:Symbol,w:Width) -> Width : - match(map(remove-cycle{n,_},w)) : - (w:MaxWidth) : - match(arg1(w),arg2(w)) : - (v1:VarWidth,v2:VarWidth) : - if name(v1) == n : arg2(w) - else if name(v2) == n : arg1(w) - else : w - (v:VarWidth,_) : - if name(v) == n : arg2(w) - else : w - (_,v:VarWidth) : - if name(v) == n : arg1(w) - else : w - (v1,v2) : w + ;println-all(["Removing cycle for " n " inside " w]) + val w* = match(map(remove-cycle{n,_},w)) : + (w:MaxWidth) : MaxWidth(to-list(filter({_ != VarWidth(n)},args(w)))) (w) : w + ;println-all(["After removing cycle for " n ", returning " w*]) + w* defn self-rec? (n:Symbol,w:Width) -> True|False : var has? = false defn look (w:Width) -> Width : @@ -1512,13 +1547,16 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> : defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False : if a typeof Int and b typeof Int : f(a as Int, b as Int) else : false + defn apply-l (l:List<Int|False>,f:(Int,Int) -> Int) -> Int|False : + if length(l) == 0 : 0 + else : apply(head(l),apply-l(tail(l),f),f) defn max (a:Int,b:Int) -> Int : if a > b : a else : b defn solve (w:Width) -> Int|False : match(w) : (w:VarWidth) : false - (w:MaxWidth) : apply(solve(arg1(w)),solve(arg2(w)),max) + (w:MaxWidth) : apply-l(map(solve,args(w)),max) (w:PlusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ + _}) (w:MinusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ - _}) (w:IntWidth) : width(w) @@ -1535,34 +1573,56 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> : ; 1) Continuous Solving (using triangular solving) ; 2) Remove Cycles ; 3) Move to solved if not self-recursive - val u = unique(l) - println("Unique Constraints") + val u = make-unique(l) + println("======== UNIQUE CONSTRAINTS ========") for x in u do : println(x) + println("====================================") val f = HashTable<Symbol,Width>(symbol-hash) val o = Vector<Symbol>() for x in u do : + ;println("==== SOLUTIONS TABLE ====") + ;for x in f do : println(x) + ;println("=========================") + val [n e] = [key(x) value(x)] - val e* = remove-cycle{n,_} $ substitute(e,f) + + val e-sub = substitute(e,f) + ;println(["Solving " n " => " e]) + ;println(["After Substitute: " n " => " e-sub]) + ;println("==== SOLUTIONS TABLE (Post Substitute) ====") + ;for x in f do : println(x) + ;println("=========================") + val e* = remove-cycle{n,_} $ e-sub + ;println(["After Remove Cycle: " n " => " e*]) if not self-rec?(n,e*) : + ;println-all(["Not rec!: " n " => " e*]) + ;println-all(["Adding [" n "=>" e* "] to Solutions Table"]) add(o,n) f[n] = e* - println("Forward Solved Constraints") - for x in f do : println(x) + ;println("Forward Solved Constraints") + ;for x in f do : println(x) ; Backwards Solve val b = HashTable<Symbol,Width>(symbol-hash) for i in (length(o) - 1) through 0 by -1 do : val n = o[i] - b[n] = substitute(f[n],b) - println("Backwards Solved Constraints") - for x in b do : println(x) + ;println-all(["SOLVE BACK: [" n " => " f[n] "]"]) + ;println("==== SOLUTIONS TABLE ====") + ;for x in b do : println(x) + ;println("=========================") + val e* = simplify(b-sub(f[n],b)) + ;println-all(["BACK RETURN: [" n " => " e* "]"]) + b[n] = e* + ;println("==== SOLUTIONS TABLE (Post backsolve) ====") + ;for x in b do : println(x) + ;println("=========================") ; Evaluate val e = evaluate(b) - println("Evaluated Constraints") - for x in e do : println(x) + ;println("Evaluated Constraints") + ;for x in e do : println(x) e defn width! (t:Type) -> Width : @@ -1570,186 +1630,234 @@ defn width! (t:Type) -> Width : (t:UIntType) : width(t) (t:SIntType) : width(t) (t) : error("No width!") -defn prim-width (e:DoPrim,v:Vector<WGeq>) -> Width : - defn e-width (e:Expression) -> Width : width!(type(e)) - defn wpc (l:List<Expression>,c:List<Int>) : PlusWidth(e-width(l[0]),IntWidth(c[0])) - defn wmc (l:List<Expression>,c:List<Int>) : MinusWidth(e-width(l[0]),IntWidth(c[0])) - defn maxw (l:List<Expression>) : MaxWidth(e-width(l[0]),e-width(l[1])) - defn mp1 (l:List<Expression>) : PlusWidth(MaxWidth(e-width(l[0]),e-width(l[1])),IntWidth(1)) - defn sum (l:List<Expression>) : PlusWidth(e-width(l[0]),e-width(l[1])) - switch {op(e) == _} : - ADD-UU-OP : mp1(args(e)) - ADD-US-OP : mp1(args(e)) - ADD-SU-OP : mp1(args(e)) - ADD-SS-OP : mp1(args(e)) - SUB-UU-OP : mp1(args(e)) - SUB-US-OP : mp1(args(e)) - SUB-SU-OP : mp1(args(e)) - SUB-SS-OP : mp1(args(e)) - MUL-UU-OP : sum(args(e)) - MUL-US-OP : sum(args(e)) - MUL-SU-OP : sum(args(e)) - MUL-SS-OP : sum(args(e)) - ;(p:DIV-UU-OP) : - ;(p:DIV-US-OP) : - ;(p:DIV-SU-OP) : - ;(p:DIV-SS-OP) : - ;(p:MOD-UU-OP) : - ;(p:MOD-US-OP) : - ;(p:MOD-SU-OP) : - ;(p:MOD-SS-OP) : - ;(p:QUO-UU-OP) : - ;(p:QUO-US-OP) : - ;(p:QUO-SU-OP) : - ;(p:QUO-SS-OP) : - ;(p:REM-UU-OP) : - ;(p:REM-US-OP) : - ;(p:REM-SU-OP) : - ;(p:REM-SS-OP) : - ADD-WRAP-UU-OP : maxw(args(e)) - ADD-WRAP-US-OP : maxw(args(e)) - ADD-WRAP-SU-OP : maxw(args(e)) - ADD-WRAP-SS-OP : maxw(args(e)) - SUB-WRAP-UU-OP : maxw(args(e)) - SUB-WRAP-US-OP : maxw(args(e)) - SUB-WRAP-SU-OP : maxw(args(e)) - SUB-WRAP-SS-OP : maxw(args(e)) - LESS-UU-OP : IntWidth(1) - LESS-US-OP : IntWidth(1) - LESS-SU-OP : IntWidth(1) - LESS-SS-OP : IntWidth(1) - LESS-EQ-UU-OP : IntWidth(1) - LESS-EQ-US-OP : IntWidth(1) - LESS-EQ-SU-OP : IntWidth(1) - LESS-EQ-SS-OP : IntWidth(1) - GREATER-UU-OP : IntWidth(1) - GREATER-US-OP : IntWidth(1) - GREATER-SU-OP : IntWidth(1) - GREATER-SS-OP : IntWidth(1) - GREATER-EQ-UU-OP : IntWidth(1) - GREATER-EQ-US-OP : IntWidth(1) - GREATER-EQ-SU-OP : IntWidth(1) - GREATER-EQ-SS-OP : IntWidth(1) - EQUAL-UU-OP : IntWidth(1) - EQUAL-SS-OP : IntWidth(1) - MUX-UU-OP : maxw(args(e)) - MUX-SS-OP : maxw(args(e)) - PAD-U-OP : IntWidth(consts(e)[0]) - PAD-S-OP : IntWidth(consts(e)[0]) - AS-UINT-U-OP : e-width(args(e)[0]) - AS-UINT-S-OP : e-width(args(e)[0]) - AS-SINT-U-OP : e-width(args(e)[0]) - AS-SINT-S-OP : e-width(args(e)[0]) - SHIFT-LEFT-U-OP : wpc(args(e),consts(e)) - SHIFT-LEFT-S-OP : wpc(args(e),consts(e)) - SHIFT-RIGHT-U-OP : wmc(args(e),consts(e)) - SHIFT-RIGHT-S-OP : wmc(args(e),consts(e)) - CONVERT-U-OP : PlusWidth(e-width(args(e)[0]),IntWidth(1)) - CONVERT-S-OP : e-width(args(e)[0]) - BIT-AND-OP : maxw(args(e)) - BIT-OR-OP : maxw(args(e)) - BIT-XOR-OP : maxw(args(e)) - CONCAT-OP : sum(args(e)) - BIT-SELECT-OP : IntWidth(1) - BITS-SELECT-OP : IntWidth(consts(e)[0] - consts(e)[1]) - -defn gen-constraints (c:Circuit, m:Module, v:Vector<WGeq>) -> Vector<WGeq> : - val h = HashTable<Symbol,Type>(symbol-hash) - - defn get-width (e:Expression) -> Width : - match(e) : - (e:WRef) : - val [wdec wref] = match(kind(e)) : - (k:InstanceKind) : error("Shouldn't be here") - (k:MemKind) : - [width! $ type $ (h[name(e)] as VectorType), - width! $ type $ (type(e) as VectorType)] - (k) : - [width! $ h[name(e)], - width! $ type(e)] - add(v,WGeq(wref,wdec)) - add(v,WGeq(wdec,wref)) - wref - (e:WSubfield) : ;assumes only subfields are instances - val wdec = width! $ bundle-field-type(h[name(exp(e) as WRef)],name(e)) - val wref = width! $ bundle-field-type(type(exp(e)),name(e)) - add(v,WGeq(wref,wdec)) - add(v,WGeq(wdec,wref)) - wref - (e:WIndex) : error("Shouldn't be here") - (e:UIntValue) : width(e) - (e:SIntValue) : width(e) - (e:DoPrim) : prim-width(e,v) - (e:ReadPort|WritePort) : - add(v,WGeq(get-width(enable(e)),IntWidth(1))) - get-width(mem(e)) - (e:Register) : - add(v,WGeq(get-width(enable(e)),IntWidth(1))) - val w = width!(type(e)) - add(v,WGeq(w,get-width(value(e)))) - w - defn gen-constraints (s:Stmt) -> Stmt : - match(map(gen-constraints,s)) : - (s:DefWire) : h[name(s)] = type(s) - (s:DefInstance) : h[name(s)] = h[name(module(s) as WRef)] - (s:DefMemory) : h[name(s)] = type(s) - (s:DefNode) : h[name(s)] = type(value(s)) +defn width! (e:Expression) -> Width : width!(type(e)) + +defn gen-constraints (m:Module, h:HashTable<Symbol,Type>, v:Vector<WGeq>) -> Module: + defn prim-type (e:DoPrim,v:Vector<WGeq>) -> Type : + defn add-c (w:Width) -> Type: + val w* = VarWidth(gensym(`w)) + add(v,WGeq(w*,w)) + add(v,WGeq(w,w*)) + match(type(e)) : + (t:UIntType) : UIntType(w*) + (t:SIntType) : SIntType(w*) + (t) : error("Shouldn't be here") + defn wpc (l:List<Expression>,c:List<Int>) : + add-c(PlusWidth(width!(l[0]),IntWidth(c[0]))) + defn wmc (l:List<Expression>,c:List<Int>) : + add-c(MinusWidth(width!(l[0]),IntWidth(c[0]))) + defn maxw (l:List<Expression>) : + add-c(MaxWidth(list(width!(l[0]),width!(l[1])))) + defn cons (ls:List<Expression>) : + val l = width!(ls[0]) + val r = width!(ls[1]) + add(v,WGeq(l,r)) + add(v,WGeq(r,l)) + add-c(l) + add-c(r) + defn mp1 (l:List<Expression>) : + add-c(PlusWidth(MaxWidth(list(width!(l[0]),width!(l[1]))),IntWidth(1))) + defn sum (l:List<Expression>) : + add-c(PlusWidth(width!(l[0]),width!(l[1]))) + + println-all(["Looking at " op(e) " with inputs " args(e)]) + switch {op(e) == _} : + ADD-UU-OP : mp1(args(e)) + ADD-US-OP : mp1(args(e)) + ADD-SU-OP : mp1(args(e)) + ADD-SS-OP : mp1(args(e)) + SUB-UU-OP : mp1(args(e)) + SUB-US-OP : mp1(args(e)) + SUB-SU-OP : mp1(args(e)) + SUB-SS-OP : mp1(args(e)) + MUL-UU-OP : sum(args(e)) + MUL-US-OP : sum(args(e)) + MUL-SU-OP : sum(args(e)) + MUL-SS-OP : sum(args(e)) + ;(p:DIV-UU-OP) : + ;(p:DIV-US-OP) : + ;(p:DIV-SU-OP) : + ;(p:DIV-SS-OP) : + ;(p:MOD-UU-OP) : + ;(p:MOD-US-OP) : + ;(p:MOD-SU-OP) : + ;(p:MOD-SS-OP) : + ;(p:QUO-UU-OP) : + ;(p:QUO-US-OP) : + ;(p:QUO-SU-OP) : + ;(p:QUO-SS-OP) : + ;(p:REM-UU-OP) : + ;(p:REM-US-OP) : + ;(p:REM-SU-OP) : + ;(p:REM-SS-OP) : + ADD-WRAP-UU-OP : maxw(args(e)) + ADD-WRAP-US-OP : maxw(args(e)) + ADD-WRAP-SU-OP : maxw(args(e)) + ADD-WRAP-SS-OP : maxw(args(e)) + SUB-WRAP-UU-OP : maxw(args(e)) + SUB-WRAP-US-OP : maxw(args(e)) + SUB-WRAP-SU-OP : maxw(args(e)) + SUB-WRAP-SS-OP : maxw(args(e)) + LESS-UU-OP : add-c(IntWidth(1)) + LESS-US-OP : add-c(IntWidth(1)) + LESS-SU-OP : add-c(IntWidth(1)) + LESS-SS-OP : add-c(IntWidth(1)) + LESS-EQ-UU-OP : add-c(IntWidth(1)) + LESS-EQ-US-OP : add-c(IntWidth(1)) + LESS-EQ-SU-OP : add-c(IntWidth(1)) + LESS-EQ-SS-OP : add-c(IntWidth(1)) + GREATER-UU-OP : add-c(IntWidth(1)) + GREATER-US-OP : add-c(IntWidth(1)) + GREATER-SU-OP : add-c(IntWidth(1)) + GREATER-SS-OP : add-c(IntWidth(1)) + GREATER-EQ-UU-OP : add-c(IntWidth(1)) + GREATER-EQ-US-OP : add-c(IntWidth(1)) + GREATER-EQ-SU-OP : add-c(IntWidth(1)) + GREATER-EQ-SS-OP : add-c(IntWidth(1)) + EQUAL-UU-OP : add-c(IntWidth(1)) + EQUAL-SS-OP : add-c(IntWidth(1)) + MUX-UU-OP : cons(args(e)) + MUX-SS-OP : cons(args(e)) + PAD-U-OP : add-c(IntWidth(consts(e)[0])) + PAD-S-OP : add-c(IntWidth(consts(e)[0])) + AS-UINT-U-OP : add-c(width!(args(e)[0])) + AS-UINT-S-OP : add-c(width!(args(e)[0])) + AS-SINT-U-OP : add-c(width!(args(e)[0])) + AS-SINT-S-OP : add-c(width!(args(e)[0])) + SHIFT-LEFT-U-OP : wpc(args(e),consts(e)) + SHIFT-LEFT-S-OP : wpc(args(e),consts(e)) + SHIFT-RIGHT-U-OP : wmc(args(e),consts(e)) + SHIFT-RIGHT-S-OP : wmc(args(e),consts(e)) + CONVERT-U-OP : add-c(PlusWidth(width!(args(e)[0]),IntWidth(1))) + CONVERT-S-OP : add-c(width!(args(e)[0])) + BIT-AND-OP : maxw(args(e)) + BIT-OR-OP : maxw(args(e)) + BIT-XOR-OP : maxw(args(e)) + CONCAT-OP : sum(args(e)) + BIT-SELECT-OP : add-c(IntWidth(1)) + BITS-SELECT-OP : add-c(IntWidth(consts(e)[0] - consts(e)[1])) + + defn gen-constraints-s (s:Stmt) -> Stmt : + match(map(gen-constraints-s,s)) : + (s:DefWire) : DefWire(name(s),h[name(s)]) + (s:DefInstance) : DefInstance(name(s),gen-constraints(module(s))) + (s:DefMemory) : DefMemory(name(s),h[name(s)] as VectorType) + (s:DefNode) : DefNode(name(s),gen-constraints(value(s))) (s:Connect) : - add(v,WGeq(get-width(loc(s)),get-width(exp(s)))) - add(v,WGeq(get-width(exp(s)),get-width(loc(s)))) - (s) : "" - s + val l = gen-constraints(loc(s)) + val e = gen-constraints(exp(s)) + add(v,WGeq(width!(type(l)),width!(type(e)))) + add(v,WGeq(width!(type(e)),width!(type(l)))) + Connect(l,e) + (s) : s + + defn gen-constraints (e:Expression) -> Expression : + match(map(gen-constraints,e)) : + (e:WRef) : WRef(name(e),h[name(e)],kind(e),gender(e)) + (e:WSubfield) : WSubfield(exp(e),name(e),bundle-field-type(type(exp(e)),name(e)),gender(e)) + (e:WIndex) : error("Shouldn't be here") + (e:DoPrim) : DoPrim(op(e),args(e),consts(e),prim-type(e,v)) + (e:ReadPort) : ReadPort(mem(e),index(e),type(type(mem(e)) as VectorType),enable(e)) + (e:WritePort) : WritePort(mem(e),index(e),type(type(mem(e)) as VectorType),enable(e)) + (e:Register) : Register(type(value(e)),value(e),enable(e)) + (e:UIntValue) : + match(width(e)) : + (w:UnknownWidth) : UIntValue(value(e),VarWidth(gensym(`w))) + (w) : e + (e:SIntValue) : + match(width(e)) : + (w:UnknownWidth) : SIntValue(value(e),VarWidth(gensym(`w))) + (w) : e + (e) : e - for m in modules(c) do : - h[name(m)] = BundleType(map(to-field,ports(m))) + val ports* = + for p in ports(m) map : Port(name(p),direction(p),h[name(p)]) + + Module(name(m),ports*,gen-constraints-s(body(m))) + +defn build-environment (c:Circuit,m:Module,h:HashTable<Symbol,Type>) -> HashTable<Symbol,Type> : + defn build-environment (s:Stmt) -> False : + match(s) : + (s:DefWire) : h[name(s)] = remove-unknowns(type(s)) + (s:DefInstance) : h[name(s)] = h[name(module(s) as WRef)] + (s:DefMemory) : h[name(s)] = remove-unknowns(type(s)) + (s:DefNode) : h[name(s)] = remove-unknowns(type(value(s))) + (s) : false + do(build-environment,s) for p in ports(m) do : - h[name(p)] = type(p) - gen-constraints(body(m)) - v + h[name(p)] = bundle-field-type(h[name(m)],name(p)) + build-environment(body(m)) + h defn replace-var-widths (c:Circuit,h:HashTable<Symbol,Int>) -> Circuit : defn replace-var-widths-w (w:Width) -> Width : - defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False : - for x in h any? : key(x) == n - match(w) : + println-all(["REPLACE: " w]) + val w* = match(w) : (w:VarWidth) : - if contains?(name(w),h) : IntWidth(h[name(w)]) + if key?(h,name(w)) : IntWidth(h[name(w)]) else: w (w) : w + println-all(["WITH: " w*]) + w* + val modules* = for m in modules(c) map : - Module{name(m),_,body(m)} $ + Module{name(m),_,mapr(replace-var-widths-w,body(m))} $ for p in ports(m) map : Port(name(p),direction(p),mapr(replace-var-widths-w,type(p))) - val modules** = for m in modules* map : - Module(name(m),ports(m),mapr(replace-var-widths-w,body(m))) - Circuit(modules**,main(c)) + Circuit(modules*,main(c)) -defn remove-unknown-widths (c:Circuit) -> Circuit : - defn remove-unknown-widths-w (w:Width) -> Width : - match(w) : - (w:UnknownWidth) : VarWidth(gensym(`w)) - (w) : w - val modules* = for m in modules(c) map : - Module{name(m),_,body(m)} $ - for p in ports(m) map : - Port(name(p),direction(p),mapr(remove-unknown-widths-w,type(p))) - - val modules** = for m in modules* map : - Module(name(m),ports(m),mapr(remove-unknown-widths-w,body(m))) - Circuit(modules**,main(c)) +;defn remove-unknown-widths (c:Circuit) -> Circuit : +; defn remove-unknown-widths-w (w:Width) -> Width : +; match(w) : +; (w:UnknownWidth) : VarWidth(gensym(`w)) +; (w) : w +; val modules* = for m in modules(c) map : +; Module{name(m),_,body(m)} $ +; for p in ports(m) map : +; Port(name(p),direction(p),mapr(remove-unknown-widths-w,type(p))) +; +; val modules** = for m in modules* map : +; Module(name(m),ports(m),mapr(remove-unknown-widths-w,body(m))) +; Circuit(modules**,main(c)) +defn remove-unknowns-w (w:Width) -> Width : + match(w) : + (w:UnknownWidth) : VarWidth(gensym(`w)) + (w) : w +defn remove-unknowns (t:Type) -> Type : mapr(remove-unknowns-w,t) + defn infer-widths (c:Circuit) -> Circuit : - val c* = remove-unknown-widths(c) + defn deepcopy (t:HashTable<Symbol,Type>) -> HashTable<Symbol,Type> : + t0 where : + val t0 = HashTable<Symbol,Type>(symbol-hash) + for x in t do : + t0[key(x)] = value(x) + + ;val c* = remove-unknown-widths(c) + ;println(c*) val v = Vector<WGeq>() - for m in modules(c*) do : - gen-constraints(c*,m,v) + val ports* = HashTable<Symbol,Type>(symbol-hash) + for m in modules(c) do : + ports*[name(m)] = remove-unknowns(BundleType(map(to-field,ports(m)))) + val modules* = for m in modules(c) map : + println-all(["====== MODULE(" name(m) ") ENV ======"]) + val h = build-environment(c,m,deepcopy(ports*)) + for x in h do: println(x) + println-all(["====================================="]) + val m* = gen-constraints(m,h,v) + println-all(["====== MODULE(" name(m) ") ======"]) + println(m*) + println-all(["====================================="]) + m* + println("======== ALL CONSTRAINTS ========") for x in v do : println(x) + println("=================================") val h = solve-constraints(to-list(v)) - println("Solved Constraints") - ;for x in h do : println(x) - ;replace-var-widths(c*,h) - c* + println("======== SOLVED CONSTRAINTS ========") + for x in h do : println(x) + println("====================================") + replace-var-widths(Circuit(modules*,main(c)),h) |
