;Define the STANDALONE flag to build this file as a standalone executable
;(it then pulls in the lexer sources and runs `main` at the bottom).
#if-defined(STANDALONE) :
   #include<"core/stringeater.stanza">
   #include<"compiler/lexer.stanza">

defpackage widthsolver :
   import core
   import verse
   import stz/lexer

;============= Language of Constraints ======================

;A single width constraint on a named variable.
public definterface WConstraint

;Constraint: the width of `name` is exactly `value`.
public defstruct WidthEqual <: WConstraint :
   name: Symbol
   value: Exp

;Constraint: the width of `name` is at least `value`.
public defstruct WidthGreater <: WConstraint :
   name: Symbol
   value: Exp

;Prints a constraint as "name = value" or "name >= value".
defmethod print (o:OutputStream, c:WConstraint) :
   print-all{o, _} $ match(c) :
      (c:WidthEqual) : [name(c) " = " value(c)]
      (c:WidthGreater) : [name(c) " >= " value(c)]

;Builds one equation per constrained variable.
;- A WidthEqual constraint supplies the variable's equation directly; if
;  several name the same variable, the last one seen wins.
;- A variable with only WidthGreater constraints gets the equation
;  max(max(0, lb1), lb2)... folded over all of its lower bounds.
;- If a variable has a WidthEqual constraint, its lower bounds are ignored.
;Every returned entry holds an Exp; the solver later overwrites solved
;entries with false.
defn construct-eqns (cs: Streamable<WConstraint>) :
   val eqns = HashTable<Symbol,False|Exp>(symbol-hash)
   val lower-bounds = HashTable<Symbol,List<Exp>>(symbol-hash)
   for c in cs do :
      match(c) :
         (c:WidthEqual) :
            eqns[name(c)] = value(c)
         (c:WidthGreater) :
            lower-bounds[name(c)] = List(value(c), get?(lower-bounds, name(c), List()))

   ;Turn each variable's lower bounds into a single max expression,
   ;unless an exact equation already exists for it
   for entry in lower-bounds do :
      val v = key(entry)
      val exps = value(entry)
      if not key?(eqns, v) :
         eqns[v] = reduce(EMax, ELit(0), exps)

   ;Return equations
   eqns

;============================================================

;============= Language of Expressions ======================

;A width expression.
public definterface Exp

;Reference to the width of another variable.
public defstruct EVar <: Exp :
   name: Symbol

;max(a, b)
public defstruct EMax <: Exp :
   a: Exp
   b: Exp

;a + b
public defstruct EPlus <: Exp :
   a: Exp
   b: Exp

;a - b
public defstruct EMinus <: Exp :
   a: Exp
   b: Exp

;An integer literal width.
public defstruct ELit <: Exp :
   width: Int

;Prints an expression in infix form; max is printed as "max(a, b)".
defmethod print (o:OutputStream, e:Exp) :
   match(e) :
      (e:EVar) : print(o, name(e))
      (e:EMax) : print-all(o, ["max(" a(e) ", " b(e) ")"])
      (e:EPlus) : print-all(o, [a(e) " + " b(e)])
      (e:EMinus) : print-all(o, [a(e) " - " b(e)])
      (e:ELit) : print(o, width(e))

;Rebuilds a binary node (EMax/EPlus/EMinus) with `f` applied to both
;operands. Leaf nodes (EVar, ELit) are returned unchanged and `f` is not
;applied to them.
defn map (f: (Exp) -> Exp, e: Exp) -> Exp :
   match(e) :
      (e:EMax) : EMax(f(a(e)), f(b(e)))
      (e:EPlus) : EPlus(f(a(e)), f(b(e)))
      (e:EMinus) : EMinus(f(a(e)), f(b(e)))
      (e:Exp) : e

;Returns the immediate operands of `e`; leaves have none.
defn children (e: Exp) -> List<Exp> :
   match(e) :
      (e:EMax) : list(a(e), b(e))
      (e:EPlus) : list(a(e), b(e))
      (e:EMinus) : list(a(e), b(e))
      (e:Exp) : list()

;============================================================

;================== Reading from File =======================

