From d4fdab6950b47379137fce750e4a3a6b262e750d Mon Sep 17 00:00:00 2001 From: azidar Date: Fri, 27 Mar 2015 17:37:04 -0700 Subject: Corrected register init by adding initialization of registers pass after lowering. Finished expand-whens. Needs more thorough testing of instances --- src/main/stanza/ir-utils.stanza | 73 +++++++++- src/main/stanza/passes.stanza | 316 ++++++++++++++++++++++++++-------------- 2 files changed, 271 insertions(+), 118 deletions(-) (limited to 'src/main') diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index 7c3537fc..0fe5ef5b 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -254,12 +254,73 @@ defmethod map (f: Stmt -> Stmt, c:Stmt) -> Stmt : (c:Begin) : Begin(map(f, body(c))) (c) : c -public defmulti children (c:Stmt) -> List -defmethod children (c:Stmt) : - match(c) : - (c:Conditionally) : list(conseq(c), alt(c)) - (c:Begin) : body(c) - (c) : List() +;================= HELPER FUNCTIONS USING MAP =================== +public defmulti do (f:Expression -> ?, e:Expression) -> False +defmethod do (f:Expression -> ?, e:Expression) -> False : + for x in e map : + f(x) + x + false + +public defmulti do (f:Expression -> ?, s:Stmt) -> False +defmethod do (f:Expression -> ?, s:Stmt) -> False : + defn f* (x:Expression) : + f(x) + x + map(f*,s) + false + +public defmulti do (f:Stmt -> ?, s:Stmt) -> False +defmethod do (f:Stmt -> ?, s:Stmt) -> False : + defn f* (x:Stmt) : + f(x) + x + map(f*,s) + false + +public defmulti dor (f:Expression -> ?, e:Expression) -> False +defmethod dor (f:Expression -> ?, e:Expression) -> False : + f(e) + for x in e map : + dor(f,x) + x + false + +public defmulti dor (f:Expression -> ?, s:Stmt) -> False +defmethod dor (f:Expression -> ?, s:Stmt) -> False : + defn f* (x:Expression) : + dor(f,x) + x + map(f*,s) + false + +public defmulti dor (f:Stmt -> ?, s:Stmt) -> False +defmethod dor (f:Stmt -> ?, s:Stmt) -> False : + f(s) + defn f* (x:Stmt) : + dor(f,x) + x + map(f*,s) + false + +public defmulti sub-exps (s:Expression|Stmt) -> List +defmethod sub-exps (e:Expression) -> List : + val l = Vector() + defn f (x:Expression) : add(l,x) + do(f,e) + to-list(l) +defmethod sub-exps (e:Stmt) -> List : + val l = Vector() + defn f (x:Expression) : add(l,x) + do(f,e) + to-list(l) + +public defmulti sub-stmts (s:Stmt) -> List +defmethod sub-stmts (s:Stmt) : + val l = Vector() + defn f (x:Stmt) : add(l,x) + do(f,s) + to-list(l) ;=================== ADAM OPS =============================== public defn split (s:String,c:Char) -> List : diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 4c45021d..2c0ad0ec 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -62,6 +62,17 @@ defstruct WDefAccessor <: Stmt : index: Expression gender: Gender +defstruct ConnectToIndexed <: Stmt : + index: Expression + locs: List + exp: Expression + +defstruct ConnectFromIndexed <: Stmt : + index: Expression + loc: Expression + exps: List + + ;================ WORKING IR UTILS ========================= defn plus (g1:Gender,g2:Gender) -> Gender : @@ -209,17 +220,26 @@ defmethod print (o:OutputStream, s:WDefAccessor) : print-all(o,["accessor " name(s) " = " source(s) "[" index(s) "]"]) print-debug(o,s) +defmethod print (o:OutputStream, c:ConnectToIndexed) : + print-all(o, [locs(c) "[" index(c) "] := " exp(c)]) + print-debug(o,c as ?) + +defmethod print (o:OutputStream, c:ConnectFromIndexed) : + print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"]) + print-debug(o,c as ?) + defmethod map (f: Expression -> Expression, e: WRegInit) : WRegInit(f(reg(e)), name(e), type(e), gender(e)) - defmethod map (f: Expression -> Expression, e: WSubfield) : WSubfield(f(exp(e)), name(e), type(e), gender(e)) - defmethod map (f: Expression -> Expression, e: WIndex) : WIndex(f(exp(e)), value(e), type(e), gender(e)) - defmethod map (f: Expression -> Expression, c:WDefAccessor) : WDefAccessor(name(c), f(source(c)), f(index(c)), gender(c)) +defmethod map (f: Expression -> Expression, c:ConnectToIndexed) : + ConnectToIndexed(f(index(c)), map(f, locs(c)), f(exp(c))) +defmethod map (f: Expression -> Expression, c:ConnectFromIndexed) : + ConnectFromIndexed(f(index(c)), f(loc(c)), map(f, exps(c))) ;================= Bring to Working IR ======================== ; Returns a new Circuit with Refs, Subfields, Indexes and DefAccessors @@ -337,6 +357,7 @@ defn make-explicit-reset (c:Circuit) : for m in modules(c) map : make-explicit-reset(m,c) + ;============== INFER TYPES ================================ ; This pass infers the type field in all IR nodes by updating ; and passing an environment to all statements in pre-order @@ -660,26 +681,6 @@ defn resolve-genders (c:Circuit) : ; ConnectFromIndexed (female) ; Eg: -defstruct ConnectToIndexed <: Stmt : - index: Expression - locs: List - exp: Expression - -defstruct ConnectFromIndexed <: Stmt : - index: Expression - loc: Expression - exps: List - -defmethod print (o:OutputStream, c:ConnectToIndexed) : - print-all(o, [locs(c) "[" index(c) "] := " exp(c)]) -defmethod print (o:OutputStream, c:ConnectFromIndexed) : - print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"]) - -defmethod map (f: Expression -> Expression, c:ConnectToIndexed) : - ConnectToIndexed(f(index(c)), map(f, locs(c)), f(exp(c))) -defmethod map (f: Expression -> Expression, c:ConnectFromIndexed) : - ConnectFromIndexed(f(index(c)), f(loc(c)), map(f, exps(c))) - defn expand-vector (e:Expression) -> List : val t = type(e) as VectorType for i in 0 to size(t) map-append : @@ -939,6 +940,83 @@ defn expand-connect-indexed (c: Circuit) -> Circuit : for m in modules(c) map : expand-connect-indexed(m) +;======= MAKE EXPLICIT REGISTER INITIALIZATION ============= +; This pass replaces the reg.init construct by creating a new +; wire that holds the value at initialization. This wire +; is then connected to the register conditionally on reset, +; at the end of the scope containing the register +; declaration +; If a register has no inital value, the wire is connected to +; a NULL node. Later passes will remove these with the base +; case Mux(reset,NULL,a) -> a, and Mux(reset,a,NULL) -> a. +; This ensures proper behavior if this pass is run multiple +; times. + +defn initialize-registers (c:Circuit) : + defn to-wire-name (y:Symbol) : to-symbol("~#init" % [y]) + defn add-when (s:Stmt,h:HashTable) -> Stmt : + var inits = List() + for kv in h do : + val refreg = WRef(key(kv),value(kv),RegKind(),FEMALE) + val refwire = WRef(to-wire-name(key(kv)),value(kv),NodeKind(),MALE) + val connect = Connect(refreg,refwire) + inits = append(inits,list(connect)) + if empty?(inits) : s + else : + val pred = WRef(`reset, UIntType(IntWidth(1)), PortKind(), MALE) + val when-reset = Conditionally(pred,Begin(inits),Begin(List())) + Begin(list(s,when-reset)) + + defn rename (s:Stmt,h:HashTable) -> [Stmt HashTable] : + val t = HashTable(symbol-hash) + defn rename-expr (e:Expression) -> Expression : + match(map(rename-expr,e)) : + (e:WRegInit) : + val new-name = to-wire-name(name(reg(e) as WRef)) + WRef(new-name,type(reg(e)),RegKind(),gender(e)) + (e) : e + defn rename-stmt (s:Stmt) -> Stmt : + match(map(rename-stmt,s)) : + (s:DefRegister) : + if h[name(s)] : + t[name(s)] = type(s) + Begin(list(s,DefWire(to-wire-name(name(s)),type(s)))) + else : s + (s) : map(rename-expr,s) + [rename-stmt(s) t] + + defn init? (y:Symbol,s:Stmt) -> True|False : + var used? = false + defn has? (e:Expression) -> Expression : + match(map(has?,e)) : + (e:WRegInit) : + if name(reg(e) as WRef) == y : used? = true + (e) : map(has?,e) + e + map(has?,s) + used? + + defn using-init (s:Stmt,h:HashTable) -> Stmt : + match(s) : + (s:DefRegister) : h[name(s)] = false + (s) : + for x in h do : + h[key(x)] = value(x) or init?(key(x),s) + map(using-init{_,h},s) + + defn explicit-init-scope (s:Stmt) -> Stmt : + val h = HashTable(symbol-hash) + using-init(s,h) + println(h) + val [s* t] = rename(s,h) + add-when(s*,t) + + Circuit(modules*, main(c)) where : + val modules* = + for m in modules(c) map : + Module(name(m), ports(m), body*) where : + val body* = explicit-init-scope(body(m)) + ;;================ EXPAND WHENS ============================= ; This pass does three things: remove last connect semantics, ; remove conditional blocks, and eliminate concept of scoping. @@ -1000,6 +1078,14 @@ defn NOT (e1:Expression) -> Expression : else if e1 == zero : one else : DoPrim(EQUAL-UU-OP,list(e1,zero),list(),UIntType(IntWidth(1))) +defn children (e:Expression) -> List : + val es = Vector() + do(add{es,_},e) + to-list(es) + + + + ; ======= Symbolic Value Library ========== public definterface SymbolicValue public defstruct SVExp <: SymbolicValue : @@ -1022,6 +1108,18 @@ defmethod map (f: SymbolicValue -> SymbolicValue, sv:SymbolicValue) -> SymbolicV (sv: SVMux) : SVMux(pred(sv),f(conseq(sv)),f(alt(sv))) (sv) : sv +defn do (f:SymbolicValue -> ?, s:SymbolicValue) -> False : + for x in s map : + f(x) + x + false +defn dor (f:SymbolicValue -> ?, e:SymbolicValue) -> False : + do(f,e) + for x in e map : + dor(f,x) + x + false + defmethod equal? (a:SymbolicValue,b:SymbolicValue) -> True|False : match(a,b) : (a:SVNul,b:SVNul) : true @@ -1029,6 +1127,7 @@ defmethod equal? (a:SymbolicValue,b:SymbolicValue) -> True|False : (a:SVMux,b:SVMux) : pred(a) == pred(b) and conseq(a) == conseq(b) and alt(a) == alt(b) (a,b) : false +;TODO add invert to primop defn optimize (sv:SymbolicValue) -> SymbolicValue : match(map(optimize,sv)) : (sv:SVMux) : @@ -1055,16 +1154,21 @@ defn new-table () -> HashTable : HashTable(symbol-hash) ; ========= Expand When Pass =========== defn expand-whens-stmt (table:HashTable,enables:HashTable) -> Stmt : + defn has-nul? (sv:SymbolicValue) -> True|False : + var has? = false + if sv typeof SVNul : has? = true + for x in sv dor : + if x typeof SVNul : has? = true + has? + defn remove-nul (sv:SymbolicValue) -> SymbolicValue : + match(map(remove-nul,sv)) : + (sv:SVMux) : + match(conseq(sv),alt(sv)) : + (c,a:SVNul) : c + (c:SVNul,a) : a + (c,a) : sv + (sv) : sv defn to-exp (sv:SymbolicValue) -> Expression : - defn remove-nul (sv:SymbolicValue) -> SymbolicValue : - match(map(remove-nul,sv)) : - (sv:SVMux) : - match(conseq(sv),alt(sv)) : - (c,a:SVNul) : c - (c:SVNul,a) : a - (c,a) : sv - (sv) : sv - match(remove-nul(sv)) : (sv:SVMux) : DoPrim(MUX-UU-OP, @@ -1087,20 +1191,29 @@ defn expand-whens-stmt (table:HashTable,enables:HashTable) -> HashTable Expression : - defn active (e:Expression) -> True|False : + defn active (e:Expression) -> True|False : match(e) : (e:WRef) : name(e) == sym - (e:WSubfield) : active(exp(e)) - (e:WRegInit) : active(reg(e)) - (es:DoPrim) : reduce-or{_} $ for e in args(es) map : active(e) - (e:ReadPort) : reduce-or{_} $ map(active,list(mem(e),index(e),enable(e))) - (e:WritePort) : reduce-or{_} $ map(active,list(mem(e),index(e),enable(e))) - (e:Register) : reduce-or{_} $ map(active,list(value(e),enable(e))) + (e) : reduce-or{_} $ map(active,children(e)) (e) : false match(sv) : (sv: SVNul) : zero @@ -1135,7 +1243,8 @@ defn get-enables (table:HashTable) -> HashTable SymbolicValue : match(map(get-write-enable,sv)) : @@ -1149,39 +1258,23 @@ defn get-enables (table:HashTable) -> HashTable) -> HashTable : - val merge = HashTable(symbol-hash) - for x in table do : - val sym = key(x) - match(stmt(value(x))) : - (s:DefRegister) : - val write = sv(value(x)) - val x = for x in table find : key(x) == symbol-join([name(s),`.init]) - val init = match(x) : - (x:False) : SVExp(zero) - (x:KeyValue) : sv(value(x)) - merge[sym] = SSV(s,SVMux(WRef(`reset,UIntType(IntWidth(1)),PortKind(),MALE),init,write)) - (s:EmptyStmt) : "Nothing" - (s) : merge[sym] = table[sym] - merge - - -defn build-table (s:Stmt, table-arg:HashTable) -> HashTable : - var table = table-arg - ;println("=====================") +defn build-table (s:Stmt, table:HashTable) -> False : + println("=====================") match(s) : (s:DefWire) : table[name(s)] = SSV(s SVNul()) (s:DefNode) : table[name(s)] = SSV(s SVNul()) (s:DefRegister) : table[name(s)] = SSV(s SVNul()) (s:WDefAccessor) : table[name(s)] = SSV(s SVNul()) - (s:DefInstance) : table[name(s)] = SSV(s SVNul()) + (s:DefInstance) : + for f in fields(type(module(s)) as BundleType) do : + table[to-symbol("~.~" % [name(s),name(f)])] = SSV(s SVNul()) (s:DefMemory) : table[name(s)] = SSV(s SVNul()) (s:Conditionally) : defn deepcopy (t:HashTable) -> HashTable : @@ -1201,57 +1294,56 @@ defn build-table (s:Stmt, table-arg:HashTable) -> HashTable) : table[key(i)] = SSV(stmt(value(i)) SVExp(exp(s))) - (s:Begin) : for s* in body(s) do: table = build-table(s*,table) - (s) : s - table + println("TABLE-C") + for x in table-c do : println(x) + println("TABLE-A") + for x in table-a do : println(x) + println("TABLE") + for x in table do : println(x) + (s:Connect) : + val key* = match(loc(s)) : + (e:WRef) : name(e) + (e:WSubfield) : to-symbol("~.~" % [name(exp(e) as WRef),name(e)]) + (e) : error("Shouldn't be here with ~" % [e]) + table[key*] = SSV(stmt(table[key*]) SVExp(exp(s))); TODO, need to check all references are declared before this point + (s:Begin) : for s* in body(s) do: build-table(s*,table) + (s) : false defn expand-whens (m:Module) -> Module : - val table = build-table(body(m),new-table()) + val table = new-table() + build-table(body(m),table) val table* = HashTable(symbol-hash) for x in table do : table*[key(x)] = SSV(stmt(value(x)) optimize(sv(value(x)))) - val table** = merge-reg-init(table*) - val enables = get-enables(table**) + val enables = get-enables(table*) val enables* = HashTable(symbol-hash) for x in enables do : enables*[key(x)] = optimize(value(x)) - ;println("Original Table") - ;for x in table do : println(x) - ;println("Optimized Table") - ;for x in table* do : println(x) - ;println("Merged Inits Table") - ;for x in table** do : println(x) - ;println("Enable Table") - ;for x in enables do : println(x) - ;println("Optimized Enable Table") - ;for x in enables* do : println(x) - - Module(name(m),ports(m),expand-whens-stmt(table**,enables*)) + println("Original Table") + for x in table do : println(x) + println("Optimized Table") + for x in table* do : println(x) + println("Enable Table") + for x in enables do : println(x) + println("Optimized Enable Table") + for x in enables* do : println(x) + + Module(name(m),ports(m),expand-whens-stmt(table*,enables*)) defn expand-whens (c:Circuit) -> Circuit : Circuit(modules*, main(c)) where : @@ -2301,12 +2393,12 @@ public defn run-passes (c: Circuit, p: List) : if contains(p,'a') : do-stage("Working IR", to-working-ir) if contains(p,'b') : do-stage("Resolve Kinds", resolve-kinds) if contains(p,'c') : do-stage("Make Explicit Reset", make-explicit-reset) - ;if contains(p,'d') : do-stage("Initialize Registers", initialize-registers) if contains(p,'e') : do-stage("Infer Types", infer-types) if contains(p,'f') : do-stage("Resolve Genders", resolve-genders) if contains(p,'g') : do-stage("Expand Accessors", expand-accessors) if contains(p,'h') : do-stage("Lower To Ground", lower-to-ground) if contains(p,'i') : do-stage("Expand Indexed Connects", expand-connect-indexed) + if contains(p,'p') : do-stage("Initialize Registers", initialize-registers) if contains(p,'j') : do-stage("Expand Whens", expand-whens) ;if contains(p,'l') : do-stage("Structural Form", structural-form) ;if contains(p,'m') : do-stage("Infer Widths", infer-widths) -- cgit v1.2.3