aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorazidar2015-04-10 16:56:36 -0700
committerazidar2015-04-10 16:56:36 -0700
commit75be996a7d26778aa6ed3c02db617b4f0516537c (patch)
tree4ad612a5907a9ea98deea42588f05bfe26d3878c
parenta604e0789a85d8b3c5d6def2f9860047f479b68a (diff)
Almost finished width inference, takes too long/infinite loop for gcd
-rw-r--r--TODO2
-rw-r--r--src/main/stanza/ir-utils.stanza16
-rw-r--r--src/main/stanza/passes.stanza182
-rw-r--r--test/passes/infer-widths/simple.fir5
4 files changed, 165 insertions, 40 deletions
diff --git a/TODO b/TODO
index dacac142..3d9beac9 100644
--- a/TODO
+++ b/TODO
@@ -1,7 +1,7 @@
TODO
Change parser to use <> syntax (and update all tests) (patrick)
Think about on-reset, add it
- Think about expanding mems, more complicated than first glance!
+ Think about expanding mems, more complicated than first glance! Change DefMem to be size, and element type
Think about max(op1,op2) for muxes.
Make instances always male, flip the bundles on declaration
Talk to palmer/patrick about how writing passes is going to be supported
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza
index 1f33547e..0bce0a90 100644
--- a/src/main/stanza/ir-utils.stanza
+++ b/src/main/stanza/ir-utils.stanza
@@ -289,6 +289,22 @@ defmethod map (f: Type -> Type, c:Stmt) -> Stmt :
(c:DefMemory) : DefRegister(name(c),f(type(c)))
(c) : c
+public defmulti mapr<?T> (f: Width -> Width, t:?T&Type) -> T
+defmethod mapr (f: Width -> Width, t:Type) -> Type :
+ defn apply-t (t:Type) -> Type :
+ map{f,_} $ map(apply-t,t)
+ apply-t(t)
+
+public defmulti mapr<?T> (f: Width -> Width, s:?T&Stmt) -> T
+defmethod mapr (f: Width -> Width, s:Stmt) -> Stmt :
+ defn apply-t (t:Type) -> Type : mapr(f,t)
+ defn apply-e (e:Expression) -> Expression :
+ map{f,_} $ map{apply-t,_} $ map(apply-e,e)
+ defn apply-s (s:Stmt) -> Stmt :
+ map{apply-t,_} $ map{apply-e,_} $ map(apply-s,s)
+ apply-s(s)
+
+
;================= HELPER FUNCTIONS USING MAP ===================
; These don't work properly..
;public defmulti do (f:Expression -> ?, e:Expression) -> False
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index 313191c8..e9e5bfc6 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -1436,6 +1436,14 @@ defstruct MaxWidth <: Width :
arg1 : Width
arg2 : Width
+public defmulti map<?T> (f: Width -> Width, w:?T&Width) -> T
+defmethod map (f: Width -> Width, w:Width) -> Width :
+ match(w) :
+ (w:MaxWidth) : MaxWidth(f(arg1(w)),f(arg2(w)))
+ (w:PlusWidth) : PlusWidth(f(arg1(w)),f(arg2(w)))
+ (w:MinusWidth) : MinusWidth(f(arg1(w)),f(arg2(w)))
+ (w) : w
+
defmethod print (o:OutputStream, w:VarWidth) :
print(o,name(w))
defmethod print (o:OutputStream, w:MaxWidth) :
@@ -1449,18 +1457,115 @@ definterface Constraint
defstruct WGeq <: Constraint :
loc : Width
exp : Width
-defstruct WEq <: Constraint :
- loc : Width
- exp : Width
defmethod print (o:OutputStream, c:WGeq) :
print-all(o,[ loc(c) " >= " exp(c)])
-defmethod print (o:OutputStream, c:WEq) :
- print-all(o,[ loc(c) " = " exp(c)])
-;defn remove-var-widths (c:Circuit) -> Circuit :
-
-;defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
+defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
+ defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False :
+ for x in h any? : key(x) == n
+ defn unique (ls:List<WGeq>) -> HashTable<Symbol,Width> :
+ val h = HashTable<Symbol,Width>(symbol-hash)
+ for g in ls do :
+ match(loc(g)) :
+ (w:VarWidth) :
+ val n = name(w)
+ if contains?(n,h) : h[n] = MaxWidth(exp(g),h[n])
+ else : h[n] = exp(g)
+ (w) : w
+ h
+ defn substitute (w:Width,h:HashTable<Symbol,Width>) -> Width :
+ match(map(substitute{_,h},w)) :
+ (w:VarWidth) :
+ if contains?(name(w),h) :
+ val t = substitute(h[name(w)],h)
+ h[name(w)] = t
+ t
+ else : w
+ (w) : w
+ defn remove-cycle (n:Symbol,w:Width) -> Width :
+ match(map(remove-cycle{n,_},w)) :
+ (w:MaxWidth) :
+ match(arg1(w),arg2(w)) :
+ (v1:VarWidth,v2:VarWidth) :
+ if name(v1) == n : arg2(w)
+ else if name(v2) == n : arg1(w)
+ else : w
+ (v:VarWidth,_) :
+ if name(v) == n : arg2(w)
+ else : w
+ (_,v:VarWidth) :
+ if name(v) == n : arg1(w)
+ else : w
+ (v1,v2) : w
+ (w) : w
+ defn self-rec? (n:Symbol,w:Width) -> True|False :
+ var has? = false
+ defn look (w:Width) -> Width :
+ match(map(look,w)) :
+ (w:VarWidth) : if name(w) == n : has? = true
+ (w) : w
+ w
+ look(w)
+ has?
+ defn evaluate (h:HashTable<Symbol,Width>) -> HashTable<Symbol,Int> :
+ defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False :
+ if a typeof Int and b typeof Int : f(a as Int, b as Int)
+ else : false
+ defn max (a:Int,b:Int) -> Int :
+ if a > b : a
+ else : b
+ defn solve (w:Width) -> Int|False :
+ match(w) :
+ (w:VarWidth) : false
+ (w:MaxWidth) : apply(solve(arg1(w)),solve(arg2(w)),max)
+ (w:PlusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ + _})
+ (w:MinusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ - _})
+ (w:IntWidth) : width(w)
+ (w) : error("Shouldn't be here")
+
+ val i = HashTable<Symbol,Int>(symbol-hash)
+ for x in h do :
+ val s = solve(value(x))
+ if s typeof Int : i[key(x)] = s as Int
+ i
+
+ ; Forward solve
+ ; Returns a solved list where each constraint undergoes:
+ ; 1) Continuous Solving (using triangular solving)
+ ; 2) Remove Cycles
+ ; 3) Move to solved if not self-recursive
+ val u = unique(l)
+ println("Unique Constraints")
+ for x in u do : println(x)
+
+ val f = HashTable<Symbol,Width>(symbol-hash)
+ val o = Vector<Symbol>()
+ for x in u do :
+ val [n e] = [key(x) value(x)]
+ val e* = substitute(e,f)
+ ;val e* = remove-cycle{n,_} $ substitute(e,f)
+ ;if not self-rec?(n,e*) :
+ add(o,n)
+ f[n] = e*
+
+ println("Forward Solved Constraints")
+ for x in f do : println(x)
+
+ ; Backwards Solve
+ ;val b = HashTable<Symbol,Width>(symbol-hash)
+ ;for i in (length(o) - 1) through 0 by -1 do :
+ ; val n = o[i]
+ ; b[n] = substitute(f[n],b)
+
+ ;println("Backwards Solved Constraints")
+ ;for x in b do : println(x)
+ ;; Evaluate
+ ;val e = evaluate(b)
+ ;println("Evaluated Constraints")
+ ;for x in e do : println(x)
+ ;e
+ HashTable<Symbol,Int>(symbol-hash)
defn width! (t:Type) -> Width :
match(t) :
@@ -1550,8 +1655,7 @@ defn prim-width (e:DoPrim,v:Vector<WGeq>) -> Width :
BIT-SELECT-OP : IntWidth(1)
BITS-SELECT-OP : IntWidth(consts(e)[0] - consts(e)[1])
-defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> :
- val v = Vector<WGeq>()
+defn gen-constraints (c:Circuit, m:Module, v:Vector<WGeq>) -> Vector<WGeq> :
val h = HashTable<Symbol,Type>(symbol-hash)
defn get-width (e:Expression) -> Width :
@@ -1565,28 +1669,25 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> :
(k) :
[width! $ h[name(e)],
width! $ type(e)]
- add(v,WGeq(wdec,wref))
add(v,WGeq(wref,wdec))
- wdec
+ add(v,WGeq(wdec,wref))
+ wref
(e:WSubfield) : ;assumes only subfields are instances
val wdec = width! $ bundle-field-type(h[name(exp(e) as WRef)],name(e))
val wref = width! $ bundle-field-type(type(exp(e)),name(e))
- add(v,WGeq(wdec,wref))
add(v,WGeq(wref,wdec))
- wdec
+ add(v,WGeq(wdec,wref))
+ wref
(e:WIndex) : error("Shouldn't be here")
(e:UIntValue) : width(e)
(e:SIntValue) : width(e)
(e:DoPrim) : prim-width(e,v)
(e:ReadPort|WritePort) :
add(v,WGeq(get-width(enable(e)),IntWidth(1)))
- add(v,WGeq(IntWidth(1),get-width(enable(e))))
get-width(mem(e))
(e:Register) :
add(v,WGeq(get-width(enable(e)),IntWidth(1)))
- add(v,WGeq(IntWidth(1),get-width(enable(e))))
val w = width!(type(e))
- add(v,WGeq(get-width(value(e)),w))
add(v,WGeq(w,get-width(value(e))))
w
defn gen-constraints (s:Stmt) -> Stmt :
@@ -1606,42 +1707,51 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> :
for p in ports(m) do :
h[name(p)] = type(p)
gen-constraints(body(m))
- to-list(v)
+ v
+defn replace-var-widths (c:Circuit,h:HashTable<Symbol,Int>) -> Circuit :
+ defn replace-var-widths-w (w:Width) -> Width :
+ defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False :
+ for x in h any? : key(x) == n
+ match(w) :
+ (w:VarWidth) :
+ if contains?(name(w),h) : IntWidth(h[name(w)])
+ else: w
+ (w) : w
+ val modules* = for m in modules(c) map :
+ Module{name(m),_,body(m)} $
+ for p in ports(m) map :
+ Port(name(p),direction(p),mapr(replace-var-widths-w,type(p)))
+
+ val modules** = for m in modules* map :
+ Module(name(m),ports(m),mapr(replace-var-widths-w,body(m)))
+ Circuit(modules**,main(c))
defn remove-unknown-widths (c:Circuit) -> Circuit :
- defn remove-unknown-widths-wid (w:Width) -> Width :
+ defn remove-unknown-widths-w (w:Width) -> Width :
match(w) :
(w:UnknownWidth) : VarWidth(gensym(`w))
(w) : w
- defn remove-unknown-widths-type (t:Type) -> Type :
- map{remove-unknown-widths-wid,_} $
- map(remove-unknown-widths-type,t)
- defn remove-unknown-widths-exp (e:Expression) -> Expression :
- map{remove-unknown-widths-wid,_} $
- map{remove-unknown-widths-type,_} $
- map(remove-unknown-widths-exp,e)
- defn remove-unknown-widths-stmt (s:Stmt) -> Stmt :
- map{remove-unknown-widths-type,_} $
- map{remove-unknown-widths-exp,_} $
- map(remove-unknown-widths-stmt,s)
val modules* = for m in modules(c) map :
Module{name(m),_,body(m)} $
for p in ports(m) map :
- Port(name(p),direction(p),remove-unknown-widths-type(type(p)))
+ Port(name(p),direction(p),mapr(remove-unknown-widths-w,type(p)))
val modules** = for m in modules* map :
- Module(name(m),ports(m),remove-unknown-widths-stmt(body(m)))
+ Module(name(m),ports(m),mapr(remove-unknown-widths-w,body(m)))
Circuit(modules**,main(c))
defn infer-widths (c:Circuit) -> Circuit :
val c* = remove-unknown-widths(c)
+ val v = Vector<WGeq>()
for m in modules(c*) do :
- val l = gen-constraints(c*,m)
- for x in l do : println(x)
- c*
- ;val h = solve-constraints(l)
+ gen-constraints(c*,m,v)
+ for x in v do : println(x)
+ val h = solve-constraints(to-list(v))
+ println("Solved Constraints")
+ ;for x in h do : println(x)
;replace-var-widths(c*,h)
+ c*
diff --git a/test/passes/infer-widths/simple.fir b/test/passes/infer-widths/simple.fir
index f98d98da..fcd08ac6 100644
--- a/test/passes/infer-widths/simple.fir
+++ b/test/passes/infer-widths/simple.fir
@@ -3,10 +3,9 @@
;CHECK: Infer Widths
circuit top :
module top :
- wire e : UInt
- wire x : UInt
+ wire e : UInt(30)
reg y : UInt
- y := mux-uu(e, UInt(1), equal-uu(gt-uu(x, x), UInt(0)))
+ y := e
; CHECK: Finished Infer Widths