;Converts one lexed form into an Exp:
;- a symbol becomes a variable reference,
;- an integer becomes a literal,
;- a list whose element 1 is `plus`, `minus` or `max` becomes that binary
;  node over elements 2 and 3.
;NOTE(review): element 0 of a list form is skipped; presumably it is a
;marker inserted by stz/lexer -- confirm against the lexer.
defn read-exp (x) -> Exp :
   match(unwrap-token(x)) :
      (x:Symbol) : EVar(x)
      (x:Int) : ELit(x)
      (x:List<?>) :
         val tag = unwrap-token(x[1])
         switch {tag == _} :
            `plus : EPlus(read-exp(x[2]), read-exp(x[3]))
            `minus : EMinus(read-exp(x[2]), read-exp(x[3]))
            `max : EMax(read-exp(x[2]), read-exp(x[3]))
            else : error $ string-join $ ["Improper expression: " x]

;Reads a constraint file. The lexed file is consumed as a flat sequence
;of triples `name op exp`, where `op` is `=` (WidthEqual) or
;`>=` (WidthGreater). Any other operator is an error.
defn read (filename: String) :
   var form:List<?> = lex-file(filename)
   val cs = Vector<WConstraint>()
   while not empty?(form) :
      val x = unwrap-token(form[0])
      val op = form[1]
      val e = read-exp(form[2])
      form = tailn(form, 3)
      add{cs, _} $ switch {unwrap-token(op) == _} :
         `= : WidthEqual(x, e)
         `>= : WidthGreater(x, e)
         else : error $ string-join $ ["Unsupported Operator: " op]
   cs

;============================================================

;============ Operations on Expressions =====================

;True if variable `v` is referenced anywhere inside `exp`.
defn occurs? (v: Symbol, exp: Exp) -> True|False :
   match(exp) :
      (exp: EVar) : name(exp) == v
      (exp: Exp) : any?(occurs?{v, _}, children(exp))

;Lists every variable referenced in `exp`, left to right.
;Duplicates are kept.
defn freevars (exp: Exp) :
   to-list $ generate<Symbol> :
      defn loop (exp: Exp) :
         match(exp) :
            (exp: EVar) : yield(name(exp))
            (exp: Exp) : do(loop, children(exp))
      loop(exp)

;True if `exp` is built only from variables, literals and max.
;Such expressions can be solved by fixpoint iteration.
defn contains-only-max? (exp: Exp) -> True|False :
   match(exp) :
      (exp:EVar|EMax|ELit) : all?(contains-only-max?, children(exp))
      (exp) : false

;Bottom-up simplification:
;- folds +, - and max over two literals;
;- drops a literal 0 operand of + (either side), of - (right side only)
;  and of max (either side).
;NOTE: rewriting max(0, e) to e assumes e is never negative.
defn simplify (exp: Exp) -> Exp :
   match(map(simplify, exp)) :
      (exp: EPlus) :
         match(a(exp), b(exp)) :
            (a: ELit, b: ELit) : ELit(width(a) + width(b))
            (a: ELit, b) :
               if width(a) == 0 : b
               else : exp
            (a, b: ELit) :
               if width(b) == 0 : a
               else : exp
            (a, b) : exp
      (exp: EMinus) :
         match(a(exp), b(exp)) :
            (a: ELit, b: ELit) : ELit(width(a) - width(b))
            (a, b: ELit) :
               if width(b) == 0 : a
               else : exp
            (a, b) : exp
      (exp: EMax) :
         match(a(exp), b(exp)) :
            (a: ELit, b: ELit) : ELit(max(width(a), width(b)))
            (a: ELit, b) :
               if width(a) == 0 : b
               else : exp
            (a, b: ELit) :
               if width(b) == 0 : a
               else : exp
            (a, b) : exp
      (exp: Exp) : exp

;Evaluates `exp` to an integer. Variable values are looked up in `state`,
;which must contain every variable `exp` references.
defn eval (exp: Exp, state: HashTable<Symbol,Int>) -> Int :
   defn loop (e: Exp) -> Int :
      match(e) :
         (e: EVar) : state[name(e)]
         (e: EMax) : max(loop(a(e)), loop(b(e)))
         (e: EPlus) : loop(a(e)) + loop(b(e))
         (e: EMinus) : loop(a(e)) - loop(b(e))
         (e: ELit) : width(e)
   loop(exp)

;============================================================

;================ Constraint Solver =========================

;Replaces every variable that has an entry in `solns` with its solution,
;recursively substituting inside that solution as well. Terminates because
;dataflow only records a solution for v when v does not occur in it.
defn substitute (solns: HashTable<Symbol,Exp>, exp: Exp) -> Exp :
   match(exp) :
      (exp: EVar) :
         match(get?(solns, name(exp), false)) :
            (s:Exp) : substitute(solns, s)
            (f:False) : exp
      (exp) : map(substitute{solns, _}, exp)

;One pass over the pending equations (entries that are not false).
;Each right-hand side is rewritten with the known solutions and simplified.
;- If the variable no longer occurs on its own right-hand side, the result
;  becomes its solution and the equation is marked solved (false).
;- Otherwise the rewritten equation is stored back.
;Returns true if at least one equation was solved in this pass.
defn dataflow (eqns: HashTable<Symbol,False|Exp>, solns: HashTable<Symbol,Exp>) :
   var progress?:True|False = false
   for entry in eqns do :
      if value(entry) != false :
         val v = key(entry)
         val exp = simplify(substitute(solns, value(entry) as Exp))
         if occurs?(v, exp) :
            eqns[v] = exp
         else :
            eqns[v] = false
            solns[v] = exp
            progress? = true
   progress?

;Finds the first pending variable whose dependency system can be solved by
;fixpoint iteration (see solvable-by-iteration?). It then:
;- solves that whole system,
;- records each result as a literal solution,
;- marks the corresponding equations solved.
;Returns true if a system was solved and false if none qualified.
defn fixpoint (eqns: HashTable<Symbol,False|Exp>, solns: HashTable<Symbol,Exp>) :
   label<True|False> break :
      for v in keys(eqns) do :
         if eqns[v] != false :
            val fix-eqns = fixpoint-eqns(v, eqns)
            if solvable-by-iteration?(fix-eqns) :
               val soln = solve-fixpoint(fix-eqns)
               for s in soln do :
                  solns[key(s)] = ELit(value(s))
                  eqns[key(s)] = false
               break(true)
      false

;True if `fix-eqns` can be handed to solve-fixpoint. This requires both:
;- every right-hand side uses only variables, literals and max;
;- every variable mentioned has its own equation in `fix-eqns`.
;Without the second check, solve-fixpoint would evaluate a variable it
;never initialized (e.g. x >= max(x, y) with y unconstrained).
defn solvable-by-iteration? (fix-eqns: List<KeyValue<Symbol,Exp>>) -> True|False :
   defn defined? (x: Symbol) :
      any?({key(_) == x}, fix-eqns)
   defn closed-max-only? (eqn: KeyValue<Symbol,Exp>) :
      contains-only-max?(value(eqn)) and all?(defined?, freevars(value(eqn)))
   all?(closed-max-only?, fix-eqns)

;Collects the pending equation of `v` together with the pending equations
;of every variable it transitively references. Variables without a pending
;equation are skipped: either no constraint mentions them, or they are
;already solved. Previously this did `eqns[v] as Exp`, which aborted on
;such variables.
defn fixpoint-eqns (v: Symbol, eqns: HashTable<Symbol,False|Exp>) :
   val vs = HashTable<Symbol,Exp>(symbol-hash)
   defn loop (v: Symbol) :
      if not key?(vs, v) :
         match(get?(eqns, v, false)) :
            (eqn:Exp) :
               vs[v] = eqn
               do(loop, freevars(eqn))
            (eqn:False) :
               false
   loop(v)
   to-list(vs)

;Solves a closed system of max-only equations by iteration.
;- Every variable starts at 0.
;- Each sweep re-evaluates all equations in order; updates are visible to
;  later equations in the same sweep.
;- Iteration stops after a sweep that changes nothing.
;With non-negative literals the values only grow and are bounded by the
;largest literal, so the loop terminates.
;Returns the variable -> width pairs.
defn solve-fixpoint (eqns: List<KeyValue<Symbol,Exp>>) :
   ;Solve for fixpoint
   val sol = HashTable<Symbol,Int>(symbol-hash)
   do({sol[key(_)] = 0}, eqns)
   defn loop () :
      var progress?:True|False = false
      for eqn in eqns do :
         val v = key(eqn)
         val x = eval(value(eqn), sol)
         if x != sol[v] :
            sol[v] = x
            progress? = true
      progress?
   while loop() : false

   ;Return solutions
   to-list(sol)

;Computes a concrete width for each variable in `vs` by evaluating its
;solution. The widths of the solution's free variables are computed first,
;recursively and memoized.
;- A variable with no solution, or one depending on such a variable, has
;  no width.
;- Returns a WidthEqual(name, ELit(width)) for every variable that got a
;  width, including dependencies reached along the way.
defn backsubstitute (vs:Streamable<Symbol>, solns: HashTable<Symbol,Exp>) :
   val widths = HashTable<Symbol,False|Int>(symbol-hash)
   defn get-width (v:Symbol) -> False|Int :
      if key?(solns, v) :
         val deps = freevars(solns[v])
         ;Calculate dependencies
         for d in deps do :
            if not key?(widths, d) :
               widths[d] = get-width(d)
         ;Compute value only when every dependency has a width
         if none?({widths[_] == false}, deps) :
            eval(solns[v], widths as HashTable<Symbol,Int>)
         else :
            false
      else :
         false

   ;Compute all widths, skipping those already computed as dependencies
   for v in vs do :
      if not key?(widths, v) :
         widths[v] = get-width(v)

   ;Return widths
   to-list $ generate<WidthEqual> :
      for entry in widths do :
         if value(entry) != false :
            yield $ WidthEqual(key(entry), ELit(value(entry) as Int))

;Solves a set of width constraints.
;- Alternates dataflow (direct substitution) and fixpoint (iteration on
;  max-only cycles) until neither makes progress. Each productive step
;  marks at least one equation solved, so the loop terminates.
;- Returns a WidthEqual with a literal width for every variable whose
;  width could be determined.
public defn solve-widths (cs: Streamable<WConstraint>) :
   ;Build a fresh equation table from the constraints
   val eqns = construct-eqns(cs)
   val solns = HashTable<Symbol,Exp>(symbol-hash)
   defn loop () :
      dataflow(eqns, solns) or fixpoint(eqns, solns)
   while loop() : false
   backsubstitute(keys(eqns), solns)

;================= Main =====================================

;Standalone entry point: reads the constraint file named by the first
;command-line argument after the program name and prints the solved widths.
#if-defined(STANDALONE) :
   defn main () :
      val input = lex(commandline-arguments())
      error("No input file!") when length(input) < 2
      val cs = read(to-string(input[1]))
      do(println, solve-widths(cs))
   main()

;============================================================