aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--TODO2
-rw-r--r--src/main/stanza/ir-utils.stanza16
-rw-r--r--src/main/stanza/passes.stanza182
-rw-r--r--test/passes/infer-widths/simple.fir5
4 files changed, 165 insertions, 40 deletions
diff --git a/TODO b/TODO
index dacac142..3d9beac9 100644
--- a/TODO
+++ b/TODO
@@ -1,7 +1,7 @@
TODO
Change parser to use <> syntax (and update all tests) (patrick)
Think about on-reset, add it
- Think about expanding mems, more complicated than first glance!
+ Think about expanding mems, more complicated than first glance! Change DefMem to be size, and element type
Think about max(op1,op2) for muxes.
Make instances always male, flip the bundles on declaration
Talk to palmer/patrick about how writing passes is going to be supported
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza
index 1f33547e..0bce0a90 100644
--- a/src/main/stanza/ir-utils.stanza
+++ b/src/main/stanza/ir-utils.stanza
@@ -289,6 +289,22 @@ defmethod map (f: Type -> Type, c:Stmt) -> Stmt :
(c:DefMemory) : DefRegister(name(c),f(type(c)))
(c) : c
+public defmulti mapr<?T> (f: Width -> Width, t:?T&Type) -> T
+defmethod mapr (f: Width -> Width, t:Type) -> Type :
+ defn apply-t (t:Type) -> Type :
+ map{f,_} $ map(apply-t,t)
+ apply-t(t)
+
+public defmulti mapr<?T> (f: Width -> Width, s:?T&Stmt) -> T
+defmethod mapr (f: Width -> Width, s:Stmt) -> Stmt :
+ defn apply-t (t:Type) -> Type : mapr(f,t)
+ defn apply-e (e:Expression) -> Expression :
+ map{f,_} $ map{apply-t,_} $ map(apply-e,e)
+ defn apply-s (s:Stmt) -> Stmt :
+ map{apply-t,_} $ map{apply-e,_} $ map(apply-s,s)
+ apply-s(s)
+
+
;================= HELPER FUNCTIONS USING MAP ===================
; These don't work properly..
;public defmulti do (f:Expression -> ?, e:Expression) -> False
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index 313191c8..e9e5bfc6 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -1436,6 +1436,14 @@ defstruct MaxWidth <: Width :
arg1 : Width
arg2 : Width
+public defmulti map<?T> (f: Width -> Width, w:?T&Width) -> T
+defmethod map (f: Width -> Width, w:Width) -> Width :
+ match(w) :
+ (w:MaxWidth) : MaxWidth(f(arg1(w)),f(arg2(w)))
+ (w:PlusWidth) : PlusWidth(f(arg1(w)),f(arg2(w)))
+ (w:MinusWidth) : MinusWidth(f(arg1(w)),f(arg2(w)))
+ (w) : w
+
defmethod print (o:OutputStream, w:VarWidth) :
print(o,name(w))
defmethod print (o:OutputStream, w:MaxWidth) :
@@ -1449,18 +1457,115 @@ definterface Constraint
defstruct WGeq <: Constraint :
loc : Width
exp : Width
-defstruct WEq <: Constraint :
- loc : Width
- exp : Width
defmethod print (o:OutputStream, c:WGeq) :
print-all(o,[ loc(c) " >= " exp(c)])
-defmethod print (o:OutputStream, c:WEq) :
- print-all(o,[ loc(c) " = " exp(c)])
-;defn remove-var-widths (c:Circuit) -> Circuit :
-
-;defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
+defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
+ defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False :
+ for x in h any? : key(x) == n
+ defn unique (ls:List<WGeq>) -> HashTable<Symbol,Width> :
+ val h = HashTable<Symbol,Width>(symbol-hash)
+ for g in ls do :
+ match(loc(g)) :
+ (w:VarWidth) :
+ val n = name(w)
+ if contains?(n,h) : h[n] = MaxWidth(exp(g),h[n])
+ else : h[n] = exp(g)
+ (w) : w
+ h
+ defn substitute (w:Width,h:HashTable<Symbol,Width>) -> Width :
+ match(map(substitute{_,h},w)) :
+ (w:VarWidth) :
+ if contains?(name(w),h) :
+ val t = substitute(h[name(w)],h)
+ h[name(w)] = t
+ t
+ else : w
+ (w) : w
+ defn remove-cycle (n:Symbol,w:Width) -> Width :
+ match(map(remove-cycle{n,_},w)) :
+ (w:MaxWidth) :
+ match(arg1(w),arg2(w)) :
+ (v1:VarWidth,v2:VarWidth) :
+ if name(v1) == n : arg2(w)
+ else if name(v2) == n : arg1(w)
+ else : w
+ (v:VarWidth,_) :
+ if name(v) == n : arg2(w)
+ else : w
+ (_,v:VarWidth) :
+ if name(v) == n : arg1(w)
+ else : w
+ (v1,v2) : w
+ (w) : w
+ defn self-rec? (n:Symbol,w:Width) -> True|False :
+ var has? = false
+ defn look (w:Width) -> Width :
+ match(map(look,w)) :
+ (w:VarWidth) : if name(w) == n : has? = true
+ (w) : w
+ w
+ look(w)
+ has?
+ defn evaluate (h:HashTable<Symbol,Width>) -> HashTable<Symbol,Int> :
+ defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False :
+ if a typeof Int and b typeof Int : f(a as Int, b as Int)
+ else : false
+ defn max (a:Int,b:Int) -> Int :
+ if a > b : a
+ else : b
+ defn solve (w:Width) -> Int|False :
+ match(w) :
+ (w:VarWidth) : false
+ (w:MaxWidth) : apply(solve(arg1(w)),solve(arg2(w)),max)
+ (w:PlusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ + _})
+ (w:MinusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ - _})
+ (w:IntWidth) : width(w)
+ (w) : error("Shouldn't be here")
+
+ val i = HashTable<Symbol,Int>(symbol-hash)
+ for x in h do :
+ val s = solve(value(x))
+ if s typeof Int : i[key(x)] = s as Int
+ i
+
+ ; Forward solve
+ ; Returns a solved list where each constraint undergoes:
+ ; 1) Continuous Solving (using triangular solving)
+ ; 2) Remove Cycles
+ ; 3) Move to solved if not self-recursive
+ val u = unique(l)
+ println("Unique Constraints")
+ for x in u do : println(x)
+
+ val f = HashTable<Symbol,Width>(symbol-hash)
+ val o = Vector<Symbol>()
+ for x in u do :
+ val [n e] = [key(x) value(x)]
+ val e* = substitute(e,f)
+ ;val e* = remove-cycle{n,_} $ substitute(e,f)
+ ;if not self-rec?(n,e*) :
+ add(o,n)
+ f[n] = e*
+
+ println("Forward Solved Constraints")
+ for x in f do : println(x)
+
+ ; Backwards Solve
+ ;val b = HashTable<Symbol,Width>(symbol-hash)
+ ;for i in (length(o) - 1) through 0 by -1 do :
+ ; val n = o[i]
+ ; b[n] = substitute(f[n],b)
+
+ ;println("Backwards Solved Constraints")
+ ;for x in b do : println(x)
+ ;; Evaluate
+ ;val e = evaluate(b)
+ ;println("Evaluated Constraints")
+ ;for x in e do : println(x)
+ ;e
+ HashTable<Symbol,Int>(symbol-hash)
defn width! (t:Type) -> Width :
match(t) :
@@ -1550,8 +1655,7 @@ defn prim-width (e:DoPrim,v:Vector<WGeq>) -> Width :
BIT-SELECT-OP : IntWidth(1)
BITS-SELECT-OP : IntWidth(consts(e)[0] - consts(e)[1])
-defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> :
- val v = Vector<WGeq>()
+defn gen-constraints (c:Circuit, m:Module, v:Vector<WGeq>) -> Vector<WGeq> :
val h = HashTable<Symbol,Type>(symbol-hash)
defn get-width (e:Expression) -> Width :
@@ -1565,28 +1669,25 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> :
(k) :
[width! $ h[name(e)],
width! $ type(e)]
- add(v,WGeq(wdec,wref))
add(v,WGeq(wref,wdec))
- wdec
+ add(v,WGeq(wdec,wref))
+ wref
(e:WSubfield) : ;assumes only subfields are instances
val wdec = width! $ bundle-field-type(h[name(exp(e) as WRef)],name(e))
val wref = width! $ bundle-field-type(type(exp(e)),name(e))
- add(v,WGeq(wdec,wref))
add(v,WGeq(wref,wdec))
- wdec
+ add(v,WGeq(wdec,wref))
+ wref
(e:WIndex) : error("Shouldn't be here")
(e:UIntValue) : width(e)
(e:SIntValue) : width(e)
(e:DoPrim) : prim-width(e,v)
(e:ReadPort|WritePort) :
add(v,WGeq(get-width(enable(e)),IntWidth(1)))
- add(v,WGeq(IntWidth(1),get-width(enable(e))))
get-width(mem(e))
(e:Register) :
add(v,WGeq(get-width(enable(e)),IntWidth(1)))
- add(v,WGeq(IntWidth(1),get-width(enable(e))))
val w = width!(type(e))
- add(v,WGeq(get-width(value(e)),w))
add(v,WGeq(w,get-width(value(e))))
w
defn gen-constraints (s:Stmt) -> Stmt :
@@ -1606,42 +1707,51 @@ defn gen-constraints (c:Circuit, m:Module) -> List<WGeq> :
for p in ports(m) do :
h[name(p)] = type(p)
gen-constraints(body(m))
- to-list(v)
+ v
+defn replace-var-widths (c:Circuit,h:HashTable<Symbol,Int>) -> Circuit :
+ defn replace-var-widths-w (w:Width) -> Width :
+ defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False :
+ for x in h any? : key(x) == n
+ match(w) :
+ (w:VarWidth) :
+ if contains?(name(w),h) : IntWidth(h[name(w)])
+ else: w
+ (w) : w
+ val modules* = for m in modules(c) map :
+ Module{name(m),_,body(m)} $
+ for p in ports(m) map :
+ Port(name(p),direction(p),mapr(replace-var-widths-w,type(p)))
+
+ val modules** = for m in modules* map :
+ Module(name(m),ports(m),mapr(replace-var-widths-w,body(m)))
+ Circuit(modules**,main(c))
defn remove-unknown-widths (c:Circuit) -> Circuit :
- defn remove-unknown-widths-wid (w:Width) -> Width :
+ defn remove-unknown-widths-w (w:Width) -> Width :
match(w) :
(w:UnknownWidth) : VarWidth(gensym(`w))
(w) : w
- defn remove-unknown-widths-type (t:Type) -> Type :
- map{remove-unknown-widths-wid,_} $
- map(remove-unknown-widths-type,t)
- defn remove-unknown-widths-exp (e:Expression) -> Expression :
- map{remove-unknown-widths-wid,_} $
- map{remove-unknown-widths-type,_} $
- map(remove-unknown-widths-exp,e)
- defn remove-unknown-widths-stmt (s:Stmt) -> Stmt :
- map{remove-unknown-widths-type,_} $
- map{remove-unknown-widths-exp,_} $
- map(remove-unknown-widths-stmt,s)
val modules* = for m in modules(c) map :
Module{name(m),_,body(m)} $
for p in ports(m) map :
- Port(name(p),direction(p),remove-unknown-widths-type(type(p)))
+ Port(name(p),direction(p),mapr(remove-unknown-widths-w,type(p)))
val modules** = for m in modules* map :
- Module(name(m),ports(m),remove-unknown-widths-stmt(body(m)))
+ Module(name(m),ports(m),mapr(remove-unknown-widths-w,body(m)))
Circuit(modules**,main(c))
defn infer-widths (c:Circuit) -> Circuit :
val c* = remove-unknown-widths(c)
+ val v = Vector<WGeq>()
for m in modules(c*) do :
- val l = gen-constraints(c*,m)
- for x in l do : println(x)
- c*
- ;val h = solve-constraints(l)
+ gen-constraints(c*,m,v)
+ for x in v do : println(x)
+ val h = solve-constraints(to-list(v))
+ println("Solved Constraints")
+ ;for x in h do : println(x)
;replace-var-widths(c*,h)
+ c*
diff --git a/test/passes/infer-widths/simple.fir b/test/passes/infer-widths/simple.fir
index f98d98da..fcd08ac6 100644
--- a/test/passes/infer-widths/simple.fir
+++ b/test/passes/infer-widths/simple.fir
@@ -3,10 +3,9 @@
;CHECK: Infer Widths
circuit top :
module top :
- wire e : UInt
- wire x : UInt
+ wire e : UInt(30)
reg y : UInt
- y := mux-uu(e, UInt(1), equal-uu(gt-uu(x, x), UInt(0)))
+ y := e
; CHECK: Finished Infer Widths