defpackage firrtl.passes :
   import core
   import verse
   import firrtl.ir2
   import firrtl.ir-utils
   import widthsolver
   import firrtl-main

;============== EXCEPTIONS =================================
; Error raised by a compiler pass. Printing the exception prints
; the message it was constructed with.
defclass PassException <: Exception
defn PassException (msg:String) :
   new PassException :
      defmethod print (o:OutputStream, this) :
         print(o, msg)

;=============== WORKING IR ================================
; A Kind records what sort of declaration a WRef names.
; to-working-ir creates every WRef with NodeKind(); resolve-kinds
; later replaces it with the kind of the declaring statement.
definterface Kind
defstruct RegKind <: Kind
defstruct AccessorKind <: Kind
defstruct PortKind <: Kind
defstruct MemKind <: Kind
defstruct NodeKind <: Kind ; All elems except structural memory, wires
defstruct ModuleKind <: Kind
defstruct InstanceKind <: Kind
defstruct StructuralMemKind <: Kind ; Separate kind because need special treatment

;val UNKNOWN-DIR = new Direction
val UNKNOWN-GENDER = new Gender
val BI-GENDER = new Gender

; Working-IR counterparts of Ref, Subfield, Index and DefAccessor.
; They additionally carry a gender and, for WRef, a kind.
defstruct WRef <: Expression :
   name: Symbol
   type: Type [multi => false]
   kind: Kind
   gender: Gender [multi => false]

defstruct WSubfield <: Expression :
   exp: Expression
   name: Symbol
   type: Type [multi => false]
   gender: Gender [multi => false]

defstruct WIndex <: Expression :
   exp: Expression
   value: Int
   type: Type [multi => false]
   gender: Gender [multi => false]

defstruct WDefAccessor <: Stmt :
   name: Symbol
   source: Expression
   index: Expression
   gender: Gender

;================ WORKING IR UTILS =========================

;============== DEBUG STUFF =============================
; Flags that switch on extra annotations when IR nodes are printed.
public var PRINT-TYPES : True|False = false
public var PRINT-KINDS : True|False = false
public var PRINT-WIDTHS : True|False = false
public var PRINT-GENDERS : True|False = false
public var PRINT-CIRCUITS : True|False = false

;=== Printers ===
defmethod print (o:OutputStream, k:Kind) :
   print{o, _} $
      match(k) :
         (k:RegKind) : "reg"
         (k:AccessorKind) : "accessor"
         (k:PortKind) : "port"
         (k:MemKind) : "mem"
         (k:NodeKind) : "n"
         (k:ModuleKind) : "module"
         (k:InstanceKind) : "inst"
         (k:StructuralMemKind) : "smem"

; True for the IR nodes that carry a width field.
defn hasWidth (e:Expression|Stmt|Type|Element|Port) :
   e typeof UIntType|SIntType|UIntValue|SIntValue

; True for the IR nodes that carry a type field.
defn hasType (e:Expression|Stmt|Type|Element|Port) :
   e typeof Ref|Subfield|Index|DoPrim|ReadPort|WRef|WSubfield|WIndex|DefWire|DefRegister|DefMemory|Register|Memory|Node|Instance|VectorType|Port

; True for the IR nodes that carry a kind field.
defn hasKind (e:Expression|Stmt|Type|Element|Port) :
   e typeof WRef

; True when at least one enabled annotation applies to e.
defn any-debug? (e:Expression|Stmt|Type|Element|Port) :
   (hasType(e) and PRINT-TYPES) or (hasWidth(e) and PRINT-WIDTHS) or (hasKind(e) and PRINT-KINDS)

; Prints the enabled debug annotations of e, introduced by "@".
defmethod print-debug (o:OutputStream, e:Expression|Stmt|Type|Element|Port) :
   ; Types are printed with their widths erased, because widths
   ; have their own annotation below.
   defn wipe-width (t:Type) -> Type :
      match(t) :
         (t:UIntType) : UIntType(UnknownWidth())
         (t:SIntType) : SIntType(UnknownWidth())
         (t) : t
   if any-debug?(e) : print(o,"@")
   if PRINT-KINDS and hasKind(e) : print-all(o,["<k:" kind(e as ?) ">"])
   if PRINT-TYPES and hasType(e) : print-all(o,["<t:" wipe-width(type(e as ?)) ">"])
   if PRINT-WIDTHS and hasWidth(e): print-all(o,["<w:" width(e as ?) ">"])

defmethod print (o:OutputStream, e:WRef) :
   print(o,name(e))
   print-debug(o,e as ?)
defmethod print (o:OutputStream, e:WSubfield) :
   print-all(o,[exp(e) "." name(e)])
   print-debug(o,e as ?)
defmethod print (o:OutputStream, e:WIndex) :
   print-all(o,[exp(e) "." value(e)])
   print-debug(o,e as ?)
defmethod print (o:OutputStream, s:WDefAccessor) :
   ; A WDefAccessor stores a gender field, not a direction.
   print-all(o,[gender(s) " accessor " name(s) " = " source(s) "[" index(s) "]"])
   print-debug(o,s)

; map rebuilds a node with f applied to its child expressions.
; The stored gender is carried over unchanged. The dir multi below
; always answers OUTPUT, so using it here would discard the gender.
defmethod map (f: Expression -> Expression, e: WSubfield) :
   WSubfield(f(exp(e)), name(e), type(e), gender(e))
defmethod map (f: Expression -> Expression, e: WIndex) :
   WIndex(f(exp(e)), value(e), type(e), gender(e))
defmethod map (f: Expression -> Expression, c:WDefAccessor) :
   WDefAccessor(name(c), f(source(c)), f(index(c)), gender(c))

;================= DIRECTION ===============================
; Placeholder: every expression currently reports OUTPUT.
; NOTE(review): OUTPUT looks like a direction constant although the
; declared result is Gender -- confirm against firrtl.ir2.
defmulti dir (e:Expression) -> Gender
defmethod dir (e:Expression) : OUTPUT

;================= Bring to Working IR ========================
; Returns a new Circuit with Refs, Subfields, Indexes and DefAccessors
; replaced with IR-internal nodes that contain additional
; information (kind, gender)
defn to-working-ir (c:Circuit) :
   ; Children are converted before their parent (bottom-up).
   ; New nodes start with UNKNOWN-GENDER, and WRefs start as NodeKind.
   defn to-exp (e:Expression) :
      match(map(to-exp,e)) :
         (e:Ref) : WRef(name(e), type(e), NodeKind(), UNKNOWN-GENDER)
         (e:Subfield) : WSubfield(exp(e), name(e), type(e), UNKNOWN-GENDER)
         (e:Index) : WIndex(exp(e), value(e), type(e), UNKNOWN-GENDER)
         (e) : e
   defn to-stmt (s:Stmt) :
      match(map(to-exp,s)) :
         (s:DefAccessor) : WDefAccessor(name(s),source(s),index(s), UNKNOWN-GENDER)
         (s) : map(to-stmt,s)
   Circuit(modules*, main(c)) where :
      val modules* =
         for m in modules(c) map :
            Module(name(m), ports(m), to-stmt(body(m)))

;=============== Resolve Kinds =============================
; It is useful for the compiler to know information about
; objects referenced. This information is stored in the kind
; field in WRef. This pass walks the graph and returns a new
; Circuit where all WRef kinds are resolved
defn resolve-kinds (c:Circuit) :
   ; Rewrites every WRef in body with the kind recorded for its name.
   ; The table lookup fails if a referenced name was never recorded.
   defn resolve (body:Stmt, kinds:HashTable<Symbol,Kind>) :
      defn resolve-stmt (s:Stmt) -> Stmt :
         map{resolve-expr,_} $ map(resolve-stmt,s)
      defn resolve-expr (e:Expression) -> Expression :
         match(e) :
            ; Only the kind is resolved; the stored gender is kept.
            (e:WRef) : WRef(name(e),type(e),kinds[name(e)],gender(e))
            (e) : map(resolve-expr,e)
      resolve-stmt(body)

   ; Records in kinds the kind of m itself, of its ports and of
   ; every name declared in its body.
   defn find (m:Module, kinds:HashTable<Symbol,Kind>) :
      defn find-stmt (s:Stmt) -> Stmt :
         match(s) :
            (s:LetRec) :
               for e in entries(s) do :
                  kinds[key(e)] = get-elem-kind(value(e))
            (s:DefWire) : kinds[name(s)] = NodeKind()
            (s:DefRegister) : kinds[name(s)] = RegKind()
            (s:DefInstance) : kinds[name(s)] = InstanceKind()
            (s:DefMemory) : kinds[name(s)] = MemKind()
            (s:WDefAccessor) : kinds[name(s)] = AccessorKind()
            (s) : false
         map(find-stmt,s)
      defn get-elem-kind (e:Element) :
         match(e) :
            (e: Memory) : StructuralMemKind()
            (e) : NodeKind()
      kinds[name(m)] = ModuleKind()
      for p in ports(m) do :
         kinds[name(p)] = PortKind()
      find-stmt(body(m))

   ; Every module name of the circuit is visible inside each module.
   defn resolve-kinds (m:Module, c:Circuit) -> Module :
      val kinds = HashTable<Symbol,Kind>(symbol-hash)
      for m in modules(c) do :
         kinds[name(m)] = ModuleKind()
      find(m,kinds)
      val body! = resolve(body(m),kinds)
      Module(name(m),ports(m),body!)

   Circuit(modules*, main(c)) where :
      val modules* =
         for m in modules(c) map :
            resolve-kinds(m,c)

