aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorazidar2015-04-10 16:56:36 -0700
committerazidar2015-04-10 16:56:36 -0700
commit75be996a7d26778aa6ed3c02db617b4f0516537c (patch)
tree4ad612a5907a9ea98deea42588f05bfe26d3878c /src
parenta604e0789a85d8b3c5d6def2f9860047f479b68a (diff)
Almost finished width inference, takes too long/infinite loop for gcd
Diffstat (limited to 'src')
-rw-r--r--src/main/stanza/ir-utils.stanza16
-rw-r--r--src/main/stanza/passes.stanza182
2 files changed, 162 insertions, 36 deletions
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza
index 1f33547e..0bce0a90 100644
--- a/src/main/stanza/ir-utils.stanza
+++ b/src/main/stanza/ir-utils.stanza
@@ -289,6 +289,22 @@ defmethod map (f: Type -> Type, c:Stmt) -> Stmt :
(c:DefMemory) : DefRegister(name(c),f(type(c)))
(c) : c
+public defmulti mapr<?T> (f: Width -> Width, t:?T&Type) -> T
+defmethod mapr (f: Width -> Width, t:Type) -> Type :
+ defn apply-t (t:Type) -> Type :
+ map{f,_} $ map(apply-t,t)
+ apply-t(t)
+
+public defmulti mapr<?T> (f: Width -> Width, s:?T&Stmt) -> T
+defmethod mapr (f: Width -> Width, s:Stmt) -> Stmt :
+ defn apply-t (t:Type) -> Type : mapr(f,t)
+ defn apply-e (e:Expression) -> Expression :
+ map{f,_} $ map{apply-t,_} $ map(apply-e,e)
+ defn apply-s (s:Stmt) -> Stmt :
+ map{apply-t,_} $ map{apply-e,_} $ map(apply-s,s)
+ apply-s(s)
+
+
;================= HELPER FUNCTIONS USING MAP ===================
; These don't work properly..
;public defmulti do (f:Expression -> ?, e:Expression) -> False
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index 313191c8..e9e5bfc6 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -1436,6 +1436,14 @@ defstruct MaxWidth <: Width :
arg1 : Width
arg2 : Width
+public defmulti map<?T> (f: Width -> Width, w:?T&Width) -> T
+defmethod map (f: Width -> Width, w:Width) -> Width :
+ match(w) :
+ (w:MaxWidth) : MaxWidth(f(arg1(w)),f(arg2(w)))
+ (w:PlusWidth) : PlusWidth(f(arg1(w)),f(arg2(w)))
+ (w:MinusWidth) : MinusWidth(f(arg1(w)),f(arg2(w)))
+ (w) : w
+
defmethod print (o:OutputStream, w:VarWidth) :
print(o,name(w))
defmethod print (o:OutputStream, w:MaxWidth) :
@@ -1449,18 +1457,115 @@ definterface Constraint
defstruct WGeq <: Constraint :
loc : Width
exp : Width
-defstruct WEq <: Constraint :
- loc : Width
- exp : Width
defmethod print (o:OutputStream, c:WGeq) :
print-all(o,[ loc(c) " >= " exp(c)])
-defmethod print (o:OutputStream, c:WEq) :
- print-all(o,[ loc(c) " = " exp(c)])
-;defn remove-var-widths (c:Circuit) -> Circuit :
-
-;defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
+defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
+ defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False :
+ for x in h any? : key(x) == n
+ defn unique (ls:List<WGeq>) -> HashTable<Symbol,Width> :
+ val h = HashTable<Symbol,Width>(symbol-hash)
+ for g in ls do :
+ match(loc(g)) :
+ (w:VarWidth) :
+ val n = name(w)
+ if contains?(n,h) : h[n] = MaxWidth(exp(g),h[n])
+ else : h[n] = exp(g)
+ (w) : w
+ h
+ defn substitute (w:Width,h:HashTable<Symbol,Width>) -> Width :
+ match(map(substitute{_,h},w)) :
+ (w:VarWidth) :
+ if contains?(name(w),h) :
+ val t = substitute(h[name(w)],h)
+ h[name(w)] = t
+ t
+ else : w
+ (w) : w
+ defn remove-cycle (n:Symbol,w:Width) -> Width :
+ match(map(remove-cycle{n,_},w)) :
+ (w:MaxWidth) :
+ match(arg1(w),arg2(w)) :
+ (v1:VarWidth,v2:VarWidth) :
+ if name(v1) == n : arg2(w)
+ else if name(v2) == n : arg1(w)
+ else : w
+ (v:VarWidth,_) :
+ if name(v) == n : arg2(w)
+ else : w
+ (_,v:VarWidth) :
+ if name(v) == n : arg1(w)
+ else : w
+ (v1,v2) : w
+ (w) : w
+ defn self-rec? (n:Symbol,w:Width) -> True|False :
+ var has? = false
+ defn look (w:Width) -> Width :
+ match(map(look,w)) :
+ (w:VarWidth) : if name(w) == n : has? = true
+ (w) : w
+ w
+ look(w)
+ has?
+ defn evaluate (h:HashTable<Symbol,Width>) -> HashTable<Symbol,Int> :
+ defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False :
+ if a typeof Int and b typeof Int : f(a as Int, b as Int)
+ else : false
+ defn max (a:Int,b:Int) -> Int :
+ if a > b : a
+ else : b
+ defn solve (w:Width) -> Int|False :
+ match(w) :
+ (w:VarWidth) : false
+ (w:MaxWidth) : apply(solve(arg1(w)),solve(arg2(w)),max)
+ (w:PlusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ + _})
+ (w:MinusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ - _})
+ (w:IntWidth) : width(w)
+ (w) : error("Shouldn't be here")
+
+ val i = HashTable<Symbol,Int>(symbol-hash)
+ for x in h do :
+ val s = solve(value(x))
+ if s typeof Int : i[key(x)] = s as Int
+ i
+
+ ; Forward solve
+ ; Returns a solved list where each constraint undergoes:
+ ; 1) Continuous Solving (using triangular solving)
+ ; 2) Remove Cycles
+ ; 3) Move to solved if not self-recursive
+ val u = unique(l)
+ println("Unique Constraints")
+ for x in u do : println(x)
+
+ val f = HashTable<Symbol,Width>(symbol-hash)
+ val o = Vector<Symbol>()
+ for x in u do :
+ val [n e] = [key(x) value(x)]
+ val e* = substitute(e,f)
+ ;val e* = remove-cycle{n,_} $ substitute(e,f)
+ ;if not self-rec?(n,e*) :
+ add(o,n)
+ f[n] = e*
+
+ println("Forward Solved Constraints")
+ for x in f do : println(x)
+
+ ; Backwards Solve
+ ;val b = HashTable<Symbol,Width>(symbol-hash)
+ ;for i in (length(o) - 1) through 0 by -1 do :
+ ; val n = o[i]
+ ; b[n] = substitute(f[n],b)
+
+ ;println("Backwards Solved Constraints")
+ ;for x in b do : println(x)
+ ;; Evaluate
+ ;val e = evaluate(b)
+ ;println("Evaluated Constraints")
+ ;for x in e do : println(x)
+ ;e
+ HashTable<Symbol,Int>(symbol-hash)
defn width! (t:Type) -> Width :
match(t) :
@@ -1550,8 +1655,7 @@ defn prim-width (e:DoPrim,v:Vector<WGeq>) -> Width :
BIT-SELECT-OP : IntWidth(1)
BITS-SELECT-OP : IntWidth(consts(e)[0] - consts(e)[1])
-defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> :
- val v = Vector<WGeq>()
+defn gen-constraints (c:Circuit, m:Module, v:Vector<WGeq>) -> Vector<WGeq> :
val h = HashTable<Symbol,Type>(symbol-hash)
defn get-width (e:Expression) -> Width :
@@ -1565,28 +1669,25 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> :
(k) :
[width! $ h[name(e)],
width! $ type(e)]
- add(v,WGeq(wdec,wref))
add(v,WGeq(wref,wdec))
- wdec
+ add(v,WGeq(wdec,wref))
+ wref
(e:WSubfield) : ;assumes only subfields are instances
val wdec = width! $ bundle-field-type(h[name(exp(e) as WRef)],name(e))
val wref = width! $ bundle-field-type(type(exp(e)),name(e))
- add(v,WGeq(wdec,wref))
add(v,WGeq(wref,wdec))
- wdec
+ add(v,WGeq(wdec,wref))
+ wref
(e:WIndex) : error("Shouldn't be here")
(e:UIntValue) : width(e)
(e:SIntValue) : width(e)
(e:DoPrim) : prim-width(e,v)
(e:ReadPort|WritePort) :
add(v,WGeq(get-width(enable(e)),IntWidth(1)))
- add(v,WGeq(IntWidth(1),get-width(enable(e))))
get-width(mem(e))
(e:Register) :
add(v,WGeq(get-width(enable(e)),IntWidth(1)))
- add(v,WGeq(IntWidth(1),get-width(enable(e))))
val w = width!(type(e))
- add(v,WGeq(get-width(value(e)),w))
add(v,WGeq(w,get-width(value(e))))
w
defn gen-constraints (s:Stmt) -> Stmt :
@@ -1606,42 +1707,51 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> :
for p in ports(m) do :
h[name(p)] = type(p)
gen-constraints(body(m))
- to-list(v)
+ v
+defn replace-var-widths (c:Circuit,h:HashTable<Symbol,Int>) -> Circuit :
+ defn replace-var-widths-w (w:Width) -> Width :
+ defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False :
+ for x in h any? : key(x) == n
+ match(w) :
+ (w:VarWidth) :
+ if contains?(name(w),h) : IntWidth(h[name(w)])
+ else: w
+ (w) : w
+ val modules* = for m in modules(c) map :
+ Module{name(m),_,body(m)} $
+ for p in ports(m) map :
+ Port(name(p),direction(p),mapr(replace-var-widths-w,type(p)))
+
+ val modules** = for m in modules* map :
+ Module(name(m),ports(m),mapr(replace-var-widths-w,body(m)))
+ Circuit(modules**,main(c))
defn remove-unknown-widths (c:Circuit) -> Circuit :
- defn remove-unknown-widths-wid (w:Width) -> Width :
+ defn remove-unknown-widths-w (w:Width) -> Width :
match(w) :
(w:UnknownWidth) : VarWidth(gensym(`w))
(w) : w
- defn remove-unknown-widths-type (t:Type) -> Type :
- map{remove-unknown-widths-wid,_} $
- map(remove-unknown-widths-type,t)
- defn remove-unknown-widths-exp (e:Expression) -> Expression :
- map{remove-unknown-widths-wid,_} $
- map{remove-unknown-widths-type,_} $
- map(remove-unknown-widths-exp,e)
- defn remove-unknown-widths-stmt (s:Stmt) -> Stmt :
- map{remove-unknown-widths-type,_} $
- map{remove-unknown-widths-exp,_} $
- map(remove-unknown-widths-stmt,s)
val modules* = for m in modules(c) map :
Module{name(m),_,body(m)} $
for p in ports(m) map :
- Port(name(p),direction(p),remove-unknown-widths-type(type(p)))
+ Port(name(p),direction(p),mapr(remove-unknown-widths-w,type(p)))
val modules** = for m in modules* map :
- Module(name(m),ports(m),remove-unknown-widths-stmt(body(m)))
+ Module(name(m),ports(m),mapr(remove-unknown-widths-w,body(m)))
Circuit(modules**,main(c))
defn infer-widths (c:Circuit) -> Circuit :
val c* = remove-unknown-widths(c)
+ val v = Vector<WGeq>()
for m in modules(c*) do :
- val l = gen-constraints(c*,m)
- for x in l do : println(x)
- c*
- ;val h = solve-constraints(l)
+ gen-constraints(c*,m,v)
+ for x in v do : println(x)
+ val h = solve-constraints(to-list(v))
+ println("Solved Constraints")
+ ;for x in h do : println(x)
;replace-var-widths(c*,h)
+ c*