diff options
| author | jackbackrack | 2015-03-03 13:28:58 -0800 |
|---|---|---|
| committer | jackbackrack | 2015-03-03 13:28:58 -0800 |
| commit | cce4364f342beb9dcc1f0e68637656dde6faac41 (patch) | |
| tree | fff04cdc82851c71179e8110aa550c7d625ad538 | |
| parent | 4bb3ec977ea29763af6f4a35f4cb5b236d7a10a5 (diff) | |
working to real + split expressions passes and flo backend
| -rw-r--r-- | src/main/stanza/passes.stanza | 470 |
1 files changed, 261 insertions, 209 deletions
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 25ce27bd..fe31f08c 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -1692,12 +1692,6 @@ defn inline-instances (c:Circuit) : ; ; - - - - - - ;;================ CLEANUP COMMANDS ========================= ;defn cleanup (c:Stmt) : ; match(c) : @@ -1719,206 +1713,268 @@ defn inline-instances (c:Circuit) : ; +;================== WORKING-TO-REAL ============================= +defn working-to-real (c:Circuit) : + ;convert working to real expressions in letrec of module + defn to-real-module (m:Module) : + defn to-real (e:Expression) -> Expression : + match(e) : + (ex:WRef) : Ref(name(ex), type(ex)) + (ex:WField) : Field(exp(ex), name(ex), type(ex)) + (ex:WIndex) : Index(exp(ex), value(ex), type(ex)) + (ex:Field) : Field(to-real(exp(ex)), name(ex), type(ex)) + (ex:Index) : Index(to-real(exp(ex)), value(ex), type(ex)) + (ex:DoPrim) : DoPrim(op(ex), map(to-real, args(ex)), consts(ex), type(ex)) + (ex:ReadPort) : ReadPort(to-real(mem(ex)), to-real(index(ex)), type(ex)) + (ex) : ex + + defn to-real-elements (es:List<KeyValue<Symbol,Element>>) : + for entry in es map : + defn to-real-writeport (e:WritePort) : + WritePort(to-real(index(e)), to-real(value(e)), to-real(enable(e))) + val k = key(entry) + match(value(entry)) : + (el:Register) : + KeyValue(k, Register(type(el), to-real(value(el)), to-real(enable(el)))) + (el:Memory) : + KeyValue(k, Memory(type(el), map(to-real-writeport, writers(el)))) + (el:Node) : + KeyValue(k, Node(type(el), to-real(value(el)))) + (el) : + entry + + defn to-real-command (c:Stmt) : + match(c) : + (c:Connect) : + Connect(to-real(loc(c)), to-real(exp(c))) + (c:Begin) : + Begin(map(to-real-command, body(c))) + (c) : + ;; error + map(to-real-letrec, c) + + defn to-real-letrec (c:Stmt) : + match(c) : + (c:LetRec) : + LetRec(to-real-elements(entries(c)), to-real-command(body(c))) + (c) : + ;; error + map(to-real-letrec, c) + + Module(name(m), ports(m), to-real-letrec(body(m))) + + ;Return the real version of expressions of circuit + ;; assert only one module + val main-module = to-real-module(modules(c)[0]) + Circuit(list(main-module), main(c)) + +;================== SPLIT-EXPRESSIONS ============================= +defn split-expressions (c:Circuit) : + ;split all expression in letrec of module + defn split-expressions (m:Module) : + defn name-elt (e:Expression, elts:Vector<KeyValue<Symbol,Element>>) -> Expression : + val name = gensym("T") + val elt = Node(type(e), e) + add(elts, KeyValue(name, elt)) + Ref(name, type(e)) + + defn split-expression (e:Expression, elts:Vector<KeyValue<Symbol,Element>>) -> Expression : + match(e) : + (ex:Ref) : ex + (ex:UIntValue) : ex + (ex:SIntValue) : ex + (ex:Field) : + name-elt(Field(split-expression(exp(ex), elts), name(ex), type(ex)), elts) + (ex:Index) : + name-elt(Index(split-expression(exp(ex), elts), value(ex), type(ex)), elts) + (ex:DoPrim) : + name-elt(DoPrim(op(ex), map(split-expression{_, elts}, args(ex)), consts(ex), type(ex)), elts) + (ex:ReadPort) : + name-elt(ReadPort(split-expression(mem(ex), elts), split-expression(index(ex), elts), type(ex)), elts) + (ex) : + println-all(["SPLIT ", ex]) + ex + + defn split-command (c:Stmt, es:Vector<KeyValue<Symbol,Element>>) : + match(c) : + (cx:Begin) : Begin(map(split-command{_, es}, body(cx))) + (cx:EmptyStmt) : cx + (cx:Connect) : Connect(loc(cx), split-expression(exp(cx), es)) + + defn split-element-expressions (es:List<KeyValue<Symbol,Element>>) : + for entry in es map-append : + val elts = Vector<KeyValue<Symbol,Element>>() + defn add-ret (key:Symbol, elt:Element) -> List<KeyValue<Symbol,Element>> : + add(elts, KeyValue(key, elt)) + to-list(elts) + defn split-writeport-expressions (e:WritePort) : + WritePort(split-expression(index(e), elts), + split-expression(value(e), elts), + split-expression(enable(e), elts)) + val k = key(entry) + match(value(entry)) : + (el:Register) : + add-ret(k, Register(type(el), split-expression(value(el), elts), + split-expression(enable(el), elts))) + (el:Memory) : + add-ret(k, Memory(type(el), map(split-writeport-expressions, writers(el)))) + (el:Node) : + add-ret(k, Node(type(el), split-expression(value(el), elts))) + (el) : + list(entry) + + defn split-letrec-expressions (c:Stmt) : + match(c) : + (c:LetRec) : + val entries* = split-element-expressions(entries(c)) + val elts = Vector<KeyValue<Symbol,Element>>() + val body* = split-command(body(c), elts) + LetRec(append(entries*, to-list(elts)), body*) + (c) : + ;; error + map(split-letrec-expressions, c) + + Module(name(m), ports(m), split-letrec-expressions(body(m))) + + ;Return the fully split out expression version of circuit + ;; assert only one module + val main-module = split-expressions(modules(c)[0]) + Circuit(list(main-module), main(c)) + + -;;;============= SHIM ======================================== -;;defn shim (i:Immediate) -> Immediate : -;; match(i) : -;; (i:RegData) : -;; Ref(name(i), direction(i), type(i)) -;; (i:InstPort) : -;; val inst = Ref(name(i), UNKNOWN-DIR, UnknownType()) -;; Field(inst, port(i), direction(i), type(i)) -;; (i:Field) : -;; val imm* = shim(imm(i)) -;; put-imm(i, imm*) -;; (i) : i -;; -;;defn shim (c:Stmt) -> Stmt : -;; val c* = map(shim{_ as Immediate}, c) -;; map(shim{_ as Stmt}, c*) -;; -;;defn shim (c:Circuit) -> Circuit : -;; val modules* = -;; for m in modules(c) map : -;; Module(name(m), ports(m), shim(body(m))) -;; Circuit(modules*, main(c)) -;; -;;;================== INLINE MODULES ========================= -;;defn cat-name (p: String|Symbol, s: String|Symbol) -> Symbol : -;; if p == "" or p == `this : ;; TODO: REMOVE THIS WHEN `THIS GETS REMOVED -;; to-symbol(s) -;; else if s == `this : ;; TODO: DITTO -;; to-symbol(p) -;; else : -;; symbol-join([p, "/", s]) -;; -;;defn inline-command (c: Stmt, mods: HashTable<Symbol, Module>, prefix: String, cmds: Vector<Stmt>) : -;; defn rename (n: Symbol) -> Symbol : -;; cat-name(prefix, n) -;; defn inline-name (i:Immediate) -> Symbol : -;; match(i) : -;; (r:Ref) : rename(name(r)) -;; (f:Field) : cat-name(inline-name(imm(f)), name(f)) -;; (f:Index) : cat-name(inline-name(imm(f)), to-string(value(f))) -;; defn inline-imm (i:Immediate) -> Ref : -;; Ref(inline-name(i), direction(i), type(i)) -;; match(c) : -;; (c:DefUInt) : add(cmds, DefUInt(rename(name(c)), value(c), width(c))) -;; (c:DefSInt) : add(cmds, DefSInt(rename(name(c)), value(c), width(c))) -;; (c:DefWire) : add(cmds, DefWire(rename(name(c)), type(c))) -;; (c:DefRegister) : add(cmds, DefRegister(rename(name(c)), type(c))) -;; (c:DefMemory) : add(cmds, DefMemory(rename(name(c)), type(c), size(c))) -;; (c:DefInstance) : inline-module(mods, mods[name(module(c))], to-string(rename(name(c))), cmds) -;; (c:DoPrim) : add(cmds, DoPrim(rename(name(c)), op(c), map(inline-imm, args(c)), consts(c))) -;; (c:DefAccessor) : add(cmds, DefAccessor(rename(name(c)), inline-imm(source(c)), direction(c), inline-imm(index(c)))) -;; (c:Connect) : add(cmds, Connect(inline-imm(loc(c)), inline-imm(exp(c)))) -;; (c:Begin) : do(inline-command{_, mods, prefix, cmds}, body(c)) -;; (c:EmptyStmt) : c -;; (c) : error("Unsupported command") -;; -;;defn inline-port (p: Port, prefix: String) -> Stmt : -;; DefWire(cat-name(prefix, name(p)), type(p)) -;; -;;defn inline-module (mods: HashTable<Symbol, Module>, mod: Module, prefix: String, cmds: Vector<Stmt>) : -;; do(add{cmds, _}, map(inline-port{_, prefix}, ports(mod))) -;; inline-command(body(mod), mods, prefix, cmds) -;; -;;defn inline-modules (c: Circuit) -> Circuit : -;; val cmds = Vector<Stmt>() -;; val mods = HashTable<Symbol, Module>(symbol-hash) -;; for mod in modules(c) do : -;; mods[name(mod)] = mod -;; val top = mods[main(c)] -;; inline-command(body(top), mods, "", cmds) -;; val main* = Module(name(top), ports(top), Begin(to-list(cmds))) -;; Circuit(list(main*), name(top)) -;; -;; ;;;============= FLO PRINTER ====================================== ;;;;; TODO: ;;;;; not supported gt, lte -;; -;;defn flo-op-name (op:PrimOp) -> String : -;; switch {op == _} : -;; ADD-OP : "add" -;; ADD-MOD-OP : "add" -;; MINUS-OP : "sub" -;; SUB-MOD-OP : "sub" -;; TIMES-OP : "mul" ;; todo: signed version -;; DIVIDE-OP : "div" ;; todo: signed version -;; MOD-OP : "mod" ;; todo: signed version -;; SHIFT-LEFT-OP : "lsh" ;; todo: signed version -;; SHIFT-RIGHT-OP : "rsh" -;; PAD-OP : "pad" ;; todo: signed version -;; BIT-AND-OP : "and" -;; BIT-OR-OP : "or" -;; BIT-XOR-OP : "xor" -;; CONCAT-OP : "cat" -;; BIT-SELECT-OP : "rsh" -;; BITS-SELECT-OP : "rsh" -;; LESS-OP : "lt" ;; todo: signed version -;; LESS-EQ-OP : "lte" ;; todo: swap args -;; GREATER-OP : "gt" ;; todo: swap args -;; GREATER-EQ-OP : "gte" ;; todo: signed version -;; EQUAL-OP : "eq" -;; MULTIPLEX-OP : "mux" -;; else : error $ string-join $ -;; ["Unable to print Primop: " op] -;; -;;defn emit (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elt) : -;; match(elt) : -;; (e:String|Symbol|Int) : -;; print(o, e) -;; (e:Ref) : -;; if key?(lits, name(e)) : -;; val lit = lits[name(e)] -;; print-all(o, [value(lit) "'" width(lit)]) -;; else : -;; if key?(ports, name(e)) : -;; print-all(o, [top "::"]) -;; print(o, name(e)) -;; (e:IntWidth) : -;; print(o, value(e)) -;; (e:PrimOp) : -;; print(o, flo-op-name(e)) -;; (e) : -;; println-all(["EMIT " e]) -;; error("Unable to emit") -;; -;;defn emit-all (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elts: Streamable) : -;; for e in elts do : emit(o, top, ports, lits, e) -;; -;;defn prim-width (type:Type) -> Width : -;; match(type) : -;; (t:UIntType) : width(t) -;; (t:SIntType) : width(t) -;; (t) : error("Bad prim width type") -;; -;;defn emit-command (o:OutputStream, cmd:Stmt, top:Symbol, lits:HashTable<Symbol, DefUInt>, regs:HashTable<Symbol, DefRegister>, accs:HashTable<Symbol, DefAccessor>, ports:HashTable<Symbol, Port>, outs:HashTable<Symbol, Port>) : -;; match(cmd) : -;; (c:DefUInt) : -;; lits[name(c)] = c -;; (c:DefSInt) : -;; emit-all(o, top, ports, lits, [name(c) " = " value(c) "'" width(c) "\n"]) -;; (c:DoPrim) : ;; NEED TO FIGURE OUT WHEN WIDTHS ARE NECESSARY AND EXTRACT -;; emit-all(o, top, ports, lits, [name(c) " = " op(c)]) -;; for arg in args(c) do : -;; print(o, " ") -;; emit(o, top, ports, lits, arg) -;; for const in consts(c) do : -;; print(o, " ") -;; emit(o, top, ports, lits, const) -;; print("\n") -;; (c:DefRegister) : -;; regs[name(c)] = c -;; (c:DefMemory) : -;; emit-all(o, top, ports, lits, [name(c) " : mem'" prim-width(type(c)) " " size(c) "\n"]) -;; (c:DefAccessor) : -;; accs[name(c)] = c -;; (c:Connect) : -;; val dst = name(loc(c) as Ref) -;; val src = name(exp(c) as Ref) -;; if key?(regs, dst) : -;; val reg = regs[dst] -;; emit-all(o, top, ports, lits, [dst " = reg'" prim-width(type(reg)) " 0'" prim-width(type(reg)) " " exp(c) "\n"]) -;; else if key?(accs, dst) : -;; val acc = accs[dst] -;; ;; assert(direction(acc) == OUTPUT) -;; emit-all(o, top, ports, lits, [dst " = wr " source(acc) " " index(acc) " " exp(c) "\n"]) -;; else if key?(outs, dst) : -;; val out = outs[dst] -;; emit-all(o, top, ports, lits, [top "::" dst " = out'" prim-width(type(out)) " " exp(c) "\n"]) -;; else if key?(accs, src) : -;; val acc = accs[src] -;; ;; assert(direction(acc) == INPUT) -;; emit-all(o, top, ports, lits, [dst " = rd " source(acc) " " index(acc) "\n"]) -;; else : -;; emit-all(o, top, ports, lits, [dst " = mov " exp(c) "\n"]) -;; (c:Begin) : -;; do(emit-command{o, _, top, lits, regs, accs, ports, outs}, body(c)) -;; (c:DefWire|EmptyStmt) : -;; print("") -;; (c) : -;; error("Unable to print command") -;; -;;defn emit-module (o:OutputStream, m:Module) : -;; val regs = HashTable<Symbol, DefRegister>(symbol-hash) -;; val accs = HashTable<Symbol, DefAccessor>(symbol-hash) -;; val lits = HashTable<Symbol, DefUInt>(symbol-hash) -;; val outs = HashTable<Symbol, Port>(symbol-hash) -;; val portz = HashTable<Symbol, Port>(symbol-hash) -;; for port in ports(m) do : -;; portz[name(port)] = port -;; if direction(port) == OUTPUT : -;; outs[name(port)] = port -;; else if name(port) == `reset : -;; print-all(o, [name(m) "::reset = rst\n"]) -;; else : -;; print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"]) -;; emit-command(o, body(m), name(m), lits, regs, accs, portz, outs) -;; -;;public defn emit-circuit (o:OutputStream, c:Circuit) : -;; emit-module(o, modules(c)[0]) + +defn flo-op-name (op:PrimOp) -> String : + switch {op == _} : + ADD-OP : "add" + ADD-MOD-OP : "add" + SUB-OP : "sub" + SUB-MOD-OP : "sub" + TIMES-OP : "mul" ;; todo: signed version + DIVIDE-OP : "div" ;; todo: signed version + MOD-OP : "mod" ;; todo: signed version + SHIFT-LEFT-OP : "lsh" ;; todo: signed version + SHIFT-RIGHT-OP : "rsh" + PAD-OP : "rsh" ;; todo: signed version + BIT-AND-OP : "and" + BIT-OR-OP : "or" + BIT-XOR-OP : "xor" + CONCAT-OP : "cat" + BIT-SELECT-OP : "rsh" + BITS-SELECT-OP : "rsh" + LESS-OP : "lt" ;; todo: signed version + LESS-EQ-OP : "lte" ;; todo: swap args + GREATER-OP : "lte" ;; todo: swap args + GREATER-EQ-OP : "lt" ;; todo: signed version + EQUAL-OP : "eq" + MULTIPLEX-OP : "mux" + else : error $ string-join $ ["Unable to print Primop: " op] + +defn sane-width (wd:Width) -> Int : + match(wd) : + (w:IntWidth) : max(1, width(w)) + (w) : error("Unknown width") + +defn prim-width (type:Type) -> Int : + match(type) : + (t:UIntType) : sane-width(width(t)) + (t:SIntType) : sane-width(width(t)) + (t) : error("Bad prim width type") + +defn emit-all (o:OutputStream, es:Streamable, top:Symbol) : + for e in es do : + match(e) : + (ex:Expression) : emit(o, ex, top) + (ex:String) : print(o, ex) + (ex:Symbol) : print(o, ex) + (ex:Int) : print(o, ex) + (ex) : print(o, ex) + +defn emit (o:OutputStream, e:Expression, top:Symbol) : + match(e) : + (ex:Ref) : print-all(o, [top "::" name(ex)]) + (ex:UIntValue) : emit-all(o, [value(ex) "'" sane-width(width(ex))], top) + (ex:SIntValue) : emit-all(o, [value(ex) "'" sane-width(width(ex))], top) + (ex:Field) : emit-all(o, [exp(ex) "/" name(ex)], top) + (ex:Index) : emit-all(o, [exp(ex) "/" value(ex)], top) + (ex:DoPrim) : + if op(ex) == EQUAL-OP or op(ex) == GREATER-OP or op(ex) == GREATER-EQ-OP : + emit-all(o, [flo-op-name(op(ex)) "'" prim-width(type(args(ex)[0]))], top) + if op(ex) == GREATER-OP or op(ex) == GREATER-EQ-OP : + emit-all(o, [" " args(ex)[1] " " args(ex)[0]], top) + else : + emit-all(o, [" " args(ex)[0] " " args(ex)[1]], top) + else : + emit-all(o, [flo-op-name(op(ex)) "'" prim-width(type(ex))], top) + if op(ex) == PAD-OP : + emit-all(o, [" " args(ex)[0] " 0"], top) + else : + for arg in args(ex) do : + print(o, " ") + emit(o, arg, top) + for const in consts(ex) do : + print(o, " ") + print(o, const) + (ex) : print-all(o, ["EMIT(" ex ")"]) + +defn emit-elt-exp (o:OutputStream, e:Expression, top:Symbol) : + match(e) : + (ex:DoPrim) : emit(o, e, top) + (ex:ReadPort) : + val m = mem(ex) + val vtype = type(m) as VectorType + emit-all(o, ["rd'" prim-width(type(vtype)) " " "1" " " mem(ex) " " index(ex)], top) + (ex) : emit-all(o, ["mov'" prim-width(type(e)) " " e], top) + +defn emit-elements (o:OutputStream, es:List<KeyValue<Symbol,Element>>, top:Symbol) : + for entry in es do : + val k = key(entry) + match(value(entry)) : + (e:Register) : + emit-all(o, [top "::" k " = reg'" prim-width(type(e)) " " "1" " " value(e) "\n"], top) + ;; enable(e) + (e:Memory) : + val vtype = type(e) as VectorType + emit-all(o, [top "::" k " = mem'" prim-width(type(vtype)) " " size(vtype) "\n"], top) + for wp in writers(e) do : + val name = gensym("T") + emit-all(o, [top "::" name " = wr'" prim-width(type(vtype)) " " enable(wp) " " top "::" k " " index(wp) " " value(wp) "\n"], top) + (e:Node) : + emit-all(o, [top "::" k " = "], top) + emit-elt-exp(o, value(e), top) + print(o, "\n") + +defn emit-outs (o:OutputStream, c:Stmt, outs:HashTable<Symbol,Port>, top:Symbol) : + match(c) : + (cmd:Connect) : emit-all(o, [loc(cmd) " = out'" prim-width(type(loc(cmd))) " " exp(cmd) "\n"], top) + (cmd:Begin) : do(emit-outs{o, _, outs, top}, body(cmd)) + (cmd) : error("Unknown") + +defn emit-letrec (o:OutputStream, c:Stmt, outs:HashTable<Symbol,Port>, top:Symbol) : + match(c) : + (c:LetRec) : + emit-elements(o, entries(c), top) + emit-outs(o, body(c), outs, top) + (c) : c ;; error + +defn emit-module (o:OutputStream, m:Module) : + val outs = HashTable<Symbol,Port>(symbol-hash) + for port in ports(m) do : + if direction(port) == OUTPUT : + outs[name(port)] = port + else if name(port) == `reset : + print-all(o, [name(m) "::" name(port) " = rst'1\n"]) + else : + print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"]) + emit-letrec(o, body(m), outs, name(m)) + +public defn emit-circuit (o:OutputStream, c:Circuit) : + emit-module(o, modules(c)[0]) + + ;============= DRIVER ====================================== @@ -1950,12 +2006,8 @@ public defn run-passes (c: Circuit, p: List<Char>) : if contains(p,'m') : do-stage("Infer Widths", infer-widths) if contains(p,'n') : do-stage("Pad Widths", pad-widths) if contains(p,'o') : do-stage("Inline Instances", inline-instances) + if contains(p,'p') : do-stage("Working To Real", working-to-real) + if contains(p,'q') : do-stage("Split Expressions", split-expressions) + if contains(p,'r') : emit-circuit(STANDARD-OUTPUT, c*) println("Done!") - - - ;; println("Shim for Jonathan's Passes") - ;; c* = shim(c*) - ;; println("Inline Modules") - ;; c* = inline-modules(c*) - ; c* |
