diff options
| author | azidar | 2015-03-04 16:25:25 -0800 |
|---|---|---|
| committer | azidar | 2015-03-04 16:25:25 -0800 |
| commit | 6ad6267d26b52258f6e0d4d004aeb5f36856cf95 (patch) | |
| tree | 16aad9875b1f58dc0cc2a5cd59091e89d57a0861 /src | |
| parent | 355749c83d2066f1a149333ed762a7945d405076 (diff) | |
Finished infer-types pass
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/stanza/ir-parser.stanza | 68 | ||||
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 56 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 2341 |
3 files changed, 1240 insertions, 1225 deletions
diff --git a/src/main/stanza/ir-parser.stanza b/src/main/stanza/ir-parser.stanza index 43383f9a..cbd57f9b 100644 --- a/src/main/stanza/ir-parser.stanza +++ b/src/main/stanza/ir-parser.stanza @@ -54,11 +54,15 @@ defn unwrap-prefix-form (form) : ;======= Split Dots ============ defn split-dots (forms:List) : + defn to-form (x:String) : + val num? = for c in x all? : + c >= '0' and c <= '9' + to-int(x) when num? else to-symbol(x) defn split (form) : match(ut(form)) : (f:Symbol) : val fstr = to-string(f) - if contains?(fstr, '.') : map(to-symbol, split-string(fstr, ".")) + if contains?(fstr, '.') : map(to-form, split-string(fstr, ".")) else : list(form) (f:List) : list(map-append(split, f)) @@ -148,10 +152,10 @@ rd.defsyntax firrtl : ut(name) => Instance(UnknownType(), module, ports) defrule exp : - (?x:#exp . ?f:#symbol) : - Field(x, ut(f), UnknownType()) (?x:#exp . ?f:#int) : Index(x, ut(f), UnknownType()) + (?x:#exp . ?f:#symbol) : + Field(x, ut(f), UnknownType()) (?x:#exp-form) : x @@ -201,26 +205,26 @@ rd.defsyntax firrtl : operators[`sub-wrap-us] = SUB-WRAP-US-OP operators[`sub-wrap-su] = SUB-WRAP-SU-OP operators[`sub-wrap-ss] = SUB-WRAP-SS-OP - operators[`less] = LESS-OP - operators[`less-uu] = LESS-UU-OP - operators[`less-us] = LESS-US-OP - operators[`less-su] = LESS-SU-OP - operators[`less-ss] = LESS-SS-OP - operators[`less-eq] = LESS-EQ-OP - operators[`less-eq-uu] = LESS-EQ-UU-OP - operators[`less-eq-us] = LESS-EQ-US-OP - operators[`less-eq-su] = LESS-EQ-SU-OP - operators[`less-eq-ss] = LESS-EQ-SS-OP - operators[`greater] = GREATER-OP - operators[`greater-uu] = GREATER-UU-OP - operators[`greater-us] = GREATER-US-OP - operators[`greater-su] = GREATER-SU-OP - operators[`greater-ss] = GREATER-SS-OP - operators[`greater-eq] = GREATER-EQ-OP - operators[`greater-eq-uu] = GREATER-EQ-UU-OP - operators[`greater-eq-us] = GREATER-EQ-US-OP - operators[`greater-eq-su] = GREATER-EQ-SU-OP - operators[`greater-eq-ss] = GREATER-EQ-SS-OP + operators[`lt] = LESS-OP + operators[`lt-uu] = LESS-UU-OP + operators[`lt-us] = LESS-US-OP + operators[`lt-su] = LESS-SU-OP + operators[`lt-ss] = LESS-SS-OP + operators[`leq] = LESS-EQ-OP + operators[`leq-uu] = LESS-EQ-UU-OP + operators[`leq-us] = LESS-EQ-US-OP + operators[`leq-su] = LESS-EQ-SU-OP + operators[`leq-ss] = LESS-EQ-SS-OP + operators[`gt] = GREATER-OP + operators[`gt-uu] = GREATER-UU-OP + operators[`gt-us] = GREATER-US-OP + operators[`gt-su] = GREATER-SU-OP + operators[`gt-ss] = GREATER-SS-OP + operators[`geq] = GREATER-EQ-OP + operators[`geq-uu] = GREATER-EQ-UU-OP + operators[`geq-us] = GREATER-EQ-US-OP + operators[`geq-su] = GREATER-EQ-SU-OP + operators[`geq-ss] = GREATER-EQ-SS-OP operators[`equal] = EQUAL-OP operators[`equal-uu] = EQUAL-UU-OP operators[`equal-ss] = EQUAL-SS-OP @@ -236,15 +240,15 @@ rd.defsyntax firrtl : operators[`as-SInt] = AS-SINT-OP operators[`as-SInt-u] = AS-SINT-U-OP operators[`as-SInt-s] = AS-SINT-S-OP - operators[`shift-left] = SHIFT-LEFT-OP - operators[`shift-left-u] = SHIFT-LEFT-U-OP - operators[`shift-left-s] = SHIFT-LEFT-S-OP - operators[`shift-right] = SHIFT-RIGHT-OP - operators[`shift-right-u] = SHIFT-RIGHT-U-OP - operators[`shift-right-s] = SHIFT-RIGHT-S-OP - operators[`convert] = SHIFT-RIGHT-OP - operators[`convert-u] = SHIFT-RIGHT-U-OP - operators[`convert-s] = SHIFT-RIGHT-S-OP + operators[`shl] = SHIFT-LEFT-OP + operators[`shl-u] = SHIFT-LEFT-U-OP + operators[`shl-s] = SHIFT-LEFT-S-OP + operators[`shr] = SHIFT-RIGHT-OP + operators[`shr-u] = SHIFT-RIGHT-U-OP + operators[`shr-s] = SHIFT-RIGHT-S-OP + operators[`convert] = CONVERT-OP + operators[`convert-u] = CONVERT-U-OP + operators[`convert-s] = CONVERT-S-OP operators[`bit-and] = BIT-AND-OP operators[`bit-or] = BIT-OR-OP operators[`bit-xor] = BIT-XOR-OP diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index 4da64981..9e8c63c5 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -69,26 +69,26 @@ defmethod print (o:OutputStream, op:PrimOp) : SUB-WRAP-US-OP : "sub-wrap-us" SUB-WRAP-SU-OP : "sub-wrap-su" SUB-WRAP-SS-OP : "sub-wrap-ss" - LESS-OP : "less" - LESS-UU-OP : "less-uu" - LESS-US-OP : "less-us" - LESS-SU-OP : "less-su" - LESS-SS-OP : "less-ss" - LESS-EQ-OP : "less-eq" - LESS-EQ-UU-OP : "less-eq-uu" - LESS-EQ-US-OP : "less-eq-us" - LESS-EQ-SU-OP : "less-eq-su" - LESS-EQ-SS-OP : "less-eq-ss" - GREATER-OP : "greater" - GREATER-UU-OP : "greater-uu" - GREATER-US-OP : "greater-us" - GREATER-SU-OP : "greater-su" - GREATER-SS-OP : "greater-ss" - GREATER-EQ-OP : "greater-eq" - GREATER-EQ-UU-OP : "greater-eq-uu" - GREATER-EQ-US-OP : "greater-eq-us" - GREATER-EQ-SU-OP : "greater-eq-su" - GREATER-EQ-SS-OP : "greater-eq-ss" + LESS-OP : "lt" + LESS-UU-OP : "lt-uu" + LESS-US-OP : "lt-us" + LESS-SU-OP : "lt-su" + LESS-SS-OP : "lt-ss" + LESS-EQ-OP : "leq" + LESS-EQ-UU-OP : "leq-uu" + LESS-EQ-US-OP : "leq-us" + LESS-EQ-SU-OP : "leq-su" + LESS-EQ-SS-OP : "leq-ss" + GREATER-OP : "gt" + GREATER-UU-OP : "gt-uu" + GREATER-US-OP : "gt-us" + GREATER-SU-OP : "gt-su" + GREATER-SS-OP : "gt-ss" + GREATER-EQ-OP : "geq" + GREATER-EQ-UU-OP : "geq-uu" + GREATER-EQ-US-OP : "geq-us" + GREATER-EQ-SU-OP : "geq-su" + GREATER-EQ-SS-OP : "geq-ss" EQUAL-OP : "equal" EQUAL-UU-OP : "equal-uu" EQUAL-SS-OP : "equal-ss" @@ -104,12 +104,12 @@ defmethod print (o:OutputStream, op:PrimOp) : AS-SINT-OP : "as-SInt" AS-SINT-U-OP : "as-SInt-u" AS-SINT-S-OP : "as-SInt-s" - SHIFT-LEFT-OP : "shift-left" - SHIFT-LEFT-U-OP : "shift-left-u" - SHIFT-LEFT-S-OP : "shift-left-s" - SHIFT-RIGHT-OP : "shift-right" - SHIFT-RIGHT-U-OP : "shift-right-u" - SHIFT-RIGHT-S-OP : "shift-right-s" + SHIFT-LEFT-OP : "shl" + SHIFT-LEFT-U-OP : "shl-u" + SHIFT-LEFT-S-OP : "shl-s" + SHIFT-RIGHT-OP : "shr" + SHIFT-RIGHT-U-OP : "shr-u" + SHIFT-RIGHT-S-OP : "shr-s" CONVERT-OP : "convert" CONVERT-U-OP : "convert-u" CONVERT-S-OP : "convert-s" @@ -198,7 +198,9 @@ defmethod print (o:OutputStream, t:Type) : (w:UnknownWidth) : print-all(o, ["UInt"]) (w) : print-all(o, ["UInt(" width(t) ")"]) (t:SIntType) : - print-all(o, ["SInt(" width(t) ")"]) + match(width(t)) : + (w:UnknownWidth) : print-all(o, ["SInt"]) + (w) : print-all(o, ["SInt(" width(t) ")"]) (t:BundleType) : print(o, "{") print-all(o, join(ports(t), ", ")) diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza index 18593499..3cdb553b 100644 --- a/src/main/stanza/passes.stanza +++ b/src/main/stanza/passes.stanza @@ -84,9 +84,15 @@ defn any-debug? (e:Expression|Stmt|Type|Element|Port) : (hasKind(e) and PRINT-KINDS) defmethod print-debug (o:OutputStream, e:Expression|Stmt|Type|Element|Port) : + defn wipe-width (t:Type) -> Type : + match(t) : + (t:UIntType) : UIntType(UnknownWidth()) + (t:SIntType) : SIntType(UnknownWidth()) + (t) : t + if any-debug?(e) : print(o,"@") if PRINT-KINDS and hasKind(e) : print-all(o,["<k:" kind(e as ?) ">"]) - if PRINT-TYPES and hasType(e) : print-all(o,["<t:" type(e as ?) ">"]) + if PRINT-TYPES and hasType(e) : print-all(o,["<t:" wipe-width(type(e as ?)) ">"]) if PRINT-WIDTHS and hasWidth(e): print-all(o,["<w:" width(e as ?) ">"]) defmethod print (o:OutputStream, e:WRef) : @@ -346,18 +352,20 @@ defn get-primop-rettype (e:DoPrim) -> Type : defn u () : UIntType(UnknownWidth()) defn s () : SIntType(UnknownWidth()) defn u-and (op1:Expression,op2:Expression) : - if type(op1) typeof UIntType and type(op2) typeof UIntType : - UIntType(UnknownWidth()) - else : - SIntType(UnknownWidth()) + match(type(op1), type(op2)) : + (t1:UIntType, t2:UIntType) : u() + (t1:SIntType, t2) : s() + (t1, t2:SIntType) : s() + (t1, t2) : UnknownType() + defn of-type (op:Expression) : - if type(op) typeof UIntType : - UIntType(UnknownWidth()) - if type(op) typeof SIntType : - SIntType(UnknownWidth()) - else : UnknownType() + match(type(op)) : + (t:UIntType) : u() + (t:SIntType) : s() + (t) : UnknownType() - switch {e == _} : + ;println-all(["Inferencing primop type: " e]) + switch {op(e) == _} : ADD-OP : u-and(args(e)[0],args(e)[1]) ADD-UU-OP : u() ADD-US-OP : s() @@ -378,11 +386,11 @@ defn get-primop-rettype (e:DoPrim) -> Type : DIV-US-OP : s() DIV-SU-OP : s() DIV-SS-OP : s() - MOD-OP : u() + MOD-OP : of-type(args(e)[0]) MOD-UU-OP : u() MOD-US-OP : u() - MOD-SU-OP : u() - MOD-SS-OP : u() + MOD-SU-OP : s() + MOD-SS-OP : s() QUO-OP : u-and(args(e)[0],args(e)[1]) QUO-UU-OP : u() QUO-US-OP : s() @@ -458,22 +466,20 @@ defn type (m:Module) -> Type : BundleType(ports(m)) defn get-type (b:Symbol,l:List<KeyValue<Symbol,Type>>) -> Type : - val contains? = for kv in l any? : b == key(kv) - if contains? : - label<Type> myret : - for kv in l do : - if b == key(kv) : myret(value(kv)) - myret(UnknownType()) - else : UnknownType() + val ma = for kv in l find : b == key(kv) + if ma != false : + val ret = value(ma as KeyValue<Symbol,Type>) + ;println-all(["Found! Returning " ret " for " b]) + ret + else : + ;println-all(["Not found! Returning " UnknownType() " for " b]) + UnknownType() defn bundle-field-type (v:Type,s:Symbol) -> Type : match(v) : (v:BundleType) : - val contains? = for p in ports(v) any? : name(p) == s - if contains? : - label<Type> myret : - for p in ports(v) do : - if b == name(p) : myret(type(p)) + val ft = for p in ports(v) find : name(p) == s + if ft != false : type(ft as Port) else : UnknownType() (v) : UnknownType() @@ -482,8 +488,9 @@ defn get-vector-subtype (v:Type) -> Type : (v:VectorType) : type(v) (v) : UnknownType() -defn infer-types (e:Expression, l:List<KeyValue<Symbol,Type>>) -> Expression : - match(map(infer-types{_,l},e)) : +defn infer-exp-types (e:Expression, l:List<KeyValue<Symbol,Type>>) -> Expression : + val r = map(infer-exp-types{_,l},e) + match(r) : (e:WRef) : WRef(name(e), get-type(name(e),l),kind(e),dir(e)) (e:WField) : WField(exp(e),name(e), bundle-field-type(type(exp(e)),name(e)),dir(e)) (e:WIndex) : WIndex(exp(e),value(e), get-vector-subtype(type(exp(e))),dir(e)) @@ -492,7 +499,7 @@ defn infer-types (e:Expression, l:List<KeyValue<Symbol,Type>>) -> Expression : (e:UIntValue|SIntValue|Null) : e defn infer-types (s:Stmt, l:List<KeyValue<Symbol,Type>>) -> [Stmt, List<KeyValue<Symbol,Type>>] : - match(s) : + match(map(infer-exp-types{_,l},s)) : (s:LetRec) : [s,l] ;TODO, this is wrong but we might be getting rid of letrecs? (s:Begin) : var env = l @@ -507,7 +514,7 @@ defn infer-types (s:Stmt, l:List<KeyValue<Symbol,Type>>) -> [Stmt, List<KeyValue (s:DefMemory) : [s,List(name(s) => type(s),l)] (s:DefInstance) : [s, List(name(s) => type(module(s)),l)] (s:DefNode) : [s, List(name(s) => type(value(s)),l)] - (s:WDefAccessor) : [s, List(name(s) => type(source(s)),l)] + (s:WDefAccessor) : [s, List(name(s) => get-vector-subtype(type(source(s))),l)] (s:Conditionally) : val [s*,l*] = infer-types(conseq(s),l) val [s**,l**] = infer-types(alt(s),l) @@ -518,6 +525,7 @@ defn infer-types (m:Module, l:List<KeyValue<Symbol,Type>>) -> Module : val ptypes = for p in ports(m) map : name(p) => type(p) + ;println-all(append(ptypes,l)) val [s,l*] = infer-types(body(m),append(ptypes, l)) Module(name(m),ports(m),s) @@ -525,1138 +533,1139 @@ defn infer-types (c:Circuit) -> Circuit : val l = for m in modules(c) map : name(m) => BundleType(ports(m)) + ;println-all(l) Circuit{ _, main(c) } $ for m in modules(c) map : infer-types(m,l) ;============= INFER DIRECTIONS ============================ -defn flip (d:Direction) : - switch {d == _} : - INPUT : OUTPUT - OUTPUT : INPUT - else : d - -defn times (d1:Direction, d2:Direction) : - if d1 == INPUT : flip(d2) - else : d2 - -defn lookup-port (ports: Streamable<Port>, port-name: Symbol) : - for port in ports find : - name(port) == port-name - -defn bundle-field-dir (t:Type, n:Symbol) -> Direction : - match(t) : - (t:BundleType) : - match(lookup-port(ports(t), n)) : - (p:Port) : direction(p) - (p) : UNKNOWN-DIR - (t) : UNKNOWN-DIR - -defn infer-dirs (m:Module) : - ;=== Direction of all Binders === - val BI-DIR = new Direction - val directions = HashTable<Symbol,Direction>(symbol-hash) - defn find-dirs (c:Stmt) : - match(c) : - (c:LetRec) : - for entry in entries(c) do : - directions[key(entry)] = OUTPUT - find-dirs(body(c)) - (c:DefWire) : - directions[name(c)] = BI-DIR - (c:DefRegister) : - directions[name(c)] = BI-DIR - (c:DefInstance) : - directions[name(c)] = OUTPUT - (c:DefMemory) : - directions[name(c)] = BI-DIR - (c:WDefAccessor) : - directions[name(c)] = dir(c) - (c) : - do(find-dirs, children(c)) - for p in ports(m) do : - directions[name(p)] = flip(direction(p)) - find-dirs(body(m)) - - ;=== Fix Point Status === - var changed? = false - - ;=== Infer directions of Expression === - defn infer-exp (e:Expression, desired:Direction) : - match(e) : - (e:WRef) : - val dir* = let : - if kind(e) typeof ModuleKind : - OUTPUT - else : - val old-dir = directions[name(e)] - switch {old-dir == _} : - BI-DIR : - desired - UNKNOWN-DIR : - if directions[name(e)] != desired : - directions[name(e)] = desired - changed? = true - desired - else : - old-dir - WRef(name(e), type(e), kind(e), dir*) - (e:WField) : - val port-dir = bundle-field-dir(type(exp(e)), name(e)) - val exp* = infer-exp(exp(e), port-dir * desired) - WField(exp*, name(e), type(e), port-dir * dir(exp*)) - (e:WIndex) : - val exp* = infer-exp(exp(e), desired) - WIndex(exp*, value(e), type(e), dir(exp*)) - (e) : - map(infer-exp{_, OUTPUT}, e) - - ;=== Infer directions of Stmts === - defn infer-comm (c:Stmt) : - match(c) : - (c:LetRec) : - val c* = map(infer-exp{_, OUTPUT}, c) - LetRec(entries(c*), infer-comm(body(c))) - (c:DefInstance) : - DefInstance(name(c), - infer-exp(module(c), OUTPUT)) - (c:WDefAccessor) : - val d = directions[name(c)] - WDefAccessor(name(c), - infer-exp(source(c), d), - infer-exp(index(c), OUTPUT), - d) - (c:Conditionally) : - Conditionally(infer-exp(pred(c), OUTPUT), - infer-comm(conseq(c)), - infer-comm(alt(c))) - (c:Connect) : - Connect(infer-exp(loc(c), INPUT), infer-exp(exp(c), OUTPUT)) - (c) : - map(infer-comm, c) - - ;=== Iterate until fix point === - defn* fixpoint (c:Stmt) : - changed? = false - val c* = infer-comm(c) - if changed? : fixpoint(c*) - else : c* - - Module(name(m), ports(m), body*) where : - val body* = fixpoint(body(m)) - -defn infer-directions (c:Circuit) : - Circuit(modules*, main(c)) where : - val modules* = map(infer-dirs, modules(c)) - - -;============== EXPAND VECS ================================ -defstruct ManyConnect <: Stmt : - index: Expression - locs: List<Expression> - exp: Expression - -defstruct ConnectMany <: Stmt : - index: Expression - loc: Expression - exps: List<Expression> - -defmethod print (o:OutputStream, c:ManyConnect) : - print-all(o, [locs(c) "[" index(c) "] := " exp(c)]) -defmethod print (o:OutputStream, c:ConnectMany) : - print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"]) - -defmethod map (f: Expression -> Expression, c:ManyConnect) : - ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c))) -defmethod map (f: Expression -> Expression, c:ConnectMany) : - ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c))) - -defn expand-accessors (m: Module) : - defn expand (c:Stmt) : - match(c) : - (c:WDefAccessor) : - ;Is the source a memory? - val mem? = - match(source(c)) : - (r:WRef) : kind(r) typeof MemKind - (r) : false - - if mem? : - c - else : - switch {dir(c) == _} : - INPUT : - Begin(list( - DefWire(name(c), type(src-type)) - ManyConnect(index(c), elems, wire-ref))) - where : - val src-type = type(source(c)) as VectorType - val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT) - val elems = to-list $ - for i in 0 to size(src-type) stream : - WIndex(source(c), i, type(src-type), INPUT) - OUTPUT : - Begin(list( - DefWire(name(c), type(src-type)) - ConnectMany(index(c), wire-ref, elems))) - where : - val src-type = type(source(c)) as VectorType - val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT) - val elems = to-list $ - for i in 0 to size(src-type) stream : - WIndex(source(c), i, type(src-type), OUTPUT) - (c) : - map(expand, c) - Module(name(m), ports(m), expand(body(m))) - -defn expand-accessors (c:Circuit) : - Circuit(modules*, main(c)) where : - val modules* = map(expand-accessors, modules(c)) - - - -;=============== BUNDLE FLATTENING ========================= -defn prefix (prefix, suffix) : - symbol-join([prefix "/" suffix]) - -defn prefix-ports (pre:Symbol, ports:List<Port>) : - for p in ports map : - Port(prefix(pre, name(p)), direction(p), type(p)) - -defn flatten-ports (port:Port) -> List<Port> : - match(type(port)) : - (t:BundleType) : - val ports = map-append(flatten-ports, ports(t)) - for p in ports map : - Port(prefix(name(port), name(p)), - direction(port) * direction(p), - type(p)) - (t:VectorType) : - val type* = flatten-type(t) - flatten-ports(Port(name(port), direction(port), type*)) - (t:Type) : - list(port) - -defn flatten-type (t:Type) -> Type : - match(t) : - (t:BundleType) : - BundleType $ - map-append(flatten-ports, ports(t)) - (t:VectorType) : - flatten-type $ BundleType $ to-list $ - for i in 0 to size(t) stream : - Port(to-symbol(i), OUTPUT, type(t)) - (t:Type) : - t - -defn flatten-bundles (c:Circuit) : - defn flatten-exp (e:Expression) : - match(map(flatten-exp, e)) : - (e:UIntValue|SIntValue) : - e - (e:WRef) : - match(kind(e)) : - (k:MemKind|StructuralMemKind) : - val type* = map(flatten-type, type(e)) - put-type(e, type*) - (k) : - val type* = flatten-type(type(e)) - put-type(e, type*) - (e) : - val type* = flatten-type(type(e)) - put-type(e, type*) - - defn flatten-element (e:Element) : - val t* = flatten-type(type(e)) - match(map(flatten-exp, e)) : - (e:Register) : Register(t*, value(e), enable(e)) - (e:Memory) : Memory(t*, writers(e)) - (e:Node) : Node(t*, value(e)) - (e:Instance) : Instance(t*, module(e), ports(e)) - - defn flatten-comm (c:Stmt) : - match(c) : - (c:LetRec) : - val entries* = - for entry in entries(c) map : - key(entry) => flatten-element(value(entry)) - LetRec(entries*, flatten-comm(body(c))) - (c:DefWire) : - DefWire(name(c), flatten-type(type(c))) - (c:DefRegister) : - DefRegister(name(c), flatten-type(type(c))) - (c:DefMemory) : - val type* = map(flatten-type, type(c)) - DefMemory(name(c), type*) - (c) : - map{flatten-comm, _} $ - map(flatten-exp, c) - - defn flatten-module (m:Module) : - val ports* = map-append(flatten-ports, ports(m)) - val body* = flatten-comm(body(m)) - Module(name(m), ports*, body*) - - Circuit(modules*, main(c)) where : - val modules* = map(flatten-module, modules(c)) - - -;================== BUNDLE EXPANSION ======================= -defn expand-bundles (m:Module) : - - ;Collapse all field/index expressions - defn collapse-exp (e:Expression) -> Expression : - match(e) : - (e:WField) : - match(collapse-exp(exp(e))) : - (ei:WRef) : - if kind(ei) typeof InstanceKind : - e - else : - WRef(name*, type(e), kind(ei), dir(e)) where : - val name* = prefix(name(ei), name(e)) - (ei:WField) : - WField(exp(ei), name*, type(e), dir(e)) where : - val name* = prefix(name(ei), name(e)) - (e:WIndex) : - collapse-exp(WField(exp(e), name, type(e), dir(e))) where : - val name = to-symbol(value(e)) - (e) : - map(collapse-exp, e) - - ;Expand expressions - defn expand-exp (e:Expression) -> List<Expression> : - match(type(e)) : - (t:BundleType) : - for p in ports(t) map : - val dir* = direction(p) * dir(e) - collapse-exp(WField(e, name(p), type(p), dir*)) - (t) : - list(collapse-exp(e)) - - ;Expand commands - defn expand-comm (c:Stmt) : - match(c) : - (c:DefWire) : - match(type(c)) : - (t:BundleType) : - Begin $ - for p in ports(t) map : - DefWire(prefix(name(c), name(p)), type(p)) - (t) : - c - (c:DefRegister) : - match(type(c)) : - (t:BundleType) : - Begin $ - for p in ports(t) map : - DefRegister(prefix(name(c), name(p)), type(p)) - (t) : - c - (c:DefMemory) : - match(type(type(c))) : - (t:BundleType) : - Begin $ - for p in ports(t) map : - DefMemory(prefix(name(c), name(p)), type*) where : - val s = size(type(c)) - val type* = VectorType(type(p), s) - (t) : - c - (c:WDefAccessor) : - match(type(source(c))) : - (t:BundleType) : - val srcs = expand-exp(source(c)) - Begin $ - for (p in ports(t), src in srcs) map : - WDefAccessor(name*, src, index(c), dir*) where : - val name* = prefix(name(c), name(p)) - val dir* = direction(p) * dir(c) - (t) : - c - (c:Connect) : - val locs = expand-exp(loc(c)) - val exps = expand-exp(exp(c)) - Begin $ - for (l in locs, e in exps) map : - switch {dir(l) == _} : - INPUT : Connect(l, e) - OUTPUT : Connect(e, l) - (c:ManyConnect) : - val locs-list = transpose(map(expand-exp, locs(c))) - val exps = expand-exp(exp(c)) - Begin $ - for (locs in locs-list, e in exps) map : - switch {dir(e) == _} : - OUTPUT : ManyConnect(index(c), locs, e) - INPUT : ConnectMany(index(c), e, locs) - (c:ConnectMany) : - val locs = expand-exp(loc(c)) - val exps-list = transpose(map(expand-exp, exps(c))) - Begin $ - for (l in locs, exps in exps-list) map : - switch {dir(l) == _} : - INPUT : ConnectMany(index(c), l, exps) - OUTPUT : ManyConnect(index(c), exps, l) - (c) : - map{expand-comm, _} $ - map(collapse-exp, c) - - Module(name(m), ports(m), expand-comm(body(m))) - -defn expand-bundles (c:Circuit) : - Circuit(modules*, main(c)) where : - val modules* = map(expand-bundles, modules(c)) - - -;=========== CONVERT MULTI CONNECTS to WHEN ================ -defn expand-multi-connects (c:Circuit) : - defn equal-exp (e1:Expression, e2:Expression) : - DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth())) - defn uint (i:Int) : - UIntValue(i, UnknownWidth()) - - defn expand-comm (c:Stmt) : - match(c) : - (c:ConnectMany) : - Begin $ to-list $ - for (i in 0 to false, e in exps(c)) stream : - Conditionally(equal-exp(index(c), uint(i)), - Connect(loc(c), e) - EmptyStmt()) - (c:ManyConnect) : - Begin $ to-list $ - for (i in 0 to false, l in locs(c)) stream : - Conditionally(equal-exp(index(c), uint(i)), - Connect(l, exp(c)) - EmptyStmt()) - (c) : - map(expand-comm, c) - - defn expand (m:Module) : - Module(name(m), ports(m), expand-comm(body(m))) - - Circuit(modules*, main(c)) where : - val modules* = map(expand, modules(c)) - - -;================ EXPAND WHENS ============================= -definterface SymbolicValue -defstruct ExpValue <: SymbolicValue : - exp: Expression -defstruct WhenValue <: SymbolicValue : - pred: Expression - conseq: SymbolicValue - alt: SymbolicValue -defstruct VoidValue <: SymbolicValue - -defmethod print (o:OutputStream, sv:SymbolicValue) : - match(sv) : - (sv:VoidValue) : print(o, "VOID") - (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"]) - (sv:ExpValue) : print(o, exp(sv)) - -defn key-eqv? (e1:Expression, e2:Expression) : - match(e1, e2) : - (e1:WRef, e2:WRef) : - name(e1) == name(e2) - (e1:WField, e2:WField) : - name(e1) == name(e2) and - key-eqv?(exp(e1), exp(e2)) - (e1, e2) : - false - -defn merge-env (pred: Expression, - con-env: List<KeyValue<Expression,SymbolicValue>>, - alt-env: List<KeyValue<Expression,SymbolicValue>>) : - val merged = Vector<KeyValue<Expression, SymbolicValue>>() - defn new-key? (k:Expression) : - for entry in merged none? : - key-eqv?(key(entry), k) - - defn sv (env:List<KeyValue<Expression,SymbolicValue>>, k:Expression) : - for entry in env search : - if key-eqv?(key(entry), k) : - value(entry) - - for k in stream(key, concat(con-env, alt-env)) do : - if new-key?(k) : - match(sv(con-env, k), sv(alt-env, k)) : - (a:SymbolicValue, b:SymbolicValue) : - if a == b : - add(merged, k => a) - else : - add(merged, k => WhenValue(pred, a, b)) - (a:SymbolicValue, b:False) : - add(merged, k => a) - (a:False, b:SymbolicValue) : - add(merged, k => b) - (a:False, b:False) : - false - - to-list(merged) - -defn simplify-env (env: List<KeyValue<Expression,SymbolicValue>>) : - val merged = Vector<KeyValue<Expression, SymbolicValue>>() - defn new-key? (k:Expression) : - for entry in merged none? : - key-eqv?(key(entry), k) - for entry in env do : - if new-key?(key(entry)) : - add(merged, entry) - to-list(merged) - -defn expand-whens (m:Module) : - val commands = Vector<Stmt>() - val elements = Vector<KeyValue<Symbol,Element>>() - defn eval (c:Stmt, env:List<KeyValue<Expression,SymbolicValue>>) -> - List<KeyValue<Expression,SymbolicValue>> : - match(c) : - (c:LetRec) : - do(add{elements, _}, entries(c)) - eval(body(c), env) - (c:DefWire) : - add(commands, c) - val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT) - List(wire-ref => VoidValue(), env) - (c:DefRegister) : - add(commands, c) - val reg-ref = WRef(name(c), type(c), RegKind(), INPUT) - List(reg-ref => VoidValue(), env) - (c:DefInstance) : - add(commands, c) - val entries = let : - val module-type = type(module(c)) as BundleType - val input-ports = to-list $ - for p in ports(module-type) filter : - direction(p) == INPUT - val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT) - for p in input-ports map : - WField(inst-ref, name(p), type(p), INPUT) => VoidValue() - append(entries, env) - (c:DefMemory) : - add(commands, c) - env - (c:WDefAccessor) : - add(commands, c) - if dir(c) == INPUT : - val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c)) - List(access-ref => VoidValue(), env) - else : - env - (c:Conditionally) : - val con-env = eval(conseq(c), env) - val alt-env = eval(alt(c), env) - merge-env(pred(c), con-env, alt-env) - (c:Begin) : - var env:List<KeyValue<Expression,SymbolicValue>> = env - for c in body(c) do : - env = eval(c, env) - env - (c:Connect) : - List(loc(c) => ExpValue(exp(c)), env) - (c:EmptyStmt) : - env - - defn convert-symbolic (key:Expression, sv:SymbolicValue) : - match(sv) : - (sv:VoidValue) : - throw $ PassException $ string-join $ [ - "No default value for " key "."] - (sv:ExpValue) : - exp(sv) - (sv:WhenValue) : - defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) : - DoPrim(MUX-OP, list(pred, conseq, alt), List(), type(conseq)) - multiplex-exp(pred(sv), - convert-symbolic(key, conseq(sv)) - convert-symbolic(key, alt(sv))) - - ;Compute final environment - val env0 = let : - val output-ports = to-list $ - for p in ports(m) filter : - direction(p) == OUTPUT - for p in output-ports map : - val port-ref = WRef(name(p), type(p), PortKind(), INPUT) - port-ref => VoidValue() - val env* = simplify-env(eval(body(m), env0)) - - ;Make new body - val body* = Begin(list(defs, LetRec(elems, connections))) where : - val defs = Begin(to-list(commands)) - val elems = to-list(elements) - val connections = Begin $ - for entry in env* map : - val sv = convert-symbolic(key(entry), value(entry)) - Connect(key(entry), sv) - - ;Final module - Module(name(m), ports(m), body*) - -defn expand-whens (c:Circuit) : - val modules* = map(expand-whens, modules(c)) - Circuit(modules*, main(c)) - - -;================ STRUCTURAL FORM ========================== - -defn structural-form (m:Module) : - val elements = Vector<(() -> KeyValue<Symbol,Element>)>() - val connected = HashTable<Symbol, Expression>(symbol-hash) - val write-accessors = HashTable<Symbol, List<WDefAccessor>>(symbol-hash) - val read-accessors = HashTable<Symbol, WDefAccessor>(symbol-hash) - val inst-ports = HashTable<Symbol, List<KeyValue<Symbol, Expression>>>(symbol-hash) - val port-connects = Vector<Connect>() - - defn scan (c:Stmt) : - match(c) : - (c:Connect) : - match(loc(c)) : - (loc:WRef) : - match(kind(loc)) : - (k:PortKind) : add(port-connects, c) - (k) : connected[name(loc)] = exp(c) - (loc:WField) : - val inst = exp(loc) as WRef - val entry = name(loc) => exp(c) - inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List())) - (c:LetRec) : - for e in entries(c) do : - add(elements, {e}) - scan(body(c)) - (c:DefWire) : - add{elements, _} $ fn () : - name(c) => Node(type(c), connected[name(c)]) - (c:DefRegister) : - add{elements, _} $ fn () : - val one = UIntValue(1, UnknownWidth()) - name(c) => Register(type(c), connected[name(c)], one) - (c:DefInstance) : - add{elements, _} $ fn () : - name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)]) - (c:DefMemory) : - add{elements, _} $ fn () : - val ports = for a in get?(write-accessors, name(c), List()) map : - val one = UIntValue(1, UnknownWidth()) - WritePort(index(a), connected[name(a)], one) - name(c) => Memory(type(c), ports) - (c:WDefAccessor) : - val mem = source(c) as WRef - switch {dir(c) == _} : - INPUT : - write-accessors[name(mem)] = List(c, - get?(write-accessors, name(mem), List())) - OUTPUT : - read-accessors[name(c)] = c - (c) : - do(scan, children(c)) - - defn make-read-ports (e:Expression) : - match(e) : - (e:WRef) : - match(kind(e)) : - (k:AccessorKind) : - val accessor = read-accessors[name(e)] - ReadPort(source(accessor), index(accessor), type(e)) - (k) : e - (e) : map(make-read-ports, e) - - Module(name(m), ports(m), body*) where : - scan(body(m)) - val elems = to-list $ - for e in elements stream : - val entry = e() - key(entry) => map(make-read-ports, value(entry)) - val connect-ports = Begin $ to-list $ - for c in port-connects stream : - Connect(loc(c), make-read-ports(exp(c))) - val body* = - if empty?(elems) : connect-ports - else : LetRec(elems, connect-ports) - -defn structural-form (c:Circuit) : - val modules* = map(structural-form, modules(c)) - Circuit(modules*, main(c)) - - -;==================== WIDTH INFERENCE ====================== -defstruct WidthVar <: Width : - name: Symbol - -defmethod print (o:OutputStream, w:WidthVar) : - print(o, name(w)) - -defn width! (t:Type) : - match(t) : - (t:UIntType) : width(t) - (t:SIntType) : width(t) - (t) : error("No width field.") - -defn put-width (t:Type, w:Width) : - match(t) : - (t:UIntType) : UIntType(w) - (t:SIntType) : SIntType(w) - (t) : t - -defn put-width (e:Expression, w:Width) : - val type* = put-width(type(e), w) - put-type(e, type*) - -defn add-width-vars (t:Type) : - defn width? (w:Width) : - match(w) : - (w:UnknownWidth) : WidthVar(gensym()) - (w) : w - match(t) : - (t:UIntType) : UIntType(width?(width(t))) - (t:SIntType) : SIntType(width?(width(t))) - (t) : map(add-width-vars, t) - -defn uint-width (i:Int) : - var v:Int = i - var n:Int = 0 - while v != 0 : - v = v >> 1 - n = n + 1 - IntWidth(n) - -defn sint-width (i:Int) : - if i > 0 : - val w = uint-width(i) - IntWidth(width(w) + 1) - else : - val w = uint-width(neg(i) - 1) - IntWidth(width(w) + 1) - -defn to-exp (w:Width) : - match(w) : - (w:IntWidth) : ELit(width(w)) - (w:WidthVar) : EVar(name(w)) - (w) : error $ string-join $ [ - "Cannot convert " w " to exp."] - -defn primop-width (op:PrimOp, ws:List<Width>, ints:List<Int>) -> Exp : - defn wmax (w1:Width, w2:Width) : - EMax(to-exp(w1), to-exp(w2)) - defn wplus (w1:Width, w2:Width) : - EPlus(to-exp(w1), to-exp(w2)) - defn wplus (w1:Width, w2:Int) : - EPlus(to-exp(w1), ELit(w2)) - defn wminus (w1:Width, w2:Width) : - EMinus(to-exp(w1), to-exp(w2)) - defn wminus (w1:Width, w2:Int) : - EMinus(to-exp(w1), ELit(w2)) - defn wmax-inc (w1:Width, w2:Width) : - EPlus(wmax(w1, w2), ELit(1)) - - switch {op == _} : - ADD-OP : wmax-inc(ws[0], ws[1]) - ADD-WRAP-OP : wmax(ws[0], ws[1]) - SUB-OP : wmax-inc(ws[0], ws[1]) - SUB-WRAP-OP : wmax(ws[0], ws[1]) - MUL-OP : wplus(ws[0], ws[1]) - DIV-OP : wminus(ws[0], ws[1]) - MOD-OP : to-exp(ws[1]) - SHIFT-LEFT-OP : wplus(ws[0], ints[0]) - SHIFT-RIGHT-OP : wminus(ws[0], ints[0]) - PAD-OP : ELit(ints[0]) - BIT-AND-OP : wmax(ws[0], ws[1]) - BIT-OR-OP : wmax(ws[0], ws[1]) - BIT-XOR-OP : wmax(ws[0], ws[1]) - CONCAT-OP : wplus(ws[0], ints[0]) - BIT-SELECT-OP : ELit(1) - BITS-SELECT-OP : ELit(ints[0]) - MUX-OP : wmax(ws[1], ws[2]) - LESS-OP : ELit(1) - LESS-EQ-OP : ELit(1) - GREATER-OP : ELit(1) - GREATER-EQ-OP : ELit(1) - EQUAL-OP : ELit(1) - -defn put-type (el:Element, t:Type) : - match(el) : - (el:Register) : Register(t, value(el), enable(el)) - (el:Memory) : Memory(t, writers(el)) - (el:Node) : Node(t, value(el)) - (el:Instance) : Instance(t, module(el), ports(el)) - -defn generate-constraints (c:Circuit) -> [Circuit, Vector<WConstraint>] : - ;Constraints - val cs = Vector<WConstraint>() - defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) : - match(wvar) : - (wvar:WidthVar) : - add(cs, Constraint(name(wvar), to-exp(width))) - (wvar) : - false - - defn to-width (e:Exp) : - match(e) : - (e:ELit) : - IntWidth(width(e)) - (e:EVar) : - WidthVar(name(e)) - (e) : - val x = gensym() - add(cs, WidthEqual(x, e)) - WidthVar(x) - - ;Module types - val mod-types = HashTable<Symbol,Type>(symbol-hash) - - defn add-port-vars (m:Module) -> Module : - val ports* = - for p in ports(m) map : - val type* = add-width-vars(type(p)) - Port(name(p), direction(p), type*) - mod-types[name(m)] = BundleType(ports*) - Module(name(m), ports*, body(m)) - - ;Add Width Variables - defn add-module-vars (m:Module) -> Module : - val types = HashTable<Symbol,Type>(symbol-hash) - for p in ports(m) do : - types[name(p)] = type(p) - - defn infer-exp-width (e:Expression) : - match(map(infer-exp-width, e)) : - (e:WRef) : - match(kind(e)) : - (k:ModuleKind) : put-type(e, mod-types[name(e)]) - (k) : put-type(e, types[name(e)]) - (e:WField) : - val t = bundle-field-type(type(exp(e)), name(e)) - put-width(e, width!(t)) - (e:UIntValue) : - match(width(e)) : - (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e))) - (w) : e - (e:SIntValue) : - match(width(e)) : - (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e))) - (w) : e - (e:DoPrim) : - val widths = map(width!{type(_)}, args(e)) - val w = to-width(primop-width(op(e), widths, consts(e))) - put-width(e, w) - (e:ReadPort) : - val elem-type = type(type(mem(e)) as VectorType) - put-width(e, width!(elem-type)) - - defn infer-comm-width (c:Stmt) : - match(c) : - (c:LetRec) : - ;Add width vars to elements - var entries*: List<KeyValue<Symbol,Element>> = - for entry in entries(c) map : - val el-name = key(entry) - key(entry) => - match(value(entry)) : - (el:Register|Node) : - put-type(el, add-width-vars(type(el))) - (el:Memory) : - el - (el:Instance) : - val mod-type = type(infer-exp-width(module(el))) as BundleType - val type = BundleType $ to-list $ - for p in ports(mod-type) filter : - direction(p) == OUTPUT - put-type(el, type) - - ;Add vars to environment - for entry in entries* do : - types[key(entry)] = type(value(entry)) - - ;Infer types for elements - entries* = - for entry in entries* map : - key(entry) => map(infer-exp-width, value(entry)) - - ;Generate constraints - for entry in entries* do : - val el-name = key(entry) - match(value(entry)) : - (el:Register) : - new-constraint(WidthEqual, reg-width, val-width) where : - val reg-width = width!(types[el-name]) - val val-width = width!(type(value(el))) - (el:Node) : - new-constraint(WidthEqual, node-width, val-width) where : - val node-width = width!(types[el-name]) - val val-width = width!(type(value(el))) - (el:Instance) : - val mod-type = type(module(el)) as BundleType - for entry in ports(el) do : - new-constraint(WidthGreater, port-width, val-width) where : - val port-name = key(entry) - val port-width = width!(bundle-field-type(mod-type, port-name)) - val val-width = width!(type(value(entry))) - (el) : false - - ;Analyze body - LetRec(entries*, infer-comm-width(body(c))) - - (c:Connect) : - val loc* = infer-exp-width(loc(c)) - val exp* = infer-exp-width(exp(c)) - new-constraint(WidthGreater, loc-width, exp-width) where : - val loc-width = width!(type(loc*)) - val exp-width = width!(type(exp*)) - Connect(loc*, exp*) - - (c:Begin) : - map(infer-comm-width, c) - - Module(name(m), ports(m), body*) where : - val body* = infer-comm-width(body(m)) - - val c* = - Circuit(modules*, main(c)) where : - val ms = map(add-port-vars, modules(c)) - val modules* = map(add-module-vars, ms) - [c*, cs] - - -;================== FILL WIDTHS ============================ -defn fill-widths (c:Circuit, solved:Streamable<WidthEqual>) : - ;Populate table - val table = HashTable<Symbol, Width>(symbol-hash) - for eq in solved do : - table[name(eq)] = IntWidth(width(value(eq) as ELit)) - - defn width? (w:Width) : - match(w) : - (w:WidthVar) : get?(table, name(w), UnknownWidth()) - (w) : w - - defn fill-type (t:Type) : - match(t) : - (t:UIntType) : UIntType(width?(width(t))) - (t:SIntType) : SIntType(width?(width(t))) - (t) : map(fill-type, t) - - defn fill-exp (e:Expression) : - val e* = map(fill-exp, e) - val type* = fill-type(type(e)) - put-type(e*, type*) - - defn fill-element (e:Element) : - val e* = map(fill-exp, e) - val type* = fill-type(type(e)) - put-type(e*, type*) - - defn fill-comm (c:Stmt) : - match(c) : - (c:LetRec) : - val entries* = - for e in entries(c) map : - key(e) => fill-element(value(e)) - LetRec(entries*, fill-comm(body(c))) - (c) : - map{fill-comm, _} $ - map(fill-exp, c) - - defn fill-port (p:Port) : - Port(name(p), direction(p), fill-type(type(p))) - - defn fill-mod (m:Module) : - Module(name(m), ports*, body*) where : - val ports* = map(fill-port, ports(m)) - val body* = fill-comm(body(m)) - - Circuit(modules*, main(c)) where : - val modules* = map(fill-mod, modules(c)) - - -;=============== TYPE INFERENCE DRIVER ===================== -defn infer-widths (c:Circuit) : - val [c*, cs] = generate-constraints(c) - val solved = solve-widths(cs) - fill-widths(c*, solved) - - -;================ PAD WIDTHS =============================== -defn pad-widths (c:Circuit) : - ;Pad an expression to the given width - defn pad-exp (e:Expression, w:Int) : - match(type(e)) : - (t:UIntType|SIntType) : - val prev-w = width!(t) as IntWidth - if width(prev-w) < w : - val type* = put-width(t, IntWidth(w)) - DoPrim(PAD-OP, list(e), list(w), type*) - else : - e - (t) : - e - - defn pad-exp (e:Expression, w:Width) : - val w-value = width(w as IntWidth) - pad-exp(e, w-value) - - ;Convenience - defn max-width (es:Streamable<Expression>) : - defn int-width (e:Expression) : - width(width!(type(e)) as IntWidth) - maximum(stream(int-width, es)) - - defn match-widths (es:List<Expression>) : - val w = max-width(es) - map(pad-exp{_, w}, es) - - ;Match widths for an expression - defn match-exp-width (e:Expression) : - match(map(match-exp-width, e)) : - (e:DoPrim) : - if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) : - val args* = match-widths(args(e)) - DoPrim(op(e), args*, consts(e), type(e)) - else if op(e) == MUX-OP : - val args* = List(head(args(e)), match-widths(tail(args(e)))) - DoPrim(op(e), args*, consts(e), type(e)) - else : - e - (e) : e - - defn match-element-width (e:Element) : - match(map(match-exp-width, e)) : - (e:Register) : - val w = width!(type(e)) - val value* = pad-exp(value(e), w) - Register(type(e), value*, enable(e)) - (e:Memory) : - val width = width!(type(type(e) as VectorType)) - val writers* = - for w in writers(e) map : - WritePort(index(w), pad-exp(value(w), width), enable(w)) - Memory(type(e), writers*) - (e:Node) : - val w = width!(type(e)) - val value* = pad-exp(value(e), w) - Node(type(e), value*) - (e:Instance) : - val mod-type = type(module(e)) as BundleType - val ports* = - for p in ports(e) map : - val port-type = bundle-field-type(mod-type, key(p)) - val port-val = pad-exp(value(p), width!(port-type)) - key(p) => port-val - Instance(type(e), module(e), ports*) - - ;Match widths for a command - defn match-comm-width (c:Stmt) : - match(map(match-exp-width, c)) : - (c:LetRec) : - val entries* = - for e in entries(c) map : - key(e) => match-element-width(value(e)) - LetRec(entries*, match-comm-width(body(c))) - (c:Connect) : - val w = width!(type(loc(c))) - val exp* = pad-exp(exp(c), w) - Connect(loc(c), exp*) - (c) : - map(match-comm-width, c) - - defn match-mod-width (m:Module) : - Module(name(m), ports(m), body*) where : - val body* = match-comm-width(body(m)) - - Circuit(modules*, main(c)) where : - val modules* = map(match-mod-width, modules(c)) - - -;================== INLINING =============================== -defn inline-instances (c:Circuit) : - val module-table = HashTable<Symbol,Module>(symbol-hash) - val inlined? = HashTable<Symbol,True|False>(symbol-hash) - for m in modules(c) do : - module-table[name(m)] = m - inlined?[name(m)] = false - - ;Convert a module into a sequence of elements - defn to-elements (m:Module, - inst:Symbol, - port-exps:List<KeyValue<Symbol,Expression>>) -> - List<KeyValue<Symbol, Element>> : - defn rename-exp (e:Expression) : - match(e) : - (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e)) - (e) : map(rename-exp, e) - - defn to-elements (c:Stmt) -> List<KeyValue<Symbol,Element>> : - match(c) : - (c:LetRec) : - val entries* = - for entry in entries(c) map : - val name* = prefix(inst, key(entry)) - val element* = map(rename-exp, value(entry)) - name* => element* - val body* = to-elements(body(c)) - append(entries*, body*) - (c:Connect) : - val ref = loc(c) as WRef - val name* = prefix(inst, name(ref)) - list(name* => Node(type(exp(c)), rename-exp(exp(c)))) - (c:Begin) : - map-append(to-elements, body(c)) - - val inputs = - for p in ports(m) map-append : - if direction(p) == INPUT : - val port-exp = lookup!(port-exps, name(p)) - val name* = prefix(inst, name(p)) - list(name* => Node(type(port-exp), port-exp)) - else : - List() - append(inputs, to-elements(body(m))) - - ;Inline all instances in the module - defn inline-instances (m:Module) : - defn rename-exp (e:Expression) : - match(e) : - (e:WField) : - val inst-exp = exp(e) as WRef - val name* = prefix(name(inst-exp), name(e)) - WRef(name*, type(e), NodeKind(), dir(e)) - (e) : - map(rename-exp, e) - - defn inline-elems (es:List<KeyValue<Symbol,Element>>) : - for entry in es map-append : - match(value(entry)) : - (el:Instance) : - val mod-name = name(module(el) as WRef) - val module = inlined-module(mod-name) - to-elements(module, key(entry), ports(el)) - (el) : - list(entry) - - defn inline-comm (c:Stmt) : - match(map(rename-exp, c)) : - (c:LetRec) : - val entries* = inline-elems(entries(c)) - LetRec(entries*, inline-comm(body(c))) - (c) : - map(inline-comm, c) - - Module(name(m), ports(m), inline-comm(body(m))) - - ;Retrieve the inlined instance of a module - defn inlined-module (name:Symbol) : - if inlined?[name] : - module-table[name] - else : - val module* = inline-instances(module-table[name]) - module-table[name] = module* - inlined?[name] = true - module* - - ;Return the fully inlined circuit - val main-module = inlined-module(main(c)) - Circuit(list(main-module), main(c)) +;defn flip (d:Direction) : +; switch {d == _} : +; INPUT : OUTPUT +; OUTPUT : INPUT +; else : d +; +;defn times (d1:Direction, d2:Direction) : +; if d1 == INPUT : flip(d2) +; else : d2 +; +;defn lookup-port (ports: Streamable<Port>, port-name: Symbol) : +; for port in ports find : +; name(port) == port-name +; +;defn bundle-field-dir (t:Type, n:Symbol) -> Direction : +; match(t) : +; (t:BundleType) : +; match(lookup-port(ports(t), n)) : +; (p:Port) : direction(p) +; (p) : UNKNOWN-DIR +; (t) : UNKNOWN-DIR +; +;defn infer-dirs (m:Module) : +; ;=== Direction of all Binders === +; val BI-DIR = new Direction +; val directions = HashTable<Symbol,Direction>(symbol-hash) +; defn find-dirs (c:Stmt) : +; match(c) : +; (c:LetRec) : +; for entry in entries(c) do : +; directions[key(entry)] = OUTPUT +; find-dirs(body(c)) +; (c:DefWire) : +; directions[name(c)] = BI-DIR +; (c:DefRegister) : +; directions[name(c)] = BI-DIR +; (c:DefInstance) : +; directions[name(c)] = OUTPUT +; (c:DefMemory) : +; directions[name(c)] = BI-DIR +; (c:WDefAccessor) : +; directions[name(c)] = dir(c) +; (c) : +; do(find-dirs, children(c)) +; for p in ports(m) do : +; directions[name(p)] = flip(direction(p)) +; find-dirs(body(m)) +; +; ;=== Fix Point Status === +; var changed? = false +; +; ;=== Infer directions of Expression === +; defn infer-exp (e:Expression, desired:Direction) : +; match(e) : +; (e:WRef) : +; val dir* = let : +; if kind(e) typeof ModuleKind : +; OUTPUT +; else : +; val old-dir = directions[name(e)] +; switch {old-dir == _} : +; BI-DIR : +; desired +; UNKNOWN-DIR : +; if directions[name(e)] != desired : +; directions[name(e)] = desired +; changed? = true +; desired +; else : +; old-dir +; WRef(name(e), type(e), kind(e), dir*) +; (e:WField) : +; val port-dir = bundle-field-dir(type(exp(e)), name(e)) +; val exp* = infer-exp(exp(e), port-dir * desired) +; WField(exp*, name(e), type(e), port-dir * dir(exp*)) +; (e:WIndex) : +; val exp* = infer-exp(exp(e), desired) +; WIndex(exp*, value(e), type(e), dir(exp*)) +; (e) : +; map(infer-exp{_, OUTPUT}, e) +; +; ;=== Infer directions of Stmts === +; defn infer-comm (c:Stmt) : +; match(c) : +; (c:LetRec) : +; val c* = map(infer-exp{_, OUTPUT}, c) +; LetRec(entries(c*), infer-comm(body(c))) +; (c:DefInstance) : +; DefInstance(name(c), +; infer-exp(module(c), OUTPUT)) +; (c:WDefAccessor) : +; val d = directions[name(c)] +; WDefAccessor(name(c), +; infer-exp(source(c), d), +; infer-exp(index(c), OUTPUT), +; d) +; (c:Conditionally) : +; Conditionally(infer-exp(pred(c), OUTPUT), +; infer-comm(conseq(c)), +; infer-comm(alt(c))) +; (c:Connect) : +; Connect(infer-exp(loc(c), INPUT), infer-exp(exp(c), OUTPUT)) +; (c) : +; map(infer-comm, c) +; +; ;=== Iterate until fix point === +; defn* fixpoint (c:Stmt) : +; changed? = false +; val c* = infer-comm(c) +; if changed? : fixpoint(c*) +; else : c* +; +; Module(name(m), ports(m), body*) where : +; val body* = fixpoint(body(m)) +; +;defn infer-directions (c:Circuit) : +; Circuit(modules*, main(c)) where : +; val modules* = map(infer-dirs, modules(c)) +; +; +;;============== EXPAND VECS ================================ +;defstruct ManyConnect <: Stmt : +; index: Expression +; locs: List<Expression> +; exp: Expression +; +;defstruct ConnectMany <: Stmt : +; index: Expression +; loc: Expression +; exps: List<Expression> +; +;defmethod print (o:OutputStream, c:ManyConnect) : +; print-all(o, [locs(c) "[" index(c) "] := " exp(c)]) +;defmethod print (o:OutputStream, c:ConnectMany) : +; print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"]) +; +;defmethod map (f: Expression -> Expression, c:ManyConnect) : +; ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c))) +;defmethod map (f: Expression -> Expression, c:ConnectMany) : +; ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c))) +; +;defn expand-accessors (m: Module) : +; defn expand (c:Stmt) : +; match(c) : +; (c:WDefAccessor) : +; ;Is the source a memory? +; val mem? = +; match(source(c)) : +; (r:WRef) : kind(r) typeof MemKind +; (r) : false +; +; if mem? : +; c +; else : +; switch {dir(c) == _} : +; INPUT : +; Begin(list( +; DefWire(name(c), type(src-type)) +; ManyConnect(index(c), elems, wire-ref))) +; where : +; val src-type = type(source(c)) as VectorType +; val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT) +; val elems = to-list $ +; for i in 0 to size(src-type) stream : +; WIndex(source(c), i, type(src-type), INPUT) +; OUTPUT : +; Begin(list( +; DefWire(name(c), type(src-type)) +; ConnectMany(index(c), wire-ref, elems))) +; where : +; val src-type = type(source(c)) as VectorType +; val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT) +; val elems = to-list $ +; for i in 0 to size(src-type) stream : +; WIndex(source(c), i, type(src-type), OUTPUT) +; (c) : +; map(expand, c) +; Module(name(m), ports(m), expand(body(m))) +; +;defn expand-accessors (c:Circuit) : +; Circuit(modules*, main(c)) where : +; val modules* = map(expand-accessors, modules(c)) +; +; +; +;;=============== BUNDLE FLATTENING ========================= +;defn prefix (prefix, suffix) : +; symbol-join([prefix "/" suffix]) +; +;defn prefix-ports (pre:Symbol, ports:List<Port>) : +; for p in ports map : +; Port(prefix(pre, name(p)), direction(p), type(p)) +; +;defn flatten-ports (port:Port) -> List<Port> : +; match(type(port)) : +; (t:BundleType) : +; val ports = map-append(flatten-ports, ports(t)) +; for p in ports map : +; Port(prefix(name(port), name(p)), +; direction(port) * direction(p), +; type(p)) +; (t:VectorType) : +; val type* = flatten-type(t) +; flatten-ports(Port(name(port), direction(port), type*)) +; (t:Type) : +; list(port) +; +;defn flatten-type (t:Type) -> Type : +; match(t) : +; (t:BundleType) : +; BundleType $ +; map-append(flatten-ports, ports(t)) +; (t:VectorType) : +; flatten-type $ BundleType $ to-list $ +; for i in 0 to size(t) stream : +; Port(to-symbol(i), OUTPUT, type(t)) +; (t:Type) : +; t +; +;defn flatten-bundles (c:Circuit) : +; defn flatten-exp (e:Expression) : +; match(map(flatten-exp, e)) : +; (e:UIntValue|SIntValue) : +; e +; (e:WRef) : +; match(kind(e)) : +; (k:MemKind|StructuralMemKind) : +; val type* = map(flatten-type, type(e)) +; put-type(e, type*) +; (k) : +; val type* = flatten-type(type(e)) +; put-type(e, type*) +; (e) : +; val type* = flatten-type(type(e)) +; put-type(e, type*) +; +; defn flatten-element (e:Element) : +; val t* = flatten-type(type(e)) +; match(map(flatten-exp, e)) : +; (e:Register) : Register(t*, value(e), enable(e)) +; (e:Memory) : Memory(t*, writers(e)) +; (e:Node) : Node(t*, value(e)) +; (e:Instance) : Instance(t*, module(e), ports(e)) +; +; defn flatten-comm (c:Stmt) : +; match(c) : +; (c:LetRec) : +; val entries* = +; for entry in entries(c) map : +; key(entry) => flatten-element(value(entry)) +; LetRec(entries*, flatten-comm(body(c))) +; (c:DefWire) : +; DefWire(name(c), flatten-type(type(c))) +; (c:DefRegister) : +; DefRegister(name(c), flatten-type(type(c))) +; (c:DefMemory) : +; val type* = map(flatten-type, type(c)) +; DefMemory(name(c), type*) +; (c) : +; map{flatten-comm, _} $ +; map(flatten-exp, c) +; +; defn flatten-module (m:Module) : +; val ports* = map-append(flatten-ports, ports(m)) +; val body* = flatten-comm(body(m)) +; Module(name(m), ports*, body*) +; +; Circuit(modules*, main(c)) where : +; val modules* = map(flatten-module, modules(c)) +; +; +;;================== BUNDLE EXPANSION ======================= +;defn expand-bundles (m:Module) : +; +; ;Collapse all field/index expressions +; defn collapse-exp (e:Expression) -> Expression : +; match(e) : +; (e:WField) : +; match(collapse-exp(exp(e))) : +; (ei:WRef) : +; if kind(ei) typeof InstanceKind : +; e +; else : +; WRef(name*, type(e), kind(ei), dir(e)) where : +; val name* = prefix(name(ei), name(e)) +; (ei:WField) : +; WField(exp(ei), name*, type(e), dir(e)) where : +; val name* = prefix(name(ei), name(e)) +; (e:WIndex) : +; collapse-exp(WField(exp(e), name, type(e), dir(e))) where : +; val name = to-symbol(value(e)) +; (e) : +; map(collapse-exp, e) +; +; ;Expand expressions +; defn expand-exp (e:Expression) -> List<Expression> : +; match(type(e)) : +; (t:BundleType) : +; for p in ports(t) map : +; val dir* = direction(p) * dir(e) +; collapse-exp(WField(e, name(p), type(p), dir*)) +; (t) : +; list(collapse-exp(e)) +; +; ;Expand commands +; defn expand-comm (c:Stmt) : +; match(c) : +; (c:DefWire) : +; match(type(c)) : +; (t:BundleType) : +; Begin $ +; for p in ports(t) map : +; DefWire(prefix(name(c), name(p)), type(p)) +; (t) : +; c +; (c:DefRegister) : +; match(type(c)) : +; (t:BundleType) : +; Begin $ +; for p in ports(t) map : +; DefRegister(prefix(name(c), name(p)), type(p)) +; (t) : +; c +; (c:DefMemory) : +; match(type(type(c))) : +; (t:BundleType) : +; Begin $ +; for p in ports(t) map : +; DefMemory(prefix(name(c), name(p)), type*) where : +; val s = size(type(c)) +; val type* = VectorType(type(p), s) +; (t) : +; c +; (c:WDefAccessor) : +; match(type(source(c))) : +; (t:BundleType) : +; val srcs = expand-exp(source(c)) +; Begin $ +; for (p in ports(t), src in srcs) map : +; WDefAccessor(name*, src, index(c), dir*) where : +; val name* = prefix(name(c), name(p)) +; val dir* = direction(p) * dir(c) +; (t) : +; c +; (c:Connect) : +; val locs = expand-exp(loc(c)) +; val exps = expand-exp(exp(c)) +; Begin $ +; for (l in locs, e in exps) map : +; switch {dir(l) == _} : +; INPUT : Connect(l, e) +; OUTPUT : Connect(e, l) +; (c:ManyConnect) : +; val locs-list = transpose(map(expand-exp, locs(c))) +; val exps = expand-exp(exp(c)) +; Begin $ +; for (locs in locs-list, e in exps) map : +; switch {dir(e) == _} : +; OUTPUT : ManyConnect(index(c), locs, e) +; INPUT : ConnectMany(index(c), e, locs) +; (c:ConnectMany) : +; val locs = expand-exp(loc(c)) +; val exps-list = transpose(map(expand-exp, exps(c))) +; Begin $ +; for (l in locs, exps in exps-list) map : +; switch {dir(l) == _} : +; INPUT : ConnectMany(index(c), l, exps) +; OUTPUT : ManyConnect(index(c), exps, l) +; (c) : +; map{expand-comm, _} $ +; map(collapse-exp, c) +; +; Module(name(m), ports(m), expand-comm(body(m))) +; +;defn expand-bundles (c:Circuit) : +; Circuit(modules*, main(c)) where : +; val modules* = map(expand-bundles, modules(c)) +; +; +;;=========== CONVERT MULTI CONNECTS to WHEN ================ +;defn expand-multi-connects (c:Circuit) : +; defn equal-exp (e1:Expression, e2:Expression) : +; DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth())) +; defn uint (i:Int) : +; UIntValue(i, UnknownWidth()) +; +; defn expand-comm (c:Stmt) : +; match(c) : +; (c:ConnectMany) : +; Begin $ to-list $ +; for (i in 0 to false, e in exps(c)) stream : +; Conditionally(equal-exp(index(c), uint(i)), +; Connect(loc(c), e) +; EmptyStmt()) +; (c:ManyConnect) : +; Begin $ to-list $ +; for (i in 0 to false, l in locs(c)) stream : +; Conditionally(equal-exp(index(c), uint(i)), +; Connect(l, exp(c)) +; EmptyStmt()) +; (c) : +; map(expand-comm, c) +; +; defn expand (m:Module) : +; Module(name(m), ports(m), expand-comm(body(m))) +; +; Circuit(modules*, main(c)) where : +; val modules* = map(expand, modules(c)) +; +; +;;================ EXPAND WHENS ============================= +;definterface SymbolicValue +;defstruct ExpValue <: SymbolicValue : +; exp: Expression +;defstruct WhenValue <: SymbolicValue : +; pred: Expression +; conseq: SymbolicValue +; alt: SymbolicValue +;defstruct VoidValue <: SymbolicValue +; +;defmethod print (o:OutputStream, sv:SymbolicValue) : +; match(sv) : +; (sv:VoidValue) : print(o, "VOID") +; (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"]) +; (sv:ExpValue) : print(o, exp(sv)) +; +;defn key-eqv? (e1:Expression, e2:Expression) : +; match(e1, e2) : +; (e1:WRef, e2:WRef) : +; name(e1) == name(e2) +; (e1:WField, e2:WField) : +; name(e1) == name(e2) and +; key-eqv?(exp(e1), exp(e2)) +; (e1, e2) : +; false +; +;defn merge-env (pred: Expression, +; con-env: List<KeyValue<Expression,SymbolicValue>>, +; alt-env: List<KeyValue<Expression,SymbolicValue>>) : +; val merged = Vector<KeyValue<Expression, SymbolicValue>>() +; defn new-key? (k:Expression) : +; for entry in merged none? : +; key-eqv?(key(entry), k) +; +; defn sv (env:List<KeyValue<Expression,SymbolicValue>>, k:Expression) : +; for entry in env search : +; if key-eqv?(key(entry), k) : +; value(entry) +; +; for k in stream(key, concat(con-env, alt-env)) do : +; if new-key?(k) : +; match(sv(con-env, k), sv(alt-env, k)) : +; (a:SymbolicValue, b:SymbolicValue) : +; if a == b : +; add(merged, k => a) +; else : +; add(merged, k => WhenValue(pred, a, b)) +; (a:SymbolicValue, b:False) : +; add(merged, k => a) +; (a:False, b:SymbolicValue) : +; add(merged, k => b) +; (a:False, b:False) : +; false +; +; to-list(merged) +; +;defn simplify-env (env: List<KeyValue<Expression,SymbolicValue>>) : +; val merged = Vector<KeyValue<Expression, SymbolicValue>>() +; defn new-key? (k:Expression) : +; for entry in merged none? : +; key-eqv?(key(entry), k) +; for entry in env do : +; if new-key?(key(entry)) : +; add(merged, entry) +; to-list(merged) +; +;defn expand-whens (m:Module) : +; val commands = Vector<Stmt>() +; val elements = Vector<KeyValue<Symbol,Element>>() +; defn eval (c:Stmt, env:List<KeyValue<Expression,SymbolicValue>>) -> +; List<KeyValue<Expression,SymbolicValue>> : +; match(c) : +; (c:LetRec) : +; do(add{elements, _}, entries(c)) +; eval(body(c), env) +; (c:DefWire) : +; add(commands, c) +; val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT) +; List(wire-ref => VoidValue(), env) +; (c:DefRegister) : +; add(commands, c) +; val reg-ref = WRef(name(c), type(c), RegKind(), INPUT) +; List(reg-ref => VoidValue(), env) +; (c:DefInstance) : +; add(commands, c) +; val entries = let : +; val module-type = type(module(c)) as BundleType +; val input-ports = to-list $ +; for p in ports(module-type) filter : +; direction(p) == INPUT +; val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT) +; for p in input-ports map : +; WField(inst-ref, name(p), type(p), INPUT) => VoidValue() +; append(entries, env) +; (c:DefMemory) : +; add(commands, c) +; env +; (c:WDefAccessor) : +; add(commands, c) +; if dir(c) == INPUT : +; val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c)) +; List(access-ref => VoidValue(), env) +; else : +; env +; (c:Conditionally) : +; val con-env = eval(conseq(c), env) +; val alt-env = eval(alt(c), env) +; merge-env(pred(c), con-env, alt-env) +; (c:Begin) : +; var env:List<KeyValue<Expression,SymbolicValue>> = env +; for c in body(c) do : +; env = eval(c, env) +; env +; (c:Connect) : +; List(loc(c) => ExpValue(exp(c)), env) +; (c:EmptyStmt) : +; env +; +; defn convert-symbolic (key:Expression, sv:SymbolicValue) : +; match(sv) : +; (sv:VoidValue) : +; throw $ PassException $ string-join $ [ +; "No default value for " key "."] +; (sv:ExpValue) : +; exp(sv) +; (sv:WhenValue) : +; defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) : +; DoPrim(MUX-OP, list(pred, conseq, alt), List(), type(conseq)) +; multiplex-exp(pred(sv), +; convert-symbolic(key, conseq(sv)) +; convert-symbolic(key, alt(sv))) +; +; ;Compute final environment +; val env0 = let : +; val output-ports = to-list $ +; for p in ports(m) filter : +; direction(p) == OUTPUT +; for p in output-ports map : +; val port-ref = WRef(name(p), type(p), PortKind(), INPUT) +; port-ref => VoidValue() +; val env* = simplify-env(eval(body(m), env0)) +; +; ;Make new body +; val body* = Begin(list(defs, LetRec(elems, connections))) where : +; val defs = Begin(to-list(commands)) +; val elems = to-list(elements) +; val connections = Begin $ +; for entry in env* map : +; val sv = convert-symbolic(key(entry), value(entry)) +; Connect(key(entry), sv) +; +; ;Final module +; Module(name(m), ports(m), body*) +; +;defn expand-whens (c:Circuit) : +; val modules* = map(expand-whens, modules(c)) +; Circuit(modules*, main(c)) +; +; +;;================ STRUCTURAL FORM ========================== +; +;defn structural-form (m:Module) : +; val elements = Vector<(() -> KeyValue<Symbol,Element>)>() +; val connected = HashTable<Symbol, Expression>(symbol-hash) +; val write-accessors = HashTable<Symbol, List<WDefAccessor>>(symbol-hash) +; val read-accessors = HashTable<Symbol, WDefAccessor>(symbol-hash) +; val inst-ports = HashTable<Symbol, List<KeyValue<Symbol, Expression>>>(symbol-hash) +; val port-connects = Vector<Connect>() +; +; defn scan (c:Stmt) : +; match(c) : +; (c:Connect) : +; match(loc(c)) : +; (loc:WRef) : +; match(kind(loc)) : +; (k:PortKind) : add(port-connects, c) +; (k) : connected[name(loc)] = exp(c) +; (loc:WField) : +; val inst = exp(loc) as WRef +; val entry = name(loc) => exp(c) +; inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List())) +; (c:LetRec) : +; for e in entries(c) do : +; add(elements, {e}) +; scan(body(c)) +; (c:DefWire) : +; add{elements, _} $ fn () : +; name(c) => Node(type(c), connected[name(c)]) +; (c:DefRegister) : +; add{elements, _} $ fn () : +; val one = UIntValue(1, UnknownWidth()) +; name(c) => Register(type(c), connected[name(c)], one) +; (c:DefInstance) : +; add{elements, _} $ fn () : +; name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)]) +; (c:DefMemory) : +; add{elements, _} $ fn () : +; val ports = for a in get?(write-accessors, name(c), List()) map : +; val one = UIntValue(1, UnknownWidth()) +; WritePort(index(a), connected[name(a)], one) +; name(c) => Memory(type(c), ports) +; (c:WDefAccessor) : +; val mem = source(c) as WRef +; switch {dir(c) == _} : +; INPUT : +; write-accessors[name(mem)] = List(c, +; get?(write-accessors, name(mem), List())) +; OUTPUT : +; read-accessors[name(c)] = c +; (c) : +; do(scan, children(c)) +; +; defn make-read-ports (e:Expression) : +; match(e) : +; (e:WRef) : +; match(kind(e)) : +; (k:AccessorKind) : +; val accessor = read-accessors[name(e)] +; ReadPort(source(accessor), index(accessor), type(e)) +; (k) : e +; (e) : map(make-read-ports, e) +; +; Module(name(m), ports(m), body*) where : +; scan(body(m)) +; val elems = to-list $ +; for e in elements stream : +; val entry = e() +; key(entry) => map(make-read-ports, value(entry)) +; val connect-ports = Begin $ to-list $ +; for c in port-connects stream : +; Connect(loc(c), make-read-ports(exp(c))) +; val body* = +; if empty?(elems) : connect-ports +; else : LetRec(elems, connect-ports) +; +;defn structural-form (c:Circuit) : +; val modules* = map(structural-form, modules(c)) +; Circuit(modules*, main(c)) +; +; +;;==================== WIDTH INFERENCE ====================== +;defstruct WidthVar <: Width : +; name: Symbol +; +;defmethod print (o:OutputStream, w:WidthVar) : +; print(o, name(w)) +; +;defn width! (t:Type) : +; match(t) : +; (t:UIntType) : width(t) +; (t:SIntType) : width(t) +; (t) : error("No width field.") +; +;defn put-width (t:Type, w:Width) : +; match(t) : +; (t:UIntType) : UIntType(w) +; (t:SIntType) : SIntType(w) +; (t) : t +; +;defn put-width (e:Expression, w:Width) : +; val type* = put-width(type(e), w) +; put-type(e, type*) +; +;defn add-width-vars (t:Type) : +; defn width? (w:Width) : +; match(w) : +; (w:UnknownWidth) : WidthVar(gensym()) +; (w) : w +; match(t) : +; (t:UIntType) : UIntType(width?(width(t))) +; (t:SIntType) : SIntType(width?(width(t))) +; (t) : map(add-width-vars, t) +; +;defn uint-width (i:Int) : +; var v:Int = i +; var n:Int = 0 +; while v != 0 : +; v = v >> 1 +; n = n + 1 +; IntWidth(n) +; +;defn sint-width (i:Int) : +; if i > 0 : +; val w = uint-width(i) +; IntWidth(width(w) + 1) +; else : +; val w = uint-width(neg(i) - 1) +; IntWidth(width(w) + 1) +; +;defn to-exp (w:Width) : +; match(w) : +; (w:IntWidth) : ELit(width(w)) +; (w:WidthVar) : EVar(name(w)) +; (w) : error $ string-join $ [ +; "Cannot convert " w " to exp."] +; +;defn primop-width (op:PrimOp, ws:List<Width>, ints:List<Int>) -> Exp : +; defn wmax (w1:Width, w2:Width) : +; EMax(to-exp(w1), to-exp(w2)) +; defn wplus (w1:Width, w2:Width) : +; EPlus(to-exp(w1), to-exp(w2)) +; defn wplus (w1:Width, w2:Int) : +; EPlus(to-exp(w1), ELit(w2)) +; defn wminus (w1:Width, w2:Width) : +; EMinus(to-exp(w1), to-exp(w2)) +; defn wminus (w1:Width, w2:Int) : +; EMinus(to-exp(w1), ELit(w2)) +; defn wmax-inc (w1:Width, w2:Width) : +; EPlus(wmax(w1, w2), ELit(1)) +; +; switch {op == _} : +; ADD-OP : wmax-inc(ws[0], ws[1]) +; ADD-WRAP-OP : wmax(ws[0], ws[1]) +; SUB-OP : wmax-inc(ws[0], ws[1]) +; SUB-WRAP-OP : wmax(ws[0], ws[1]) +; MUL-OP : wplus(ws[0], ws[1]) +; DIV-OP : wminus(ws[0], ws[1]) +; MOD-OP : to-exp(ws[1]) +; SHIFT-LEFT-OP : wplus(ws[0], ints[0]) +; SHIFT-RIGHT-OP : wminus(ws[0], ints[0]) +; PAD-OP : ELit(ints[0]) +; BIT-AND-OP : wmax(ws[0], ws[1]) +; BIT-OR-OP : wmax(ws[0], ws[1]) +; BIT-XOR-OP : wmax(ws[0], ws[1]) +; CONCAT-OP : wplus(ws[0], ints[0]) +; BIT-SELECT-OP : ELit(1) +; BITS-SELECT-OP : ELit(ints[0]) +; MUX-OP : wmax(ws[1], ws[2]) +; LESS-OP : ELit(1) +; LESS-EQ-OP : ELit(1) +; GREATER-OP : ELit(1) +; GREATER-EQ-OP : ELit(1) +; EQUAL-OP : ELit(1) +; +;defn put-type (el:Element, t:Type) -> Element : +; match(el) : +; (el:Register) : Register(t, value(el), enable(el)) +; (el:Memory) : Memory(t, writers(el)) +; (el:Node) : Node(t, value(el)) +; (el:Instance) : Instance(t, module(el), ports(el)) +; +;defn generate-constraints (c:Circuit) -> [Circuit, Vector<WConstraint>] : +; ;Constraints +; val cs = Vector<WConstraint>() +; defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) : +; match(wvar) : +; (wvar:WidthVar) : +; add(cs, Constraint(name(wvar), to-exp(width))) +; (wvar) : +; false +; +; defn to-width (e:Exp) : +; match(e) : +; (e:ELit) : +; IntWidth(width(e)) +; (e:EVar) : +; WidthVar(name(e)) +; (e) : +; val x = gensym() +; add(cs, WidthEqual(x, e)) +; WidthVar(x) +; +; ;Module types +; val mod-types = HashTable<Symbol,Type>(symbol-hash) +; +; defn add-port-vars (m:Module) -> Module : +; val ports* = +; for p in ports(m) map : +; val type* = add-width-vars(type(p)) +; Port(name(p), direction(p), type*) +; mod-types[name(m)] = BundleType(ports*) +; Module(name(m), ports*, body(m)) +; +; ;Add Width Variables +; defn add-module-vars (m:Module) -> Module : +; val types = HashTable<Symbol,Type>(symbol-hash) +; for p in ports(m) do : +; types[name(p)] = type(p) +; +; defn infer-exp-width (e:Expression) -> Expression : +; match(map(infer-exp-width, e)) : +; (e:WRef) : +; match(kind(e)) : +; (k:ModuleKind) : put-type(e, mod-types[name(e)]) +; (k) : put-type(e, types[name(e)]) +; (e:WField) : +; val t = bundle-field-type(type(exp(e)), name(e)) +; put-width(e, width!(t)) +; (e:UIntValue) : +; match(width(e)) : +; (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e))) +; (w) : e +; (e:SIntValue) : +; match(width(e)) : +; (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e))) +; (w) : e +; (e:DoPrim) : +; val widths = map(width!{type(_)}, args(e)) +; val w = to-width(primop-width(op(e), widths, consts(e))) +; put-width(e, w) +; (e:ReadPort) : +; val elem-type = type(type(mem(e)) as VectorType) +; put-width(e, width!(elem-type)) +; +; defn infer-comm-width (c:Stmt) : +; match(c) : +; (c:LetRec) : +; ;Add width vars to elements +; var entries*: List<KeyValue<Symbol,Element>> = +; for entry in entries(c) map : +; val el-name = key(entry) +; key(entry) => +; match(value(entry)) : +; (el:Register|Node) : +; put-type(el, add-width-vars(type(el))) +; (el:Memory) : +; el +; (el:Instance) : +; val mod-type = type(infer-exp-width(module(el))) as BundleType +; val type = BundleType $ to-list $ +; for p in ports(mod-type) filter : +; direction(p) == OUTPUT +; put-type(el, type) +; +; ;Add vars to environment +; for entry in entries* do : +; types[key(entry)] = type(value(entry)) +; +; ;Infer types for elements +; entries* = +; for entry in entries* map : +; key(entry) => map(infer-exp-width, value(entry)) +; +; ;Generate constraints +; for entry in entries* do : +; val el-name = key(entry) +; match(value(entry)) : +; (el:Register) : +; new-constraint(WidthEqual, reg-width, val-width) where : +; val reg-width = width!(types[el-name]) +; val val-width = width!(type(value(el))) +; (el:Node) : +; new-constraint(WidthEqual, node-width, val-width) where : +; val node-width = width!(types[el-name]) +; val val-width = width!(type(value(el))) +; (el:Instance) : +; val mod-type = type(module(el)) as BundleType +; for entry in ports(el) do : +; new-constraint(WidthGreater, port-width, val-width) where : +; val port-name = key(entry) +; val port-width = width!(bundle-field-type(mod-type, port-name)) +; val val-width = width!(type(value(entry))) +; (el) : false +; +; ;Analyze body +; LetRec(entries*, infer-comm-width(body(c))) +; +; (c:Connect) : +; val loc* = infer-exp-width(loc(c)) +; val exp* = infer-exp-width(exp(c)) +; new-constraint(WidthGreater, loc-width, exp-width) where : +; val loc-width = width!(type(loc*)) +; val exp-width = width!(type(exp*)) +; Connect(loc*, exp*) +; +; (c:Begin) : +; map(infer-comm-width, c) +; +; Module(name(m), ports(m), body*) where : +; val body* = infer-comm-width(body(m)) +; +; val c* = +; Circuit(modules*, main(c)) where : +; val ms = map(add-port-vars, modules(c)) +; val modules* = map(add-module-vars, ms) +; [c*, cs] +; +; +;;================== FILL WIDTHS ============================ +;defn fill-widths (c:Circuit, solved:Streamable<WidthEqual>) : +; ;Populate table +; val table = HashTable<Symbol, Width>(symbol-hash) +; for eq in solved do : +; table[name(eq)] = IntWidth(width(value(eq) as ELit)) +; +; defn width? (w:Width) : +; match(w) : +; (w:WidthVar) : get?(table, name(w), UnknownWidth()) +; (w) : w +; +; defn fill-type (t:Type) : +; match(t) : +; (t:UIntType) : UIntType(width?(width(t))) +; (t:SIntType) : SIntType(width?(width(t))) +; (t) : map(fill-type, t) +; +; defn fill-exp (e:Expression) -> Expression : +; val e* = map(fill-exp, e) +; val type* = fill-type(type(e)) +; put-type(e*, type*) +; +; defn fill-element (e:Element) : +; val e* = map(fill-exp, e) +; val type* = fill-type(type(e)) +; put-type(e*, type*) +; +; defn fill-comm (c:Stmt) : +; match(c) : +; (c:LetRec) : +; val entries* = +; for e in entries(c) map : +; key(e) => fill-element(value(e)) +; LetRec(entries*, fill-comm(body(c))) +; (c) : +; map{fill-comm, _} $ +; map(fill-exp, c) +; +; defn fill-port (p:Port) : +; Port(name(p), direction(p), fill-type(type(p))) +; +; defn fill-mod (m:Module) : +; Module(name(m), ports*, body*) where : +; val ports* = map(fill-port, ports(m)) +; val body* = fill-comm(body(m)) +; +; Circuit(modules*, main(c)) where : +; val modules* = map(fill-mod, modules(c)) +; +; +;;=============== TYPE INFERENCE DRIVER ===================== +;defn infer-widths (c:Circuit) : +; val [c*, cs] = generate-constraints(c) +; val solved = solve-widths(cs) +; fill-widths(c*, solved) +; +; +;;================ PAD WIDTHS =============================== +;defn pad-widths (c:Circuit) : +; ;Pad an expression to the given width +; defn pad-exp (e:Expression, w:Int) : +; match(type(e)) : +; (t:UIntType|SIntType) : +; val prev-w = width!(t) as IntWidth +; if width(prev-w) < w : +; val type* = put-width(t, IntWidth(w)) +; DoPrim(PAD-OP, list(e), list(w), type*) +; else : +; e +; (t) : +; e +; +; defn pad-exp (e:Expression, w:Width) : +; val w-value = width(w as IntWidth) +; pad-exp(e, w-value) +; +; ;Convenience +; defn max-width (es:Streamable<Expression>) : +; defn int-width (e:Expression) : +; width(width!(type(e)) as IntWidth) +; maximum(stream(int-width, es)) +; +; defn match-widths (es:List<Expression>) : +; val w = max-width(es) +; map(pad-exp{_, w}, es) +; +; ;Match widths for an expression +; defn match-exp-width (e:Expression) : +; match(map(match-exp-width, e)) : +; (e:DoPrim) : +; if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) : +; val args* = match-widths(args(e)) +; DoPrim(op(e), args*, consts(e), type(e)) +; else if op(e) == MUX-OP : +; val args* = List(head(args(e)), match-widths(tail(args(e)))) +; DoPrim(op(e), args*, consts(e), type(e)) +; else : +; e +; (e) : e +; +; defn match-element-width (e:Element) : +; match(map(match-exp-width, e)) : +; (e:Register) : +; val w = width!(type(e)) +; val value* = pad-exp(value(e), w) +; Register(type(e), value*, enable(e)) +; (e:Memory) : +; val width = width!(type(type(e) as VectorType)) +; val writers* = +; for w in writers(e) map : +; WritePort(index(w), pad-exp(value(w), width), enable(w)) +; Memory(type(e), writers*) +; (e:Node) : +; val w = width!(type(e)) +; val value* = pad-exp(value(e), w) +; Node(type(e), value*) +; (e:Instance) : +; val mod-type = type(module(e)) as BundleType +; val ports* = +; for p in ports(e) map : +; val port-type = bundle-field-type(mod-type, key(p)) +; val port-val = pad-exp(value(p), width!(port-type)) +; key(p) => port-val +; Instance(type(e), module(e), ports*) +; +; ;Match widths for a command +; defn match-comm-width (c:Stmt) : +; match(map(match-exp-width, c)) : +; (c:LetRec) : +; val entries* = +; for e in entries(c) map : +; key(e) => match-element-width(value(e)) +; LetRec(entries*, match-comm-width(body(c))) +; (c:Connect) : +; val w = width!(type(loc(c))) +; val exp* = pad-exp(exp(c), w) +; Connect(loc(c), exp*) +; (c) : +; map(match-comm-width, c) +; +; defn match-mod-width (m:Module) : +; Module(name(m), ports(m), body*) where : +; val body* = match-comm-width(body(m)) +; +; Circuit(modules*, main(c)) where : +; val modules* = map(match-mod-width, modules(c)) +; +; +;;================== INLINING =============================== +;defn inline-instances (c:Circuit) : +; val module-table = HashTable<Symbol,Module>(symbol-hash) +; val inlined? = HashTable<Symbol,True|False>(symbol-hash) +; for m in modules(c) do : +; module-table[name(m)] = m +; inlined?[name(m)] = false +; +; ;Convert a module into a sequence of elements +; defn to-elements (m:Module, +; inst:Symbol, +; port-exps:List<KeyValue<Symbol,Expression>>) -> +; List<KeyValue<Symbol, Element>> : +; defn rename-exp (e:Expression) : +; match(e) : +; (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e)) +; (e) : map(rename-exp, e) +; +; defn to-elements (c:Stmt) -> List<KeyValue<Symbol,Element>> : +; match(c) : +; (c:LetRec) : +; val entries* = +; for entry in entries(c) map : +; val name* = prefix(inst, key(entry)) +; val element* = map(rename-exp, value(entry)) +; name* => element* +; val body* = to-elements(body(c)) +; append(entries*, body*) +; (c:Connect) : +; val ref = loc(c) as WRef +; val name* = prefix(inst, name(ref)) +; list(name* => Node(type(exp(c)), rename-exp(exp(c)))) +; (c:Begin) : +; map-append(to-elements, body(c)) +; +; val inputs = +; for p in ports(m) map-append : +; if direction(p) == INPUT : +; val port-exp = lookup!(port-exps, name(p)) +; val name* = prefix(inst, name(p)) +; list(name* => Node(type(port-exp), port-exp)) +; else : +; List() +; append(inputs, to-elements(body(m))) +; +; ;Inline all instances in the module +; defn inline-instances (m:Module) : +; defn rename-exp (e:Expression) : +; match(e) : +; (e:WField) : +; val inst-exp = exp(e) as WRef +; val name* = prefix(name(inst-exp), name(e)) +; WRef(name*, type(e), NodeKind(), dir(e)) +; (e) : +; map(rename-exp, e) +; +; defn inline-elems (es:List<KeyValue<Symbol,Element>>) : +; for entry in es map-append : +; match(value(entry)) : +; (el:Instance) : +; val mod-name = name(module(el) as WRef) +; val module = inlined-module(mod-name) +; to-elements(module, key(entry), ports(el)) +; (el) : +; list(entry) +; +; defn inline-comm (c:Stmt) : +; match(map(rename-exp, c)) : +; (c:LetRec) : +; val entries* = inline-elems(entries(c)) +; LetRec(entries*, inline-comm(body(c))) +; (c) : +; map(inline-comm, c) +; +; Module(name(m), ports(m), inline-comm(body(m))) +; +; ;Retrieve the inlined instance of a module +; defn inlined-module (name:Symbol) : +; if inlined?[name] : +; module-table[name] +; else : +; val module* = inline-instances(module-table[name]) +; module-table[name] = module* +; inlined?[name] = true +; module* +; +; ;Return the fully inlined circuit +; val main-module = inlined-module(main(c)) +; Circuit(list(main-module), main(c)) ;;;================ UTILITIES ================================ @@ -1982,16 +1991,16 @@ public defn run-passes (c: Circuit, p: List<Char>) : if contains(p,'c') : do-stage("Make Explicit Reset", make-explicit-reset) if contains(p,'d') : do-stage("Initialize Registers", initialize-registers) if contains(p,'e') : do-stage("Infer Types", infer-types) - if contains(p,'f') : do-stage("Infer Directions", infer-directions) - if contains(p,'g') : do-stage("Expand Accessors", expand-accessors) - if contains(p,'h') : do-stage("Flatten Bundles", flatten-bundles) - if contains(p,'i') : do-stage("Expand Bundles", expand-bundles) - if contains(p,'j') : do-stage("Expand Multi Connects", expand-multi-connects) - if contains(p,'k') : do-stage("Expand Whens", expand-whens) - if contains(p,'l') : do-stage("Structural Form", structural-form) - if contains(p,'m') : do-stage("Infer Widths", infer-widths) - if contains(p,'n') : do-stage("Pad Widths", pad-widths) - if contains(p,'o') : do-stage("Inline Instances", inline-instances) + ;if contains(p,'f') : do-stage("Infer Directions", infer-directions) + ;if contains(p,'g') : do-stage("Expand Accessors", expand-accessors) + ;if contains(p,'h') : do-stage("Flatten Bundles", flatten-bundles) + ;if contains(p,'i') : do-stage("Expand Bundles", expand-bundles) + ;if contains(p,'j') : do-stage("Expand Multi Connects", expand-multi-connects) + ;if contains(p,'k') : do-stage("Expand Whens", expand-whens) + ;if contains(p,'l') : do-stage("Structural Form", structural-form) + ;if contains(p,'m') : do-stage("Infer Widths", infer-widths) + ;if contains(p,'n') : do-stage("Pad Widths", pad-widths) + ;if contains(p,'o') : do-stage("Inline Instances", inline-instances) println("Done!") |
