diff options
| author | azidar | 2015-02-13 15:42:47 -0800 |
|---|---|---|
| committer | azidar | 2015-02-13 15:42:47 -0800 |
| commit | 4f68f75415eb89427062eb86ff21b0e53bf4cadd (patch) | |
| tree | 1f6a552e18eed4874a563359e95e5aad87a8ef50 /src/main/stanza/passes.stanza | |
| parent | 4deb61cefa9c0ef7806e3986231865ce59673bc2 (diff) | |
First commit.
Added stanza as a .zip, changed names from ch to firrtl, and spec.tex is
included. need to add installation instructions. TODO's included in
README
Diffstat (limited to 'src/main/stanza/passes.stanza')
| -rw-r--r-- | src/main/stanza/passes.stanza | 1878 |
1 files changed, 1878 insertions, 0 deletions
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza new file mode 100644 index 00000000..61eac73c --- /dev/null +++ b/src/main/stanza/passes.stanza @@ -0,0 +1,1878 @@ +defpackage chipper.passes : + import core + import verse + import chipper.ir2 + import chipper.ir-utils + import widthsolver + +;============== EXCEPTIONS ================================= +defclass PassException <: Exception +defn PassException (msg:String) : + new PassException : + defmethod print (o:OutputStream, this) : + print(o, msg) + +;=============== WORKING IR ================================ +definterface Kind +defstruct RegKind <: Kind +defstruct AccessorKind <: Kind +defstruct PortKind <: Kind +defstruct MemKind <: Kind +defstruct NodeKind <: Kind +defstruct ModuleKind <: Kind +defstruct InstanceKind <: Kind +defstruct StructuralMemKind <: Kind + +defstruct WRef <: Expression : + name: Symbol + type: Type [multi => false] + kind: Kind + dir: Direction [multi => false] + +defstruct WField <: Expression : + exp: Expression + name: Symbol + type: Type [multi => false] + dir: Direction [multi => false] + +defstruct WIndex <: Expression : + exp: Expression + value: Int + type: Type [multi => false] + dir: Direction [multi => false] + +defstruct WDefAccessor <: Command : + name: Symbol + source: Expression + index: Expression + dir: Direction + +;================ WORKING IR UTILS ========================= +;=== Printers === +defmethod print (o:OutputStream, k:Kind) : + print{o, _} $ + match(k) : + (k:RegKind) : "reg:" + (k:AccessorKind) : "accessor:" + (k:PortKind) : "port:" + (k:MemKind) : "mem:" + (k:NodeKind) : "n:" + (k:ModuleKind) : "module:" + (k:InstanceKind) : "inst:" + (k:StructuralMemKind) : "smem:" + +defmethod print (o:OutputStream, e:WRef) : + print-all(o, [name(e)]) +defmethod print (o:OutputStream, e:WField) : + print-all(o, [exp(e) "." name(e)]) +defmethod print (o:OutputStream, e:WIndex) : + print-all(o, [exp(e) "." value(e)]) + +defmethod print (o:OutputStream, c:WDefAccessor) : + print-all(o, [dir(c) " accessor " name(c) " = " source(c) "[" index(c) "]"]) + +defmethod map (f: Expression -> Expression, e: WField) : + WField(f(exp(e)), name(e), type(e), dir(e)) + +defmethod map (f: Expression -> Expression, e: WIndex) : + WIndex(f(exp(e)), value(e), type(e), dir(e)) + +defmethod map (f: Expression -> Expression, c:WDefAccessor) : + WDefAccessor(name(c), f(source(c)), f(index(c)), dir(c)) + +;================= DIRECTION =============================== +defmulti dir (e:Expression) -> Direction +defmethod dir (e:Expression) : + OUTPUT + +;============== Bring to Working IR ======================== +defn to-working-ir (c:Circuit) : + defn to-exp (e:Expression) : + match(map(to-exp, e)) : + (e:Ref) : WRef(name(e), type(e), NodeKind(), UNKNOWN-DIR) + (e:Field) : WField(exp(e), name(e), type(e), UNKNOWN-DIR) + (e:Index) : WIndex(exp(e), value(e), type(e), UNKNOWN-DIR) + (e) : e + defn to-command (c:Command) : + match(map(to-exp, c)) : + (c:DefAccessor) : + WDefAccessor(name(c), source(c), index(c), UNKNOWN-DIR) + (c) : + map(to-command, c) + + Circuit(modules*, main(c)) where : + val modules* = + for m in modules(c) map : + Module(name(m), ports(m), to-command(body(m))) + +;=============== Resolve Kinds ============================= +defn resolve-kinds (c:Circuit) : + defn resolve-exp (e:Expression, kinds:HashTable<Symbol,Kind>) : + match(e) : + (e:WRef) : WRef(name(e), type(e), kinds[name(e)], dir(e)) + (e) : map(resolve-exp{_, kinds}, e) + + defn resolve-comm (c:Command, kinds:HashTable<Symbol,Kind>) -> Command : + map{resolve-comm{_, kinds}, _} $ + map(resolve-exp{_, kinds}, c) + + defn find-kinds (c:Command, kinds:HashTable<Symbol,Kind>) : + match(c) : + (c:LetRec) : + for entry in entries(c) do : + kinds[key(entry)] = element-kind(value(entry)) + (c:DefWire) : kinds[name(c)] = NodeKind() + (c:DefRegister) : kinds[name(c)] = RegKind() + (c:DefInstance) : kinds[name(c)] = InstanceKind() + (c:DefMemory) : kinds[name(c)] = MemKind() + (c:WDefAccessor) : kinds[name(c)] = AccessorKind() + (c) : false + do(find-kinds{_, kinds}, children(c)) + + defn element-kind (e:Element) : + match(e) : + (e:Memory) : StructuralMemKind() + (e) : NodeKind() + + defn resolve-mod (m:Module, modules:List<Symbol>) : + val kinds = HashTable<Symbol,Kind>(symbol-hash) + for module in modules do : + kinds[module] = ModuleKind() + for port in ports(m) do : + kinds[name(port)] = PortKind() + find-kinds(body(m), kinds) + Module(name(m), ports(m), body*) where : + val body* = resolve-comm(body(m), kinds) + + Circuit(modules*, main(c)) where : + val mod-names = map(name, modules(c)) + val modules* = map(resolve-mod{_, mod-names}, modules(c)) + +;=============== MAKE RESET EXPLICIT ======================= +defn make-explicit-reset (c:Circuit) : + defn reset-instances (c:Command, reset?: List<Symbol>) -> Command : + match(c) : + (c:DefInstance) : + val module = module(c) as WRef + if contains?(reset?, name(module)) : + c + else : + Begin $ list(c, Connect(WField(inst, `reset, UnknownType(), UNKNOWN-DIR), reset)) where : + val inst = WRef(name(c), UnknownType(), InstanceKind(), UNKNOWN-DIR) + val reset = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-DIR) + (c) : + map(reset-instances{_:Command, reset?}, c) + + defn make-explicit-reset (m:Module, reset-list: List<Symbol>) : + val reset? = contains?(reset-list, name(m)) + + ;Add reset port if necessary + val ports* = + if reset? : + ports(m) + else : + val reset = Port(`reset, INPUT, UIntType(IntWidth(1))) + List(reset, ports(m)) + + ;Reset Instances + val body* = reset-instances(body(m), reset-list) + val m* = Module(name(m), ports*, body*) + + ;Initialize registers if necessary + if reset? : m* + else : initialize-registers(m*) + + Circuit(modules*, main(c)) where : + defn reset? (m:Module) : + for p in ports(m) any? : + name(p) == `reset + val reset-list = to-list(stream(name, filter(reset?, modules(c)))) + val modules* = map(make-explicit-reset{_, reset-list}, modules(c)) + +;======= MAKE EXPLICIT REGISTER INITIALIZATION ============= +defn initialize-registers (m:Module) : + ;=== Initializing Expressions === + defn init-exps (inits: List<KeyValue<Symbol,Expression>>) : + if empty?(inits) : + EmptyCommand() + else : + Conditionally(reset, Begin(map(connect, inits)), EmptyCommand()) where : + val reset = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-DIR) + defn connect (init: KeyValue<Symbol, Expression>) : + val reg-ref = WRef(key(init), UnknownType(), RegKind(), UNKNOWN-DIR) + Connect(reg-ref, value(init)) + + defn initialize-registers (c: Command + inits: List<KeyValue<Symbol,Expression>>) -> + [Command, List<KeyValue<Symbol,Expression>>] : + ;=== Rename Expressions === + defn rename (e:Expression) : + match(e) : + (e:WField) : + switch {name(e) == _} : + `init : + if reg?(exp(e)) : init-wire(exp(e)) + else : map(rename, e) + else : map(rename, e) + (e) : map(rename, e) + defn reg? (e:Expression) : + match(e) : + (e:WRef) : kind(e) typeof RegKind + (e) : false + defn init-wire (e:Expression) : + lookup!(inits, name(e as WRef)) + + ;=== Driver === + match(c) : + (c:DefRegister) : + [new-command, list(init-entry)] where : + val wire-name = gensym() + val wire-ref = WRef(wire-name, UnknownType(), NodeKind(), UNKNOWN-DIR) + val reg-ref = WRef(name(c), UnknownType(), RegKind(), UNKNOWN-DIR) + val def-init-wire = DefWire(wire-name, type(c)) + val init-wire = Connect(wire-ref, reg-ref) + val init-reg = Connect(reg-ref, wire-ref) + val new-command = Begin(to-list([c, def-init-wire, init-wire, init-reg])) + val init-entry = name(c) => wire-ref + (c:Conditionally) : + val pred* = rename(pred(c)) + val [conseq* con-inits] = initialize-registers(conseq(c), inits) + val [alt* alt-inits] = initialize-registers(alt(c), inits) + val c* = Conditionally(pred*, conseq+inits, alt+inits) where : + val conseq+inits = Begin(list(conseq*, init-exps(con-inits))) + val alt+inits = Begin(list(alt*, init-exps(alt-inits))) + [c*, List()] + (c:LetRec) : + val c* = map(rename, c) + val [body*, body-inits] = initialize-registers(body(c), inits) + val new-command = + LetRec(entries(c*), body+inits) where : + val body+inits = Begin(list(body*, init-exps(body-inits))) + [new-command, List()] + (c:Begin) : + var inits-in:List<KeyValue<Symbol,Expression>> = inits + var inits-out:List<KeyValue<Symbol,Expression>> = List() + val body* = + for c in body(c) map : + val [c* inits*] = initialize-registers(c, inits-in) + inits-in = append(inits*, inits-in) + inits-out = append(inits*, inits-out) + c* + [Begin(body*), inits-out] + (c) : + val c* = map(rename, c) + [c*, List()] + + Module(name(m), ports(m), body+inits) where : + val [body*, inits] = initialize-registers(body(m), List()) + val body+inits = Begin(list(body*, init-exps(inits))) + + +;============== INFER TYPES ================================ +defmethod type (v:UIntValue) : + UIntType(width(v)) + +defmethod type (v:SIntValue) : + SIntType(width(v)) + +defn put-type (e:Expression, t:Type) -> Expression : + match(e) : + (e:WRef) : WRef(name(e), t, kind(e), dir(e)) + (e:WField) : WField(exp(e), name(e), t, dir(e)) + (e:WIndex) : WIndex(exp(e), value(e), t, dir(e)) + (e:DoPrim) : DoPrim(op(e), args(e), consts(e), t) + (e:ReadPort) : ReadPort(mem(e), index(e), t) + (e) : e + +defn lookup-port (ports: Streamable<Port>, port-name: Symbol) : + for port in ports find : + name(port) == port-name + +defn infer (op:PrimOp, arg-types: List<Type>) -> Type : + defn wipe-width (t:Type) : + match(t) : + (t:UIntType) : UIntType(UnknownWidth()) + (t:SIntType) : SIntType(UnknownWidth()) + + defn arg0 () : wipe-width(arg-types[0]) + defn arg1 () : wipe-width(arg-types[1]) + switch {op == _} : + ADD-OP : arg0() + ADD-MOD-OP : arg0() + SUB-OP : arg0() + SUB-MOD-OP : arg0() + TIMES-OP : arg0() + DIVIDE-OP : arg0() + MOD-OP : arg0() + SHIFT-LEFT-OP : arg0() + SHIFT-RIGHT-OP : arg0() + PAD-OP : arg0() + BIT-AND-OP : arg0() + BIT-OR-OP : arg0() + BIT-XOR-OP : arg0() + CONCAT-OP : arg0() + BIT-SELECT-OP : UIntType(UnknownWidth()) + BITS-SELECT-OP : arg0() + MULTIPLEX-OP : arg0() + LESS-OP : UIntType(UnknownWidth()) + LESS-EQ-OP : UIntType(UnknownWidth()) + GREATER-OP : UIntType(UnknownWidth()) + GREATER-EQ-OP : UIntType(UnknownWidth()) + EQUAL-OP : UIntType(UnknownWidth()) + +defn bundle-field-type (t:Type, n:Symbol) -> Type : + match(t) : + (t:BundleType) : + match(lookup-port(ports(t), n)) : + (p:Port) : type(p) + (p) : UnknownType() + (t) : UnknownType() + +defn vector-elem-type (t:Type) -> Type : + match(t) : + (t:VectorType) : type(t) + (t) : UnknownType() + +;e is the environment that contains all definitions seen so far. +defn infer (c:Command, e:List<KeyValue<Symbol, Type>>) -> [Command, List<KeyValue<Symbol,Type>>] : + defn infer-exp (e:Expression, env:List<KeyValue<Symbol,Type>>) : + match(map(infer-exp{_, env}, e)) : + (e:WRef) : + put-type(e, lookup!(env, name(e))) + (e:WField) : + put-type(e, bundle-field-type(type(exp(e)), name(e))) + (e:WIndex) : + put-type(e, vector-elem-type(type(exp(e)))) + (e:UIntValue) : e + (e:SIntValue) : e + (e:DoPrim) : + put-type(e, infer(op(e), map(type, args(e)))) + (e:ReadPort) : + put-type(e, vector-elem-type(type(mem(e)))) + + defn element-type (e:Element, env:List<KeyValue<Symbol,Type>>) : + match(e) : + (e:Instance) : + val t = type(infer-exp(module(e), env)) + match(t) : + (t:BundleType) : + BundleType $ to-list $ + for p in ports(t) filter : + direction(p) == OUTPUT + (t) : UnknownType() + (e) : type(e) + + match(c) : + (c:LetRec) : + val e* = append(elem-types, e) where : + val elem-types = + for entry in entries(c) map : + key(entry) => element-type(value(entry), e) + val c* = map(infer-exp{_, e*}, c) + val [body*, be] = infer(body(c*), e*) + [LetRec(entries(c*), body*), e] + (c) : + match(map(infer-exp{_, e}, c)) : + (c:DefWire) : + [c, List(entry, e)] where : + val entry = name(c) => type(c) + (c:DefRegister) : + [c, List(entry, e)] where : + val entry = name(c) => type(c) + (c:DefInstance) : + [c, List(entry, e)] where : + val entry = name(c) => type(module(c)) + (c:DefMemory) : + [c, List(entry, e)] where : + val entry = name(c) => type(c) + (c:WDefAccessor) : + [c, List(entry, e)] where : + val src-type = type(source(c)) + val entry = name(c) => vector-elem-type(src-type) + (c:Begin) : + var current-e: List<KeyValue<Symbol,Type>> = e + val body* = for c in body(c) map : + val [c*, e*] = infer(c, current-e) + current-e = e* + c* + [Begin(body*), current-e] + (c) : + defn infer-comm (c:Command) : + val [c* e*] = infer(c, e) + c* + val c* = map(infer-comm, c) + [c*, e] + +defn infer (m:Module, e:List<KeyValue<Symbol, Type>>) -> Module : + val env = append{_, e} $ + for p in ports(m) map : + name(p) => type(p) + val [body*, e*] = infer(body(m), env) + Module(name(m), ports(m), body*) + +defn infer-types (c:Circuit) -> Circuit : + val env = + for m in modules(c) map : + name(m) => BundleType(ports(m)) + Circuit(map(infer{_, env}, modules(c)), + main(c)) + +;============= INFER DIRECTIONS ============================ +defn flip (d:Direction) : + switch {d == _} : + INPUT : OUTPUT + OUTPUT : INPUT + else : d + +defn times (d1:Direction, d2:Direction) : + if d1 == INPUT : flip(d2) + else : d2 + +defn bundle-field-dir (t:Type, n:Symbol) -> Direction : + match(t) : + (t:BundleType) : + match(lookup-port(ports(t), n)) : + (p:Port) : direction(p) + (p) : UNKNOWN-DIR + (t) : UNKNOWN-DIR + +defn infer-dirs (m:Module) : + ;=== Direction of all Binders === + val BI-DIR = new Direction + val directions = HashTable<Symbol,Direction>(symbol-hash) + defn find-dirs (c:Command) : + match(c) : + (c:LetRec) : + for entry in entries(c) do : + directions[key(entry)] = OUTPUT + find-dirs(body(c)) + (c:DefWire) : + directions[name(c)] = BI-DIR + (c:DefRegister) : + directions[name(c)] = BI-DIR + (c:DefInstance) : + directions[name(c)] = OUTPUT + (c:DefMemory) : + directions[name(c)] = BI-DIR + (c:WDefAccessor) : + directions[name(c)] = dir(c) + (c) : + do(find-dirs, children(c)) + for p in ports(m) do : + directions[name(p)] = flip(direction(p)) + find-dirs(body(m)) + + ;=== Fix Point Status === + var changed? = false + + ;=== Infer directions of Expression === + defn infer-exp (e:Expression, desired:Direction) : + match(e) : + (e:WRef) : + val dir* = let : + if kind(e) typeof ModuleKind : + OUTPUT + else : + val old-dir = directions[name(e)] + switch {old-dir == _} : + BI-DIR : + desired + UNKNOWN-DIR : + if directions[name(e)] != desired : + directions[name(e)] = desired + changed? = true + desired + else : + old-dir + WRef(name(e), type(e), kind(e), dir*) + (e:WField) : + val port-dir = bundle-field-dir(type(exp(e)), name(e)) + val exp* = infer-exp(exp(e), port-dir * desired) + WField(exp*, name(e), type(e), port-dir * dir(exp*)) + (e:WIndex) : + val exp* = infer-exp(exp(e), desired) + WIndex(exp*, value(e), type(e), dir(exp*)) + (e) : + map(infer-exp{_, OUTPUT}, e) + + ;=== Infer directions of Commands === + defn infer-comm (c:Command) : + match(c) : + (c:LetRec) : + val c* = map(infer-exp{_, OUTPUT}, c) + LetRec(entries(c*), infer-comm(body(c))) + (c:DefInstance) : + DefInstance(name(c), + infer-exp(module(c), OUTPUT)) + (c:WDefAccessor) : + val d = directions[name(c)] + WDefAccessor(name(c), + infer-exp(source(c), d), + infer-exp(index(c), OUTPUT), + d) + (c:Conditionally) : + Conditionally(infer-exp(pred(c), OUTPUT), + infer-comm(conseq(c)), + infer-comm(alt(c))) + (c:Connect) : + Connect(infer-exp(loc(c), INPUT), infer-exp(exp(c), OUTPUT)) + (c) : + map(infer-comm, c) + + ;=== Iterate until fix point === + defn* fixpoint (c:Command) : + changed? = false + val c* = infer-comm(c) + if changed? : fixpoint(c*) + else : c* + + Module(name(m), ports(m), body*) where : + val body* = fixpoint(body(m)) + +defn infer-directions (c:Circuit) : + Circuit(modules*, main(c)) where : + val modules* = map(infer-dirs, modules(c)) + + +;============== EXPAND VECS ================================ +defstruct ManyConnect <: Command : + index: Expression + locs: List<Expression> + exp: Expression + +defstruct ConnectMany <: Command : + index: Expression + loc: Expression + exps: List<Expression> + +defmethod print (o:OutputStream, c:ManyConnect) : + print-all(o, [locs(c) "[" index(c) "] := " exp(c)]) +defmethod print (o:OutputStream, c:ConnectMany) : + print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"]) + +defmethod map (f: Expression -> Expression, c:ManyConnect) : + ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c))) +defmethod map (f: Expression -> Expression, c:ConnectMany) : + ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c))) + +defn expand-accessors (m: Module) : + defn expand (c:Command) : + match(c) : + (c:WDefAccessor) : + ;Is the source a memory? + val mem? = + match(source(c)) : + (r:WRef) : kind(r) typeof MemKind + (r) : false + + if mem? : + c + else : + switch {dir(c) == _} : + INPUT : + Begin(list( + DefWire(name(c), type(src-type)) + ManyConnect(index(c), elems, wire-ref))) + where : + val src-type = type(source(c)) as VectorType + val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT) + val elems = to-list $ + for i in 0 to size(src-type) stream : + WIndex(source(c), i, type(src-type), INPUT) + OUTPUT : + Begin(list( + DefWire(name(c), type(src-type)) + ConnectMany(index(c), wire-ref, elems))) + where : + val src-type = type(source(c)) as VectorType + val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT) + val elems = to-list $ + for i in 0 to size(src-type) stream : + WIndex(source(c), i, type(src-type), OUTPUT) + (c) : + map(expand, c) + Module(name(m), ports(m), expand(body(m))) + +defn expand-accessors (c:Circuit) : + Circuit(modules*, main(c)) where : + val modules* = map(expand-accessors, modules(c)) + + + +;=============== BUNDLE FLATTENING ========================= +defn prefix (prefix, suffix) : + symbol-join([prefix "/" suffix]) + +defn prefix-ports (pre:Symbol, ports:List<Port>) : + for p in ports map : + Port(prefix(pre, name(p)), direction(p), type(p)) + +defn flatten-ports (port:Port) -> List<Port> : + match(type(port)) : + (t:BundleType) : + val ports = map-append(flatten-ports, ports(t)) + for p in ports map : + Port(prefix(name(port), name(p)), + direction(port) * direction(p), + type(p)) + (t:VectorType) : + val type* = flatten-type(t) + flatten-ports(Port(name(port), direction(port), type*)) + (t:Type) : + list(port) + +defn flatten-type (t:Type) -> Type : + match(t) : + (t:BundleType) : + BundleType $ + map-append(flatten-ports, ports(t)) + (t:VectorType) : + flatten-type $ BundleType $ to-list $ + for i in 0 to size(t) stream : + Port(to-symbol(i), OUTPUT, type(t)) + (t:Type) : + t + +defn flatten-bundles (c:Circuit) : + defn flatten-exp (e:Expression) : + match(map(flatten-exp, e)) : + (e:UIntValue|SIntValue) : + e + (e:WRef) : + match(kind(e)) : + (k:MemKind|StructuralMemKind) : + val type* = map(flatten-type, type(e)) + put-type(e, type*) + (k) : + val type* = flatten-type(type(e)) + put-type(e, type*) + (e) : + val type* = flatten-type(type(e)) + put-type(e, type*) + + defn flatten-element (e:Element) : + val t* = flatten-type(type(e)) + match(map(flatten-exp, e)) : + (e:Register) : Register(t*, value(e), enable(e)) + (e:Memory) : Memory(t*, writers(e)) + (e:Node) : Node(t*, value(e)) + (e:Instance) : Instance(t*, module(e), ports(e)) + + defn flatten-comm (c:Command) : + match(c) : + (c:LetRec) : + val entries* = + for entry in entries(c) map : + key(entry) => flatten-element(value(entry)) + LetRec(entries*, flatten-comm(body(c))) + (c:DefWire) : + DefWire(name(c), flatten-type(type(c))) + (c:DefRegister) : + DefRegister(name(c), flatten-type(type(c))) + (c:DefMemory) : + val type* = map(flatten-type, type(c)) + DefMemory(name(c), type*) + (c) : + map{flatten-comm, _} $ + map(flatten-exp, c) + + defn flatten-module (m:Module) : + val ports* = map-append(flatten-ports, ports(m)) + val body* = flatten-comm(body(m)) + Module(name(m), ports*, body*) + + Circuit(modules*, main(c)) where : + val modules* = map(flatten-module, modules(c)) + + +;================== BUNDLE EXPANSION ======================= +defn expand-bundles (m:Module) : + + ;Collapse all field/index expressions + defn collapse-exp (e:Expression) -> Expression : + match(e) : + (e:WField) : + match(collapse-exp(exp(e))) : + (ei:WRef) : + if kind(ei) typeof InstanceKind : + e + else : + WRef(name*, type(e), kind(ei), dir(e)) where : + val name* = prefix(name(ei), name(e)) + (ei:WField) : + WField(exp(ei), name*, type(e), dir(e)) where : + val name* = prefix(name(ei), name(e)) + (e:WIndex) : + collapse-exp(WField(exp(e), name, type(e), dir(e))) where : + val name = to-symbol(value(e)) + (e) : + map(collapse-exp, e) + + ;Expand expressions + defn expand-exp (e:Expression) -> List<Expression> : + match(type(e)) : + (t:BundleType) : + for p in ports(t) map : + val dir* = direction(p) * dir(e) + collapse-exp(WField(e, name(p), type(p), dir*)) + (t) : + list(collapse-exp(e)) + + ;Expand commands + defn expand-comm (c:Command) : + match(c) : + (c:DefWire) : + match(type(c)) : + (t:BundleType) : + Begin $ + for p in ports(t) map : + DefWire(prefix(name(c), name(p)), type(p)) + (t) : + c + (c:DefRegister) : + match(type(c)) : + (t:BundleType) : + Begin $ + for p in ports(t) map : + DefRegister(prefix(name(c), name(p)), type(p)) + (t) : + c + (c:DefMemory) : + match(type(type(c))) : + (t:BundleType) : + Begin $ + for p in ports(t) map : + DefMemory(prefix(name(c), name(p)), type*) where : + val s = size(type(c)) + val type* = VectorType(type(p), s) + (t) : + c + (c:WDefAccessor) : + match(type(source(c))) : + (t:BundleType) : + val srcs = expand-exp(source(c)) + Begin $ + for (p in ports(t), src in srcs) map : + WDefAccessor(name*, src, index(c), dir*) where : + val name* = prefix(name(c), name(p)) + val dir* = direction(p) * dir(c) + (t) : + c + (c:Connect) : + val locs = expand-exp(loc(c)) + val exps = expand-exp(exp(c)) + Begin $ + for (l in locs, e in exps) map : + switch {dir(l) == _} : + INPUT : Connect(l, e) + OUTPUT : Connect(e, l) + (c:ManyConnect) : + val locs-list = transpose(map(expand-exp, locs(c))) + val exps = expand-exp(exp(c)) + Begin $ + for (locs in locs-list, e in exps) map : + switch {dir(e) == _} : + OUTPUT : ManyConnect(index(c), locs, e) + INPUT : ConnectMany(index(c), e, locs) + (c:ConnectMany) : + val locs = expand-exp(loc(c)) + val exps-list = transpose(map(expand-exp, exps(c))) + Begin $ + for (l in locs, exps in exps-list) map : + switch {dir(l) == _} : + INPUT : ConnectMany(index(c), l, exps) + OUTPUT : ManyConnect(index(c), exps, l) + (c) : + map{expand-comm, _} $ + map(collapse-exp, c) + + Module(name(m), ports(m), expand-comm(body(m))) + +defn expand-bundles (c:Circuit) : + Circuit(modules*, main(c)) where : + val modules* = map(expand-bundles, modules(c)) + + +;=========== CONVERT MULTI CONNECTS to WHEN ================ +defn expand-multi-connects (c:Circuit) : + defn equal-exp (e1:Expression, e2:Expression) : + DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth())) + defn uint (i:Int) : + UIntValue(i, UnknownWidth()) + + defn expand-comm (c:Command) : + match(c) : + (c:ConnectMany) : + Begin $ to-list $ + for (i in 0 to false, e in exps(c)) stream : + Conditionally(equal-exp(index(c), uint(i)), + Connect(loc(c), e) + EmptyCommand()) + (c:ManyConnect) : + Begin $ to-list $ + for (i in 0 to false, l in locs(c)) stream : + Conditionally(equal-exp(index(c), uint(i)), + Connect(l, exp(c)) + EmptyCommand()) + (c) : + map(expand-comm, c) + + defn expand (m:Module) : + Module(name(m), ports(m), expand-comm(body(m))) + + Circuit(modules*, main(c)) where : + val modules* = map(expand, modules(c)) + + +;================ EXPAND WHENS ============================= +definterface SymbolicValue +defstruct ExpValue <: SymbolicValue : + exp: Expression +defstruct WhenValue <: SymbolicValue : + pred: Expression + conseq: SymbolicValue + alt: SymbolicValue +defstruct VoidValue <: SymbolicValue + +defmethod print (o:OutputStream, sv:SymbolicValue) : + match(sv) : + (sv:VoidValue) : print(o, "VOID") + (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"]) + (sv:ExpValue) : print(o, exp(sv)) + +defn key-eqv? (e1:Expression, e2:Expression) : + match(e1, e2) : + (e1:WRef, e2:WRef) : + name(e1) == name(e2) + (e1:WField, e2:WField) : + name(e1) == name(e2) and + key-eqv?(exp(e1), exp(e2)) + (e1, e2) : + false + +defn merge-env (pred: Expression, + con-env: List<KeyValue<Expression,SymbolicValue>>, + alt-env: List<KeyValue<Expression,SymbolicValue>>) : + val merged = Vector<KeyValue<Expression, SymbolicValue>>() + defn new-key? (k:Expression) : + for entry in merged none? : + key-eqv?(key(entry), k) + + defn sv (env:List<KeyValue<Expression,SymbolicValue>>, k:Expression) : + for entry in env search : + if key-eqv?(key(entry), k) : + value(entry) + + for k in stream(key, concat(con-env, alt-env)) do : + if new-key?(k) : + match(sv(con-env, k), sv(alt-env, k)) : + (a:SymbolicValue, b:SymbolicValue) : + if a == b : + add(merged, k => a) + else : + add(merged, k => WhenValue(pred, a, b)) + (a:SymbolicValue, b:False) : + add(merged, k => a) + (a:False, b:SymbolicValue) : + add(merged, k => b) + (a:False, b:False) : + false + + to-list(merged) + +defn simplify-env (env: List<KeyValue<Expression,SymbolicValue>>) : + val merged = Vector<KeyValue<Expression, SymbolicValue>>() + defn new-key? (k:Expression) : + for entry in merged none? : + key-eqv?(key(entry), k) + for entry in env do : + if new-key?(key(entry)) : + add(merged, entry) + to-list(merged) + +defn expand-whens (m:Module) : + val commands = Vector<Command>() + val elements = Vector<KeyValue<Symbol,Element>>() + defn eval (c:Command, env:List<KeyValue<Expression,SymbolicValue>>) -> + List<KeyValue<Expression,SymbolicValue>> : + match(c) : + (c:LetRec) : + do(add{elements, _}, entries(c)) + eval(body(c), env) + (c:DefWire) : + add(commands, c) + val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT) + List(wire-ref => VoidValue(), env) + (c:DefRegister) : + add(commands, c) + val reg-ref = WRef(name(c), type(c), RegKind(), INPUT) + List(reg-ref => VoidValue(), env) + (c:DefInstance) : + add(commands, c) + val entries = let : + val module-type = type(module(c)) as BundleType + val input-ports = to-list $ + for p in ports(module-type) filter : + direction(p) == INPUT + val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT) + for p in input-ports map : + WField(inst-ref, name(p), type(p), INPUT) => VoidValue() + append(entries, env) + (c:DefMemory) : + add(commands, c) + env + (c:WDefAccessor) : + add(commands, c) + if dir(c) == INPUT : + val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c)) + List(access-ref => VoidValue(), env) + else : + env + (c:Conditionally) : + val con-env = eval(conseq(c), env) + val alt-env = eval(alt(c), env) + merge-env(pred(c), con-env, alt-env) + (c:Begin) : + var env:List<KeyValue<Expression,SymbolicValue>> = env + for c in body(c) do : + env = eval(c, env) + env + (c:Connect) : + List(loc(c) => ExpValue(exp(c)), env) + (c:EmptyCommand) : + env + + defn convert-symbolic (key:Expression, sv:SymbolicValue) : + match(sv) : + (sv:VoidValue) : + throw $ PassException $ string-join $ [ + "No default value for " key "."] + (sv:ExpValue) : + exp(sv) + (sv:WhenValue) : + defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) : + DoPrim(MULTIPLEX-OP, list(pred, conseq, alt), List(), type(conseq)) + multiplex-exp(pred(sv), + convert-symbolic(key, conseq(sv)) + convert-symbolic(key, alt(sv))) + + ;Compute final environment + val env0 = let : + val output-ports = to-list $ + for p in ports(m) filter : + direction(p) == OUTPUT + for p in output-ports map : + val port-ref = WRef(name(p), type(p), PortKind(), INPUT) + port-ref => VoidValue() + val env* = simplify-env(eval(body(m), env0)) + + ;Make new body + val body* = Begin(list(defs, LetRec(elems, connections))) where : + val defs = Begin(to-list(commands)) + val elems = to-list(elements) + val connections = Begin $ + for entry in env* map : + val sv = convert-symbolic(key(entry), value(entry)) + Connect(key(entry), sv) + + ;Final module + Module(name(m), ports(m), body*) + +defn expand-whens (c:Circuit) : + val modules* = map(expand-whens, modules(c)) + Circuit(modules*, main(c)) + + +;================ STRUCTURAL FORM ========================== + +defn structural-form (m:Module) : + val elements = Vector<(() -> KeyValue<Symbol,Element>)>() + val connected = HashTable<Symbol, Expression>(symbol-hash) + val write-accessors = HashTable<Symbol, List<WDefAccessor>>(symbol-hash) + val read-accessors = HashTable<Symbol, WDefAccessor>(symbol-hash) + val inst-ports = HashTable<Symbol, List<KeyValue<Symbol, Expression>>>(symbol-hash) + val port-connects = Vector<Connect>() + + defn scan (c:Command) : + match(c) : + (c:Connect) : + match(loc(c)) : + (loc:WRef) : + match(kind(loc)) : + (k:PortKind) : add(port-connects, c) + (k) : connected[name(loc)] = exp(c) + (loc:WField) : + val inst = exp(loc) as WRef + val entry = name(loc) => exp(c) + inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List())) + (c:LetRec) : + for e in entries(c) do : + add(elements, {e}) + scan(body(c)) + (c:DefWire) : + add{elements, _} $ fn () : + name(c) => Node(type(c), connected[name(c)]) + (c:DefRegister) : + add{elements, _} $ fn () : + val one = UIntValue(1, UnknownWidth()) + name(c) => Register(type(c), connected[name(c)], one) + (c:DefInstance) : + add{elements, _} $ fn () : + name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)]) + (c:DefMemory) : + add{elements, _} $ fn () : + val ports = for a in get?(write-accessors, name(c), List()) map : + val one = UIntValue(1, UnknownWidth()) + WritePort(index(a), connected[name(a)], one) + name(c) => Memory(type(c), ports) + (c:WDefAccessor) : + val mem = source(c) as WRef + switch {dir(c) == _} : + INPUT : + write-accessors[name(mem)] = List(c, + get?(write-accessors, name(mem), List())) + OUTPUT : + read-accessors[name(c)] = c + (c) : + do(scan, children(c)) + + defn make-read-ports (e:Expression) : + match(e) : + (e:WRef) : + match(kind(e)) : + (k:AccessorKind) : + val accessor = read-accessors[name(e)] + ReadPort(source(accessor), index(accessor), type(e)) + (k) : e + (e) : map(make-read-ports, e) + + Module(name(m), ports(m), body*) where : + scan(body(m)) + val elems = to-list $ + for e in elements stream : + val entry = e() + key(entry) => map(make-read-ports, value(entry)) + val connect-ports = Begin $ to-list $ + for c in port-connects stream : + Connect(loc(c), make-read-ports(exp(c))) + val body* = + if empty?(elems) : connect-ports + else : LetRec(elems, connect-ports) + +defn structural-form (c:Circuit) : + val modules* = map(structural-form, modules(c)) + Circuit(modules*, main(c)) + + +;==================== WIDTH INFERENCE ====================== +defstruct WidthVar <: Width : + name: Symbol + +defmethod print (o:OutputStream, w:WidthVar) : + print(o, name(w)) + +defn width! (t:Type) : + match(t) : + (t:UIntType) : width(t) + (t:SIntType) : width(t) + (t) : error("No width field.") + +defn put-width (t:Type, w:Width) : + match(t) : + (t:UIntType) : UIntType(w) + (t:SIntType) : SIntType(w) + (t) : t + +defn put-width (e:Expression, w:Width) : + val type* = put-width(type(e), w) + put-type(e, type*) + +defn add-width-vars (t:Type) : + defn width? (w:Width) : + match(w) : + (w:UnknownWidth) : WidthVar(gensym()) + (w) : w + match(t) : + (t:UIntType) : UIntType(width?(width(t))) + (t:SIntType) : SIntType(width?(width(t))) + (t) : map(add-width-vars, t) + +defn uint-width (i:Int) : + var v:Int = i + var n:Int = 0 + while v != 0 : + v = v >> 1 + n = n + 1 + IntWidth(n) + +defn sint-width (i:Int) : + if i > 0 : + val w = uint-width(i) + IntWidth(width(w) + 1) + else : + val w = uint-width(neg(i) - 1) + IntWidth(width(w) + 1) + +defn to-exp (w:Width) : + match(w) : + (w:IntWidth) : ELit(width(w)) + (w:WidthVar) : EVar(name(w)) + (w) : error $ string-join $ [ + "Cannot convert " w " to exp."] + +defn primop-width (op:PrimOp, ws:List<Width>, ints:List<Int>) -> Exp : + defn wmax (w1:Width, w2:Width) : + EMax(to-exp(w1), to-exp(w2)) + defn wplus (w1:Width, w2:Width) : + EPlus(to-exp(w1), to-exp(w2)) + defn wplus (w1:Width, w2:Int) : + EPlus(to-exp(w1), ELit(w2)) + defn wminus (w1:Width, w2:Width) : + EMinus(to-exp(w1), to-exp(w2)) + defn wminus (w1:Width, w2:Int) : + EMinus(to-exp(w1), ELit(w2)) + defn wmax-inc (w1:Width, w2:Width) : + EPlus(wmax(w1, w2), ELit(1)) + + switch {op == _} : + ADD-OP : wmax-inc(ws[0], ws[1]) + ADD-MOD-OP : wmax(ws[0], ws[1]) + SUB-OP : wmax-inc(ws[0], ws[1]) + SUB-MOD-OP : wmax(ws[0], ws[1]) + TIMES-OP : wplus(ws[0], ws[1]) + DIVIDE-OP : wminus(ws[0], ws[1]) + MOD-OP : to-exp(ws[1]) + SHIFT-LEFT-OP : wplus(ws[0], ints[0]) + SHIFT-RIGHT-OP : wminus(ws[0], ints[0]) + PAD-OP : ELit(ints[0]) + BIT-AND-OP : wmax(ws[0], ws[1]) + BIT-OR-OP : wmax(ws[0], ws[1]) + BIT-XOR-OP : wmax(ws[0], ws[1]) + CONCAT-OP : wplus(ws[0], ints[0]) + BIT-SELECT-OP : ELit(1) + BITS-SELECT-OP : ELit(ints[0]) + MULTIPLEX-OP : wmax(ws[1], ws[2]) + LESS-OP : ELit(1) + LESS-EQ-OP : ELit(1) + GREATER-OP : ELit(1) + GREATER-EQ-OP : ELit(1) + EQUAL-OP : ELit(1) + +defn put-type (el:Element, t:Type) : + match(el) : + (el:Register) : Register(t, value(el), enable(el)) + (el:Memory) : Memory(t, writers(el)) + (el:Node) : Node(t, value(el)) + (el:Instance) : Instance(t, module(el), ports(el)) + +defn generate-constraints (c:Circuit) -> [Circuit, Vector<WConstraint>] : + ;Constraints + val cs = Vector<WConstraint>() + defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) : + match(wvar) : + (wvar:WidthVar) : + add(cs, Constraint(name(wvar), to-exp(width))) + (wvar) : + false + + defn to-width (e:Exp) : + match(e) : + (e:ELit) : + IntWidth(width(e)) + (e:EVar) : + WidthVar(name(e)) + (e) : + val x = gensym() + add(cs, WidthEqual(x, e)) + WidthVar(x) + + ;Module types + val mod-types = HashTable<Symbol,Type>(symbol-hash) + + defn add-port-vars (m:Module) -> Module : + val ports* = + for p in ports(m) map : + val type* = add-width-vars(type(p)) + Port(name(p), direction(p), type*) + mod-types[name(m)] = BundleType(ports*) + Module(name(m), ports*, body(m)) + + ;Add Width Variables + defn add-module-vars (m:Module) -> Module : + val types = HashTable<Symbol,Type>(symbol-hash) + for p in ports(m) do : + types[name(p)] = type(p) + + defn infer-exp-width (e:Expression) : + match(map(infer-exp-width, e)) : + (e:WRef) : + match(kind(e)) : + (k:ModuleKind) : put-type(e, mod-types[name(e)]) + (k) : put-type(e, types[name(e)]) + (e:WField) : + val t = bundle-field-type(type(exp(e)), name(e)) + put-width(e, width!(t)) + (e:UIntValue) : + match(width(e)) : + (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e))) + (w) : e + (e:SIntValue) : + match(width(e)) : + (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e))) + (w) : e + (e:DoPrim) : + val widths = map(width!{type(_)}, args(e)) + val w = to-width(primop-width(op(e), widths, consts(e))) + put-width(e, w) + (e:ReadPort) : + val elem-type = type(type(mem(e)) as VectorType) + put-width(e, width!(elem-type)) + + defn infer-comm-width (c:Command) : + match(c) : + (c:LetRec) : + ;Add width vars to elements + var entries*: List<KeyValue<Symbol,Element>> = + for entry in entries(c) map : + val el-name = key(entry) + key(entry) => + match(value(entry)) : + (el:Register|Node) : + put-type(el, add-width-vars(type(el))) + (el:Memory) : + el + (el:Instance) : + val mod-type = type(infer-exp-width(module(el))) as BundleType + val type = BundleType $ to-list $ + for p in ports(mod-type) filter : + direction(p) == OUTPUT + put-type(el, type) + + ;Add vars to environment + for entry in entries* do : + types[key(entry)] = type(value(entry)) + + ;Infer types for elements + entries* = + for entry in entries* map : + key(entry) => map(infer-exp-width, value(entry)) + + ;Generate constraints + for entry in entries* do : + val el-name = key(entry) + match(value(entry)) : + (el:Register) : + new-constraint(WidthEqual, reg-width, val-width) where : + val reg-width = width!(types[el-name]) + val val-width = width!(type(value(el))) + (el:Node) : + new-constraint(WidthEqual, node-width, val-width) where : + val node-width = width!(types[el-name]) + val val-width = width!(type(value(el))) + (el:Instance) : + val mod-type = type(module(el)) as BundleType + for entry in ports(el) do : + new-constraint(WidthGreater, port-width, val-width) where : + val port-name = key(entry) + val port-width = width!(bundle-field-type(mod-type, port-name)) + val val-width = width!(type(value(entry))) + (el) : false + + ;Analyze body + LetRec(entries*, infer-comm-width(body(c))) + + (c:Connect) : + val loc* = infer-exp-width(loc(c)) + val exp* = infer-exp-width(exp(c)) + new-constraint(WidthGreater, loc-width, exp-width) where : + val loc-width = width!(type(loc*)) + val exp-width = width!(type(exp*)) + Connect(loc*, exp*) + + (c:Begin) : + map(infer-comm-width, c) + + Module(name(m), ports(m), body*) where : + val body* = infer-comm-width(body(m)) + + val c* = + Circuit(modules*, main(c)) where : + val ms = map(add-port-vars, modules(c)) + val modules* = map(add-module-vars, ms) + [c*, cs] + + +;================== FILL WIDTHS ============================ +defn fill-widths (c:Circuit, solved:Streamable<WidthEqual>) : + ;Populate table + val table = HashTable<Symbol, Width>(symbol-hash) + for eq in solved do : + table[name(eq)] = IntWidth(width(value(eq) as ELit)) + + defn width? (w:Width) : + match(w) : + (w:WidthVar) : get?(table, name(w), UnknownWidth()) + (w) : w + + defn fill-type (t:Type) : + match(t) : + (t:UIntType) : UIntType(width?(width(t))) + (t:SIntType) : SIntType(width?(width(t))) + (t) : map(fill-type, t) + + defn fill-exp (e:Expression) : + val e* = map(fill-exp, e) + val type* = fill-type(type(e)) + put-type(e*, type*) + + defn fill-element (e:Element) : + val e* = map(fill-exp, e) + val type* = fill-type(type(e)) + put-type(e*, type*) + + defn fill-comm (c:Command) : + match(c) : + (c:LetRec) : + val entries* = + for e in entries(c) map : + key(e) => fill-element(value(e)) + LetRec(entries*, fill-comm(body(c))) + (c) : + map{fill-comm, _} $ + map(fill-exp, c) + + defn fill-port (p:Port) : + Port(name(p), direction(p), fill-type(type(p))) + + defn fill-mod (m:Module) : + Module(name(m), ports*, body*) where : + val ports* = map(fill-port, ports(m)) + val body* = fill-comm(body(m)) + + Circuit(modules*, main(c)) where : + val modules* = map(fill-mod, modules(c)) + + +;=============== TYPE INFERENCE DRIVER ===================== +defn infer-widths (c:Circuit) : + val [c*, cs] = generate-constraints(c) + val solved = solve-widths(cs) + fill-widths(c*, solved) + + +;================ PAD WIDTHS =============================== +defn pad-widths (c:Circuit) : + ;Pad an expression to the given width + defn pad-exp (e:Expression, w:Int) : + match(type(e)) : + (t:UIntType|SIntType) : + val prev-w = width!(t) as IntWidth + if width(prev-w) < w : + val type* = put-width(t, IntWidth(w)) + DoPrim(PAD-OP, list(e), list(w), type*) + else : + e + (t) : + e + + defn pad-exp (e:Expression, w:Width) : + val w-value = width(w as IntWidth) + pad-exp(e, w-value) + + ;Convenience + defn max-width (es:Streamable<Expression>) : + defn int-width (e:Expression) : + width(width!(type(e)) as IntWidth) + maximum(stream(int-width, es)) + + defn match-widths (es:List<Expression>) : + val w = max-width(es) + map(pad-exp{_, w}, es) + + ;Match widths for an expression + defn match-exp-width (e:Expression) : + match(map(match-exp-width, e)) : + (e:DoPrim) : + if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) : + val args* = match-widths(args(e)) + DoPrim(op(e), args*, consts(e), type(e)) + else if op(e) == MULTIPLEX-OP : + val args* = List(head(args(e)), match-widths(tail(args(e)))) + DoPrim(op(e), args*, consts(e), type(e)) + else : + e + (e) : e + + defn match-element-width (e:Element) : + match(map(match-exp-width, e)) : + (e:Register) : + val w = width!(type(e)) + val value* = pad-exp(value(e), w) + Register(type(e), value*, enable(e)) + (e:Memory) : + val width = width!(type(type(e) as VectorType)) + val writers* = + for w in writers(e) map : + WritePort(index(w), pad-exp(value(w), width), enable(w)) + Memory(type(e), writers*) + (e:Node) : + val w = width!(type(e)) + val value* = pad-exp(value(e), w) + Node(type(e), value*) + (e:Instance) : + val mod-type = type(module(e)) as BundleType + val ports* = + for p in ports(e) map : + val port-type = bundle-field-type(mod-type, key(p)) + val port-val = pad-exp(value(p), width!(port-type)) + key(p) => port-val + Instance(type(e), module(e), ports*) + + ;Match widths for a command + defn match-comm-width (c:Command) : + match(map(match-exp-width, c)) : + (c:LetRec) : + val entries* = + for e in entries(c) map : + key(e) => match-element-width(value(e)) + LetRec(entries*, match-comm-width(body(c))) + (c:Connect) : + val w = width!(type(loc(c))) + val exp* = pad-exp(exp(c), w) + Connect(loc(c), exp*) + (c) : + map(match-comm-width, c) + + defn match-mod-width (m:Module) : + Module(name(m), ports(m), body*) where : + val body* = match-comm-width(body(m)) + + Circuit(modules*, main(c)) where : + val modules* = map(match-mod-width, modules(c)) + + +;================== INLINING =============================== +defn inline-instances (c:Circuit) : + val module-table = HashTable<Symbol,Module>(symbol-hash) + val inlined? = HashTable<Symbol,True|False>(symbol-hash) + for m in modules(c) do : + module-table[name(m)] = m + inlined?[name(m)] = false + + ;Convert a module into a sequence of elements + defn to-elements (m:Module, + inst:Symbol, + port-exps:List<KeyValue<Symbol,Expression>>) -> + List<KeyValue<Symbol, Element>> : + defn rename-exp (e:Expression) : + match(e) : + (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e)) + (e) : map(rename-exp, e) + + defn to-elements (c:Command) -> List<KeyValue<Symbol,Element>> : + match(c) : + (c:LetRec) : + val entries* = + for entry in entries(c) map : + val name* = prefix(inst, key(entry)) + val element* = map(rename-exp, value(entry)) + name* => element* + val body* = to-elements(body(c)) + append(entries*, body*) + (c:Connect) : + val ref = loc(c) as WRef + val name* = prefix(inst, name(ref)) + list(name* => Node(type(exp(c)), rename-exp(exp(c)))) + (c:Begin) : + map-append(to-elements, body(c)) + + val inputs = + for p in ports(m) map-append : + if direction(p) == INPUT : + val port-exp = lookup!(port-exps, name(p)) + val name* = prefix(inst, name(p)) + list(name* => Node(type(port-exp), port-exp)) + else : + List() + append(inputs, to-elements(body(m))) + + ;Inline all instances in the module + defn inline-instances (m:Module) : + defn rename-exp (e:Expression) : + match(e) : + (e:WField) : + val inst-exp = exp(e) as WRef + val name* = prefix(name(inst-exp), name(e)) + WRef(name*, type(e), NodeKind(), dir(e)) + (e) : + map(rename-exp, e) + + defn inline-elems (es:List<KeyValue<Symbol,Element>>) : + for entry in es map-append : + match(value(entry)) : + (el:Instance) : + val mod-name = name(module(el) as WRef) + val module = inlined-module(mod-name) + to-elements(module, key(entry), ports(el)) + (el) : + list(entry) + + defn inline-comm (c:Command) : + match(map(rename-exp, c)) : + (c:LetRec) : + val entries* = inline-elems(entries(c)) + LetRec(entries*, inline-comm(body(c))) + (c) : + map(inline-comm, c) + + Module(name(m), ports(m), inline-comm(body(m))) + + ;Retrieve the inlined instance of a module + defn inlined-module (name:Symbol) : + if inlined?[name] : + module-table[name] + else : + val module* = inline-instances(module-table[name]) + module-table[name] = module* + inlined?[name] = true + module* + + ;Return the fully inlined circuit + val main-module = inlined-module(main(c)) + Circuit(list(main-module), main(c)) + + +;;;================ UTILITIES ================================ +; +; +; +;defn* root-ref (i:Immediate) : +; match(i) : +; (f:Field) : root-ref(imm(f)) +; (ind:Index) : root-ref(imm(ind)) +; (r) : r +; +;;defn lookup<?T> (e: Streamable<KeyValue<Immediate,?T>>, i:Immediate) : +;; for entry in e search : +;; if eqv?(key(entry), i) : +;; value(entry) +;; +;;defn lookup!<?T> (e: Streamable<KeyValue<Immediate,?T>>, i:Immediate) : +;; lookup(e, i) as T +;; +;;============ CHECK IF NAMES ARE UNIQUE ==================== +;defn check-duplicate-symbols (names: Streamable<Symbol>, msg: String) : +; val dict = HashTable<Symbol, True>(symbol-hash) +; for name in names do: +; if key?(dict, name): +; throw $ PassException $ string-join $ +; [msg ": " name] +; else: +; dict[name] = true +; +;defn check-duplicates (t: Type) : +; match(t) : +; (t:BundleType) : +; val names = map(name, ports(t)) +; check-duplicate-symbols{names, string-join(_)} $ +; ["Duplicate port name in bundle "] +; do(check-duplicates{type(_)}, ports(t)) +; (t:VectorType) : +; check-duplicates(type(t)) +; (t) : false +; +;defn check-duplicates (c: Command) : +; match(c) : +; (c:DefWire) : check-duplicates(type(c)) +; (c:DefRegister) : check-duplicates(type(c)) +; (c:DefMemory) : check-duplicates(type(c)) +; (c) : do(check-duplicates, children(c)) +; +;defn defined-names (c: Command) : +; generate<Symbol> : +; loop(c) where : +; defn loop (c:Command) : +; match(c) : +; (c:Command&HasName) : yield(name(c)) +; (c) : do(loop, children(c)) +; +;defn check-duplicates (m: Module): +; ;Check all duplicate names in all types in all ports and body +; do(check-duplicates{type(_)}, ports(m)) +; check-duplicates(body(m)) +; +; ;Check all names defined in module +; val names = concat(stream(name, ports(m)), +; defined-names(body(m))) +; check-duplicate-symbols{names, string-join(_)} $ +; ["Duplicate definition name in module " name(m)] +; +;defn check-duplicates (c: Circuit) : +; ;Check all duplicate names in all modules +; do(check-duplicates, modules(c)) +; +; ;Check all defined modules +; val names = stream(name, modules(c)) +; check-duplicate-symbols(names, "Duplicate module name") +; +; + + + + + + + +;;================ CLEANUP COMMANDS ========================= +;defn cleanup (c:Command) : +; match(c) : +; (c:Begin) : +; to-command $ generate<Command> : +; loop(c) where : +; defn loop (c:Command) : +; match(c) : +; (c:Begin) : do(loop, body(c)) +; (c:EmptyCommand) : false +; (c) : yield(cleanup(c)) +; (c) : map(cleanup{_ as Command}, c) +; +;defn cleanup (c:Circuit) : +; val modules* = +; for m in modules(c) map : +; map(cleanup, m) +; Circuit(modules*, main(c)) +; + + + +;;;============= SHIM ======================================== +;;defn shim (i:Immediate) -> Immediate : +;; match(i) : +;; (i:RegData) : +;; Ref(name(i), direction(i), type(i)) +;; (i:InstPort) : +;; val inst = Ref(name(i), UNKNOWN-DIR, UnknownType()) +;; Field(inst, port(i), direction(i), type(i)) +;; (i:Field) : +;; val imm* = shim(imm(i)) +;; put-imm(i, imm*) +;; (i) : i +;; +;;defn shim (c:Command) -> Command : +;; val c* = map(shim{_ as Immediate}, c) +;; map(shim{_ as Command}, c*) +;; +;;defn shim (c:Circuit) -> Circuit : +;; val modules* = +;; for m in modules(c) map : +;; Module(name(m), ports(m), shim(body(m))) +;; Circuit(modules*, main(c)) +;; +;;;================== INLINE MODULES ========================= +;;defn cat-name (p: String|Symbol, s: String|Symbol) -> Symbol : +;; if p == "" or p == `this : ;; TODO: REMOVE THIS WHEN `THIS GETS REMOVED +;; to-symbol(s) +;; else if s == `this : ;; TODO: DITTO +;; to-symbol(p) +;; else : +;; symbol-join([p, "/", s]) +;; +;;defn inline-command (c: Command, mods: HashTable<Symbol, Module>, prefix: String, cmds: Vector<Command>) : +;; defn rename (n: Symbol) -> Symbol : +;; cat-name(prefix, n) +;; defn inline-name (i:Immediate) -> Symbol : +;; match(i) : +;; (r:Ref) : rename(name(r)) +;; (f:Field) : cat-name(inline-name(imm(f)), name(f)) +;; (f:Index) : cat-name(inline-name(imm(f)), to-string(value(f))) +;; defn inline-imm (i:Immediate) -> Ref : +;; Ref(inline-name(i), direction(i), type(i)) +;; match(c) : +;; (c:DefUInt) : add(cmds, DefUInt(rename(name(c)), value(c), width(c))) +;; (c:DefSInt) : add(cmds, DefSInt(rename(name(c)), value(c), width(c))) +;; (c:DefWire) : add(cmds, DefWire(rename(name(c)), type(c))) +;; (c:DefRegister) : add(cmds, DefRegister(rename(name(c)), type(c))) +;; (c:DefMemory) : add(cmds, DefMemory(rename(name(c)), type(c), size(c))) +;; (c:DefInstance) : inline-module(mods, mods[name(module(c))], to-string(rename(name(c))), cmds) +;; (c:DoPrim) : add(cmds, DoPrim(rename(name(c)), op(c), map(inline-imm, args(c)), consts(c))) +;; (c:DefAccessor) : add(cmds, DefAccessor(rename(name(c)), inline-imm(source(c)), direction(c), inline-imm(index(c)))) +;; (c:Connect) : add(cmds, Connect(inline-imm(loc(c)), inline-imm(exp(c)))) +;; (c:Begin) : do(inline-command{_, mods, prefix, cmds}, body(c)) +;; (c:EmptyCommand) : c +;; (c) : error("Unsupported command") +;; +;;defn inline-port (p: Port, prefix: String) -> Command : +;; DefWire(cat-name(prefix, name(p)), type(p)) +;; +;;defn inline-module (mods: HashTable<Symbol, Module>, mod: Module, prefix: String, cmds: Vector<Command>) : +;; do(add{cmds, _}, map(inline-port{_, prefix}, ports(mod))) +;; inline-command(body(mod), mods, prefix, cmds) +;; +;;defn inline-modules (c: Circuit) -> Circuit : +;; val cmds = Vector<Command>() +;; val mods = HashTable<Symbol, Module>(symbol-hash) +;; for mod in modules(c) do : +;; mods[name(mod)] = mod +;; val top = mods[main(c)] +;; inline-command(body(top), mods, "", cmds) +;; val main* = Module(name(top), ports(top), Begin(to-list(cmds))) +;; Circuit(list(main*), name(top)) +;; +;; +;;;============= FLO PRINTER ====================================== +;;;;; TODO: +;;;;; not supported gt, lte +;; +;;defn flo-op-name (op:PrimOp) -> String : +;; switch {op == _} : +;; ADD-OP : "add" +;; ADD-MOD-OP : "add" +;; MINUS-OP : "sub" +;; SUB-MOD-OP : "sub" +;; TIMES-OP : "mul" ;; todo: signed version +;; DIVIDE-OP : "div" ;; todo: signed version +;; MOD-OP : "mod" ;; todo: signed version +;; SHIFT-LEFT-OP : "lsh" ;; todo: signed version +;; SHIFT-RIGHT-OP : "rsh" +;; PAD-OP : "pad" ;; todo: signed version +;; BIT-AND-OP : "and" +;; BIT-OR-OP : "or" +;; BIT-XOR-OP : "xor" +;; CONCAT-OP : "cat" +;; BIT-SELECT-OP : "rsh" +;; BITS-SELECT-OP : "rsh" +;; LESS-OP : "lt" ;; todo: signed version +;; LESS-EQ-OP : "lte" ;; todo: swap args +;; GREATER-OP : "gt" ;; todo: swap args +;; GREATER-EQ-OP : "gte" ;; todo: signed version +;; EQUAL-OP : "eq" +;; MULTIPLEX-OP : "mux" +;; else : error $ string-join $ +;; ["Unable to print Primop: " op] +;; +;;defn emit (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elt) : +;; match(elt) : +;; (e:String|Symbol|Int) : +;; print(o, e) +;; (e:Ref) : +;; if key?(lits, name(e)) : +;; val lit = lits[name(e)] +;; print-all(o, [value(lit) "'" width(lit)]) +;; else : +;; if key?(ports, name(e)) : +;; print-all(o, [top "::"]) +;; print(o, name(e)) +;; (e:IntWidth) : +;; print(o, value(e)) +;; (e:PrimOp) : +;; print(o, flo-op-name(e)) +;; (e) : +;; println-all(["EMIT " e]) +;; error("Unable to emit") +;; +;;defn emit-all (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elts: Streamable) : +;; for e in elts do : emit(o, top, ports, lits, e) +;; +;;defn prim-width (type:Type) -> Width : +;; match(type) : +;; (t:UIntType) : width(t) +;; (t:SIntType) : width(t) +;; (t) : error("Bad prim width type") +;; +;;defn emit-command (o:OutputStream, cmd:Command, top:Symbol, lits:HashTable<Symbol, DefUInt>, regs:HashTable<Symbol, DefRegister>, accs:HashTable<Symbol, DefAccessor>, ports:HashTable<Symbol, Port>, outs:HashTable<Symbol, Port>) : +;; match(cmd) : +;; (c:DefUInt) : +;; lits[name(c)] = c +;; (c:DefSInt) : +;; emit-all(o, top, ports, lits, [name(c) " = " value(c) "'" width(c) "\n"]) +;; (c:DoPrim) : ;; NEED TO FIGURE OUT WHEN WIDTHS ARE NECESSARY AND EXTRACT +;; emit-all(o, top, ports, lits, [name(c) " = " op(c)]) +;; for arg in args(c) do : +;; print(o, " ") +;; emit(o, top, ports, lits, arg) +;; for const in consts(c) do : +;; print(o, " ") +;; emit(o, top, ports, lits, const) +;; print("\n") +;; (c:DefRegister) : +;; regs[name(c)] = c +;; (c:DefMemory) : +;; emit-all(o, top, ports, lits, [name(c) " : mem'" prim-width(type(c)) " " size(c) "\n"]) +;; (c:DefAccessor) : +;; accs[name(c)] = c +;; (c:Connect) : +;; val dst = name(loc(c) as Ref) +;; val src = name(exp(c) as Ref) +;; if key?(regs, dst) : +;; val reg = regs[dst] +;; emit-all(o, top, ports, lits, [dst " = reg'" prim-width(type(reg)) " 0'" prim-width(type(reg)) " " exp(c) "\n"]) +;; else if key?(accs, dst) : +;; val acc = accs[dst] +;; ;; assert(direction(acc) == OUTPUT) +;; emit-all(o, top, ports, lits, [dst " = wr " source(acc) " " index(acc) " " exp(c) "\n"]) +;; else if key?(outs, dst) : +;; val out = outs[dst] +;; emit-all(o, top, ports, lits, [top "::" dst " = out'" prim-width(type(out)) " " exp(c) "\n"]) +;; else if key?(accs, src) : +;; val acc = accs[src] +;; ;; assert(direction(acc) == INPUT) +;; emit-all(o, top, ports, lits, [dst " = rd " source(acc) " " index(acc) "\n"]) +;; else : +;; emit-all(o, top, ports, lits, [dst " = mov " exp(c) "\n"]) +;; (c:Begin) : +;; do(emit-command{o, _, top, lits, regs, accs, ports, outs}, body(c)) +;; (c:DefWire|EmptyCommand) : +;; print("") +;; (c) : +;; error("Unable to print command") +;; +;;defn emit-module (o:OutputStream, m:Module) : +;; val regs = HashTable<Symbol, DefRegister>(symbol-hash) +;; val accs = HashTable<Symbol, DefAccessor>(symbol-hash) +;; val lits = HashTable<Symbol, DefUInt>(symbol-hash) +;; val outs = HashTable<Symbol, Port>(symbol-hash) +;; val portz = HashTable<Symbol, Port>(symbol-hash) +;; for port in ports(m) do : +;; portz[name(port)] = port +;; if direction(port) == OUTPUT : +;; outs[name(port)] = port +;; else if name(port) == `reset : +;; print-all(o, [name(m) "::reset = rst\n"]) +;; else : +;; print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"]) +;; emit-command(o, body(m), name(m), lits, regs, accs, portz, outs) +;; +;;public defn emit-circuit (o:OutputStream, c:Circuit) : +;; emit-module(o, modules(c)[0]) + + +;============= DRIVER ====================================== +public defn run-passes (c: Circuit) : + var c*:Circuit = c + defn do-stage (name:String, f: Circuit -> Circuit) : + println(name) + c* = f(c*) + println(c*) + println("\n\n\n\n") + + do-stage("Working IR", to-working-ir) + do-stage("Resolve Kinds", resolve-kinds) + do-stage("Make Explicit Reset", make-explicit-reset) + do-stage("Infer Types", infer-types) + do-stage("Infer Directions", infer-directions) + do-stage("Expand Accessors", expand-accessors) + do-stage("Flatten Bundles", flatten-bundles) + do-stage("Expand Bundles", expand-bundles) + do-stage("Expand Multi Connects", expand-multi-connects) + do-stage("Expand Whens", expand-whens) + do-stage("Structural Form", structural-form) + do-stage("Infer Widths", infer-widths) + do-stage("Pad Widths", pad-widths) + do-stage("Inline Instances", inline-instances) + + + ;; println("Shim for Jonathan's Passes") + ;; c* = shim(c*) + ;; println("Inline Modules") + ;; c* = inline-modules(c*) + ; c* |