;=============== MAKE EXPLICIT RESET =======================
; All modules have an implicit reset signal - however, the
; programmer can explicitly reference this signal if desired.
; This pass makes all implicit resets explicit while
; preserving any previously explicit resets
; If reset is not explicitly passed to instantiations, then this
; pass automatically connects the parent module's reset to the
; instantiation's reset
defn make-explicit-reset (c:Circuit) :
   ; Names of the modules that already declare a port named reset.
   defn find-explicit (c:Circuit) -> List<Symbol> :
      defn explicit? (m:Module) -> True|False :
         for p in ports(m) any? :
            name(p) == `reset
      val explicit-reset = Vector<Symbol>()
      for m in modules(c) do :
         if explicit?(m) : add(explicit-reset,name(m))
      to-list(explicit-reset)

   ; Adds a 1-bit reset port to m unless m already declares one, and
   ; follows every instance declaration with inst.reset := reset.
   defn make-explicit (m:Module, explicit-reset:List<Symbol>) -> Module :
      defn route-reset (s:Stmt) -> Stmt :
         match(s) :
            (s:DefInstance) :
               val iref = WSubfield(WRef(name(s), UnknownType(), InstanceKind(), UNKNOWN-GENDER),`reset,UnknownType(),UNKNOWN-GENDER)
               val pref = WRef(`reset, UnknownType(), PortKind(), MALE)
               Begin(to-list([s,Connect(iref,pref)]))
            (s) : map(route-reset,s)
      var ports! = ports(m)
      if not contains?(explicit-reset,name(m)) :
         ports! = append(ports(m),list(Port(`reset,MALE,UIntType(IntWidth(1)))))
      val body! = route-reset(body(m))
      Module(name(m),ports!,body!)

   ; The set of explicit-reset modules depends only on c, so it is
   ; computed once rather than once per module.
   val explicit-reset = find-explicit(c)
   Circuit(modules*, main(c)) where :
      val modules* =
         for m in modules(c) map :
            make-explicit(m,explicit-reset)

