diff options
| author | azidar | 2015-03-27 17:37:04 -0700 |
|---|---|---|
| committer | azidar | 2015-03-27 17:37:04 -0700 |
| commit | d4fdab6950b47379137fce750e4a3a6b262e750d (patch) | |
| tree | 60b2f6b6b89358f5311ba7409a6b7ccdb8ac4fed /src | |
| parent | a1a1156df859eb815f8b345d24198dbfe3857832 (diff) | |
Corrected register init by adding initialization of registers pass after lowering. Finished expand-whens. Needs more thorough testing of instances
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 73 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 316 |
2 files changed, 271 insertions, 118 deletions
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index 7c3537fc..0fe5ef5b 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -254,12 +254,73 @@ defmethod map (f: Stmt -> Stmt, c:Stmt) -> Stmt : (c:Begin) : Begin(map(f, body(c))) (c) : c -public defmulti children (c:Stmt) -> List<Stmt> -defmethod children (c:Stmt) : - match(c) : - (c:Conditionally) : list(conseq(c), alt(c)) - (c:Begin) : body(c) - (c) : List() +;================= HELPER FUNCTIONS USING MAP =================== +public defmulti do (f:Expression -> ?, e:Expression) -> False +defmethod do (f:Expression -> ?, e:Expression) -> False : + for x in e map : + f(x) + x + false + +public defmulti do (f:Expression -> ?, s:Stmt) -> False +defmethod do (f:Expression -> ?, s:Stmt) -> False : + defn f* (x:Expression) : + f(x) + x + map(f*,s) + false + +public defmulti do (f:Stmt -> ?, s:Stmt) -> False +defmethod do (f:Stmt -> ?, s:Stmt) -> False : + defn f* (x:Stmt) : + f(x) + x + map(f*,s) + false + +public defmulti dor (f:Expression -> ?, e:Expression) -> False +defmethod dor (f:Expression -> ?, e:Expression) -> False : + f(e) + for x in e map : + dor(f,x) + x + false + +public defmulti dor (f:Expression -> ?, s:Stmt) -> False +defmethod dor (f:Expression -> ?, s:Stmt) -> False : + defn f* (x:Expression) : + dor(f,x) + x + map(f*,s) + false + +public defmulti dor (f:Stmt -> ?, s:Stmt) -> False +defmethod dor (f:Stmt -> ?, s:Stmt) -> False : + f(s) + defn f* (x:Stmt) : + dor(f,x) + x + map(f*,s) + false + +public defmulti sub-exps (s:Expression|Stmt) -> List<Expression> +defmethod sub-exps (e:Expression) -> List<Expression> : + val l = Vector<Expression>() + defn f (x:Expression) : add(l,x) + do(f,e) + to-list(l) +defmethod sub-exps (e:Stmt) -> List<Expression> : + val l = Vector<Expression>() + defn f (x:Expression) : add(l,x) + do(f,e) + to-list(l) + +public defmulti sub-stmts (s:Stmt) -> List<Stmt> +defmethod sub-stmts (s:Stmt) : + val l = Vector<Stmt>() + defn f (x:Stmt) : add(l,x) + do(f,s) + to-list(l) ;=================== ADAM OPS =============================== public defn split (s:String,c:Char) -> List<String> : diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 4c45021d..2c0ad0ec 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -62,6 +62,17 @@ defstruct WDefAccessor <: Stmt : index: Expression gender: Gender +defstruct ConnectToIndexed <: Stmt : + index: Expression + locs: List<Expression> + exp: Expression + +defstruct ConnectFromIndexed <: Stmt : + index: Expression + loc: Expression + exps: List<Expression> + + ;================ WORKING IR UTILS ========================= defn plus (g1:Gender,g2:Gender) -> Gender : @@ -209,17 +220,26 @@ defmethod print (o:OutputStream, s:WDefAccessor) : print-all(o,["accessor " name(s) " = " source(s) "[" index(s) "]"]) print-debug(o,s) +defmethod print (o:OutputStream, c:ConnectToIndexed) : + print-all(o, [locs(c) "[" index(c) "] := " exp(c)]) + print-debug(o,c as ?) + +defmethod print (o:OutputStream, c:ConnectFromIndexed) : + print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"]) + print-debug(o,c as ?) + defmethod map (f: Expression -> Expression, e: WRegInit) : WRegInit(f(reg(e)), name(e), type(e), gender(e)) - defmethod map (f: Expression -> Expression, e: WSubfield) : WSubfield(f(exp(e)), name(e), type(e), gender(e)) - defmethod map (f: Expression -> Expression, e: WIndex) : WIndex(f(exp(e)), value(e), type(e), gender(e)) - defmethod map (f: Expression -> Expression, c:WDefAccessor) : WDefAccessor(name(c), f(source(c)), f(index(c)), gender(c)) +defmethod map (f: Expression -> Expression, c:ConnectToIndexed) : + ConnectToIndexed(f(index(c)), map(f, locs(c)), f(exp(c))) +defmethod map (f: Expression -> Expression, c:ConnectFromIndexed) : + ConnectFromIndexed(f(index(c)), f(loc(c)), map(f, exps(c))) ;================= Bring to Working IR ======================== ; Returns a new Circuit with Refs, Subfields, Indexes and DefAccessors @@ -337,6 +357,7 @@ defn make-explicit-reset (c:Circuit) : for m in modules(c) map : make-explicit-reset(m,c) + ;============== INFER TYPES ================================ ; This pass infers the type field in all IR nodes by updating ; and passing an environment to all statements in pre-order @@ -660,26 +681,6 @@ defn resolve-genders (c:Circuit) : ; ConnectFromIndexed (female) ; Eg: -defstruct ConnectToIndexed <: Stmt : - index: Expression - locs: List<Expression> - exp: Expression - -defstruct ConnectFromIndexed <: Stmt : - index: Expression - loc: Expression - exps: List<Expression> - -defmethod print (o:OutputStream, c:ConnectToIndexed) : - print-all(o, [locs(c) "[" index(c) "] := " exp(c)]) -defmethod print (o:OutputStream, c:ConnectFromIndexed) : - print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"]) - -defmethod map (f: Expression -> Expression, c:ConnectToIndexed) : - ConnectToIndexed(f(index(c)), map(f, locs(c)), f(exp(c))) -defmethod map (f: Expression -> Expression, c:ConnectFromIndexed) : - ConnectFromIndexed(f(index(c)), f(loc(c)), map(f, exps(c))) - defn expand-vector (e:Expression) -> List<Expression> : val t = type(e) as VectorType for i in 0 to size(t) map-append : @@ -939,6 +940,83 @@ defn expand-connect-indexed (c: Circuit) -> Circuit : for m in modules(c) map : expand-connect-indexed(m) +;======= MAKE EXPLICIT REGISTER INITIALIZATION ============= +; This pass replaces the reg.init construct by creating a new +; wire that holds the value at initialization. This wire +; is then connected to the register conditionally on reset, +; at the end of the scope containing the register +; declaration +; If a register has no inital value, the wire is connected to +; a NULL node. Later passes will remove these with the base +; case Mux(reset,NULL,a) -> a, and Mux(reset,a,NULL) -> a. +; This ensures proper behavior if this pass is run multiple +; times. + +defn initialize-registers (c:Circuit) : + defn to-wire-name (y:Symbol) : to-symbol("~#init" % [y]) + defn add-when (s:Stmt,h:HashTable<Symbol,Type>) -> Stmt : + var inits = List<Stmt>() + for kv in h do : + val refreg = WRef(key(kv),value(kv),RegKind(),FEMALE) + val refwire = WRef(to-wire-name(key(kv)),value(kv),NodeKind(),MALE) + val connect = Connect(refreg,refwire) + inits = append(inits,list(connect)) + if empty?(inits) : s + else : + val pred = WRef(`reset, UIntType(IntWidth(1)), PortKind(), MALE) + val when-reset = Conditionally(pred,Begin(inits),Begin(List<Stmt>())) + Begin(list(s,when-reset)) + + defn rename (s:Stmt,h:HashTable<Symbol,True|False>) -> [Stmt HashTable<Symbol,Type>] : + val t = HashTable<Symbol,Type>(symbol-hash) + defn rename-expr (e:Expression) -> Expression : + match(map(rename-expr,e)) : + (e:WRegInit) : + val new-name = to-wire-name(name(reg(e) as WRef)) + WRef(new-name,type(reg(e)),RegKind(),gender(e)) + (e) : e + defn rename-stmt (s:Stmt) -> Stmt : + match(map(rename-stmt,s)) : + (s:DefRegister) : + if h[name(s)] : + t[name(s)] = type(s) + Begin(list(s,DefWire(to-wire-name(name(s)),type(s)))) + else : s + (s) : map(rename-expr,s) + [rename-stmt(s) t] + + defn init? (y:Symbol,s:Stmt) -> True|False : + var used? = false + defn has? (e:Expression) -> Expression : + match(map(has?,e)) : + (e:WRegInit) : + if name(reg(e) as WRef) == y : used? = true + (e) : map(has?,e) + e + map(has?,s) + used? + + defn using-init (s:Stmt,h:HashTable<Symbol,True|False>) -> Stmt : + match(s) : + (s:DefRegister) : h[name(s)] = false + (s) : + for x in h do : + h[key(x)] = value(x) or init?(key(x),s) + map(using-init{_,h},s) + + defn explicit-init-scope (s:Stmt) -> Stmt : + val h = HashTable<Symbol,True|False>(symbol-hash) + using-init(s,h) + println(h) + val [s* t] = rename(s,h) + add-when(s*,t) + + Circuit(modules*, main(c)) where : + val modules* = + for m in modules(c) map : + Module(name(m), ports(m), body*) where : + val body* = explicit-init-scope(body(m)) + ;;================ EXPAND WHENS ============================= ; This pass does three things: remove last connect semantics, ; remove conditional blocks, and eliminate concept of scoping. @@ -1000,6 +1078,14 @@ defn NOT (e1:Expression) -> Expression : else if e1 == zero : one else : DoPrim(EQUAL-UU-OP,list(e1,zero),list(),UIntType(IntWidth(1))) +defn children (e:Expression) -> List<Expression> : + val es = Vector<Expression>() + do(add{es,_},e) + to-list(es) + + + + ; ======= Symbolic Value Library ========== public definterface SymbolicValue public defstruct SVExp <: SymbolicValue : @@ -1022,6 +1108,18 @@ defmethod map (f: SymbolicValue -> SymbolicValue, sv:SymbolicValue) -> SymbolicV (sv: SVMux) : SVMux(pred(sv),f(conseq(sv)),f(alt(sv))) (sv) : sv +defn do (f:SymbolicValue -> ?, s:SymbolicValue) -> False : + for x in s map : + f(x) + x + false +defn dor (f:SymbolicValue -> ?, e:SymbolicValue) -> False : + do(f,e) + for x in e map : + dor(f,x) + x + false + defmethod equal? (a:SymbolicValue,b:SymbolicValue) -> True|False : match(a,b) : (a:SVNul,b:SVNul) : true @@ -1029,6 +1127,7 @@ defmethod equal? (a:SymbolicValue,b:SymbolicValue) -> True|False : (a:SVMux,b:SVMux) : pred(a) == pred(b) and conseq(a) == conseq(b) and alt(a) == alt(b) (a,b) : false +;TODO add invert to primop defn optimize (sv:SymbolicValue) -> SymbolicValue : match(map(optimize,sv)) : (sv:SVMux) : @@ -1055,16 +1154,21 @@ defn new-table () -> HashTable<Symbol,SSV> : HashTable<Symbol,SSV>(symbol-hash) ; ========= Expand When Pass =========== defn expand-whens-stmt (table:HashTable<Symbol,SSV>,enables:HashTable<Symbol,SymbolicValue>) -> Stmt : + defn has-nul? (sv:SymbolicValue) -> True|False : + var has? = false + if sv typeof SVNul : has? = true + for x in sv dor : + if x typeof SVNul : has? = true + has? + defn remove-nul (sv:SymbolicValue) -> SymbolicValue : + match(map(remove-nul,sv)) : + (sv:SVMux) : + match(conseq(sv),alt(sv)) : + (c,a:SVNul) : c + (c:SVNul,a) : a + (c,a) : sv + (sv) : sv defn to-exp (sv:SymbolicValue) -> Expression : - defn remove-nul (sv:SymbolicValue) -> SymbolicValue : - match(map(remove-nul,sv)) : - (sv:SVMux) : - match(conseq(sv),alt(sv)) : - (c,a:SVNul) : c - (c:SVNul,a) : a - (c,a) : sv - (sv) : sv - match(remove-nul(sv)) : (sv:SVMux) : DoPrim(MUX-UU-OP, @@ -1087,20 +1191,29 @@ defn expand-whens-stmt (table:HashTable<Symbol,SSV>,enables:HashTable<Symbol,Sym add{connections,_} $ Connect{_,WRef(sym,ty,NodeKind(),MALE)} $ WritePort(source(s),index(s),ty,to-exp(enables[sym])) - add{connections,_} $ - Connect{_,to-exp(sv(ssv))} $ - WRef(name(stmt(ssv) as ?), UnknownType(), NodeKind(), FEMALE) + val sv* = remove-nul(sv(ssv)) + if sv* == SVNul : println("Uninitialized: ~" % [to-string(s)]) + else : + add{connections,_} $ + Connect{_,to-exp(sv*)} $ + WRef(name(stmt(ssv) as ?), UnknownType(), NodeKind(), FEMALE) MALE : add{connections,_} $ Connect{WRef(sym,ty,NodeKind(),FEMALE),_} $ ReadPort(source(s),index(s),ty,to-exp(enables[sym])) (s:DefRegister) : + val sv* = remove-nul(sv(ssv)) + val exp* = + if sv* typeof SVNul : WRef(sym,type(s),NodeKind(),MALE) + else : to-exp(sv*) add(declarations,s) - add{connections,_} $ Connect{WRef(sym,type(s),NodeKind(),FEMALE),_} $ - Register(type(s),to-exp(sv(ssv)),to-exp(enables[sym])) + add{connections,_} $ + Connect{WRef(sym,type(s),NodeKind(),FEMALE),_} $ + Register(type(s),exp*,to-exp(enables[sym])) (s) : add(declarations,stmt(value(x))) - if not sv(ssv) typeof SVNul : + if has-nul?(sv(ssv)) : println("Uninitialized: ~" % [to-string(s)]);TODO actually collect error + else : add{connections,_} $ Connect{_,to-exp(sv(ssv))} $ WRef(name(stmt(ssv) as ?), UnknownType(), NodeKind(), FEMALE) @@ -1116,15 +1229,10 @@ defn get-enables (table:HashTable<Symbol,SSV>) -> HashTable<Symbol,SymbolicValue else : OR(head(l) reduce-or(tail(l))) defn get-read-enable (sym:Symbol,sv:SymbolicValue) -> Expression : - defn active (e:Expression) -> True|False : + defn active (e:Expression) -> True|False : match(e) : (e:WRef) : name(e) == sym - (e:WSubfield) : active(exp(e)) - (e:WRegInit) : active(reg(e)) - (es:DoPrim) : reduce-or{_} $ for e in args(es) map : active(e) - (e:ReadPort) : reduce-or{_} $ map(active,list(mem(e),index(e),enable(e))) - (e:WritePort) : reduce-or{_} $ map(active,list(mem(e),index(e),enable(e))) - (e:Register) : reduce-or{_} $ map(active,list(value(e),enable(e))) + (e) : reduce-or{_} $ map(active,children(e)) (e) : false match(sv) : (sv: SVNul) : zero @@ -1135,7 +1243,8 @@ defn get-enables (table:HashTable<Symbol,SSV>) -> HashTable<Symbol,SymbolicValue val e0 = get-read-enable(sym,SVExp(pred(sv))) val e1 = get-read-enable(sym,conseq(sv)) val e2 = get-read-enable(sym,alt(sv)) - OR(e0,OR(AND(pred(sv),e1),AND(NOT(pred(sv)),e2))) + if e1 == e2 : OR(e0,e1) + else : OR(e0,OR(AND(pred(sv),e1),AND(NOT(pred(sv)),e2))) defn get-write-enable (sv:SymbolicValue) -> SymbolicValue : match(map(get-write-enable,sv)) : @@ -1149,39 +1258,23 @@ defn get-enables (table:HashTable<Symbol,SSV>) -> HashTable<Symbol,SymbolicValue match(stmt(value(x))) : (s:WDefAccessor) : switch {_ == gender(s)} : FEMALE : enables[sym] = get-write-enable(sv(value(x))) - MALE : enables[sym] = SVExp{_} $ reduce-or{_} $ - for y in table map-append : - list(get-read-enable(sym,sv(value(y)))) + MALE : enables[sym] = SVExp{_} $ reduce-or{_} $ to-list{_} $ + for y in table stream : + get-read-enable(sym,sv(value(y))) (s:DefRegister) : enables[sym] = get-write-enable(sv(value(x))) (s) : s enables -defn merge-reg-init (table:HashTable<Symbol,SSV>) -> HashTable<Symbol,SSV> : - val merge = HashTable<Symbol,SSV>(symbol-hash) - for x in table do : - val sym = key(x) - match(stmt(value(x))) : - (s:DefRegister) : - val write = sv(value(x)) - val x = for x in table find : key(x) == symbol-join([name(s),`.init]) - val init = match(x) : - (x:False) : SVExp(zero) - (x:KeyValue<Symbol,SSV>) : sv(value(x)) - merge[sym] = SSV(s,SVMux(WRef(`reset,UIntType(IntWidth(1)),PortKind(),MALE),init,write)) - (s:EmptyStmt) : "Nothing" - (s) : merge[sym] = table[sym] - merge - - -defn build-table (s:Stmt, table-arg:HashTable<Symbol,SSV>) -> HashTable<Symbol,SSV> : - var table = table-arg - ;println("=====================") +defn build-table (s:Stmt, table:HashTable<Symbol,SSV>) -> False : + println("=====================") match(s) : (s:DefWire) : table[name(s)] = SSV(s SVNul()) (s:DefNode) : table[name(s)] = SSV(s SVNul()) (s:DefRegister) : table[name(s)] = SSV(s SVNul()) (s:WDefAccessor) : table[name(s)] = SSV(s SVNul()) - (s:DefInstance) : table[name(s)] = SSV(s SVNul()) + (s:DefInstance) : + for f in fields(type(module(s)) as BundleType) do : + table[to-symbol("~.~" % [name(s),name(f)])] = SSV(s SVNul()) (s:DefMemory) : table[name(s)] = SSV(s SVNul()) (s:Conditionally) : defn deepcopy (t:HashTable<Symbol,SSV>) -> HashTable<Symbol,SSV> : @@ -1201,57 +1294,56 @@ defn build-table (s:Stmt, table-arg:HashTable<Symbol,SSV>) -> HashTable<Symbol,S for t in v do : val duplicate? = for x in t0 any? : x == key(t) if not duplicate? : add(t0,key(t)) - val table1 = deepcopy(table) - val table2 = deepcopy(table) - val table-c = build-table(conseq(s),table1) - val table-a = build-table(alt(s),table2) - val table-m = new-table() + val table-c = deepcopy(table) + val table-a = deepcopy(table) + build-table(conseq(s),table-c) + build-table(alt(s),table-a) for i in get-unique-keys(list(table-c,table-a)) do : - table-m[i] = match(get(table-c,i),get(table-a,i)) : + table[i] = match(get(table-c,i),get(table-a,i)) : (c:SSV,a:SSV) : SSV(stmt(c),SVMux(pred(s),sv(c),sv(a))) - (c:SSV,a:False) : SSV(stmt(c),SVMux(pred(s),sv(c),SVNul())) - (c:False,a:SSV) : SSV(stmt(a),SVMux(pred(s),SVNul(),sv(a))) + (c:SSV,a:False) : + if stmt(c) typeof DefWire|DefInstance : + SSV(stmt(c),sv(c)) + else : SSV(stmt(c),SVMux(pred(s),sv(c),SVNul())) + (c:False,a:SSV) : + if stmt(a) typeof DefWire|DefInstance : + SSV(stmt(a),sv(a)) + else : SSV(stmt(a),SVMux(pred(s),SVNul(),sv(a))) (c:False,a:False) : error("Shouldn't be here") - table = table-m - ;println("TABLE") - ;for x in table do : println(x) - ;println("TABLE-C") - ;for x in table-c do : println(x) - ;println("TABLE-A") - ;for x in table-a do : println(x) - ;println("TABLE-M") - ;for x in table-m do : println(x) - (s:Connect) : - val i = for kv in table find : - name(loc(s) as ?) == key(kv) - match(i) : - (i:False) : table[name(loc(s) as ?)] = SSV(EmptyStmt() SVExp(exp(s))) - (i:KeyValue<Symbol,SSV>) : table[key(i)] = SSV(stmt(value(i)) SVExp(exp(s))) - (s:Begin) : for s* in body(s) do: table = build-table(s*,table) - (s) : s - table + println("TABLE-C") + for x in table-c do : println(x) + println("TABLE-A") + for x in table-a do : println(x) + println("TABLE") + for x in table do : println(x) + (s:Connect) : + val key* = match(loc(s)) : + (e:WRef) : name(e) + (e:WSubfield) : to-symbol("~.~" % [name(exp(e) as WRef),name(e)]) + (e) : error("Shouldn't be here with ~" % [e]) + table[key*] = SSV(stmt(table[key*]) SVExp(exp(s))); TODO, need to check all references are declared before this point + (s:Begin) : for s* in body(s) do: build-table(s*,table) + (s) : false defn expand-whens (m:Module) -> Module : - val table = build-table(body(m),new-table()) + val table = new-table() + build-table(body(m),table) val table* = HashTable<Symbol,SSV>(symbol-hash) for x in table do : table*[key(x)] = SSV(stmt(value(x)) optimize(sv(value(x)))) - val table** = merge-reg-init(table*) - val enables = get-enables(table**) + val enables = get-enables(table*) val enables* = HashTable<Symbol,SymbolicValue>(symbol-hash) for x in enables do : enables*[key(x)] = optimize(value(x)) - ;println("Original Table") - ;for x in table do : println(x) - ;println("Optimized Table") - ;for x in table* do : println(x) - ;println("Merged Inits Table") - ;for x in table** do : println(x) - ;println("Enable Table") - ;for x in enables do : println(x) - ;println("Optimized Enable Table") - ;for x in enables* do : println(x) - - Module(name(m),ports(m),expand-whens-stmt(table**,enables*)) + println("Original Table") + for x in table do : println(x) + println("Optimized Table") + for x in table* do : println(x) + println("Enable Table") + for x in enables do : println(x) + println("Optimized Enable Table") + for x in enables* do : println(x) + + Module(name(m),ports(m),expand-whens-stmt(table*,enables*)) defn expand-whens (c:Circuit) -> Circuit : Circuit(modules*, main(c)) where : @@ -2301,12 +2393,12 @@ public defn run-passes (c: Circuit, p: List<Char>) : if contains(p,'a') : do-stage("Working IR", to-working-ir) if contains(p,'b') : do-stage("Resolve Kinds", resolve-kinds) if contains(p,'c') : do-stage("Make Explicit Reset", make-explicit-reset) - ;if contains(p,'d') : do-stage("Initialize Registers", initialize-registers) if contains(p,'e') : do-stage("Infer Types", infer-types) if contains(p,'f') : do-stage("Resolve Genders", resolve-genders) if contains(p,'g') : do-stage("Expand Accessors", expand-accessors) if contains(p,'h') : do-stage("Lower To Ground", lower-to-ground) if contains(p,'i') : do-stage("Expand Indexed Connects", expand-connect-indexed) + if contains(p,'p') : do-stage("Initialize Registers", initialize-registers) if contains(p,'j') : do-stage("Expand Whens", expand-whens) ;if contains(p,'l') : do-stage("Structural Form", structural-form) ;if contains(p,'m') : do-stage("Infer Widths", infer-widths) |
