aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/stanza/passes.stanza470
1 files changed, 261 insertions, 209 deletions
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index 25ce27bd..fe31f08c 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -1692,12 +1692,6 @@ defn inline-instances (c:Circuit) :
;
;
-
-
-
-
-
-
;;================ CLEANUP COMMANDS =========================
;defn cleanup (c:Stmt) :
; match(c) :
@@ -1719,206 +1713,268 @@ defn inline-instances (c:Circuit) :
;
+;================== WORKING-TO-REAL =============================
+defn working-to-real (c:Circuit) :
+ ;convert working to real expressions in letrec of module
+ defn to-real-module (m:Module) :
+ defn to-real (e:Expression) -> Expression :
+ match(e) :
+ (ex:WRef) : Ref(name(ex), type(ex))
+ (ex:WField) : Field(exp(ex), name(ex), type(ex))
+ (ex:WIndex) : Index(exp(ex), value(ex), type(ex))
+ (ex:Field) : Field(to-real(exp(ex)), name(ex), type(ex))
+ (ex:Index) : Index(to-real(exp(ex)), value(ex), type(ex))
+ (ex:DoPrim) : DoPrim(op(ex), map(to-real, args(ex)), consts(ex), type(ex))
+ (ex:ReadPort) : ReadPort(to-real(mem(ex)), to-real(index(ex)), type(ex))
+ (ex) : ex
+
+ defn to-real-elements (es:List<KeyValue<Symbol,Element>>) :
+ for entry in es map :
+ defn to-real-writeport (e:WritePort) :
+ WritePort(to-real(index(e)), to-real(value(e)), to-real(enable(e)))
+ val k = key(entry)
+ match(value(entry)) :
+ (el:Register) :
+ KeyValue(k, Register(type(el), to-real(value(el)), to-real(enable(el))))
+ (el:Memory) :
+ KeyValue(k, Memory(type(el), map(to-real-writeport, writers(el))))
+ (el:Node) :
+ KeyValue(k, Node(type(el), to-real(value(el))))
+ (el) :
+ entry
+
+ defn to-real-command (c:Stmt) :
+ match(c) :
+ (c:Connect) :
+ Connect(to-real(loc(c)), to-real(exp(c)))
+ (c:Begin) :
+ Begin(map(to-real-command, body(c)))
+ (c) :
+ ;; error
+ map(to-real-letrec, c)
+
+ defn to-real-letrec (c:Stmt) :
+ match(c) :
+ (c:LetRec) :
+ LetRec(to-real-elements(entries(c)), to-real-command(body(c)))
+ (c) :
+ ;; error
+ map(to-real-letrec, c)
+
+ Module(name(m), ports(m), to-real-letrec(body(m)))
+
+ ;Return the real version of expressions of circuit
+ ;; assert only one module
+ val main-module = to-real-module(modules(c)[0])
+ Circuit(list(main-module), main(c))
+
+;================== SPLIT-EXPRESSIONS =============================
+defn split-expressions (c:Circuit) :
+ ;split all expression in letrec of module
+ defn split-expressions (m:Module) :
+ defn name-elt (e:Expression, elts:Vector<KeyValue<Symbol,Element>>) -> Expression :
+ val name = gensym("T")
+ val elt = Node(type(e), e)
+ add(elts, KeyValue(name, elt))
+ Ref(name, type(e))
+
+ defn split-expression (e:Expression, elts:Vector<KeyValue<Symbol,Element>>) -> Expression :
+ match(e) :
+ (ex:Ref) : ex
+ (ex:UIntValue) : ex
+ (ex:SIntValue) : ex
+ (ex:Field) :
+ name-elt(Field(split-expression(exp(ex), elts), name(ex), type(ex)), elts)
+ (ex:Index) :
+ name-elt(Index(split-expression(exp(ex), elts), value(ex), type(ex)), elts)
+ (ex:DoPrim) :
+ name-elt(DoPrim(op(ex), map(split-expression{_, elts}, args(ex)), consts(ex), type(ex)), elts)
+ (ex:ReadPort) :
+ name-elt(ReadPort(split-expression(mem(ex), elts), split-expression(index(ex), elts), type(ex)), elts)
+ (ex) :
+ println-all(["SPLIT ", ex])
+ ex
+
+ defn split-command (c:Stmt, es:Vector<KeyValue<Symbol,Element>>) :
+ match(c) :
+ (cx:Begin) : Begin(map(split-command{_, es}, body(cx)))
+ (cx:EmptyStmt) : cx
+ (cx:Connect) : Connect(loc(cx), split-expression(exp(cx), es))
+
+ defn split-element-expressions (es:List<KeyValue<Symbol,Element>>) :
+ for entry in es map-append :
+ val elts = Vector<KeyValue<Symbol,Element>>()
+ defn add-ret (key:Symbol, elt:Element) -> List<KeyValue<Symbol,Element>> :
+ add(elts, KeyValue(key, elt))
+ to-list(elts)
+ defn split-writeport-expressions (e:WritePort) :
+ WritePort(split-expression(index(e), elts),
+ split-expression(value(e), elts),
+ split-expression(enable(e), elts))
+ val k = key(entry)
+ match(value(entry)) :
+ (el:Register) :
+ add-ret(k, Register(type(el), split-expression(value(el), elts),
+ split-expression(enable(el), elts)))
+ (el:Memory) :
+ add-ret(k, Memory(type(el), map(split-writeport-expressions, writers(el))))
+ (el:Node) :
+ add-ret(k, Node(type(el), split-expression(value(el), elts)))
+ (el) :
+ list(entry)
+
+ defn split-letrec-expressions (c:Stmt) :
+ match(c) :
+ (c:LetRec) :
+ val entries* = split-element-expressions(entries(c))
+ val elts = Vector<KeyValue<Symbol,Element>>()
+ val body* = split-command(body(c), elts)
+ LetRec(append(entries*, to-list(elts)), body*)
+ (c) :
+ ;; error
+ map(split-letrec-expressions, c)
+
+ Module(name(m), ports(m), split-letrec-expressions(body(m)))
+
+ ;Return the fully split out expression version of circuit
+ ;; assert only one module
+ val main-module = split-expressions(modules(c)[0])
+ Circuit(list(main-module), main(c))
+
+
-;;;============= SHIM ========================================
-;;defn shim (i:Immediate) -> Immediate :
-;; match(i) :
-;; (i:RegData) :
-;; Ref(name(i), direction(i), type(i))
-;; (i:InstPort) :
-;; val inst = Ref(name(i), UNKNOWN-DIR, UnknownType())
-;; Field(inst, port(i), direction(i), type(i))
-;; (i:Field) :
-;; val imm* = shim(imm(i))
-;; put-imm(i, imm*)
-;; (i) : i
-;;
-;;defn shim (c:Stmt) -> Stmt :
-;; val c* = map(shim{_ as Immediate}, c)
-;; map(shim{_ as Stmt}, c*)
-;;
-;;defn shim (c:Circuit) -> Circuit :
-;; val modules* =
-;; for m in modules(c) map :
-;; Module(name(m), ports(m), shim(body(m)))
-;; Circuit(modules*, main(c))
-;;
-;;;================== INLINE MODULES =========================
-;;defn cat-name (p: String|Symbol, s: String|Symbol) -> Symbol :
-;; if p == "" or p == `this : ;; TODO: REMOVE THIS WHEN `THIS GETS REMOVED
-;; to-symbol(s)
-;; else if s == `this : ;; TODO: DITTO
-;; to-symbol(p)
-;; else :
-;; symbol-join([p, "/", s])
-;;
-;;defn inline-command (c: Stmt, mods: HashTable<Symbol, Module>, prefix: String, cmds: Vector<Stmt>) :
-;; defn rename (n: Symbol) -> Symbol :
-;; cat-name(prefix, n)
-;; defn inline-name (i:Immediate) -> Symbol :
-;; match(i) :
-;; (r:Ref) : rename(name(r))
-;; (f:Field) : cat-name(inline-name(imm(f)), name(f))
-;; (f:Index) : cat-name(inline-name(imm(f)), to-string(value(f)))
-;; defn inline-imm (i:Immediate) -> Ref :
-;; Ref(inline-name(i), direction(i), type(i))
-;; match(c) :
-;; (c:DefUInt) : add(cmds, DefUInt(rename(name(c)), value(c), width(c)))
-;; (c:DefSInt) : add(cmds, DefSInt(rename(name(c)), value(c), width(c)))
-;; (c:DefWire) : add(cmds, DefWire(rename(name(c)), type(c)))
-;; (c:DefRegister) : add(cmds, DefRegister(rename(name(c)), type(c)))
-;; (c:DefMemory) : add(cmds, DefMemory(rename(name(c)), type(c), size(c)))
-;; (c:DefInstance) : inline-module(mods, mods[name(module(c))], to-string(rename(name(c))), cmds)
-;; (c:DoPrim) : add(cmds, DoPrim(rename(name(c)), op(c), map(inline-imm, args(c)), consts(c)))
-;; (c:DefAccessor) : add(cmds, DefAccessor(rename(name(c)), inline-imm(source(c)), direction(c), inline-imm(index(c))))
-;; (c:Connect) : add(cmds, Connect(inline-imm(loc(c)), inline-imm(exp(c))))
-;; (c:Begin) : do(inline-command{_, mods, prefix, cmds}, body(c))
-;; (c:EmptyStmt) : c
-;; (c) : error("Unsupported command")
-;;
-;;defn inline-port (p: Port, prefix: String) -> Stmt :
-;; DefWire(cat-name(prefix, name(p)), type(p))
-;;
-;;defn inline-module (mods: HashTable<Symbol, Module>, mod: Module, prefix: String, cmds: Vector<Stmt>) :
-;; do(add{cmds, _}, map(inline-port{_, prefix}, ports(mod)))
-;; inline-command(body(mod), mods, prefix, cmds)
-;;
-;;defn inline-modules (c: Circuit) -> Circuit :
-;; val cmds = Vector<Stmt>()
-;; val mods = HashTable<Symbol, Module>(symbol-hash)
-;; for mod in modules(c) do :
-;; mods[name(mod)] = mod
-;; val top = mods[main(c)]
-;; inline-command(body(top), mods, "", cmds)
-;; val main* = Module(name(top), ports(top), Begin(to-list(cmds)))
-;; Circuit(list(main*), name(top))
-;;
-;;
;;;============= FLO PRINTER ======================================
;;;;; TODO:
;;;;; not supported gt, lte
-;;
-;;defn flo-op-name (op:PrimOp) -> String :
-;; switch {op == _} :
-;; ADD-OP : "add"
-;; ADD-MOD-OP : "add"
-;; MINUS-OP : "sub"
-;; SUB-MOD-OP : "sub"
-;; TIMES-OP : "mul" ;; todo: signed version
-;; DIVIDE-OP : "div" ;; todo: signed version
-;; MOD-OP : "mod" ;; todo: signed version
-;; SHIFT-LEFT-OP : "lsh" ;; todo: signed version
-;; SHIFT-RIGHT-OP : "rsh"
-;; PAD-OP : "pad" ;; todo: signed version
-;; BIT-AND-OP : "and"
-;; BIT-OR-OP : "or"
-;; BIT-XOR-OP : "xor"
-;; CONCAT-OP : "cat"
-;; BIT-SELECT-OP : "rsh"
-;; BITS-SELECT-OP : "rsh"
-;; LESS-OP : "lt" ;; todo: signed version
-;; LESS-EQ-OP : "lte" ;; todo: swap args
-;; GREATER-OP : "gt" ;; todo: swap args
-;; GREATER-EQ-OP : "gte" ;; todo: signed version
-;; EQUAL-OP : "eq"
-;; MULTIPLEX-OP : "mux"
-;; else : error $ string-join $
-;; ["Unable to print Primop: " op]
-;;
-;;defn emit (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elt) :
-;; match(elt) :
-;; (e:String|Symbol|Int) :
-;; print(o, e)
-;; (e:Ref) :
-;; if key?(lits, name(e)) :
-;; val lit = lits[name(e)]
-;; print-all(o, [value(lit) "'" width(lit)])
-;; else :
-;; if key?(ports, name(e)) :
-;; print-all(o, [top "::"])
-;; print(o, name(e))
-;; (e:IntWidth) :
-;; print(o, value(e))
-;; (e:PrimOp) :
-;; print(o, flo-op-name(e))
-;; (e) :
-;; println-all(["EMIT " e])
-;; error("Unable to emit")
-;;
-;;defn emit-all (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elts: Streamable) :
-;; for e in elts do : emit(o, top, ports, lits, e)
-;;
-;;defn prim-width (type:Type) -> Width :
-;; match(type) :
-;; (t:UIntType) : width(t)
-;; (t:SIntType) : width(t)
-;; (t) : error("Bad prim width type")
-;;
-;;defn emit-command (o:OutputStream, cmd:Stmt, top:Symbol, lits:HashTable<Symbol, DefUInt>, regs:HashTable<Symbol, DefRegister>, accs:HashTable<Symbol, DefAccessor>, ports:HashTable<Symbol, Port>, outs:HashTable<Symbol, Port>) :
-;; match(cmd) :
-;; (c:DefUInt) :
-;; lits[name(c)] = c
-;; (c:DefSInt) :
-;; emit-all(o, top, ports, lits, [name(c) " = " value(c) "'" width(c) "\n"])
-;; (c:DoPrim) : ;; NEED TO FIGURE OUT WHEN WIDTHS ARE NECESSARY AND EXTRACT
-;; emit-all(o, top, ports, lits, [name(c) " = " op(c)])
-;; for arg in args(c) do :
-;; print(o, " ")
-;; emit(o, top, ports, lits, arg)
-;; for const in consts(c) do :
-;; print(o, " ")
-;; emit(o, top, ports, lits, const)
-;; print("\n")
-;; (c:DefRegister) :
-;; regs[name(c)] = c
-;; (c:DefMemory) :
-;; emit-all(o, top, ports, lits, [name(c) " : mem'" prim-width(type(c)) " " size(c) "\n"])
-;; (c:DefAccessor) :
-;; accs[name(c)] = c
-;; (c:Connect) :
-;; val dst = name(loc(c) as Ref)
-;; val src = name(exp(c) as Ref)
-;; if key?(regs, dst) :
-;; val reg = regs[dst]
-;; emit-all(o, top, ports, lits, [dst " = reg'" prim-width(type(reg)) " 0'" prim-width(type(reg)) " " exp(c) "\n"])
-;; else if key?(accs, dst) :
-;; val acc = accs[dst]
-;; ;; assert(direction(acc) == OUTPUT)
-;; emit-all(o, top, ports, lits, [dst " = wr " source(acc) " " index(acc) " " exp(c) "\n"])
-;; else if key?(outs, dst) :
-;; val out = outs[dst]
-;; emit-all(o, top, ports, lits, [top "::" dst " = out'" prim-width(type(out)) " " exp(c) "\n"])
-;; else if key?(accs, src) :
-;; val acc = accs[src]
-;; ;; assert(direction(acc) == INPUT)
-;; emit-all(o, top, ports, lits, [dst " = rd " source(acc) " " index(acc) "\n"])
-;; else :
-;; emit-all(o, top, ports, lits, [dst " = mov " exp(c) "\n"])
-;; (c:Begin) :
-;; do(emit-command{o, _, top, lits, regs, accs, ports, outs}, body(c))
-;; (c:DefWire|EmptyStmt) :
-;; print("")
-;; (c) :
-;; error("Unable to print command")
-;;
-;;defn emit-module (o:OutputStream, m:Module) :
-;; val regs = HashTable<Symbol, DefRegister>(symbol-hash)
-;; val accs = HashTable<Symbol, DefAccessor>(symbol-hash)
-;; val lits = HashTable<Symbol, DefUInt>(symbol-hash)
-;; val outs = HashTable<Symbol, Port>(symbol-hash)
-;; val portz = HashTable<Symbol, Port>(symbol-hash)
-;; for port in ports(m) do :
-;; portz[name(port)] = port
-;; if direction(port) == OUTPUT :
-;; outs[name(port)] = port
-;; else if name(port) == `reset :
-;; print-all(o, [name(m) "::reset = rst\n"])
-;; else :
-;; print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"])
-;; emit-command(o, body(m), name(m), lits, regs, accs, portz, outs)
-;;
-;;public defn emit-circuit (o:OutputStream, c:Circuit) :
-;; emit-module(o, modules(c)[0])
+
+defn flo-op-name (op:PrimOp) -> String :
+ switch {op == _} :
+ ADD-OP : "add"
+ ADD-MOD-OP : "add"
+ SUB-OP : "sub"
+ SUB-MOD-OP : "sub"
+ TIMES-OP : "mul" ;; todo: signed version
+ DIVIDE-OP : "div" ;; todo: signed version
+ MOD-OP : "mod" ;; todo: signed version
+ SHIFT-LEFT-OP : "lsh" ;; todo: signed version
+ SHIFT-RIGHT-OP : "rsh"
+ PAD-OP : "rsh" ;; todo: signed version
+ BIT-AND-OP : "and"
+ BIT-OR-OP : "or"
+ BIT-XOR-OP : "xor"
+ CONCAT-OP : "cat"
+ BIT-SELECT-OP : "rsh"
+ BITS-SELECT-OP : "rsh"
+ LESS-OP : "lt" ;; todo: signed version
+ LESS-EQ-OP : "lte" ;; todo: swap args
+ GREATER-OP : "lte" ;; todo: swap args
+ GREATER-EQ-OP : "lt" ;; todo: signed version
+ EQUAL-OP : "eq"
+ MULTIPLEX-OP : "mux"
+ else : error $ string-join $ ["Unable to print Primop: " op]
+
+defn sane-width (wd:Width) -> Int :
+ match(wd) :
+ (w:IntWidth) : max(1, width(w))
+ (w) : error("Unknown width")
+
+defn prim-width (type:Type) -> Int :
+ match(type) :
+ (t:UIntType) : sane-width(width(t))
+ (t:SIntType) : sane-width(width(t))
+ (t) : error("Bad prim width type")
+
+defn emit-all (o:OutputStream, es:Streamable, top:Symbol) :
+ for e in es do :
+ match(e) :
+ (ex:Expression) : emit(o, ex, top)
+ (ex:String) : print(o, ex)
+ (ex:Symbol) : print(o, ex)
+ (ex:Int) : print(o, ex)
+ (ex) : print(o, ex)
+
+defn emit (o:OutputStream, e:Expression, top:Symbol) :
+ match(e) :
+ (ex:Ref) : print-all(o, [top "::" name(ex)])
+ (ex:UIntValue) : emit-all(o, [value(ex) "'" sane-width(width(ex))], top)
+ (ex:SIntValue) : emit-all(o, [value(ex) "'" sane-width(width(ex))], top)
+ (ex:Field) : emit-all(o, [exp(ex) "/" name(ex)], top)
+ (ex:Index) : emit-all(o, [exp(ex) "/" value(ex)], top)
+ (ex:DoPrim) :
+ if op(ex) == EQUAL-OP or op(ex) == GREATER-OP or op(ex) == GREATER-EQ-OP :
+ emit-all(o, [flo-op-name(op(ex)) "'" prim-width(type(args(ex)[0]))], top)
+ if op(ex) == GREATER-OP or op(ex) == GREATER-EQ-OP :
+ emit-all(o, [" " args(ex)[1] " " args(ex)[0]], top)
+ else :
+ emit-all(o, [" " args(ex)[0] " " args(ex)[1]], top)
+ else :
+ emit-all(o, [flo-op-name(op(ex)) "'" prim-width(type(ex))], top)
+ if op(ex) == PAD-OP :
+ emit-all(o, [" " args(ex)[0] " 0"], top)
+ else :
+ for arg in args(ex) do :
+ print(o, " ")
+ emit(o, arg, top)
+ for const in consts(ex) do :
+ print(o, " ")
+ print(o, const)
+ (ex) : print-all(o, ["EMIT(" ex ")"])
+
+defn emit-elt-exp (o:OutputStream, e:Expression, top:Symbol) :
+ match(e) :
+ (ex:DoPrim) : emit(o, e, top)
+ (ex:ReadPort) :
+ val m = mem(ex)
+ val vtype = type(m) as VectorType
+ emit-all(o, ["rd'" prim-width(type(vtype)) " " "1" " " mem(ex) " " index(ex)], top)
+ (ex) : emit-all(o, ["mov'" prim-width(type(e)) " " e], top)
+
+defn emit-elements (o:OutputStream, es:List<KeyValue<Symbol,Element>>, top:Symbol) :
+ for entry in es do :
+ val k = key(entry)
+ match(value(entry)) :
+ (e:Register) :
+ emit-all(o, [top "::" k " = reg'" prim-width(type(e)) " " "1" " " value(e) "\n"], top)
+ ;; enable(e)
+ (e:Memory) :
+ val vtype = type(e) as VectorType
+ emit-all(o, [top "::" k " = mem'" prim-width(type(vtype)) " " size(vtype) "\n"], top)
+ for wp in writers(e) do :
+ val name = gensym("T")
+ emit-all(o, [top "::" name " = wr'" prim-width(type(vtype)) " " enable(wp) " " top "::" k " " index(wp) " " value(wp) "\n"], top)
+ (e:Node) :
+ emit-all(o, [top "::" k " = "], top)
+ emit-elt-exp(o, value(e), top)
+ print(o, "\n")
+
+defn emit-outs (o:OutputStream, c:Stmt, outs:HashTable<Symbol,Port>, top:Symbol) :
+ match(c) :
+ (cmd:Connect) : emit-all(o, [loc(cmd) " = out'" prim-width(type(loc(cmd))) " " exp(cmd) "\n"], top)
+ (cmd:Begin) : do(emit-outs{o, _, outs, top}, body(cmd))
+ (cmd) : error("Unknown")
+
+defn emit-letrec (o:OutputStream, c:Stmt, outs:HashTable<Symbol,Port>, top:Symbol) :
+ match(c) :
+ (c:LetRec) :
+ emit-elements(o, entries(c), top)
+ emit-outs(o, body(c), outs, top)
+ (c) : c ;; error
+
+defn emit-module (o:OutputStream, m:Module) :
+ val outs = HashTable<Symbol,Port>(symbol-hash)
+ for port in ports(m) do :
+ if direction(port) == OUTPUT :
+ outs[name(port)] = port
+ else if name(port) == `reset :
+ print-all(o, [name(m) "::" name(port) " = rst'1\n"])
+ else :
+ print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"])
+ emit-letrec(o, body(m), outs, name(m))
+
+public defn emit-circuit (o:OutputStream, c:Circuit) :
+ emit-module(o, modules(c)[0])
+
+
;============= DRIVER ======================================
@@ -1950,12 +2006,8 @@ public defn run-passes (c: Circuit, p: List<Char>) :
if contains(p,'m') : do-stage("Infer Widths", infer-widths)
if contains(p,'n') : do-stage("Pad Widths", pad-widths)
if contains(p,'o') : do-stage("Inline Instances", inline-instances)
+ if contains(p,'p') : do-stage("Working To Real", working-to-real)
+ if contains(p,'q') : do-stage("Split Expressions", split-expressions)
+ if contains(p,'r') : emit-circuit(STANDARD-OUTPUT, c*)
println("Done!")
-
-
- ;; println("Shim for Jonathan's Passes")
- ;; c* = shim(c*)
- ;; println("Inline Modules")
- ;; c* = inline-modules(c*)
- ; c*