;======= MAKE EXPLICIT REGISTER INITIALIZATION =============
; This pass replaces the reg.init construct by creating a new
; wire that holds the value at initialization. This wire
; is then connected to the register conditionally on reset,
; at the end of the scope containing the register
; declaration
; If a register has no initial value, the wire is connected to
; a NULL node. Later passes will remove these with the base
; case Mux(reset,NULL,a) -> a, and Mux(reset,a,NULL) -> a.
; This ensures proper behavior if this pass is run multiple
; times.
defn initialize-registers (c:Circuit) :
   ; Appends `when reset : reg := wire` after s, with one connect per
   ; reg -> wire entry of renames. Returns s unchanged when renames
   ; is empty.
   defn add-when (s:Stmt,renames:HashTable<Symbol,Symbol>) -> Stmt :
      var inits = List<Stmt>()
      for kv in renames do :
         val refreg = WRef(key(kv),UnknownType(),RegKind(),UNKNOWN-GENDER)
         val refwire = WRef(value(kv),UnknownType(),NodeKind(),UNKNOWN-GENDER)
         val connect = Connect(refreg,refwire)
         inits = append(inits,list(connect))
      if empty?(inits) : s
      else :
         val pred = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-GENDER)
         val when-reset = Conditionally(pred,Begin(inits),Begin(List()))
         Begin(list(s,when-reset))

   ; Replaces every reg.init in s (reg being a RegKind WRef) with a
   ; reference to the wire that l maps reg to.
   defn rename (s:Stmt,l:HashTable<Symbol,Symbol>) -> Stmt :
      defn rename-stmt (s:Stmt) -> Stmt :
         map{rename-expr,_} $ map(rename-stmt,s)
      defn rename-expr (e:Expression) -> Expression :
         match(e) :
            (e:WSubfield) :
               if name(e) == `init and register?(exp(e)) :
                  ;TODO Error if l does not contain register
                  val new-name = l[name(exp(e) as WRef)]
                  WRef(new-name,UnknownType(),NodeKind(),UNKNOWN-GENDER)
               else :
                  ; A reg.init may sit underneath this subfield
                  ; (e.g. r.init.x), so keep descending.
                  map(rename-expr,e)
            (e) : map(rename-expr,e)
      defn register? (e:Expression) -> True|False :
         match(e) :
            (e:WRef) : kind(e) typeof RegKind
            (e) : false
      rename-stmt(s)

   ; Walks one scope. Right after each DefRegister it declares a fresh
   ; wire connected to Null(). It returns the reg -> wire mapping for
   ; the registers of this scope. The bodies of Conditionally and
   ; LetRec are initialized as separate scopes via initialize-scope.
   defn initialize-registers (s:Stmt) -> [Stmt,HashTable<Symbol,Symbol>] :
      val empty-hash = HashTable<Symbol,Symbol>(symbol-hash)
      match(s) :
         (s:Begin) :
            var body! = List<Stmt>()
            val renames = HashTable<Symbol,Symbol>(symbol-hash)
            for s in body(s) do :
               val [s!,renames!] = initialize-registers(s)
               body! = append(body!,list(s!))
               merge!(renames,renames!)
            [Begin(body!),renames]
         (s:DefRegister) :
            val wire-name = gensym()
            val renames = HashTable<Symbol,Symbol>(symbol-hash)
            renames[name(s)] = wire-name
            [Begin(body!),renames] where :
               val defreg = s
               val defwire = DefWire(wire-name,type(s))
               val conwire = Connect(WRef(wire-name,UnknownType(),NodeKind(),UNKNOWN-GENDER),Null())
               val body! = list(defreg,defwire,conwire)
         (s:Conditionally) :
            [Conditionally(pred(s),initialize-scope(conseq(s)),initialize-scope(alt(s))),empty-hash]
         (s:LetRec) :
            [LetRec(entries(s),initialize-scope(body(s))),empty-hash]
         ;TODO Add Letrec
         (s) : [s,empty-hash]

   ; Initializes the registers of one scope. It creates their init
   ; wires, rewrites reg.init uses and appends the reset connects.
   defn initialize-scope (s:Stmt) -> Stmt :
      val [s!,renames] = initialize-registers(s)
      val s!! = rename(s!,renames)
      val s!!! = add-when(s!!,renames)
      s!!!

   defn initialize-module (m:Module) -> Module :
      Module(name(m), ports(m), body!) where :
         val body! = initialize-scope(body(m))

   Circuit(modules*, main(c)) where :
      val modules* =
         for m in modules(c) map :
            initialize-module(m)

;============== INFER TYPES ================================
; This pass infers the type field in all IR nodes by updating
; and passing an environment to all statements in pre-order
; traversal, and resolving types in expressions in post-
; order traversal.
; Type propagation for primitive ops is defined here.
; Notable cases: LetRec requires updating environment before
; resolving the subexpressions in its elements.
; Type errors are not checked in this pass, as this is
; postponed for a later/earlier pass.
; Result type of a primitive op. The signed/unsigned-suffixed ops
; have fixed result types. The generic ops derive theirs from the
; types of their arguments. Widths are always left unknown here.
defn get-primop-rettype (e:DoPrim) -> Type :
   defn u () : UIntType(UnknownWidth())
   defn s () : SIntType(UnknownWidth())
   ; UInt if both operands are UInt, SInt if either is SInt,
   ; otherwise unknown.
   defn u-and (op1:Expression,op2:Expression) :
      match(type(op1), type(op2)) :
         (t1:UIntType, t2:UIntType) : u()
         (t1:SIntType, t2) : s()
         (t1, t2:SIntType) : s()
         (t1, t2) : UnknownType()
   ; Same signedness as op's type, otherwise unknown.
   defn of-type (op:Expression) :
      match(type(op)) :
         (t:UIntType) : u()
         (t:SIntType) : s()
         (t) : UnknownType()
   ;println-all(["Inferencing primop type: " e])
   switch {op(e) == _} :
      ADD-OP : u-and(args(e)[0],args(e)[1])
      ADD-UU-OP : u()
      ADD-US-OP : s()
      ADD-SU-OP : s()
      ADD-SS-OP : s()
      SUB-OP : s()
      SUB-UU-OP : s()
      SUB-US-OP : s()
      SUB-SU-OP : s()
      SUB-SS-OP : s()
      MUL-OP : u-and(args(e)[0],args(e)[1])
      MUL-UU-OP : u()
      MUL-US-OP : s()
      MUL-SU-OP : s()
      MUL-SS-OP : s()
      DIV-OP : u-and(args(e)[0],args(e)[1])
      DIV-UU-OP : u()
      DIV-US-OP : s()
      DIV-SU-OP : s()
      DIV-SS-OP : s()
      MOD-OP : of-type(args(e)[0])
      MOD-UU-OP : u()
      MOD-US-OP : u()
      MOD-SU-OP : s()
      MOD-SS-OP : s()
      QUO-OP : u-and(args(e)[0],args(e)[1])
      QUO-UU-OP : u()
      QUO-US-OP : s()
      QUO-SU-OP : s()
      QUO-SS-OP : s()
      REM-OP : of-type(args(e)[1])
      REM-UU-OP : u()
      REM-US-OP : s()
      REM-SU-OP : u()
      REM-SS-OP : s()
      ADD-WRAP-OP : u-and(args(e)[0],args(e)[1])
      ADD-WRAP-UU-OP : u()
      ADD-WRAP-US-OP : s()
      ADD-WRAP-SU-OP : s()
      ADD-WRAP-SS-OP : s()
      SUB-WRAP-OP : u-and(args(e)[0],args(e)[1])
      SUB-WRAP-UU-OP : u()
      SUB-WRAP-US-OP : s()
      SUB-WRAP-SU-OP : s()
      SUB-WRAP-SS-OP : s()
      LESS-OP : u()
      LESS-UU-OP : u()
      LESS-US-OP : u()
      LESS-SU-OP : u()
      LESS-SS-OP : u()
      LESS-EQ-OP : u()
      LESS-EQ-UU-OP : u()
      LESS-EQ-US-OP : u()
      LESS-EQ-SU-OP : u()
      LESS-EQ-SS-OP : u()
      GREATER-OP : u()
      GREATER-UU-OP : u()
      GREATER-US-OP : u()
      GREATER-SU-OP : u()
      GREATER-SS-OP : u()
      GREATER-EQ-OP : u()
      GREATER-EQ-UU-OP : u()
      GREATER-EQ-US-OP : u()
      GREATER-EQ-SU-OP : u()
      GREATER-EQ-SS-OP : u()
      EQUAL-OP : u()
      EQUAL-UU-OP : u()
      EQUAL-SS-OP : u()
      ; Mux arguments are (pred, conseq, alt). The result follows the
      ; two data operands, not the predicate.
      MUX-OP : u-and(args(e)[1],args(e)[2])
      MUX-UU-OP : u()
      MUX-SS-OP : s()
      PAD-OP : of-type(args(e)[0])
      PAD-U-OP : u()
      PAD-S-OP : s()
      AS-UINT-OP : u()
      AS-UINT-U-OP : u()
      AS-UINT-S-OP : u()
      AS-SINT-OP : s()
      AS-SINT-U-OP : s()
      AS-SINT-S-OP : s()
      SHIFT-LEFT-OP : of-type(args(e)[0])
      SHIFT-LEFT-U-OP : u()
      SHIFT-LEFT-S-OP : s()
      SHIFT-RIGHT-OP : of-type(args(e)[0])
      SHIFT-RIGHT-U-OP : u()
      SHIFT-RIGHT-S-OP : s()
      CONVERT-OP : s()
      CONVERT-U-OP : s()
      CONVERT-S-OP : s()
      BIT-AND-OP : u()
      BIT-OR-OP : u()
      BIT-XOR-OP : u()
      CONCAT-OP : u()
      BIT-SELECT-OP : u()
      BITS-SELECT-OP : u()

; A module's type is the bundle of its ports.
defn type (m:Module) -> Type :
   BundleType(ports(m))

; Type bound to b in environment l, or UnknownType() if b is unbound.
; New bindings are consed onto the front of l, so the most recent
; binding wins.
defn get-type (b:Symbol,l:List<KeyValue<Symbol,Type>>) -> Type :
   val ma = for kv in l find :
      b == key(kv)
   if ma != false :
      val ret = value(ma as KeyValue<Symbol,Type>)
      ret
   else : UnknownType()

; Type of field s of bundle type v, or UnknownType() if v is not a
; bundle or has no such field.
defn bundle-field-type (v:Type,s:Symbol) -> Type :
   match(v) :
      (v:BundleType) :
         val ft = for p in ports(v) find :
            name(p) == s
         if ft != false : type(ft as Port)
         else : UnknownType()
      (v) : UnknownType()

; Element type of a vector type, otherwise UnknownType().
defn get-vector-subtype (v:Type) -> Type :
   match(v) :
      (v:VectorType) : type(v)
      (v) : UnknownType()

; Infers types bottom-up; children are typed before their parent.
; The gender stored on W* nodes is carried over unchanged.
defn infer-exp-types (e:Expression, l:List<KeyValue<Symbol,Type>>) -> Expression :
   val r = map(infer-exp-types{_,l},e)
   match(r) :
      (e:WRef) : WRef(name(e), get-type(name(e),l),kind(e),gender(e))
      (e:WSubfield) : WSubfield(exp(e),name(e), bundle-field-type(type(exp(e)),name(e)),gender(e))
      (e:WIndex) : WIndex(exp(e),value(e), get-vector-subtype(type(exp(e))),gender(e))
      (e:DoPrim) : DoPrim(op(e),args(e),consts(e),get-primop-rettype(e))
      (e:ReadPort) : ReadPort(mem(e),index(e),get-vector-subtype(type(mem(e))))
      (e:UIntValue|SIntValue|Null) : e

; Types the expressions of s and returns the environment extended
; with the names s declares. Bindings made inside a Conditionally
; branch do not escape it.
defn infer-types (s:Stmt, l:List<KeyValue<Symbol,Type>>) -> [Stmt, List<KeyValue<Symbol,Type>>] :
   match(map(infer-exp-types{_,l},s)) :
      (s:LetRec) : [s,l] ;TODO, this is wrong but we might be getting rid of letrecs?
      (s:Begin) :
         var env = l
         val body* =
            for s in body(s) map :
               val [s*,l*] = infer-types(s,env)
               env = l*
               s*
         [Begin(body*),env]
      (s:DefWire) : [s,List(name(s) => type(s),l)]
      (s:DefRegister) : [s,List(name(s) => type(s),l)]
      (s:DefMemory) : [s,List(name(s) => type(s),l)]
      (s:DefInstance) : [s, List(name(s) => type(module(s)),l)]
      (s:DefNode) : [s, List(name(s) => type(value(s)),l)]
      (s:WDefAccessor) : [s, List(name(s) => get-vector-subtype(type(source(s))),l)]
      (s:Conditionally) :
         val [s*,l*] = infer-types(conseq(s),l)
         val [s**,l**] = infer-types(alt(s),l)
         [Conditionally(pred(s),s*,s**),l]
      (s:Connect|EmptyStmt) : [s,l]

; Module ports are bound in front of the circuit-level environment.
defn infer-types (m:Module, l:List<KeyValue<Symbol,Type>>) -> Module :
   val ptypes =
      for p in ports(m) map :
         name(p) => type(p)
   ;println-all(append(ptypes,l))
   val [s,l*] = infer-types(body(m),append(ptypes, l))
   Module(name(m),ports(m),s)

; Every module name is bound to the bundle of its ports.
defn infer-types (c:Circuit) -> Circuit :
   val l =
      for m in modules(c) map :
         name(m) => BundleType(ports(m))
   ;println-all(l)
   Circuit{ _, main(c) } $
      for m in modules(c) map :
         infer-types(m,l)

;============= INFER DIRECTIONS ============================
; To ensure a proper circuit, we must ensure that assignments
; only work on expressions that can be assigned to. Similarly,
; we must ensure that only expressions that can be read from
; are used to assign from. This invariant requires each
; expression's gender to be inferred.
; Various elements can be bi-gender (e.g. wires) and can
; thus be treated as either female or male. Conversely, some
; elements are single-gender (e.g. accessors, ports).
; Because accessor gender is not known during declaration,
; this pass requires iterating until a fixed point is reached.
;
; NOTE(review): this section is an unfinished draft. The body of
; resolve is entirely commented out, and `kinds` is used in find and
; resolve-genders without being bound in either.
defn infer-genders (c:Circuit) :
   defn resolve (body:Stmt, kinds:HashTable) :
      defn resolve-stmt (s:Stmt) -> Stmt :
         match(s)
         ;   (s:LetRec) : s ; TODO get rid of this
         ;   (s:DefInstance) : genders[name(s)] = MALE
         ;   (s:DefMemory) : genders[name(s)] = BI-GENDER ;TODO WHY??
         ;   (s:WDefAccessor) : genders[name(s)] = gender(s)
         ;   (s) : false
         ;   map(find-stmt,s)
      ;defn resolve-expr (e:Expression) -> Expression :
      ;   match(e) :
      ;      (e:WRef) : WRef(name(e),type(e),kinds[name(e)],dir(e))
      ;      (e) : map(resolve-expr,e)
      ;resolve-stmt(body)

   ; Records the gender of every name declared in m.
   defn find (m:Module, genders:HashTable) :
      defn find-stmt (s:Stmt) -> Stmt :
         match(s) :
            (s:LetRec) : s ; TODO get rid of this
            (s:DefWire) : genders[name(s)] = BI-GENDER
            (s:DefRegister) : genders[name(s)] = BI-GENDER
            (s:DefInstance) : genders[name(s)] = MALE
            (s:DefMemory) : genders[name(s)] = BI-GENDER ;TODO WHY??
            (s:WDefAccessor) : genders[name(s)] = gender(s)
            (s) : false
         map(find-stmt,s)
      ; NOTE(review): copied from resolve-kinds; `kinds` is not bound
      ; here -- presumably these entries belong in genders.
      kinds[name(m)] = ModuleKind()
      for p in ports(m) do :
         kinds[name(p)] = PortKind()
      find-stmt(body(m))

   defn resolve-genders (m:Module, c:Circuit) -> Module :
      val genders = HashTable(symbol-hash)
      for m in modules(c) do :
         genders[name(m)] = flip(to-field(ports(m)))
      find(m,genders)
      ; NOTE(review): `kinds` is not bound here either.
      val [body*,done?] = resolve(body(m),kinds)
      val module* = Module(name(m),ports(m),body*)
      ; Iterate on the rewritten module until resolve reports done.
      if done? : module*
      else : resolve-genders(module*,c)

   Circuit(modules*, main(c)) where :
      val modules* =
         for m in modules(c) map :
            resolve-genders(m,c)

;first pass, probably will delete
; NOTE(review): draft -- refers to `c` and `modules*`, which are not
; bound in this function, and recurses with a single argument.
defn iter-infer-genders (m:Module,l:HashTable) -> [Circuit, True|False] :
   var all-done? = false
   val body* = for m in modules(c) map :
      val [m*,done?] = iter-infer-genders(m)
      all-done? = all-done? and done?
      m*
   [Circuit(modules*,name(c)),all-done?]

; Runs one gender-inference step over every module. The flag is true
; only if every module reported done.
; NOTE(review): no one-argument Module overload of iter-infer-genders
; is visible for the call below.
defn iter-infer-genders (c:Circuit) -> [Circuit, True|False] :
   ; Seeded with true: with false, `all-done? and done?` could never
   ; become true and infer-genders would recurse forever.
   var all-done? = true
   val modules* = for m in modules(c) map :
      val [m*,done?] = iter-infer-genders(m)
      all-done? = all-done? and done?
      m*
   [Circuit(modules*,main(c)),all-done?]

; Repeats iter-infer-genders until it reports a fixed point.
defn infer-genders (c:Circuit) -> Circuit :
   val [c*,done?] = iter-infer-genders(c)
   if done? : c*
   else : infer-genders(c*)

;;============== EXPAND VECS ================================
;defstruct ManyConnect <: Stmt :
;   index: Expression
;   locs: List
;   exp: Expression
;
;defstruct ConnectMany <: Stmt :
;   index: Expression
;   loc: Expression
;   exps: List
;
;defmethod print (o:OutputStream, c:ManyConnect) :
;   print-all(o, [locs(c) "[" index(c) "] := " exp(c)])
;defmethod print (o:OutputStream, c:ConnectMany) :
;   print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"])
;
;defmethod map (f: Expression -> Expression, c:ManyConnect) :
;   ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c)))
;defmethod map (f: Expression -> Expression, c:ConnectMany) :
;   ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c)))
;
;defn expand-accessors (m: Module) :
;   defn expand (c:Stmt) :
;      match(c) :
;         (c:WDefAccessor) :
;            ;Is the source a memory?
;            val mem? =
;               match(source(c)) :
;                  (r:WRef) : kind(r) typeof MemKind
;                  (r) : false
;
;            if mem? :
;               c
;            else :
;               switch {dir(c) == _} :
;                  INPUT :
;                     Begin(list(
;                        DefWire(name(c), type(src-type))
;                        ManyConnect(index(c), elems, wire-ref)))
;                     where :
;                        val src-type = type(source(c)) as VectorType
;                        val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT)
;                        val elems = to-list $
;                           for i in 0 to size(src-type) stream :
;                              WIndex(source(c), i, type(src-type), INPUT)
;                  OUTPUT :
;                     Begin(list(
;                        DefWire(name(c), type(src-type))
;                        ConnectMany(index(c), wire-ref, elems)))
;                     where :
;                        val src-type = type(source(c)) as VectorType
;                        val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT)
;                        val elems = to-list $
;                           for i in 0 to size(src-type) stream :
;                              WIndex(source(c), i, type(src-type), OUTPUT)
;         (c) :
;            map(expand, c)
;   Module(name(m), ports(m), expand(body(m)))
;
;defn expand-accessors (c:Circuit) :
;   Circuit(modules*, main(c)) where :
;      val modules* = map(expand-accessors, modules(c))
;
;
;
;;=============== BUNDLE FLATTENING =========================
;defn prefix (prefix, suffix) :
;   symbol-join([prefix "/" suffix])
;
;defn prefix-ports
(pre:Symbol, ports:List) : ; for p in ports map : ; Port(prefix(pre, name(p)), gender(p), type(p)) ; ;defn flatten-ports (port:Port) -> List : ; match(type(port)) : ; (t:BundleType) : ; val ports = map-append(flatten-ports, ports(t)) ; for p in ports map : ; Port(prefix(name(port), name(p)), ; gender(port) * gender(p), ; type(p)) ; (t:VectorType) : ; val type* = flatten-type(t) ; flatten-ports(Port(name(port), gender(port), type*)) ; (t:Type) : ; list(port) ; ;defn flatten-type (t:Type) -> Type : ; match(t) : ; (t:BundleType) : ; BundleType $ ; map-append(flatten-ports, ports(t)) ; (t:VectorType) : ; flatten-type $ BundleType $ to-list $ ; for i in 0 to size(t) stream : ; Port(to-symbol(i), OUTPUT, type(t)) ; (t:Type) : ; t ; ;defn flatten-bundles (c:Circuit) : ; defn flatten-exp (e:Expression) : ; match(map(flatten-exp, e)) : ; (e:UIntValue|SIntValue) : ; e ; (e:WRef) : ; match(kind(e)) : ; (k:MemKind|StructuralMemKind) : ; val type* = map(flatten-type, type(e)) ; put-type(e, type*) ; (k) : ; val type* = flatten-type(type(e)) ; put-type(e, type*) ; (e) : ; val type* = flatten-type(type(e)) ; put-type(e, type*) ; ; defn flatten-element (e:Element) : ; val t* = flatten-type(type(e)) ; match(map(flatten-exp, e)) : ; (e:Register) : Register(t*, value(e), enable(e)) ; (e:Memory) : Memory(t*, writers(e)) ; (e:Node) : Node(t*, value(e)) ; (e:Instance) : Instance(t*, module(e), ports(e)) ; ; defn flatten-comm (c:Stmt) : ; match(c) : ; (c:LetRec) : ; val entries* = ; for entry in entries(c) map : ; key(entry) => flatten-element(value(entry)) ; LetRec(entries*, flatten-comm(body(c))) ; (c:DefWire) : ; DefWire(name(c), flatten-type(type(c))) ; (c:DefRegister) : ; DefRegister(name(c), flatten-type(type(c))) ; (c:DefMemory) : ; val type* = map(flatten-type, type(c)) ; DefMemory(name(c), type*) ; (c) : ; map{flatten-comm, _} $ ; map(flatten-exp, c) ; ; defn flatten-module (m:Module) : ; val ports* = map-append(flatten-ports, ports(m)) ; val body* = flatten-comm(body(m)) ; 
Module(name(m), ports*, body*) ; ; Circuit(modules*, main(c)) where : ; val modules* = map(flatten-module, modules(c)) ; ; ;;================== BUNDLE EXPANSION ======================= ;defn expand-bundles (m:Module) : ; ; ;Collapse all field/index expressions ; defn collapse-exp (e:Expression) -> Expression : ; match(e) : ; (e:WSubfield) : ; match(collapse-exp(exp(e))) : ; (ei:WRef) : ; if kind(ei) typeof InstanceKind : ; e ; else : ; WRef(name*, type(e), kind(ei), dir(e)) where : ; val name* = prefix(name(ei), name(e)) ; (ei:WSubfield) : ; WSubfield(exp(ei), name*, type(e), dir(e)) where : ; val name* = prefix(name(ei), name(e)) ; (e:WIndex) : ; collapse-exp(WSubfield(exp(e), name, type(e), dir(e))) where : ; val name = to-symbol(value(e)) ; (e) : ; map(collapse-exp, e) ; ; ;Expand expressions ; defn expand-exp (e:Expression) -> List : ; match(type(e)) : ; (t:BundleType) : ; for p in ports(t) map : ; val dir* = gender(p) * dir(e) ; collapse-exp(WSubfield(e, name(p), type(p), dir*)) ; (t) : ; list(collapse-exp(e)) ; ; ;Expand commands ; defn expand-comm (c:Stmt) : ; match(c) : ; (c:DefWire) : ; match(type(c)) : ; (t:BundleType) : ; Begin $ ; for p in ports(t) map : ; DefWire(prefix(name(c), name(p)), type(p)) ; (t) : ; c ; (c:DefRegister) : ; match(type(c)) : ; (t:BundleType) : ; Begin $ ; for p in ports(t) map : ; DefRegister(prefix(name(c), name(p)), type(p)) ; (t) : ; c ; (c:DefMemory) : ; match(type(type(c))) : ; (t:BundleType) : ; Begin $ ; for p in ports(t) map : ; DefMemory(prefix(name(c), name(p)), type*) where : ; val s = size(type(c)) ; val type* = VectorType(type(p), s) ; (t) : ; c ; (c:WDefAccessor) : ; match(type(source(c))) : ; (t:BundleType) : ; val srcs = expand-exp(source(c)) ; Begin $ ; for (p in ports(t), src in srcs) map : ; WDefAccessor(name*, src, index(c), dir*) where : ; val name* = prefix(name(c), name(p)) ; val dir* = gender(p) * dir(c) ; (t) : ; c ; (c:Connect) : ; val locs = expand-exp(loc(c)) ; val exps = expand-exp(exp(c)) ; Begin $ ; 
for (l in locs, e in exps) map : ; switch {dir(l) == _} : ; INPUT : Connect(l, e) ; OUTPUT : Connect(e, l) ; (c:ManyConnect) : ; val locs-list = transpose(map(expand-exp, locs(c))) ; val exps = expand-exp(exp(c)) ; Begin $ ; for (locs in locs-list, e in exps) map : ; switch {dir(e) == _} : ; OUTPUT : ManyConnect(index(c), locs, e) ; INPUT : ConnectMany(index(c), e, locs) ; (c:ConnectMany) : ; val locs = expand-exp(loc(c)) ; val exps-list = transpose(map(expand-exp, exps(c))) ; Begin $ ; for (l in locs, exps in exps-list) map : ; switch {dir(l) == _} : ; INPUT : ConnectMany(index(c), l, exps) ; OUTPUT : ManyConnect(index(c), exps, l) ; (c) : ; map{expand-comm, _} $ ; map(collapse-exp, c) ; ; Module(name(m), ports(m), expand-comm(body(m))) ; ;defn expand-bundles (c:Circuit) : ; Circuit(modules*, main(c)) where : ; val modules* = map(expand-bundles, modules(c)) ; ; ;;=========== CONVERT MULTI CONNECTS to WHEN ================ ;defn expand-multi-connects (c:Circuit) : ; defn equal-exp (e1:Expression, e2:Expression) : ; DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth())) ; defn uint (i:Int) : ; UIntValue(i, UnknownWidth()) ; ; defn expand-comm (c:Stmt) : ; match(c) : ; (c:ConnectMany) : ; Begin $ to-list $ ; for (i in 0 to false, e in exps(c)) stream : ; Conditionally(equal-exp(index(c), uint(i)), ; Connect(loc(c), e) ; EmptyStmt()) ; (c:ManyConnect) : ; Begin $ to-list $ ; for (i in 0 to false, l in locs(c)) stream : ; Conditionally(equal-exp(index(c), uint(i)), ; Connect(l, exp(c)) ; EmptyStmt()) ; (c) : ; map(expand-comm, c) ; ; defn expand (m:Module) : ; Module(name(m), ports(m), expand-comm(body(m))) ; ; Circuit(modules*, main(c)) where : ; val modules* = map(expand, modules(c)) ; ; ;;================ EXPAND WHENS ============================= ;definterface SymbolicValue ;defstruct ExpValue <: SymbolicValue : ; exp: Expression ;defstruct WhenValue <: SymbolicValue : ; pred: Expression ; conseq: SymbolicValue ; alt: SymbolicValue ;defstruct VoidValue <: 
SymbolicValue ; ;defmethod print (o:OutputStream, sv:SymbolicValue) : ; match(sv) : ; (sv:VoidValue) : print(o, "VOID") ; (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"]) ; (sv:ExpValue) : print(o, exp(sv)) ; ;defn key-eqv? (e1:Expression, e2:Expression) : ; match(e1, e2) : ; (e1:WRef, e2:WRef) : ; name(e1) == name(e2) ; (e1:WSubfield, e2:WSubfield) : ; name(e1) == name(e2) and ; key-eqv?(exp(e1), exp(e2)) ; (e1, e2) : ; false ; ;defn merge-env (pred: Expression, ; con-env: List>, ; alt-env: List>) : ; val merged = Vector>() ; defn new-key? (k:Expression) : ; for entry in merged none? : ; key-eqv?(key(entry), k) ; ; defn sv (env:List>, k:Expression) : ; for entry in env search : ; if key-eqv?(key(entry), k) : ; value(entry) ; ; for k in stream(key, concat(con-env, alt-env)) do : ; if new-key?(k) : ; match(sv(con-env, k), sv(alt-env, k)) : ; (a:SymbolicValue, b:SymbolicValue) : ; if a == b : ; add(merged, k => a) ; else : ; add(merged, k => WhenValue(pred, a, b)) ; (a:SymbolicValue, b:False) : ; add(merged, k => a) ; (a:False, b:SymbolicValue) : ; add(merged, k => b) ; (a:False, b:False) : ; false ; ; to-list(merged) ; ;defn simplify-env (env: List>) : ; val merged = Vector>() ; defn new-key? (k:Expression) : ; for entry in merged none? 
: ; key-eqv?(key(entry), k) ; for entry in env do : ; if new-key?(key(entry)) : ; add(merged, entry) ; to-list(merged) ; ;defn expand-whens (m:Module) : ; val commands = Vector() ; val elements = Vector>() ; defn eval (c:Stmt, env:List>) -> ; List> : ; match(c) : ; (c:LetRec) : ; do(add{elements, _}, entries(c)) ; eval(body(c), env) ; (c:DefWire) : ; add(commands, c) ; val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT) ; List(wire-ref => VoidValue(), env) ; (c:DefRegister) : ; add(commands, c) ; val reg-ref = WRef(name(c), type(c), RegKind(), INPUT) ; List(reg-ref => VoidValue(), env) ; (c:DefInstance) : ; add(commands, c) ; val entries = let : ; val module-type = type(module(c)) as BundleType ; val input-ports = to-list $ ; for p in ports(module-type) filter : ; gender(p) == INPUT ; val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT) ; for p in input-ports map : ; WSubfield(inst-ref, name(p), type(p), INPUT) => VoidValue() ; append(entries, env) ; (c:DefMemory) : ; add(commands, c) ; env ; (c:WDefAccessor) : ; add(commands, c) ; if dir(c) == INPUT : ; val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c)) ; List(access-ref => VoidValue(), env) ; else : ; env ; (c:Conditionally) : ; val con-env = eval(conseq(c), env) ; val alt-env = eval(alt(c), env) ; merge-env(pred(c), con-env, alt-env) ; (c:Begin) : ; var env:List> = env ; for c in body(c) do : ; env = eval(c, env) ; env ; (c:Connect) : ; List(loc(c) => ExpValue(exp(c)), env) ; (c:EmptyStmt) : ; env ; ; defn convert-symbolic (key:Expression, sv:SymbolicValue) : ; match(sv) : ; (sv:VoidValue) : ; throw $ PassException $ string-join $ [ ; "No default value for " key "."] ; (sv:ExpValue) : ; exp(sv) ; (sv:WhenValue) : ; defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) : ; DoPrim(MUX-OP, list(pred, conseq, alt), List(), type(conseq)) ; multiplex-exp(pred(sv), ; convert-symbolic(key, conseq(sv)) ; convert-symbolic(key, alt(sv))) ; ; ;Compute final 
environment ; val env0 = let : ; val output-ports = to-list $ ; for p in ports(m) filter : ; gender(p) == OUTPUT ; for p in output-ports map : ; val port-ref = WRef(name(p), type(p), PortKind(), INPUT) ; port-ref => VoidValue() ; val env* = simplify-env(eval(body(m), env0)) ; ; ;Make new body ; val body* = Begin(list(defs, LetRec(elems, connections))) where : ; val defs = Begin(to-list(commands)) ; val elems = to-list(elements) ; val connections = Begin $ ; for entry in env* map : ; val sv = convert-symbolic(key(entry), value(entry)) ; Connect(key(entry), sv) ; ; ;Final module ; Module(name(m), ports(m), body*) ; ;defn expand-whens (c:Circuit) : ; val modules* = map(expand-whens, modules(c)) ; Circuit(modules*, main(c)) ; ; ;;================ STRUCTURAL FORM ========================== ; ;defn structural-form (m:Module) : ; val elements = Vector<(() -> KeyValue)>() ; val connected = HashTable(symbol-hash) ; val write-accessors = HashTable>(symbol-hash) ; val read-accessors = HashTable(symbol-hash) ; val inst-ports = HashTable>>(symbol-hash) ; val port-connects = Vector() ; ; defn scan (c:Stmt) : ; match(c) : ; (c:Connect) : ; match(loc(c)) : ; (loc:WRef) : ; match(kind(loc)) : ; (k:PortKind) : add(port-connects, c) ; (k) : connected[name(loc)] = exp(c) ; (loc:WSubfield) : ; val inst = exp(loc) as WRef ; val entry = name(loc) => exp(c) ; inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List())) ; (c:LetRec) : ; for e in entries(c) do : ; add(elements, {e}) ; scan(body(c)) ; (c:DefWire) : ; add{elements, _} $ fn () : ; name(c) => Node(type(c), connected[name(c)]) ; (c:DefRegister) : ; add{elements, _} $ fn () : ; val one = UIntValue(1, UnknownWidth()) ; name(c) => Register(type(c), connected[name(c)], one) ; (c:DefInstance) : ; add{elements, _} $ fn () : ; name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)]) ; (c:DefMemory) : ; add{elements, _} $ fn () : ; val ports = for a in get?(write-accessors, name(c), List()) map : ; val one = 
UIntValue(1, UnknownWidth()) ; WritePort(index(a), connected[name(a)], one) ; name(c) => Memory(type(c), ports) ; (c:WDefAccessor) : ; val mem = source(c) as WRef ; switch {dir(c) == _} : ; INPUT : ; write-accessors[name(mem)] = List(c, ; get?(write-accessors, name(mem), List())) ; OUTPUT : ; read-accessors[name(c)] = c ; (c) : ; do(scan, children(c)) ; ; defn make-read-ports (e:Expression) : ; match(e) : ; (e:WRef) : ; match(kind(e)) : ; (k:AccessorKind) : ; val accessor = read-accessors[name(e)] ; ReadPort(source(accessor), index(accessor), type(e)) ; (k) : e ; (e) : map(make-read-ports, e) ; ; Module(name(m), ports(m), body*) where : ; scan(body(m)) ; val elems = to-list $ ; for e in elements stream : ; val entry = e() ; key(entry) => map(make-read-ports, value(entry)) ; val connect-ports = Begin $ to-list $ ; for c in port-connects stream : ; Connect(loc(c), make-read-ports(exp(c))) ; val body* = ; if empty?(elems) : connect-ports ; else : LetRec(elems, connect-ports) ; ;defn structural-form (c:Circuit) : ; val modules* = map(structural-form, modules(c)) ; Circuit(modules*, main(c)) ; ; ;;==================== WIDTH INFERENCE ====================== ;defstruct WidthVar <: Width : ; name: Symbol ; ;defmethod print (o:OutputStream, w:WidthVar) : ; print(o, name(w)) ; ;defn width! (t:Type) : ; match(t) : ; (t:UIntType) : width(t) ; (t:SIntType) : width(t) ; (t) : error("No width field.") ; ;defn put-width (t:Type, w:Width) : ; match(t) : ; (t:UIntType) : UIntType(w) ; (t:SIntType) : SIntType(w) ; (t) : t ; ;defn put-width (e:Expression, w:Width) : ; val type* = put-width(type(e), w) ; put-type(e, type*) ; ;defn add-width-vars (t:Type) : ; defn width? 
(w:Width) : ; match(w) : ; (w:UnknownWidth) : WidthVar(gensym()) ; (w) : w ; match(t) : ; (t:UIntType) : UIntType(width?(width(t))) ; (t:SIntType) : SIntType(width?(width(t))) ; (t) : map(add-width-vars, t) ; ;defn uint-width (i:Int) : ; var v:Int = i ; var n:Int = 0 ; while v != 0 : ; v = v >> 1 ; n = n + 1 ; IntWidth(n) ; ;defn sint-width (i:Int) : ; if i > 0 : ; val w = uint-width(i) ; IntWidth(width(w) + 1) ; else : ; val w = uint-width(neg(i) - 1) ; IntWidth(width(w) + 1) ; ;defn to-exp (w:Width) : ; match(w) : ; (w:IntWidth) : ELit(width(w)) ; (w:WidthVar) : EVar(name(w)) ; (w) : error $ string-join $ [ ; "Cannot convert " w " to exp."] ; ;defn primop-width (op:PrimOp, ws:List, ints:List) -> Exp : ; defn wmax (w1:Width, w2:Width) : ; EMax(to-exp(w1), to-exp(w2)) ; defn wplus (w1:Width, w2:Width) : ; EPlus(to-exp(w1), to-exp(w2)) ; defn wplus (w1:Width, w2:Int) : ; EPlus(to-exp(w1), ELit(w2)) ; defn wminus (w1:Width, w2:Width) : ; EMinus(to-exp(w1), to-exp(w2)) ; defn wminus (w1:Width, w2:Int) : ; EMinus(to-exp(w1), ELit(w2)) ; defn wmax-inc (w1:Width, w2:Width) : ; EPlus(wmax(w1, w2), ELit(1)) ; ; switch {op == _} : ; ADD-OP : wmax-inc(ws[0], ws[1]) ; ADD-WRAP-OP : wmax(ws[0], ws[1]) ; SUB-OP : wmax-inc(ws[0], ws[1]) ; SUB-WRAP-OP : wmax(ws[0], ws[1]) ; MUL-OP : wplus(ws[0], ws[1]) ; DIV-OP : wminus(ws[0], ws[1]) ; MOD-OP : to-exp(ws[1]) ; SHIFT-LEFT-OP : wplus(ws[0], ints[0]) ; SHIFT-RIGHT-OP : wminus(ws[0], ints[0]) ; PAD-OP : ELit(ints[0]) ; BIT-AND-OP : wmax(ws[0], ws[1]) ; BIT-OR-OP : wmax(ws[0], ws[1]) ; BIT-XOR-OP : wmax(ws[0], ws[1]) ; CONCAT-OP : wplus(ws[0], ints[0]) ; BIT-SELECT-OP : ELit(1) ; BITS-SELECT-OP : ELit(ints[0]) ; MUX-OP : wmax(ws[1], ws[2]) ; LESS-OP : ELit(1) ; LESS-EQ-OP : ELit(1) ; GREATER-OP : ELit(1) ; GREATER-EQ-OP : ELit(1) ; EQUAL-OP : ELit(1) ; ;defn put-type (el:Element, t:Type) -> Element : ; match(el) : ; (el:Register) : Register(t, value(el), enable(el)) ; (el:Memory) : Memory(t, writers(el)) ; (el:Node) : Node(t, 
value(el)) ; (el:Instance) : Instance(t, module(el), ports(el)) ; ;defn generate-constraints (c:Circuit) -> [Circuit, Vector] : ; ;Constraints ; val cs = Vector() ; defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) : ; match(wvar) : ; (wvar:WidthVar) : ; add(cs, Constraint(name(wvar), to-exp(width))) ; (wvar) : ; false ; ; defn to-width (e:Exp) : ; match(e) : ; (e:ELit) : ; IntWidth(width(e)) ; (e:EVar) : ; WidthVar(name(e)) ; (e) : ; val x = gensym() ; add(cs, WidthEqual(x, e)) ; WidthVar(x) ; ; ;Module types ; val mod-types = HashTable(symbol-hash) ; ; defn add-port-vars (m:Module) -> Module : ; val ports* = ; for p in ports(m) map : ; val type* = add-width-vars(type(p)) ; Port(name(p), gender(p), type*) ; mod-types[name(m)] = BundleType(ports*) ; Module(name(m), ports*, body(m)) ; ; ;Add Width Variables ; defn add-module-vars (m:Module) -> Module : ; val types = HashTable(symbol-hash) ; for p in ports(m) do : ; types[name(p)] = type(p) ; ; defn infer-exp-width (e:Expression) -> Expression : ; match(map(infer-exp-width, e)) : ; (e:WRef) : ; match(kind(e)) : ; (k:ModuleKind) : put-type(e, mod-types[name(e)]) ; (k) : put-type(e, types[name(e)]) ; (e:WSubfield) : ; val t = bundle-field-type(type(exp(e)), name(e)) ; put-width(e, width!(t)) ; (e:UIntValue) : ; match(width(e)) : ; (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e))) ; (w) : e ; (e:SIntValue) : ; match(width(e)) : ; (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e))) ; (w) : e ; (e:DoPrim) : ; val widths = map(width!{type(_)}, args(e)) ; val w = to-width(primop-width(op(e), widths, consts(e))) ; put-width(e, w) ; (e:ReadPort) : ; val elem-type = type(type(mem(e)) as VectorType) ; put-width(e, width!(elem-type)) ; ; defn infer-comm-width (c:Stmt) : ; match(c) : ; (c:LetRec) : ; ;Add width vars to elements ; var entries*: List> = ; for entry in entries(c) map : ; val el-name = key(entry) ; key(entry) => ; match(value(entry)) : ; (el:Register|Node) : ; 
put-type(el, add-width-vars(type(el))) ; (el:Memory) : ; el ; (el:Instance) : ; val mod-type = type(infer-exp-width(module(el))) as BundleType ; val type = BundleType $ to-list $ ; for p in ports(mod-type) filter : ; gender(p) == OUTPUT ; put-type(el, type) ; ; ;Add vars to environment ; for entry in entries* do : ; types[key(entry)] = type(value(entry)) ; ; ;Infer types for elements ; entries* = ; for entry in entries* map : ; key(entry) => map(infer-exp-width, value(entry)) ; ; ;Generate constraints ; for entry in entries* do : ; val el-name = key(entry) ; match(value(entry)) : ; (el:Register) : ; new-constraint(WidthEqual, reg-width, val-width) where : ; val reg-width = width!(types[el-name]) ; val val-width = width!(type(value(el))) ; (el:Node) : ; new-constraint(WidthEqual, node-width, val-width) where : ; val node-width = width!(types[el-name]) ; val val-width = width!(type(value(el))) ; (el:Instance) : ; val mod-type = type(module(el)) as BundleType ; for entry in ports(el) do : ; new-constraint(WidthGreater, port-width, val-width) where : ; val port-name = key(entry) ; val port-width = width!(bundle-field-type(mod-type, port-name)) ; val val-width = width!(type(value(entry))) ; (el) : false ; ; ;Analyze body ; LetRec(entries*, infer-comm-width(body(c))) ; ; (c:Connect) : ; val loc* = infer-exp-width(loc(c)) ; val exp* = infer-exp-width(exp(c)) ; new-constraint(WidthGreater, loc-width, exp-width) where : ; val loc-width = width!(type(loc*)) ; val exp-width = width!(type(exp*)) ; Connect(loc*, exp*) ; ; (c:Begin) : ; map(infer-comm-width, c) ; ; Module(name(m), ports(m), body*) where : ; val body* = infer-comm-width(body(m)) ; ; val c* = ; Circuit(modules*, main(c)) where : ; val ms = map(add-port-vars, modules(c)) ; val modules* = map(add-module-vars, ms) ; [c*, cs] ; ; ;;================== FILL WIDTHS ============================ ;defn fill-widths (c:Circuit, solved:Streamable) : ; ;Populate table ; val table = HashTable(symbol-hash) ; for eq in solved do : 
; table[name(eq)] = IntWidth(width(value(eq) as ELit)) ; ; defn width? (w:Width) : ; match(w) : ; (w:WidthVar) : get?(table, name(w), UnknownWidth()) ; (w) : w ; ; defn fill-type (t:Type) : ; match(t) : ; (t:UIntType) : UIntType(width?(width(t))) ; (t:SIntType) : SIntType(width?(width(t))) ; (t) : map(fill-type, t) ; ; defn fill-exp (e:Expression) -> Expression : ; val e* = map(fill-exp, e) ; val type* = fill-type(type(e)) ; put-type(e*, type*) ; ; defn fill-element (e:Element) : ; val e* = map(fill-exp, e) ; val type* = fill-type(type(e)) ; put-type(e*, type*) ; ; defn fill-comm (c:Stmt) : ; match(c) : ; (c:LetRec) : ; val entries* = ; for e in entries(c) map : ; key(e) => fill-element(value(e)) ; LetRec(entries*, fill-comm(body(c))) ; (c) : ; map{fill-comm, _} $ ; map(fill-exp, c) ; ; defn fill-port (p:Port) : ; Port(name(p), gender(p), fill-type(type(p))) ; ; defn fill-mod (m:Module) : ; Module(name(m), ports*, body*) where : ; val ports* = map(fill-port, ports(m)) ; val body* = fill-comm(body(m)) ; ; Circuit(modules*, main(c)) where : ; val modules* = map(fill-mod, modules(c)) ; ; ;;=============== TYPE INFERENCE DRIVER ===================== ;defn infer-widths (c:Circuit) : ; val [c*, cs] = generate-constraints(c) ; val solved = solve-widths(cs) ; fill-widths(c*, solved) ; ; ;;================ PAD WIDTHS =============================== ;defn pad-widths (c:Circuit) : ; ;Pad an expression to the given width ; defn pad-exp (e:Expression, w:Int) : ; match(type(e)) : ; (t:UIntType|SIntType) : ; val prev-w = width!(t) as IntWidth ; if width(prev-w) < w : ; val type* = put-width(t, IntWidth(w)) ; DoPrim(PAD-OP, list(e), list(w), type*) ; else : ; e ; (t) : ; e ; ; defn pad-exp (e:Expression, w:Width) : ; val w-value = width(w as IntWidth) ; pad-exp(e, w-value) ; ; ;Convenience ; defn max-width (es:Streamable) : ; defn int-width (e:Expression) : ; width(width!(type(e)) as IntWidth) ; maximum(stream(int-width, es)) ; ; defn match-widths (es:List) : ; val w = 
max-width(es) ; map(pad-exp{_, w}, es) ; ; ;Match widths for an expression ; defn match-exp-width (e:Expression) : ; match(map(match-exp-width, e)) : ; (e:DoPrim) : ; if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) : ; val args* = match-widths(args(e)) ; DoPrim(op(e), args*, consts(e), type(e)) ; else if op(e) == MUX-OP : ; val args* = List(head(args(e)), match-widths(tail(args(e)))) ; DoPrim(op(e), args*, consts(e), type(e)) ; else : ; e ; (e) : e ; ; defn match-element-width (e:Element) : ; match(map(match-exp-width, e)) : ; (e:Register) : ; val w = width!(type(e)) ; val value* = pad-exp(value(e), w) ; Register(type(e), value*, enable(e)) ; (e:Memory) : ; val width = width!(type(type(e) as VectorType)) ; val writers* = ; for w in writers(e) map : ; WritePort(index(w), pad-exp(value(w), width), enable(w)) ; Memory(type(e), writers*) ; (e:Node) : ; val w = width!(type(e)) ; val value* = pad-exp(value(e), w) ; Node(type(e), value*) ; (e:Instance) : ; val mod-type = type(module(e)) as BundleType ; val ports* = ; for p in ports(e) map : ; val port-type = bundle-field-type(mod-type, key(p)) ; val port-val = pad-exp(value(p), width!(port-type)) ; key(p) => port-val ; Instance(type(e), module(e), ports*) ; ; ;Match widths for a command ; defn match-comm-width (c:Stmt) : ; match(map(match-exp-width, c)) : ; (c:LetRec) : ; val entries* = ; for e in entries(c) map : ; key(e) => match-element-width(value(e)) ; LetRec(entries*, match-comm-width(body(c))) ; (c:Connect) : ; val w = width!(type(loc(c))) ; val exp* = pad-exp(exp(c), w) ; Connect(loc(c), exp*) ; (c) : ; map(match-comm-width, c) ; ; defn match-mod-width (m:Module) : ; Module(name(m), ports(m), body*) where : ; val body* = match-comm-width(body(m)) ; ; Circuit(modules*, main(c)) where : ; val modules* = map(match-mod-width, modules(c)) ; ; ;;================== INLINING =============================== ;defn inline-instances (c:Circuit) : ; val module-table = HashTable(symbol-hash) ; val inlined? 
= HashTable(symbol-hash) ; for m in modules(c) do : ; module-table[name(m)] = m ; inlined?[name(m)] = false ; ; ;Convert a module into a sequence of elements ; defn to-elements (m:Module, ; inst:Symbol, ; port-exps:List>) -> ; List> : ; defn rename-exp (e:Expression) : ; match(e) : ; (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e)) ; (e) : map(rename-exp, e) ; ; defn to-elements (c:Stmt) -> List> : ; match(c) : ; (c:LetRec) : ; val entries* = ; for entry in entries(c) map : ; val name* = prefix(inst, key(entry)) ; val element* = map(rename-exp, value(entry)) ; name* => element* ; val body* = to-elements(body(c)) ; append(entries*, body*) ; (c:Connect) : ; val ref = loc(c) as WRef ; val name* = prefix(inst, name(ref)) ; list(name* => Node(type(exp(c)), rename-exp(exp(c)))) ; (c:Begin) : ; map-append(to-elements, body(c)) ; ; val inputs = ; for p in ports(m) map-append : ; if gender(p) == INPUT : ; val port-exp = lookup!(port-exps, name(p)) ; val name* = prefix(inst, name(p)) ; list(name* => Node(type(port-exp), port-exp)) ; else : ; List() ; append(inputs, to-elements(body(m))) ; ; ;Inline all instances in the module ; defn inline-instances (m:Module) : ; defn rename-exp (e:Expression) : ; match(e) : ; (e:WSubfield) : ; val inst-exp = exp(e) as WRef ; val name* = prefix(name(inst-exp), name(e)) ; WRef(name*, type(e), NodeKind(), dir(e)) ; (e) : ; map(rename-exp, e) ; ; defn inline-elems (es:List>) : ; for entry in es map-append : ; match(value(entry)) : ; (el:Instance) : ; val mod-name = name(module(el) as WRef) ; val module = inlined-module(mod-name) ; to-elements(module, key(entry), ports(el)) ; (el) : ; list(entry) ; ; defn inline-comm (c:Stmt) : ; match(map(rename-exp, c)) : ; (c:LetRec) : ; val entries* = inline-elems(entries(c)) ; LetRec(entries*, inline-comm(body(c))) ; (c) : ; map(inline-comm, c) ; ; Module(name(m), ports(m), inline-comm(body(m))) ; ; ;Retrieve the inlined instance of a module ; defn inlined-module (name:Symbol) : ; if 
inlined?[name] : ; module-table[name] ; else : ; val module* = inline-instances(module-table[name]) ; module-table[name] = module* ; inlined?[name] = true ; module* ; ; ;Return the fully inlined circuit ; val main-module = inlined-module(main(c)) ; Circuit(list(main-module), main(c)) ;;;================ UTILITIES ================================ ; ; ; ;defn* root-ref (i:Immediate) : ; match(i) : ; (f:Subfield) : root-ref(imm(f)) ; (ind:Index) : root-ref(imm(ind)) ; (r) : r ; ;;defn lookup (e: Streamable>, i:Immediate) : ;; for entry in e search : ;; if eqv?(key(entry), i) : ;; value(entry) ;; ;;defn lookup! (e: Streamable>, i:Immediate) : ;; lookup(e, i) as T ;; ;;============ CHECK IF NAMES ARE UNIQUE ==================== ;defn check-duplicate-symbols (names: Streamable, msg: String) : ; val dict = HashTable(symbol-hash) ; for name in names do: ; if key?(dict, name): ; throw $ PassException $ string-join $ ; [msg ": " name] ; else: ; dict[name] = true ; ;defn check-duplicates (t: Type) : ; match(t) : ; (t:BundleType) : ; val names = map(name, ports(t)) ; check-duplicate-symbols{names, string-join(_)} $ ; ["Duplicate port name in bundle "] ; do(check-duplicates{type(_)}, ports(t)) ; (t:VectorType) : ; check-duplicates(type(t)) ; (t) : false ; ;defn check-duplicates (c: Stmt) : ; match(c) : ; (c:DefWire) : check-duplicates(type(c)) ; (c:DefRegister) : check-duplicates(type(c)) ; (c:DefMemory) : check-duplicates(type(c)) ; (c) : do(check-duplicates, children(c)) ; ;defn defined-names (c: Stmt) : ; generate : ; loop(c) where : ; defn loop (c:Stmt) : ; match(c) : ; (c:Stmt&HasName) : yield(name(c)) ; (c) : do(loop, children(c)) ; ;defn check-duplicates (m: Module): ; ;Check all duplicate names in all types in all ports and body ; do(check-duplicates{type(_)}, ports(m)) ; check-duplicates(body(m)) ; ; ;Check all names defined in module ; val names = concat(stream(name, ports(m)), ; defined-names(body(m))) ; check-duplicate-symbols{names, string-join(_)} $ ; ["Duplicate 
definition name in module " name(m)] ; ;defn check-duplicates (c: Circuit) : ; ;Check all duplicate names in all modules ; do(check-duplicates, modules(c)) ; ; ;Check all defined modules ; val names = stream(name, modules(c)) ; check-duplicate-symbols(names, "Duplicate module name") ; ; ;;================ CLEANUP COMMANDS ========================= ;defn cleanup (c:Stmt) : ; match(c) : ; (c:Begin) : ; to-command $ generate : ; loop(c) where : ; defn loop (c:Stmt) : ; match(c) : ; (c:Begin) : do(loop, body(c)) ; (c:EmptyStmt) : false ; (c) : yield(cleanup(c)) ; (c) : map(cleanup{_ as Stmt}, c) ; ;defn cleanup (c:Circuit) : ; val modules* = ; for m in modules(c) map : ; map(cleanup, m) ; Circuit(modules*, main(c)) ; ;;;============= SHIM ======================================== ;;defn shim (i:Immediate) -> Immediate : ;; match(i) : ;; (i:RegData) : ;; Ref(name(i), gender(i), type(i)) ;; (i:InstPort) : ;; val inst = Ref(name(i), UNKNOWN-GENDER, UnknownType()) ;; Subfield(inst, port(i), gender(i), type(i)) ;; (i:Subfield) : ;; val imm* = shim(imm(i)) ;; put-imm(i, imm*) ;; (i) : i ;; ;;defn shim (c:Stmt) -> Stmt : ;; val c* = map(shim{_ as Immediate}, c) ;; map(shim{_ as Stmt}, c*) ;; ;;defn shim (c:Circuit) -> Circuit : ;; val modules* = ;; for m in modules(c) map : ;; Module(name(m), ports(m), shim(body(m))) ;; Circuit(modules*, main(c)) ;; ;;;================== INLINE MODULES ========================= ;;defn cat-name (p: String|Symbol, s: String|Symbol) -> Symbol : ;; if p == "" or p == `this : ;; TODO: REMOVE THIS WHEN `THIS GETS REMOVED ;; to-symbol(s) ;; else if s == `this : ;; TODO: DITTO ;; to-symbol(p) ;; else : ;; symbol-join([p, "/", s]) ;; ;;defn inline-command (c: Stmt, mods: HashTable, prefix: String, cmds: Vector) : ;; defn rename (n: Symbol) -> Symbol : ;; cat-name(prefix, n) ;; defn inline-name (i:Immediate) -> Symbol : ;; match(i) : ;; (r:Ref) : rename(name(r)) ;; (f:Subfield) : cat-name(inline-name(imm(f)), name(f)) ;; (f:Index) : 
cat-name(inline-name(imm(f)), to-string(value(f))) ;; defn inline-imm (i:Immediate) -> Ref : ;; Ref(inline-name(i), gender(i), type(i)) ;; match(c) : ;; (c:DefUInt) : add(cmds, DefUInt(rename(name(c)), value(c), width(c))) ;; (c:DefSInt) : add(cmds, DefSInt(rename(name(c)), value(c), width(c))) ;; (c:DefWire) : add(cmds, DefWire(rename(name(c)), type(c))) ;; (c:DefRegister) : add(cmds, DefRegister(rename(name(c)), type(c))) ;; (c:DefMemory) : add(cmds, DefMemory(rename(name(c)), type(c), size(c))) ;; (c:DefInstance) : inline-module(mods, mods[name(module(c))], to-string(rename(name(c))), cmds) ;; (c:DoPrim) : add(cmds, DoPrim(rename(name(c)), op(c), map(inline-imm, args(c)), consts(c))) ;; (c:DefAccessor) : add(cmds, DefAccessor(rename(name(c)), inline-imm(source(c)), gender(c), inline-imm(index(c)))) ;; (c:Connect) : add(cmds, Connect(inline-imm(loc(c)), inline-imm(exp(c)))) ;; (c:Begin) : do(inline-command{_, mods, prefix, cmds}, body(c)) ;; (c:EmptyStmt) : c ;; (c) : error("Unsupported command") ;; ;;defn inline-port (p: Port, prefix: String) -> Stmt : ;; DefWire(cat-name(prefix, name(p)), type(p)) ;; ;;defn inline-module (mods: HashTable, mod: Module, prefix: String, cmds: Vector) : ;; do(add{cmds, _}, map(inline-port{_, prefix}, ports(mod))) ;; inline-command(body(mod), mods, prefix, cmds) ;; ;;defn inline-modules (c: Circuit) -> Circuit : ;; val cmds = Vector() ;; val mods = HashTable(symbol-hash) ;; for mod in modules(c) do : ;; mods[name(mod)] = mod ;; val top = mods[main(c)] ;; inline-command(body(top), mods, "", cmds) ;; val main* = Module(name(top), ports(top), Begin(to-list(cmds))) ;; Circuit(list(main*), name(top)) ;; ;; ;;;============= FLO PRINTER ====================================== ;;;;; TODO: ;;;;; not supported gt, lte ;; ;;defn flo-op-name (op:PrimOp) -> String : ;; switch {op == _} : ;; ADD-OP : "add" ;; ADD-MOD-OP : "add" ;; MINUS-OP : "sub" ;; SUB-MOD-OP : "sub" ;; MUL-OP : "mul" ;; todo: signed version ;; DIV-OP : "div" ;; todo: signed 
version ;; MOD-OP : "mod" ;; todo: signed version ;; SHIFT-LEFT-OP : "lsh" ;; todo: signed version ;; SHIFT-RIGHT-OP : "rsh" ;; PAD-OP : "pad" ;; todo: signed version ;; BIT-AND-OP : "and" ;; BIT-OR-OP : "or" ;; BIT-XOR-OP : "xor" ;; CONCAT-OP : "cat" ;; BIT-SELECT-OP : "rsh" ;; BITS-SELECT-OP : "rsh" ;; LESS-OP : "lt" ;; todo: signed version ;; LESS-EQ-OP : "lte" ;; todo: swap args ;; GREATER-OP : "gt" ;; todo: swap args ;; GREATER-EQ-OP : "gte" ;; todo: signed version ;; EQUAL-OP : "eq" ;; MUX-OP : "mux" ;; else : error $ string-join $ ;; ["Unable to print Primop: " op] ;; ;;defn emit (o:OutputStream, top:Symbol, ports:HashTable, lits:HashTable, elt) : ;; match(elt) : ;; (e:String|Symbol|Int) : ;; print(o, e) ;; (e:Ref) : ;; if key?(lits, name(e)) : ;; val lit = lits[name(e)] ;; print-all(o, [value(lit) "'" width(lit)]) ;; else : ;; if key?(ports, name(e)) : ;; print-all(o, [top "::"]) ;; print(o, name(e)) ;; (e:IntWidth) : ;; print(o, value(e)) ;; (e:PrimOp) : ;; print(o, flo-op-name(e)) ;; (e) : ;; println-all(["EMIT " e]) ;; error("Unable to emit") ;; ;;defn emit-all (o:OutputStream, top:Symbol, ports:HashTable, lits:HashTable, elts: Streamable) : ;; for e in elts do : emit(o, top, ports, lits, e) ;; ;;defn prim-width (type:Type) -> Width : ;; match(type) : ;; (t:UIntType) : width(t) ;; (t:SIntType) : width(t) ;; (t) : error("Bad prim width type") ;; ;;defn emit-command (o:OutputStream, cmd:Stmt, top:Symbol, lits:HashTable, regs:HashTable, accs:HashTable, ports:HashTable, outs:HashTable) : ;; match(cmd) : ;; (c:DefUInt) : ;; lits[name(c)] = c ;; (c:DefSInt) : ;; emit-all(o, top, ports, lits, [name(c) " = " value(c) "'" width(c) "\n"]) ;; (c:DoPrim) : ;; NEED TO FIGURE OUT WHEN WIDTHS ARE NECESSARY AND EXTRACT ;; emit-all(o, top, ports, lits, [name(c) " = " op(c)]) ;; for arg in args(c) do : ;; print(o, " ") ;; emit(o, top, ports, lits, arg) ;; for const in consts(c) do : ;; print(o, " ") ;; emit(o, top, ports, lits, const) ;; print("\n") ;; (c:DefRegister) : 
;;         regs[name(c)] = c
;;      (c:DefMemory) :
;;         emit-all(o, top, ports, lits, [name(c) " : mem'" prim-width(type(c)) " " size(c) "\n"])
;;      (c:DefAccessor) :
;;         accs[name(c)] = c
;;      (c:Connect) :
;;         val dst = name(loc(c) as Ref)
;;         val src = name(exp(c) as Ref)
;;         if key?(regs, dst) :
;;            val reg = regs[dst]
;;            emit-all(o, top, ports, lits, [dst " = reg'" prim-width(type(reg)) " 0'" prim-width(type(reg)) " " exp(c) "\n"])
;;         else if key?(accs, dst) :
;;            val acc = accs[dst]
;;            ;; assert(gender(acc) == OUTPUT)
;;            emit-all(o, top, ports, lits, [dst " = wr " source(acc) " " index(acc) " " exp(c) "\n"])
;;         else if key?(outs, dst) :
;;            val out = outs[dst]
;;            emit-all(o, top, ports, lits, [top "::" dst " = out'" prim-width(type(out)) " " exp(c) "\n"])
;;         else if key?(accs, src) :
;;            val acc = accs[src]
;;            ;; assert(gender(acc) == INPUT)
;;            emit-all(o, top, ports, lits, [dst " = rd " source(acc) " " index(acc) "\n"])
;;         else :
;;            emit-all(o, top, ports, lits, [dst " = mov " exp(c) "\n"])
;;      (c:Begin) :
;;         do(emit-command{o, _, top, lits, regs, accs, ports, outs}, body(c))
;;      (c:DefWire|EmptyStmt) :
;;         print("")
;;      (c) :
;;         error("Unable to print command")
;;
;;defn emit-module (o:OutputStream, m:Module) :
;;   val regs = HashTable(symbol-hash)
;;   val accs = HashTable(symbol-hash)
;;   val lits = HashTable(symbol-hash)
;;   val outs = HashTable(symbol-hash)
;;   val portz = HashTable(symbol-hash)
;;   for port in ports(m) do :
;;      portz[name(port)] = port
;;      if gender(port) == OUTPUT :
;;         outs[name(port)] = port
;;      else if name(port) == `reset :
;;         print-all(o, [name(m) "::reset = rst\n"])
;;      else :
;;         print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"])
;;   emit-command(o, body(m), name(m), lits, regs, accs, portz, outs)
;;
;;public defn emit-circuit (o:OutputStream, c:Circuit) :
;;   emit-module(o, modules(c)[0])

;============= DRIVER ======================================
; Runs the selected compiler passes over a circuit.
;
;   c : the input circuit.
;   p : pass selectors. A stage runs when contains(p, <letter>) is true:
;         'a' Working IR          'b' Resolve Kinds
;         'c' Make Explicit Reset 'd' Initialize Registers
;         'e' Infer Types
;       The stages always run in the order listed (a, b, c, d, e),
;       regardless of the order of the letters in p.
;
; When PRINT-CIRCUITS is set, the original circuit and the circuit
; produced by every stage are printed.
;
; Returns the circuit produced by the last stage that ran, or c itself
; when no stage was selected.
public defn run-passes (c: Circuit, p: List) :
   var c*:Circuit = c
   println("Compiling!")
   if PRINT-CIRCUITS :
      println("Original Circuit")
      print(c)

   ; Run one stage: replace c* with f(c*), tracing when PRINT-CIRCUITS is on.
   defn do-stage (name:String, f: Circuit -> Circuit) :
      if PRINT-CIRCUITS : println(name)
      c* = f(c*)
      if PRINT-CIRCUITS :
         print(c*)
         println-all(["Finished " name "\n"])

   ; Early passes:
   ; If modules have a reset defined, must be an INPUT and UInt(1)
   if contains(p,'a') : do-stage("Working IR", to-working-ir)
   if contains(p,'b') : do-stage("Resolve Kinds", resolve-kinds)
   if contains(p,'c') : do-stage("Make Explicit Reset", make-explicit-reset)
   if contains(p,'d') : do-stage("Initialize Registers", initialize-registers)
   if contains(p,'e') : do-stage("Infer Types", infer-types)
   ;if contains(p,'f') : do-stage("Infer Genders", infer-genders)
   ;if contains(p,'g') : do-stage("Expand Accessors", expand-accessors)
   ;if contains(p,'h') : do-stage("Flatten Bundles", flatten-bundles)
   ;if contains(p,'i') : do-stage("Expand Bundles", expand-bundles)
   ;if contains(p,'j') : do-stage("Expand Multi Connects", expand-multi-connects)
   ;if contains(p,'k') : do-stage("Expand Whens", expand-whens)
   ;if contains(p,'l') : do-stage("Structural Form", structural-form)
   ;if contains(p,'m') : do-stage("Infer Widths", infer-widths)
   ;if contains(p,'n') : do-stage("Pad Widths", pad-widths)
   ;if contains(p,'o') : do-stage("Inline Instances", inline-instances)
   println("Done!")

   ;; println("Shim for Jonathan's Passes")
   ;; c* = shim(c*)
   ;; println("Inline Modules")
   ;; c* = inline-modules(c*)

   ; Hand the transformed circuit back to the caller. Without this final
   ; expression the function's value would be that of println, and the
   ; output of every stage would be lost.
   c*