defpackage firrtl.passes :
   import core
   import verse
   import firrtl.ir2
   import firrtl.ir-utils
   import widthsolver
   import firrtl-main

;============== EXCEPTIONS =================================
; PassException is an Exception whose printed form is exactly its message.
defclass PassException <: Exception
defn PassException (msg:String) :
   new PassException :
      defmethod print (o:OutputStream, this) :
         print(o, msg)

;=============== WORKING IR ================================
; Kind records what a WRef refers to. Kinds are filled in by the
; resolve-kinds pass and inspected by later passes.
definterface Kind
defstruct WireKind <: Kind
defstruct RegKind <: Kind
defstruct InstanceKind <: Kind
defstruct ReadAccessorKind <: Kind
defstruct WriteAccessorKind <: Kind
defstruct PortKind <: Kind
defstruct NodeKind <: Kind ; All elems except structural memory, wires
defstruct MemKind <: Kind
defstruct ModuleKind <: Kind
defstruct StructuralMemKind <: Kind ; Separate kind because need special treatment
defstruct AccessorKind <: Kind

; Gender records how an expression may be used: MALE expressions are
; read from, FEMALE expressions are assigned to, BI-GENDER elements
; (e.g. wires) may be used either way, and UNKNOWN-GENDER marks a
; gender that has not been resolved yet.
public definterface Gender
public val MALE = new Gender
public val FEMALE = new Gender
public val UNKNOWN-GENDER = new Gender
public val BI-GENDER = new Gender

; Working-IR counterparts of Ref / Subfield / Index / DefAccessor that
; additionally carry kind and/or gender information.
defstruct WRef <: Expression :
   name: Symbol
   type: Type [multi => false]
   kind: Kind
   gender: Gender [multi => false]
; Reference to the `init` subfield of a register (`reg.init`).
defstruct WRegInit <: Expression :
   reg: Expression
   name: Symbol
   type: Type [multi => false]
   gender: Gender [multi => false]
defstruct WSubfield <: Expression :
   exp: Expression
   name: Symbol
   type: Type [multi => false]
   gender: Gender [multi => false]
defstruct WIndex <: Expression :
   exp: Expression
   value: Int
   type: Type [multi => false]
   gender: Gender [multi => false]
defstruct WDefAccessor <: Stmt :
   name: Symbol
   source: Expression
   index: Expression
   gender: Gender
; locs[index] := exp
defstruct ConnectToIndexed <: Stmt :
   index: Expression
   locs: List
   exp: Expression
; loc := exps[index]
defstruct ConnectFromIndexed <: Stmt :
   index: Expression
   loc: Expression
   exps: List

;================ WORKING IR UTILS =========================
; Combines two genders. Agreeing genders, or BI-GENDER paired with a
; specific gender, yield that gender; conflicting pairs yield
; UNKNOWN-GENDER. BI-GENDER + BI-GENDER stays BI-GENDER, and any pair
; involving UNKNOWN-GENDER is UNKNOWN-GENDER (previously these pairs had
; no case and raised a no-match error while accumulating genders).
defn plus (g1:Gender,g2:Gender) -> Gender :
   switch fn ([x,y]) : g1 == x and g2 == y :
      [FEMALE,MALE] : UNKNOWN-GENDER
      [MALE,FEMALE] : UNKNOWN-GENDER
      [MALE,MALE] : MALE
      [FEMALE,FEMALE] : FEMALE
      [BI-GENDER,MALE] : MALE
      [BI-GENDER,FEMALE] : FEMALE
      [MALE,BI-GENDER] : MALE
      [FEMALE,BI-GENDER] : FEMALE
      [BI-GENDER,BI-GENDER] : BI-GENDER
      else : UNKNOWN-GENDER

; Inverts a gender; BI-GENDER and UNKNOWN-GENDER are their own inverse.
defn swap (g:Gender) -> Gender :
   switch {_ == g} :
      UNKNOWN-GENDER : UNKNOWN-GENDER
      MALE : FEMALE
      FEMALE : MALE
      BI-GENDER : BI-GENDER

defn swap (f:Flip) -> Flip :
   switch {_ == f} :
      DEFAULT : REVERSE
      REVERSE : DEFAULT

defn swap (d:Direction) -> Direction :
   switch {_ == d} :
      OUTPUT : INPUT
      INPUT : OUTPUT

; Flip * Direction. Delegates to the (Direction, Flip) overload.
; (Previously written `flip * d`, which dispatched back to this same
; overload and recursed forever.)
defn times (flip:Flip,d:Direction) -> Direction : d * flip
; Applies a flip to a direction: REVERSE swaps it, DEFAULT keeps it.
defn times (d:Direction,flip:Flip) -> Direction :
   switch {_ == flip} :
      DEFAULT : d
      REVERSE : swap(d)

; Gender * Flip. Delegates to the (Flip, Gender) overload.
defn times (g:Gender,flip:Flip) -> Gender : flip * g
; Applies a flip to a gender: REVERSE swaps it, DEFAULT keeps it.
defn times (flip:Flip,g:Gender) -> Gender :
   switch {_ == flip} :
      DEFAULT : g
      REVERSE : swap(g)

; Composes two flips.
defn times (f1:Flip,f2:Flip) -> Flip :
   switch {_ == f2} :
      DEFAULT : f1
      REVERSE : swap(f1)

; Converts a module port into a bundle field: OUTPUT ports become
; REVERSE fields and INPUT ports become DEFAULT fields.
defn to-field (p:Port) -> Field :
   if direction(p) == OUTPUT : Field(name(p),REVERSE,type(p))
   else if direction(p) == INPUT : Field(name(p),DEFAULT,type(p))
   else : error("Shouldn't be here")

; Maps MALE to INPUT and FEMALE to OUTPUT; other genders have no case.
defn to-dir (g:Gender) -> Direction :
   switch {_ == g} :
      MALE : INPUT
      FEMALE : OUTPUT

defmulti gender (e:Expression) -> Gender
; Default gender for expressions without a gender field.
defmethod gender (e:Expression) :
   MALE ; TODO, why was this OUTPUT before? It makes sense as male, not female

defmethod print (o:OutputStream, g:Gender) :
   print{o, _} $ switch {g == _} :
      MALE : "male"
      FEMALE : "female"
      BI-GENDER : "bi"
      UNKNOWN-GENDER : "unknown"

defmethod type (exp:UIntValue) -> Type : UIntType(width(exp))
defmethod type (exp:SIntValue) -> Type : SIntType(width(exp))

;============== DEBUG STUFF =============================
; Global switches controlling which annotations print-debug emits.
public var PRINT-TYPES : True|False = false
public var PRINT-KINDS : True|False = false
public var PRINT-WIDTHS : True|False = false
public var PRINT-TWIDTHS : True|False = false
public var PRINT-GENDERS : True|False = false
public var PRINT-CIRCUITS : True|False = false

;=== Printers ===
defmethod print (o:OutputStream, k:Kind) :
   print{o, _} $ match(k) :
      (k:WireKind) : "wire"
      (k:RegKind) : "reg"
      (k:AccessorKind) : "accessor"
      (k:PortKind) : "port"
      (k:MemKind) : "mem"
      (k:NodeKind) : "n"
      (k:ModuleKind) : "module"
      (k:InstanceKind) : "inst"
      (k:StructuralMemKind) : "smem"
      (k:ReadAccessorKind) : "racc"
      (k:WriteAccessorKind) : "wacc"

; Predicates telling print-debug which annotations apply to a node.
defn hasGender (e:Expression|Stmt|Type|Port|Field) :
   e typeof WRef|WSubfield|WIndex|WDefAccessor|WRegInit

defn hasWidth (e:Expression|Stmt|Type|Port|Field) :
   e typeof UIntType|SIntType|UIntValue|SIntValue

defn hasType (e:Expression|Stmt|Type|Port|Field) :
   e typeof Ref|Subfield|Index|DoPrim|WritePort|ReadPort|WRef|WSubfield|WIndex|DefWire|DefRegister|DefMemory|Register|VectorType|Port|Field|WRegInit

defn hasKind (e:Expression|Stmt|Type|Port|Field) :
   e typeof WRef

; True when at least one enabled debug flag applies to e. Each flag is
; paired with the same predicate print-debug uses for it (the
; PRINT-TWIDTHS term previously tested PRINT-WIDTHS by mistake).
defn any-debug? (e:Expression|Stmt|Type|Port|Field) :
   (hasGender(e) and PRINT-GENDERS) or (hasType(e) and PRINT-TYPES) or (hasWidth(e) and PRINT-WIDTHS) or (hasType(e) and PRINT-TWIDTHS) or (hasKind(e) and PRINT-KINDS)

; Prints an "@" marker followed by one annotation per enabled flag.
; NOTE(review): the annotation strings below read `[""]`; their original
; contents (which presumably used wipe-width) appear to have been
; stripped from this copy of the file - restore from history.
defmethod print-debug (o:OutputStream, e:Expression|Stmt|Type|Port|Field) :
   defn wipe-width (t:Type) -> Type :
      match(t) :
         (t:UIntType) : UIntType(UnknownWidth())
         (t:SIntType) : SIntType(UnknownWidth())
         (t) : t

   if any-debug?(e) : print(o,"@")
   if PRINT-KINDS and hasKind(e) : print-all(o,[""])
   if PRINT-TYPES and hasType(e) : print-all(o,[""])
   if PRINT-TWIDTHS and hasType(e) : print-all(o,[""])
   if PRINT-WIDTHS and hasWidth(e) : print-all(o,[""])
   if PRINT-GENDERS and hasGender(e) : print-all(o,[""])

defmethod print (o:OutputStream, e:WRef) :
   print(o,name(e))
   print-debug(o,e as ?)

defmethod print (o:OutputStream, e:WRegInit) :
   print-all(o,[name(e)])
   print-debug(o,e as ?)

defmethod print (o:OutputStream, e:WSubfield) :
   print-all(o,[exp(e) "." name(e)])
   print-debug(o,e as ?)

defmethod print (o:OutputStream, e:WIndex) :
   print-all(o,[exp(e) "." value(e)])
   print-debug(o,e as ?)

defmethod print (o:OutputStream, s:WDefAccessor) :
   print-all(o,["accessor " name(s) " = " source(s) "[" index(s) "]"])
   print-debug(o,s)

defmethod print (o:OutputStream, c:ConnectToIndexed) :
   print-all(o, [locs(c) "[" index(c) "] := " exp(c)])
   print-debug(o,c as ?)

defmethod print (o:OutputStream, c:ConnectFromIndexed) :
   print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"])
   print-debug(o,c as ?)
; map over the sub-expressions of working-IR nodes.
defmethod map (f: Expression -> Expression, e: WRegInit) :
   WRegInit(f(reg(e)), name(e), type(e), gender(e))
defmethod map (f: Expression -> Expression, e: WSubfield) :
   WSubfield(f(exp(e)), name(e), type(e), gender(e))
defmethod map (f: Expression -> Expression, e: WIndex) :
   WIndex(f(exp(e)), value(e), type(e), gender(e))
defmethod map (f: Expression -> Expression, c:WDefAccessor) :
   WDefAccessor(name(c), f(source(c)), f(index(c)), gender(c))
defmethod map (f: Expression -> Expression, c:ConnectToIndexed) :
   ConnectToIndexed(f(index(c)), map(f, locs(c)), f(exp(c)))
defmethod map (f: Expression -> Expression, c:ConnectFromIndexed) :
   ConnectFromIndexed(f(index(c)), f(loc(c)), map(f, exps(c)))

; map over the type field of working-IR expressions.
defmethod map (f: Type -> Type, e: WRef) :
   WRef(name(e), f(type(e)), kind(e), gender(e))
defmethod map (f: Type -> Type, e: WRegInit) :
   WRegInit(reg(e), name(e), f(type(e)), gender(e))
defmethod map (f: Type -> Type, e: WSubfield) :
   WSubfield(exp(e), name(e), f(type(e)), gender(e))
defmethod map (f: Type -> Type, e: WIndex) :
   WIndex(exp(e), value(e), f(type(e)), gender(e))

;================= Bring to Working IR ========================
; Returns a new Circuit with Refs, Subfields, Indexes and DefAccessors
; replaced with IR-internal nodes that contain additional
; information (kind, gender)

defn to-working-ir (c:Circuit) :
   ; Bottom-up: children are converted before the node itself. A Subfield
   ; named `init` becomes a WRegInit named "<reg>.init"; every new node
   ; starts as NodeKind / UNKNOWN-GENDER until later passes resolve it.
   defn to-exp (e:Expression) :
      match(map(to-exp,e)) :
         (e:Ref) : WRef(name(e), type(e), NodeKind(), UNKNOWN-GENDER)
         (e:Subfield) :
            if name(e) == `init :
               WRegInit(exp(e), to-symbol("~.init" % [name(exp(e) as WRef)]), type(e), UNKNOWN-GENDER)
            else :
               WSubfield(exp(e), name(e), type(e), UNKNOWN-GENDER)
         (e:Index) : WIndex(exp(e), value(e), type(e), UNKNOWN-GENDER)
         (e) : e
   defn to-stmt (s:Stmt) :
      match(map(to-exp,s)) :
         (s:DefAccessor) : WDefAccessor(name(s),source(s),index(s), UNKNOWN-GENDER)
         (s) : map(to-stmt,s)

   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         Module(name(m), ports(m), to-stmt(body(m)))

;=============== Resolve Kinds =============================
; It is useful for the compiler to know information about
; objects referenced. This information is stored in the kind
; field in WRef. This pass walks the graph and returns a new
; Circuit where all WRef kinds are resolved

defn resolve-kinds (c:Circuit) :
   ; Rewrites every WRef in body with the kind recorded for its name.
   defn resolve (body:Stmt, kinds:HashTable) :
      defn resolve-stmt (s:Stmt) -> Stmt :
         map{resolve-expr,_} $ map(resolve-stmt,s)
      defn resolve-expr (e:Expression) -> Expression :
         match(e) :
            (e:WRef) : WRef(name(e),type(e),kinds[name(e)],gender(e))
            (e) : map(resolve-expr,e)
      resolve-stmt(body)

   ; Records the kind of every name declared in m (module, ports, and
   ; declarations in the body). Note that DefWire is recorded as NodeKind.
   defn find (m:Module, kinds:HashTable) :
      defn find-stmt (s:Stmt) -> Stmt :
         match(s) :
            (s:DefWire) : kinds[name(s)] = NodeKind()
            (s:DefNode) : kinds[name(s)] = NodeKind()
            (s:DefRegister) : kinds[name(s)] = RegKind()
            (s:DefInstance) : kinds[name(s)] = InstanceKind()
            (s:DefMemory) : kinds[name(s)] = MemKind()
            (s:WDefAccessor) : kinds[name(s)] = AccessorKind()
            (s) : false
         map(find-stmt,s)
      kinds[name(m)] = ModuleKind()
      for p in ports(m) do :
         kinds[name(p)] = PortKind()
      find-stmt(body(m))

   ; All module names are registered first so that instance module
   ; references resolve to ModuleKind.
   ; NOTE(review): the original layout was lost; `find` is assumed to run
   ; once, for the module being resolved, after the registration loop.
   defn resolve-kinds (m:Module, c:Circuit) -> Module :
      val kinds = HashTable(symbol-hash)
      for m in modules(c) do :
         kinds[name(m)] = ModuleKind()
      find(m,kinds)
      val body! = resolve(body(m),kinds)
      Module(name(m),ports(m),body!)

   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         resolve-kinds(m,c)

;=============== MAKE EXPLICIT RESET =======================
; All modules have an implicit reset signal - however, the
; programmer can explicitly reference this signal if desired.
; This pass makes all implicit resets explicit while
; preserving any previously explicit resets
; If reset is not explicitly passed to instantiations, then this
; pass automatically connects the parent module's reset to the
; instantiation's reset
; NOTE(review): route-reset below adds `inst.reset := reset` after every
; DefInstance unconditionally; confirm that is the intended behaviour.

defn make-explicit-reset (c:Circuit) :
   ; Names of the modules that already declare a `reset` port.
   defn find-explicit (c:Circuit) -> List :
      defn explicit? (m:Module) -> True|False :
         for p in ports(m) any? :
            name(p) == `reset
      val explicit-reset = Vector()
      for m in modules(c) do :
         if explicit?(m) : add(explicit-reset,name(m))
      to-list(explicit-reset)

   ; Adds a 1-bit `reset` input port to m unless it already has one, and
   ; connects every instance's reset to this module's reset.
   defn make-explicit (m:Module, explicit-reset:List) -> Module :
      defn route-reset (s:Stmt) -> Stmt :
         match(s) :
            (s:DefInstance) :
               val iref = WSubfield(WRef(name(s), UnknownType(), InstanceKind(), UNKNOWN-GENDER),`reset,UnknownType(),UNKNOWN-GENDER)
               val pref = WRef(`reset, UnknownType(), PortKind(), MALE)
               Begin(to-list([s,Connect(iref,pref)]))
            (s) : map(route-reset,s)

      var ports! = ports(m)
      if not contains?(explicit-reset,name(m)) :
         ports! = append(ports(m),list(Port(`reset,INPUT,UIntType(IntWidth(1)))))
      val body! = route-reset(body(m))
      Module(name(m),ports!,body!)

   ; The set of explicit-reset modules depends only on the circuit, so it
   ; is computed once here instead of once per module.
   val explicit-reset = find-explicit(c)
   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         make-explicit(m,explicit-reset)

;============== INFER TYPES ================================
; This pass infers the type field in all IR nodes by updating
; and passing an environment to all statements in pre-order
; traversal, and resolving types in expressions in post-
; order traversal.
; Type propagation for primary ops are defined here.
; Notable cases: LetRec requires updating environment before
; resolving the subexpressions in its elements.
; NOTE(review): no LetRec case appears in infer-types in this file; the
; remark above may be stale.
; Type errors are not checked in this pass, as this is
; postponed for a later/earlier pass.
; Returns the result type of a primitive operation. Widths are always
; UnknownWidth here. Suffixed ops (-UU/-US/-SU/-SS/-U/-S) have a fixed
; signedness; generic ops derive it from their operand types.
defn get-primop-rettype (e:DoPrim) -> Type :
   defn u () : UIntType(UnknownWidth())
   defn s () : SIntType(UnknownWidth())
   ; UInt only if both operands are UInt; SInt if either is SInt;
   ; otherwise unknown.
   defn u-and (op1:Expression,op2:Expression) :
      match(type(op1), type(op2)) :
         (t1:UIntType, t2:UIntType) : u()
         (t1:SIntType, t2) : s()
         (t1, t2:SIntType) : s()
         (t1, t2) : UnknownType()
   ; Same signedness as op's type, or unknown for non-ground types.
   defn of-type (op:Expression) :
      match(type(op)) :
         (t:UIntType) : u()
         (t:SIntType) : s()
         (t) : UnknownType()
   ;println-all(["Inferencing primop type: " e])
   switch {op(e) == _} :
      ADD-OP : u-and(args(e)[0],args(e)[1])
      ADD-UU-OP : u()
      ADD-US-OP : s()
      ADD-SU-OP : s()
      ADD-SS-OP : s()
      SUB-OP : s()
      SUB-UU-OP : s()
      SUB-US-OP : s()
      SUB-SU-OP : s()
      SUB-SS-OP : s()
      MUL-OP : u-and(args(e)[0],args(e)[1])
      MUL-UU-OP : u()
      MUL-US-OP : s()
      MUL-SU-OP : s()
      MUL-SS-OP : s()
      DIV-OP : u-and(args(e)[0],args(e)[1])
      DIV-UU-OP : u()
      DIV-US-OP : s()
      DIV-SU-OP : s()
      DIV-SS-OP : s()
      MOD-OP : of-type(args(e)[0])
      MOD-UU-OP : u()
      MOD-US-OP : u()
      MOD-SU-OP : s()
      MOD-SS-OP : s()
      QUO-OP : u-and(args(e)[0],args(e)[1])
      QUO-UU-OP : u()
      QUO-US-OP : s()
      QUO-SU-OP : s()
      QUO-SS-OP : s()
      REM-OP : of-type(args(e)[1])
      REM-UU-OP : u()
      REM-US-OP : s()
      REM-SU-OP : u()
      REM-SS-OP : s()
      ADD-WRAP-OP : u-and(args(e)[0],args(e)[1])
      ADD-WRAP-UU-OP : u()
      ADD-WRAP-US-OP : s()
      ADD-WRAP-SU-OP : s()
      ADD-WRAP-SS-OP : s()
      SUB-WRAP-OP : u-and(args(e)[0],args(e)[1])
      SUB-WRAP-UU-OP : u()
      SUB-WRAP-US-OP : s()
      SUB-WRAP-SU-OP : s()
      SUB-WRAP-SS-OP : s()
      LESS-OP : u()
      LESS-UU-OP : u()
      LESS-US-OP : u()
      LESS-SU-OP : u()
      LESS-SS-OP : u()
      LESS-EQ-OP : u()
      LESS-EQ-UU-OP : u()
      LESS-EQ-US-OP : u()
      LESS-EQ-SU-OP : u()
      LESS-EQ-SS-OP : u()
      GREATER-OP : u()
      GREATER-UU-OP : u()
      GREATER-US-OP : u()
      GREATER-SU-OP : u()
      GREATER-SS-OP : u()
      GREATER-EQ-OP : u()
      GREATER-EQ-UU-OP : u()
      GREATER-EQ-US-OP : u()
      GREATER-EQ-SU-OP : u()
      GREATER-EQ-SS-OP : u()
      EQUAL-OP : u()
      EQUAL-UU-OP : u()
      EQUAL-SS-OP : u()
      ; args[0] is the select predicate; the result has the type of the
      ; value operands (cf. MUX-UU / MUX-SS), so use args[1]. Using args[0]
      ; made every generic mux unsigned.
      MUX-OP : of-type(args(e)[1])
      MUX-UU-OP : u()
      MUX-SS-OP : s()
      PAD-OP : of-type(args(e)[0])
      PAD-U-OP : u()
      PAD-S-OP : s()
      AS-UINT-OP : u()
      AS-UINT-U-OP : u()
      AS-UINT-S-OP : u()
      AS-SINT-OP : s()
      AS-SINT-U-OP : s()
      AS-SINT-S-OP : s()
      SHIFT-LEFT-OP : of-type(args(e)[0])
      SHIFT-LEFT-U-OP : u()
      SHIFT-LEFT-S-OP : s()
      SHIFT-RIGHT-OP : of-type(args(e)[0])
      SHIFT-RIGHT-U-OP : u()
      SHIFT-RIGHT-S-OP : s()
      CONVERT-OP : s()
      CONVERT-U-OP : s()
      CONVERT-S-OP : s()
      BIT-AND-OP : u()
      BIT-OR-OP : u()
      BIT-XOR-OP : u()
      CONCAT-OP : u()
      BIT-SELECT-OP : u()
      BITS-SELECT-OP : u()

; The type of a module is the bundle of its ports (see to-field).
defn type (m:Module) -> Type :
   BundleType(for p in ports(m) map : to-field(p))

; Looks b up in the association list l; UnknownType if absent.
defn get-type (b:Symbol,l:List>) -> Type :
   val ma = for kv in l find : b == key(kv)
   if ma != false :
      val ret = value(ma as KeyValue)
      ret
   else : UnknownType()

; Type of field s of bundle v; UnknownType if the field is missing,
; error if v is not a bundle.
defn bundle-field-type (v:Type,s:Symbol) -> Type :
   match(v) :
      (v:BundleType) :
         val ft = for p in fields(v) find : name(p) == s
         if ft != false : type(ft as Field)
         else : UnknownType()
      (v) : error(string-join(["Accessing subfield " s " on a non-Bundle type."]))

; Element type of a vector; UnknownType for anything else.
defn get-vector-subtype (v:Type) -> Type :
   match(v) :
      (v:VectorType) : type(v)
      (v) : UnknownType()

; Post-order: children are typed first, then the node's own type is
; computed from the environment l or from its (already typed) children.
defn infer-exp-types (e:Expression, l:List>) -> Expression :
   val r = map(infer-exp-types{_,l},e)
   match(r) :
      (e:WRef) : WRef(name(e), get-type(name(e),l),kind(e),gender(e))
      (e:WSubfield) : WSubfield(exp(e),name(e), bundle-field-type(type(exp(e)),name(e)),gender(e))
      (e:WRegInit) : WRegInit(reg(e),name(e),get-type(name(reg(e) as WRef),l),gender(e))
      (e:WIndex) : WIndex(exp(e),value(e), get-vector-subtype(type(exp(e))),gender(e))
      (e:DoPrim) : DoPrim(op(e),args(e),consts(e),get-primop-rettype(e))
      (e:ReadPort) : ReadPort(mem(e),index(e),get-vector-subtype(type(mem(e))),enable(e))
      (e:WritePort) : WritePort(mem(e),index(e),get-vector-subtype(type(mem(e))),enable(e))
      (e:UIntValue|SIntValue) : e

; Types the expressions of s, then returns [typed stmt, environment
; after s]. Declarations prepend name => type to the environment; a
; Begin threads the environment through its body; both branches of a
; Conditionally start from l and their additions are discarded.
defn infer-types (s:Stmt, l:List>) -> [Stmt List>] :
   match(map(infer-exp-types{_,l},s)) :
      (s:Begin) :
         var env = l
         val body* = for s in body(s) map :
            val [s*,l*] = infer-types(s,env)
            env = l*
            s*
         [Begin(body*),env]
      (s:DefWire) : [s,List(name(s) => type(s),l)]
      (s:DefRegister) : [s,List(name(s) => type(s),l)]
      (s:DefMemory) : [s,List(name(s) => type(s),l)]
      (s:DefInstance) : [s, List(name(s) => type(module(s)),l)]
      (s:DefNode) : [s, List(name(s) => type(value(s)),l)]
      (s:WDefAccessor) : [s, List(name(s) => get-vector-subtype(type(source(s))),l)]
      (s:Conditionally) :
         val [s*,l*] = infer-types(conseq(s),l)
         val [s**,l**] = infer-types(alt(s),l)
         [Conditionally(pred(s),s*,s**),l]
      (s:Connect|EmptyStmt) : [s,l]

; Types a module body starting from its port types plus l.
defn infer-types (m:Module, l:List>) -> Module :
   val ptypes = for p in ports(m) map :
      name(p) => type(p)
   ;println-all(append(ptypes,l))
   val [s,l*] = infer-types(body(m),append(ptypes, l))
   Module(name(m),ports(m),s)

; The initial environment maps every module name to its port bundle.
defn infer-types (c:Circuit) -> Circuit :
   val l = for m in modules(c) map :
      name(m) => BundleType(map(to-field,ports(m)))
   ;println-all(l)
   Circuit{ _, main(c) } $ for m in modules(c) map :
      infer-types(m,l)

;============= RESOLVE ACCESSOR GENDER ============================
; To ensure a proper circuit, we must ensure that assignments
; only work on expressions that can be assigned to. Similarly,
; we must ensure that only expressions that can be read from
; are used to assign from. This invariant requires each
; expression's gender to be inferred.
; Various elements can be bi-gender (e.g. wires) and can
; thus be treated as either female or male. Conversely, some
; elements are single-gender (e.g. accessors, ports).
; Because accessor gender is not known during declaration,
; this pass requires iterating until a fixed point is reached.

; Flip of field n in bundle t; errors if t is not a bundle or lacks n.
defn bundle-field-flip (n:Symbol,t:Type) -> Flip :
   match(t) :
      (b:BundleType) :
         val field = for f in fields(b) find :
            name(f) == n
         match(field) :
            (f:Field) : flip(f)
            (f) : error(string-join(["Could not find " n " in bundle "]))
      (b) : error(string-join(["Accessing subfield " n " on a non-Bundle type."]))

defn resolve-genders (c:Circuit) :
   defn resolve-module (m:Module, genders:HashTable) -> Module :
      ; Cleared whenever an iteration records a new gender.
      var done? = true

      defn resolve-iter (m:Module) -> Module :
         val body* = resolve-stmt(body(m))
         Module(name(m),ports(m),body*)

      ; Returns the gender recorded for n, recording g (and clearing
      ; done?) when n has no entry yet or only an UNKNOWN-GENDER entry
      ; and g is known.
      ; NOTE(review): the lookup scans the whole table linearly.
      defn get-gender (n:Symbol,g:Gender) -> Gender :
         defn force-gender (n:Symbol,g:Gender) -> Gender :
            genders[n] = g
            done? = false
            g
         val entry = for kv in genders find : key(kv) == n
         match(entry) :
            (e:KeyValue) :
               val value = value(e)
               if value == UNKNOWN-GENDER and g == UNKNOWN-GENDER : g
               else if value != UNKNOWN-GENDER and g == UNKNOWN-GENDER : value
               else if value == UNKNOWN-GENDER and g != UNKNOWN-GENDER : force-gender(n,g)
               else : value
            (e:False) : force-gender(n,g)

      defn resolve-stmt (s:Stmt) -> Stmt :
         match(s) :
            (s:DefWire) :
               get-gender(name(s),BI-GENDER)
               s
            (s:DefRegister) :
               get-gender(name(s),BI-GENDER)
               s
            (s:DefMemory) :
               get-gender(name(s),BI-GENDER)
               s
            (s:DefNode) : DefNode(name(s),resolve-expr(value(s),get-gender(name(s),MALE)))
            (s:DefInstance) :
               get-gender(name(s),FEMALE)
               DefInstance(name(s),resolve-expr(module(s),FEMALE))
            (s:WDefAccessor) :
               ; The accessor's own gender, once known, is the gender
               ; requested of its source vector.
               val gender* = get-gender(name(s),UNKNOWN-GENDER)
               val index* = resolve-expr(index(s),MALE)
               val source* = resolve-expr(source(s),gender*)
               WDefAccessor(name(s),source*,index*,gender*)
            (s:Connect) : Connect(resolve-expr(loc(s),FEMALE),resolve-expr(exp(s),MALE))
            (s:Conditionally) :
               val pred* = resolve-expr(pred(s),MALE)
               val conseq* = resolve-stmt(conseq(s))
               val alt* = resolve-stmt(alt(s))
               Conditionally(pred*,conseq*,alt*)
            (s) : map(resolve-stmt,s)

      ; Resolves e for use with the desired gender. A BI-GENDER
      ; declaration takes on the desired gender at each use.
      defn resolve-expr (e:Expression,desired:Gender) -> Expression :
         match(e) :
            (e:WRef) :
               val gender = get-gender(name(e),desired)
               WRef{name(e),type(e),kind(e),_} $ if gender == BI-GENDER : desired else : gender
            (e:WRegInit) :
               val gender = get-gender(name(reg(e) as WRef),desired)
               WRegInit{reg(e),name(e),type(e),_} $ if gender == BI-GENDER : desired else : gender
            (e:WSubfield) :
               ; A flipped field reverses the gender required of the bundle.
               val field-flip = bundle-field-flip(name(e),type(exp(e)))
               val exp* = resolve-expr(exp(e),field-flip * desired)
               val gender* = field-flip * gender(exp*)
               WSubfield(exp*,name(e),type(e),gender*)
            (e:WIndex) :
               val exp* = resolve-expr(exp(e),desired)
               val gender* = gender(exp*)
               WIndex(exp*,value(e),type(e),gender*)
            (e) : map(resolve-expr{_,MALE},e)

      ; Each iteration re-resolves the ORIGINAL module against the
      ; accumulated table until no new gender is recorded.
      ; NOTE(review): prints the table on every iteration (debug output).
      var module* = resolve-iter(m)
      println(genders)
      while not done? :
         done? = true
         module* = resolve-iter(m)
         println(genders)
      module*

   defn resolve-genders (m:Module, c:Circuit) -> Module :
      val genders = HashTable(symbol-hash)
      resolve-module(m,genders)

   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         resolve-genders(m,c)

;;============== EXPAND ACCESSORS ================================
; This pass expands non-memory accessors into ConnectToIndexed or
; ConnectFromIndexed. All elements of the vector are
; explicitly written out, then indexed. Depending on the gender
; of the accessor, it is transformed into ConnectFromIndexed (male,
; i.e. the accessor is read) or ConnectToIndexed (female, i.e. the
; accessor is written).
; Eg:

; Lists v.0 .. v.(n-1) for a vector-typed expression, each carrying e's gender.
defn expand-vector (e:Expression) -> List :
   val t = type(e) as VectorType
   for i in 0 to size(t) map-append :
      list(WIndex(e,i,type(t),gender(e as ?))) ;always be WRef|WSubfield|WIndex

; Replaces each non-memory WDefAccessor with a wire of the element type
; plus an indexed connect between that wire and every vector element.
; Accessors whose source is a MemKind WRef are left unchanged.
; NOTE(review): prints a debug line for every accessor; genders other
; than MALE/FEMALE have no case in the switch.
defn expand-stmt (s:Stmt) -> Stmt :
   match(s) :
      (s:WDefAccessor) :
         println-all(["Matched WDefAcc with " name(s)])
         val mem? = match(source(s)) :
            (e:WRef) : kind(e) typeof MemKind
            (e) : false
         if mem? : s
         else :
            val vtype = type(type(source(s)) as VectorType)
            val wire = DefWire(name(s),vtype)
            switch {gender(s) == _} :
               MALE :
                  Begin{list(wire,_)} $ ConnectFromIndexed(
                     index(s),
                     WRef(name(wire),vtype,NodeKind(),FEMALE),
                     expand-vector(source(s)))
               FEMALE :
                  Begin{list(wire,_)} $ ConnectToIndexed(
                     index(s),
                     expand-vector(source(s)),
                     WRef(name(wire),vtype,NodeKind(),MALE))
      (s) : map(expand-stmt,s)

defn expand-accessors (c:Circuit) :
   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         Module(name(m),ports(m),expand-stmt(body(m)))

;;=============== LOWERING TO GROUND TYPES =============================
; All non-ground (elevated) types (Vectors, Bundles) are expanded out to
; individual ground types.
; This pass involves filling a table mapping the name of elevated types
; to the lowered ground expressions, each paired with its Flip relative
; to the original aggregate. This allows references to be resolved.

; Number of ground elements in t (1 for a ground type).
defn num-elems (t:Type) -> Int :
   match(t) :
      (t:BundleType) :
         var sum = 0
         for f in fields(t) do :
            sum = sum + num-elems(type(f))
         sum
      (t:VectorType) : size(t) * num-elems(type(t))
      (t) : 1

; Offset of field s's first ground element within bundle t.
defn index-of-elem (t:BundleType, s:Symbol) -> Int :
   var sum = 0
   label ret :
      for f in fields(t) do :
         if s == name(f) : ret(sum)
         else : sum = sum + num-elems(type(f))
      error("Shouldn't be here")

; One ground port per table entry of m; each direction is the original
; port direction adjusted by the entry's flip.
defn lower-ports (m:Module, table:HashTable>>) -> List :
   val entries = table[name(m)]
   val directions = for p in ports(m) map-append :
      to-list(for i in 0 to num-elems(type(p)) stream : direction(p))
   for (kv in entries, d in directions) map :
      val exp = key(kv) as WRef
      val dir* = d * value(kv)
      Port(name(exp),dir*,type(exp))

defn lower (body:Stmt, table:HashTable>>) -> Stmt :
   defn lower-stmt (s:Stmt) -> Stmt :
      ; Appends k to ctable[y], creating the entry if needed.
      defn add-to-table (y:Symbol,k:KeyValue,ctable:HashTable>>) :
         val contains? = for x in ctable any? : key(x) == y
         if contains? : ctable[y] = append(ctable[y],list(k))
         else : ctable[y] = list(k)
      ; Uses `typeof`, like the MemKind check in expand-stmt. The previous
      ; `kind(e) == InstanceKind()` compared against a freshly constructed
      ; struct instead of testing the kind's type.
      defn is-instance (e:Expression) -> True|False :
         match(e) :
            (e:WRef) : kind(e) typeof InstanceKind
            (e) : false
      ; Gender of e for the use g: the recorded gender for references,
      ; instance ports and indices; other subfields push g through the
      ; field flip down to their root.
      defn calc-gender (g:Gender, e:Expression) -> Gender :
         match(e) :
            (e:WRef) : gender(e)
            (e:WRegInit) : gender(e)
            (e:WSubfield) :
               if is-instance(exp(e)) : gender(e)
               else : calc-gender(bundle-field-flip(name(e),type(exp(e))) * g,exp(e))
            (e:WIndex) : gender(e)
            (e) : g

      match(s) :
         (s:DefWire) : Begin{_} $ for t in table[name(s)] map :
            DefWire(name(key(t) as WRef),type(key(t)))
         (s:DefRegister) : Begin{_} $ for t in table[name(s)] map :
            DefRegister(name(key(t) as WRef),type(key(t)))
         (s:DefInstance) : s
         (s:DefNode) :
            ; A node is lowered as a wire declaration plus a connect.
            val s* = Begin $ list(
               DefWire(name(s),type(value(s))),
               Connect(WRef(name(s),type(value(s)),NodeKind(),FEMALE),value(s)))
            lower-stmt(s*)
         (s:Connect) : Begin{_} $ for (l in expand-expr(loc(s)), r in expand-expr(exp(s))) map :
            ; Flipped elements are connected in the opposite direction.
            val lgender = calc-gender(FEMALE,loc(s)) * value(l)
            val rgender = calc-gender(MALE,exp(s)) * value(r)
            switch fn ([x,y]) : lgender == x and rgender == y :
               [FEMALE,MALE] : Connect(key(l),key(r))
               [MALE,FEMALE] : Connect(key(r),key(l))
         (s:WDefAccessor) : Begin{_} $ for (l in table[name(s)], r in expand-expr(source(s))) map :
            WDefAccessor(name(key(l) as WRef),key(r),index(s),value(r) * gender(s))
         (s:ConnectFromIndexed) : Begin(ls) where :
            ; Group, for each ground element of loc, the matching ground
            ; element of every candidate in exps.
            val ctable = HashTable>>(symbol-hash)
            for e in exps(s) do :
               for (r in expand-expr(e),l in expand-expr(loc(s))) do :
                  add-to-table(name(key(l) as WRef),r,ctable)
            val ls = for l in expand-expr(loc(s)) map :
               val cg = calc-gender(FEMALE,loc(s))
               val lgender = cg * value(l)
               var rgender = BI-GENDER
               val exps = for e in ctable[name(key(l) as WRef)] map :
                  rgender = rgender + (swap(cg) * value(e))
                  key(e)
               switch fn ([x,y]) : lgender == x and rgender == y :
                  [FEMALE,MALE] : ConnectFromIndexed(index(s),key(l),exps)
                  [MALE,FEMALE] : ConnectToIndexed(index(s),exps,key(l))
         (s:ConnectToIndexed) : Begin(ls) where :
            ; Group, for each ground element of exp, the matching ground
            ; element of every candidate in locs.
            val ctable = HashTable>>(symbol-hash)
            for ls in locs(s) do :
               for (l in expand-expr(ls),r in expand-expr(exp(s))) do :
                  add-to-table(name(key(r) as WRef),l,ctable)
            val ls = for r in expand-expr(exp(s)) map :
               val n = name(key(r) as WRef)
               val cg = calc-gender(MALE,exp(s))
               val rgender = cg * value(r)
               var lgender = BI-GENDER
               val locs = for l in ctable[n] map :
                  lgender = lgender + (swap(cg) * value(l))
                  key(l)
               switch fn ([x,y]) : lgender == x and rgender == y :
                  [FEMALE,MALE] : ConnectToIndexed(index(s),locs,key(r))
                  [MALE,FEMALE] : ConnectFromIndexed(index(s),key(r),locs)
         (s:DefMemory) : Begin{_} $ for t in table[name(s)] map :
            DefMemory(name(key(t) as WRef),type(key(t)) as VectorType)
         (s) : map(lower-stmt,s)

   ; The table entries (ground expression => flip) covered by e.
   ; Subfields and indices select a contiguous slice of their parent's
   ; entries.
   defn expand-expr (e:Expression) -> List> :
      match(e) :
         (e:WRef) : table[name(e)]
         (e:WRegInit) : table[name(e)]
         (e:WSubfield) :
            val exps = expand-expr(exp(e))
            val begin = index-of-elem(type(exp(e)) as BundleType,name(e))
            val len = num-elems(type(e))
            headn(tailn(exps,begin),len)
         (e:WIndex) :
            val exps = expand-expr(exp(e))
            val len = num-elems(type(e))
            headn(tailn(exps,len * value(e)),len)
         (e) : list(KeyValue(e, DEFAULT))

   ; NOTE(review): debug output of the whole table.
   println(table)
   lower-stmt(body)

; Flattens name n of type t into ground references named "n$f", "n$0",
; etc., each paired with its accumulated bundle flip.
defn get-entries (n:Symbol,t:Type) -> List> :
   defn uniquify (w:WRef) -> WRef :
      val name* = symbol-join([n "$" name(w)])
      WRef(name*,type(w),kind(w),gender(w))
   match(t) :
      (t:BundleType) :
         for f in fields(t) map-append :
            val es = get-entries(name(f),type(f))
            for e in es map :
               uniquify(key(e)) => value(e) * flip(f)
      (t:VectorType) :
         for i in 0 to size(t) map-append :
            val es = get-entries(to-symbol(i),type(t))
            for e in es map :
               uniquify(key(e)) => value(e)
      (t) : list(KeyValue(WRef(n,t,NodeKind(),UNKNOWN-GENDER),DEFAULT))

; Fills table with entries for m's ports and declarations, then lowers
; its body and ports using that table.
defn lower-module (m:Module,table:HashTable>>) -> Module :
   defn build-table-ports (ports:List) :
      for p in ports do :
         table[name(p)] = get-entries(name(p),type(p))

   defn build-table-stmt (stmt:Stmt) -> Stmt :
      match(stmt) :
         (s:DefWire) : table[name(s)] = get-entries(name(s),type(s))
         (s:DefRegister) :
            ; Also registers "<reg>.init" with matching WRegInit entries.
            val regs = get-entries(name(s),type(s))
            val init-sym = symbol-join([name(s),`.init])
            val init-regs = for r in regs map :
               val [e f] = [key(r) value(r)]
               WRegInit(e,symbol-join([name(e),`.init]),type(e),gender(e)) => f
            table[name(s)] = regs
            table[init-sym] = init-regs
         (s:DefInstance) :
            ; Instance entries are subfields of the instance, one per
            ; lowered port of the instantiated module.
            val r = WRef(name(s),type(module(s)),InstanceKind(),FEMALE)
            val ports = table[name(module(s) as WRef)]
            table[name(s)] = for w in ports map-append :
               list(KeyValue(WSubfield(r,name(key(w) as WRef),type(key(w) as WRef),UNKNOWN-GENDER), value(w)))
         (s:DefMemory) : table[name(s)] =
            for x in get-entries(name(s),type(type(s) as VectorType)) map :
               val [w f] = [key(x) value(x)]
               WRef(name(w),VectorType(type(w),size(type(s) as VectorType)),kind(w),gender(w)) => f
         (s:DefNode) : table[name(s)] = get-entries(name(s),type(value(s)))
         (s:WDefAccessor) : table[name(s)] = get-entries(name(s),type(type(source(s)) as VectorType))
         (s) : map(build-table-stmt,s)
      stmt

   build-table-ports(ports(m))
   build-table-stmt(body(m))
   Module(name(m),ports*,body*) where :
      val body* = lower(body(m),table)
      val ports* = lower-ports(m,table)

; NOTE(review): one table is shared by all modules and keyed by bare
; names, so a signal named like another module overwrites that module's
; port entries.
defn lower-to-ground (c:Circuit) -> Circuit :
   val table = HashTable>>(symbol-hash)
   defn build-table-module (m:Module) -> ? :
      table[name(m)] = for p in ports(m) map-append :
         get-entries(name(p),type(p))

   ; Run only for its side effect on table.
   for m in modules(c) do :
      build-table-module(m)

   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         lower-module(m,table)

;;=========== CONVERT MULTI CONNECTS to WHEN ================
; This pass converts ConnectToIndexed and ConnectFromIndexed
; into a series of when statements. TODO what about initial
; values?

defn expand-connect-indexed-stmt (s: Stmt) -> Stmt :
   defn equality (e1:Expression,e2:Expression) -> Expression :
      DoPrim(EQUAL-UU-OP,list(e1,e2),List(),UIntType(UnknownWidth()))
   match(s) :
      (s:ConnectToIndexed) :
         ; locs[index] := exp must drive only the selected element, so
         ; every element - including element 0 - is guarded by an index
         ; comparison. (Element 0 was previously connected unconditionally,
         ; clobbering it whenever any other index was written.)
         if length(locs(s)) == 0 :
            Begin(list(EmptyStmt()))
         else :
            val guarded = for (i in 0 to false, l in locs(s)) stream :
               Conditionally(
                  equality(index(s),UIntValue(i,UnknownWidth())),
                  Connect(l,exp(s)),
                  EmptyStmt())
            Begin(to-list(guarded))
      (s:ConnectFromIndexed) :
         ; loc := exps[index]: default to exps[0], then let later
         ; conditional connects override it for indices 1..n-1.
         if length(exps(s)) == 0 :
            Begin(list(EmptyStmt()))
         else :
            val overrides = for (i in 1 to false, e in tail(exps(s))) stream :
               Conditionally(
                  equality(index(s),UIntValue(i,UnknownWidth())),
                  Connect(loc(s),e),
                  EmptyStmt())
            Begin(List(Connect(loc(s),head(exps(s))), to-list(overrides)))
      (s) : map(expand-connect-indexed-stmt,s)

defn expand-connect-indexed (m: Module) -> Module :
   Module(name(m),ports(m),expand-connect-indexed-stmt(body(m)))

defn expand-connect-indexed (c: Circuit) -> Circuit :
   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         expand-connect-indexed(m)

;======= MAKE EXPLICIT REGISTER INITIALIZATION =============
; This pass replaces the reg.init construct by creating a new
; wire that holds the value at initialization. This wire
; is then connected to the register conditionally on reset,
; at the end of the scope containing the register
; declaration
; If a register has no initial value, the wire is connected to
; a NULL node. Later passes will remove these with the base
; case Mux(reset,NULL,a) -> a, and Mux(reset,a,NULL) -> a.
; This ensures proper behavior if this pass is run multiple
; times.
; NOTE(review): initialize-registers only creates the init wire for
; registers whose .init is referenced; no NULL connection is made there.
defn initialize-registers (c:Circuit) :
   ; Name of the wire holding register y's init value.
   defn to-wire-name (y:Symbol) : to-symbol("~$init" % [y])

   ; Appends `when reset : reg := reg$init` for every register in h
   ; (register name => type) after s; returns s unchanged if h is empty.
   defn add-when (s:Stmt,h:HashTable) -> Stmt :
      var inits = List()
      for kv in h do :
         val refreg = WRef(key(kv),value(kv),RegKind(),FEMALE)
         val refwire = WRef(to-wire-name(key(kv)),value(kv),NodeKind(),MALE)
         val connect = Connect(refreg,refwire)
         inits = append(inits,list(connect))
      if empty?(inits) : s
      else :
         val pred = WRef(`reset, UIntType(IntWidth(1)), PortKind(), MALE)
         val when-reset = Conditionally(pred,Begin(inits),Begin(List()))
         Begin(list(s,when-reset))

   ; Replaces each reg.init reference with a reference to the init wire,
   ; and declares that wire right after every register marked true in h.
   ; Returns the new statement and a table (register name => type) of the
   ; registers that received a wire.
   defn rename (s:Stmt,h:HashTable) -> [Stmt HashTable] :
      val t = HashTable(symbol-hash)
      defn rename-expr (e:Expression) -> Expression :
         match(map(rename-expr,e)) :
            (e:WRegInit) :
               ; The init value lives in a wire, so the reference uses
               ; NodeKind - the same kind add-when uses for this wire
               ; (previously RegKind).
               val new-name = to-wire-name(name(reg(e) as WRef))
               WRef(new-name,type(reg(e)),NodeKind(),gender(e))
            (e) : e
      defn rename-stmt (s:Stmt) -> Stmt :
         match(map(rename-stmt,s)) :
            (s:DefRegister) :
               if h[name(s)] :
                  t[name(s)] = type(s)
                  Begin(list(s,DefWire(to-wire-name(name(s)),type(s))))
               else : s
            (s) : map(rename-expr,s)
      [rename-stmt(s) t]

   ; True if any expression directly in s references y.init.
   defn init? (y:Symbol,s:Stmt) -> True|False :
      var used? = false
      ; Checks this node, then visits each child exactly once. The
      ; previous version mapped over the children twice per level,
      ; which is exponential in expression depth.
      defn has? (e:Expression) -> Expression :
         match(e) :
            (e:WRegInit) :
               if name(reg(e) as WRef) == y : used? = true
            (e) : false
         map(has?,e)
      map(has?,s)
      used?

   ; Fills h with register name => whether its .init is referenced by
   ; any statement visited after the register's declaration.
   defn using-init (s:Stmt,h:HashTable) -> Stmt :
      match(s) :
         (s:DefRegister) : h[name(s)] = false
         (s) :
            for x in h do :
               h[key(x)] = value(x) or init?(key(x),s)
      map(using-init{_,h},s)

   defn explicit-init-scope (s:Stmt) -> Stmt :
      val h = HashTable(symbol-hash)
      using-init(s,h)
      ;println(h)
      val [s* t] = rename(s,h)
      add-when(s*,t)

   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         Module(name(m), ports(m), body*) where :
            val body* = explicit-init-scope(body(m))

;;================ EXPAND WHENS =============================
; This pass does three things: remove last connect semantics,
; remove conditional blocks, and eliminate concept of scoping.
; First, we scan the circuit to build a table mapping references
; to the final assigned value, represented with SymbolicValues.
; Within a scope, we remove the last connect semantics to get
; the final value. When leaving a scope, the resulting table
; is merged with the parent scope by using the SVMux.
; We also collect the kind of reference to know how to declare
; it in a following stage.
; Second, we use the table to declare each reference, then
; assign to each once. This is relatively straightforward
; except calculating the WritePort/ReadPort enables.
; Finally, we scan the table to remove redundant values.
; The WritePort enable is calculated by returning 1 for all conditions
; for which the corresponding symbolic value is not SVNul.
; The ReadPort enable is calculated by scanning all entries in
; the table for when this is referenced (a read). All conditions
; are accumulated and OR'ed together.

; ======== Expression Computation Library ===========

;1-bit constants used by the boolean simplifiers below.
val zero = UIntValue(0,IntWidth(1))
val one = UIntValue(1,IntWidth(1))

;Structural equality used by the simplifiers. It is conservative:
;unhandled pairs (DoPrim, WIndex, mismatched kinds) compare unequal.
defmethod equal? (e1:Expression,e2:Expression) -> True|False :
   match(e1,e2) :
      (e1:UIntValue,e2:UIntValue) :
         if value(e1) == value(e2) :
            match(width(e1), width(e2)) :
               (w1:IntWidth,w2:IntWidth) : width(w1) == width(w2)
               (w1,w2) : false
         else : false
      (e1:SIntValue,e2:SIntValue) :
         if value(e1) == value(e2) :
            ;Fallback mirrors the UIntValue case. Without it, a literal with a
            ;non-IntWidth width would make this match fail at runtime.
            match(width(e1), width(e2)) :
               (w1:IntWidth,w2:IntWidth) : width(w1) == width(w2)
               (w1,w2) : false
         else : false
      (e1:WRef,e2:WRef) : name(e1) == name(e2)
      ;(e1:DoPrim,e2:DoPrim) : TODO
      (e1:WRegInit,e2:WRegInit) : reg(e1) == reg(e2) and name(e1) == name(e2)
      ;The subfield's parent must match too, otherwise a.x == b.x.
      (e1:WSubfield,e2:WSubfield) : exp(e1) == exp(e2) and name(e1) == name(e2)
      (e1,e2) : false

;1-bit AND with constant folding.
defn AND (e1:Expression,e2:Expression) -> Expression :
   if e1 == e2 : e1
   else if e1 == zero or e2 == zero : zero
   else if e1 == one : e2
   else if e2 == one : e1
   else : DoPrim(BIT-AND-OP,list(e1,e2),list(),UIntType(IntWidth(1)))

;1-bit OR with constant folding.
defn OR (e1:Expression,e2:Expression) -> Expression :
   if e1 == e2 : e1
   else if e1 == one or e2 == one : one
   else if e1 == zero : e2
   else if e2 == zero : e1
   else : DoPrim(BIT-OR-OP,list(e1,e2),list(),UIntType(IntWidth(1)))

;1-bit NOT, expressed as (e1 == 0) for non-constant inputs.
defn NOT (e1:Expression) -> Expression :
   if e1 == one : zero
   else if e1 == zero : one
   else : DoPrim(EQUAL-UU-OP,list(e1,zero),list(),UIntType(IntWidth(1)))

;Immediate sub-expressions of e, in map order.
defn children (e:Expression) -> List :
   val es = Vector()
   defn f (e:Expression) :
      add(es,e)
      e
   map(f,e)
   to-list(es)

; ======= Symbolic Value Library ==========
public definterface SymbolicValue
public defstruct SVExp <: SymbolicValue :
   exp : Expression
public defstruct SVMux <: SymbolicValue :
   pred : Expression
   conseq : SymbolicValue
   alt : SymbolicValue
public defstruct SVNul <: SymbolicValue

defmethod print (o:OutputStream, sv:SymbolicValue) :
   match(sv) :
      (sv: SVExp) : print(o, exp(sv))
      (sv: SVMux) : print-all(o, ["(" pred(sv) " ? " conseq(sv) " : " alt(sv) ")"])
      (sv: SVNul) : print(o, "SVNUL")

;Apply f to the immediate symbolic children (only SVMux has any).
defmulti map (f: SymbolicValue -> SymbolicValue, sv:?T&SymbolicValue) -> T
defmethod map (f: SymbolicValue -> SymbolicValue, sv:SymbolicValue) -> SymbolicValue :
   match(sv) :
      (sv: SVMux) : SVMux(pred(sv),f(conseq(sv)),f(alt(sv)))
      (sv) : sv

;Call f on each immediate child of s, for side effects only.
defn do (f:SymbolicValue -> ?, s:SymbolicValue) -> False :
   for x in s map :
      f(x)
      x
   false

;Call f on every descendant of e (not on e itself), for side effects only.
defn dor (f:SymbolicValue -> ?, e:SymbolicValue) -> False :
   do(f,e)
   for x in e map :
      dor(f,x)
      x
   false

defmethod equal? (a:SymbolicValue,b:SymbolicValue) -> True|False :
   match(a,b) :
      (a:SVNul,b:SVNul) : true
      (a:SVExp,b:SVExp) : exp(a) == exp(b)
      (a:SVMux,b:SVMux) : pred(a) == pred(b) and conseq(a) == conseq(b) and alt(a) == alt(b)
      (a,b) : false

;Bottom-up simplification of symbolic muxes:
;  p ? x : x -> x,   p ? 1 : 0 -> p,   p ? 0 : 1 -> !p
;TODO add invert to primop
defn optimize (sv:SymbolicValue) -> SymbolicValue :
   match(map(optimize,sv)) :
      (sv:SVMux) :
         ;Equal arms (including two SVExps with equal exps) collapse here,
         ;so the SVExp case below only handles the 1/0 patterns.
         if conseq(sv) == alt(sv) : conseq(sv)
         else :
            match(conseq(sv),alt(sv)) :
               (c:SVExp,a:SVExp) :
                  if exp(c) == one and exp(a) == zero : SVExp(pred(sv))
                  else if exp(c) == zero and exp(a) == one : SVExp(NOT(pred(sv)))
                  else : sv
               (c,a) : sv
      (sv) : sv

; ========== Expand When Utilz ==========

;Shallow copy of a symbol-keyed table. The values themselves are shared,
;which is fine because SymbolicValues are never mutated.
defn deepcopy (t:HashTable) -> HashTable :
   t0 where :
      val t0 = HashTable(symbol-hash)
      for x in t do :
         t0[key(x)] = value(x)

;Union of the keys of every table in ts, in first-seen order.
defn get-unique-keys (ts:List<HashTable<Symbol,SymbolicValue>>) -> Vector :
   t0 where :
      val t0 = Vector()
      ;`seen` makes the duplicate test a hash lookup instead of a rescan of t0.
      val seen = HashTable(symbol-hash)
      for v in ts do :
         for t in v do :
            if not key?(seen,key(t)) :
               seen[key(t)] = true
               add(t0,key(t))

;True if sv is SVNul or contains an SVNul anywhere beneath it.
defn has-nul? (sv:SymbolicValue) -> True|False :
   var has? = false
   if sv typeof SVNul : has? = true
   for x in sv dor :
      if x typeof SVNul : has? = true
   has?
;Drop SVNul arms from muxes bottom-up: (p ? x : nul) and (p ? nul : x) become x.
defn remove-nul (sv:SymbolicValue) -> SymbolicValue :
   match(map(remove-nul,sv)) :
      (sv:SVMux) :
         match(conseq(sv),alt(sv)) :
            (c,a:SVNul) : c
            (c:SVNul,a) : a
            (c,a) : sv
      (sv) : sv

;Lower a symbolic value to a mux tree of expressions.
;Errors if sv is (or reduces to) a bare SVNul.
defn to-exp (sv:SymbolicValue) -> Expression :
   match(remove-nul(sv)) :
      (sv:SVMux) :
         ;NOTE(review): the mux is always typed UIntType(IntWidth(1)) and uses
         ;MUX-UU-OP, whatever the type of its data arms - confirm this is intended.
         DoPrim(MUX-UU-OP, list(pred(sv),to-exp(conseq(sv)),to-exp(alt(sv))), list(), UIntType(IntWidth(1)))
      (sv:SVExp) : exp(sv)
      (sv) : error("Shouldn't be here")

;Logical OR of a list of booleans; false for the empty list.
;empty? avoids recomputing the list length on every recursion.
defn reduce-or (l:List) -> True|False :
   if empty?(l) : false
   else : head(l) or reduce-or(tail(l))

;OR of a list of 1-bit expressions; zero for the empty list.
defn reduce-or (l:List) -> Expression :
   if empty?(l) : zero
   else : OR(head(l), reduce-or(tail(l)))

; ========= Expand When Pass ===========

; TODO: replace stmt with wr (WRefs). The KIND of wref will help figure out what to emit as far as
; declarations, especially with not declaring anything for ports. We need WRefs, and not just Kinds,
; because we need the name of the symbolic expression. I think? Or maybe we can use the key?

; 1) Build Table, Build Declaration List

;Emit one final connection per entry of `assign`, appended after the
;declarations already collected in `decs`, and wrap everything in a Begin.
;The shape of each connection is chosen by the entry's kind in `kinds`.
defn expand-whens (assign:HashTable, kinds:HashTable, stmts:HashTable, decs:Vector, enables:HashTable) -> Stmt :
   for x in assign do :
      val n = key(x)
      match(kinds[n]) :
         (k:WriteAccessorKind) :
            ;First create WritePort and assign from accessor-turned-wire
            val s = stmts[n] as WDefAccessor
            val t = type(type(source(s)) as VectorType)
            val ref = WRef(n,t,k,MALE)
            val wp = WritePort(source(s),index(s),t,to-exp(enables[n]))
            add(decs,Connect(wp,ref))
            ;If initialized, assign input to accessor-turned-wire.
            ;Test the type: `sv == SVNul` compared against the constructor
            ;and so never detected an uninitialized accessor.
            val sv = remove-nul(assign[n])
            if sv typeof SVNul :
               ;TODO actually collect error
               println("Uninitialized: ~" % [to-string(n)])
            else : add(decs,Connect(ref,to-exp(sv)))
         (k:ReadAccessorKind) :
            val s = stmts[n] as WDefAccessor
            val t = type(type(source(s)) as VectorType)
            val ref = WRef(n,t,k,FEMALE)
            val rp = ReadPort(source(s),index(s),t,to-exp(enables[n]))
            add(decs,Connect(ref,rp))
         (k:RegKind) :
            val s = stmts[n] as DefRegister
            val ref = WRef(n,type(s),k,FEMALE)
            val sv = remove-nul(assign[n])
            ;A register that is never assigned gets value 0 with enable 0.
            val reg =
               if sv typeof SVNul : Register(type(s),UIntValue(0,width(type(s) as ?)),zero)
               else : Register(type(s),to-exp(sv),to-exp(enables[n]))
            add(decs,Connect(ref,reg))
         (k:InstanceKind) :
            ;n has the form "<instance>.<port>" (see build-tables).
            val s = stmts[n] as DefInstance
            val f = to-symbol(split(to-string(n),'.')[1])
            val ref = WSubfield(module(s),f,bundle-field-type(type(module(s)),f),FEMALE)
            if has-nul?(assign[n]) :
               ;TODO actually collect error
               println("Uninitialized: ~" % [to-string(n)])
            else : add(decs,Connect(ref,to-exp(assign[n])))
         (k) :
            val s = stmts[n] as DefWire
            val ref = WRef(n,type(s),k,FEMALE)
            if has-nul?(assign[n]) :
               ;TODO actually collect error
               println("Uninitialized: ~" % [to-string(n)])
            else : add(decs,Connect(ref,to-exp(assign[n])))
   Begin(to-list(decs))

;Compute enable conditions:
;  read accessors  - OR of every condition under which the accessor is referenced;
;  write accessors and registers - 1 wherever the symbolic value is not SVNul.
defn get-enables (assign:HashTable, kinds:HashTable) -> HashTable :
   defn get-read-enable (sym:Symbol,sv:SymbolicValue) -> Expression :
      ;True if e references sym anywhere.
      defn active (e:Expression) -> True|False :
         match(e) :
            (e:WRef) : name(e) == sym
            (e) : reduce-or{_} $ map(active,children(e))
      match(sv) :
         (sv: SVNul) : zero
         (sv: SVExp) :
            if active(exp(sv)) : one
            else : zero
         (sv: SVMux) :
            val e0 = get-read-enable(sym,SVExp(pred(sv)))
            val e1 = get-read-enable(sym,conseq(sv))
            val e2 = get-read-enable(sym,alt(sv))
            if e1 == e2 : OR(e0,e1)
            else : OR(e0,OR(AND(pred(sv),e1),AND(NOT(pred(sv)),e2)))

   defn get-write-enable (sv:SymbolicValue) -> SymbolicValue :
      match(map(get-write-enable,sv)) :
         (sv: SVExp) : SVExp(one)
         (sv: SVNul) : SVExp(zero)
         (sv) : sv

   val enables = HashTable(symbol-hash)
   for x in assign do :
      val sym = key(x)
      match(kinds[sym]) :
         (k:ReadAccessorKind) :
            enables[sym] = SVExp{_} $ reduce-or{_} $ to-list{_} $ for y in assign stream :
               get-read-enable(sym,value(y))
         (k:WriteAccessorKind) : enables[sym] = get-write-enable(value(x))
         (k:RegKind) : enables[sym] = get-write-enable(value(x))
         (k) : k
   enables

;Walk s, recording declarations in decs, the kind and declaring statement of
;each assignable name in kinds/stmts, and its final symbolic value in assign.
;A when-block is processed in copies of the table, and the two copies are
;merged back with SVMux.
defn build-tables (s:Stmt, assign:HashTable, kinds:HashTable, decs:Vector, stmts:HashTable) -> False :
   match(s) :
      (s:DefWire) :
         add(decs,s)
         kinds[name(s)] = WireKind()
         assign[name(s)] = SVNul()
         stmts[name(s)] = s
      (s:DefNode) : add(decs,s)
      (s:DefRegister) :
         add(decs,DefWire(name(s),type(s)))
         kinds[name(s)] = RegKind()
         assign[name(s)] = SVNul()
         stmts[name(s)] = s
      (s:WDefAccessor) :
         add(decs,DefWire(name(s),type(type(source(s)) as VectorType)))
         assign[name(s)] = SVNul()
         kinds[name(s)] = switch {_ == gender(s)} :
            MALE : ReadAccessorKind()
            FEMALE : WriteAccessorKind()
         stmts[name(s)] = s
      (s:DefInstance) :
         add(decs,s)
         for f in fields(type(module(s)) as BundleType) do :
            val n = to-symbol("~.~" % [name(s),name(f)]) ; only on inputs
            ;println-all(["In DefInst adding: " n])
            kinds[n] = InstanceKind()
            assign[n] = SVNul()
            stmts[n] = s
      (s:DefMemory) : add(decs,s)
      (s:Conditionally) :
         val assign-c = deepcopy(assign)
         val assign-a = deepcopy(assign)
         build-tables(conseq(s),assign-c,kinds,decs,stmts)
         build-tables(alt(s),assign-a,kinds,decs,stmts)
         for i in get-unique-keys(list(assign-c,assign-a)) do :
            ;A key missing from one branch was declared inside the other branch.
            assign[i] = match(get?(assign-c,i,false),get?(assign-a,i,false)) : ;TODO add to syntax highlighting
               (c:SymbolicValue,a:SymbolicValue) : SVMux(pred(s),c,a)
               (c:SymbolicValue,a:False) :
                  if kinds[i] typeof WireKind|InstanceKind|NodeKind : c
                  else : SVMux(pred(s),c,SVNul())
               (c:False,a:SymbolicValue) :
                  if kinds[i] typeof WireKind|InstanceKind|NodeKind : a
                  else : SVMux(pred(s),SVNul(),a)
               (c:False,a:False) : error("Shouldn't be here")
         ;println("TABLE-C")
         ;for x in assign-c do : println(x)
         ;println("TABLE-A")
         ;for x in assign-a do : println(x)
         ;println("TABLE")
         ;for x in assign do : println(x)
      (s:Connect) :
         val key* = match(loc(s)) :
            (e:WRef) : name(e)
            (e:WSubfield) : symbol-join([name(exp(e) as ?) `. name(e)])
            (e) : error("Shouldn't be here with ~" % [e])
         assign[key*] = SVExp(exp(s)) ; TODO, need to check all references are declared before this point
      (s:Begin) :
         for s* in body(s) do :
            build-tables(s*,assign,kinds,decs,stmts)
      (s) : false

;Expand the whens of a single module. Output ports are seeded as unassigned
;wires so that they receive a final connection.
defn expand-whens (m:Module) -> Module :
   val assign = HashTable(symbol-hash)
   val decs = Vector()
   val kinds = HashTable(symbol-hash)
   val stmts = HashTable(symbol-hash)
   for p in ports(m) do :
      if direction(p) == OUTPUT :
         assign[name(p)] = SVNul()
         kinds[name(p)] = PortKind()
         stmts[name(p)] = DefWire(name(p),type(p))

   build-tables(body(m),assign,kinds,decs,stmts)
   for x in assign do :
      assign[key(x)] = optimize(value(x))
   val enables = get-enables(assign,kinds)
   for x in enables do :
      enables[key(x)] = optimize(value(x))

   ;println("Assigns")
   ;for x in assign do : println(x)
   ;println("Kinds")
   ;for x in kinds do : println(x)
   ;println("Decs")
   ;for x in decs do : println(x)
   ;println("Enables")
   ;for x in enables do : println(x)

   Module(name(m),ports(m),expand-whens(assign,kinds,stmts,decs,enables))

;Expand whens in every module of the circuit.
defn expand-whens (c:Circuit) -> Circuit :
   Circuit(modules*, main(c)) where :
      val modules* =
         for m in modules(c) map :
            expand-whens(m)

;;================ INFER WIDTHS =============================
; First, you replace all unknown widths with a unique width
; variable.
; Then, you collect all width constraints.
; Then, you solve width constraints.
; Finally, you replace all width variables with the solved
; widths.
; Low FIRRTL Pass.
;Symbolic widths used while solving constraints.
defstruct VarWidth <: Width :
   name: Symbol
defstruct PlusWidth <: Width :
   arg1 : Width
   arg2 : Width
defstruct MinusWidth <: Width :
   arg1 : Width
   arg2 : Width
defstruct MaxWidth <: Width :
   arg1 : Width
   arg2 : Width

;Apply f to the immediate sub-widths of a compound width.
public defmulti map (f: Width -> Width, w:?T&Width) -> T
defmethod map (f: Width -> Width, w:Width) -> Width :
   match(w) :
      (w:MaxWidth) : MaxWidth(f(arg1(w)),f(arg2(w)))
      (w:PlusWidth) : PlusWidth(f(arg1(w)),f(arg2(w)))
      (w:MinusWidth) : MinusWidth(f(arg1(w)),f(arg2(w)))
      (w) : w

defmethod print (o:OutputStream, w:VarWidth) :
   print(o,name(w))
defmethod print (o:OutputStream, w:MaxWidth) :
   print-all(o,["max(" arg1(w) "," arg2(w) ")"])
defmethod print (o:OutputStream, w:PlusWidth) :
   print-all(o,[ arg1(w) " + " arg2(w)])
defmethod print (o:OutputStream, w:MinusWidth) :
   print-all(o,[ arg1(w) " - " arg2(w)])

;A width constraint: loc >= exp.
definterface Constraint
defstruct WGeq <: Constraint :
   loc : Width
   exp : Width
defmethod print (o:OutputStream, c:WGeq) :
   print-all(o,[ loc(c) " >= " exp(c)])

;Work-in-progress width solver. It prints intermediate tables and currently
;always returns an empty table; the backward solve and evaluation steps are
;commented out below.
defn solve-constraints (l:List) -> HashTable :
   defn contains? (n:Symbol,h:HashTable) -> True|False :
      key?(h,n)

   ;Merge all lower bounds on the same variable into a single max(...).
   defn unique (ls:List) -> HashTable :
      val h = HashTable(symbol-hash)
      for g in ls do :
         match(loc(g)) :
            (w:VarWidth) :
               val n = name(w)
               if contains?(n,h) : h[n] = MaxWidth(exp(g),h[n])
               else : h[n] = exp(g)
            (w) : w
      h

   ;Replace every bound variable in w by its (recursively substituted)
   ;binding in h, caching the result back into h.
   ;NOTE(review): if h binds a variable to an expression containing itself,
   ;this recurses without bound. gen-constraints always emits both a >= b and
   ;b >= a, which appears to create such bindings in the forward solve below.
   ;The commented-out remove-cycle/self-rec? steps look intended to prevent
   ;this - TODO confirm.
   defn substitute (w:Width,h:HashTable) -> Width :
      match(map(substitute{_,h},w)) :
         (w:VarWidth) :
            if contains?(name(w),h) :
               val t = substitute(h[name(w)],h)
               h[name(w)] = t
               t
            else : w
         (w) : w

   ;Drop n from max(n, x) / max(x, n) terms, since n >= max(n, x) reduces to n >= x.
   defn remove-cycle (n:Symbol,w:Width) -> Width :
      match(map(remove-cycle{n,_},w)) :
         (w:MaxWidth) :
            match(arg1(w),arg2(w)) :
               (v1:VarWidth,v2:VarWidth) :
                  if name(v1) == n : arg2(w)
                  else if name(v2) == n : arg1(w)
                  else : w
               (v:VarWidth,_) :
                  if name(v) == n : arg2(w)
                  else : w
               (_,v:VarWidth) :
                  if name(v) == n : arg1(w)
                  else : w
               (v1,v2) : w
         (w) : w

   ;True if variable n occurs anywhere in w.
   defn self-rec? (n:Symbol,w:Width) -> True|False :
      var has? = false
      defn look (w:Width) -> Width :
         match(map(look,w)) :
            (w:VarWidth) :
               if name(w) == n : has? = true
            (w) : w
         w
      look(w)
      has?

   ;Evaluate every fully-ground width in h to an Int; unsolved entries are dropped.
   defn evaluate (h:HashTable) -> HashTable :
      defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False :
         if a typeof Int and b typeof Int : f(a as Int, b as Int)
         else : false
      defn max (a:Int,b:Int) -> Int :
         if a > b : a
         else : b
      defn solve (w:Width) -> Int|False :
         match(w) :
            (w:VarWidth) : false
            (w:MaxWidth) : apply(solve(arg1(w)),solve(arg2(w)),max)
            (w:PlusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ + _})
            (w:MinusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ - _})
            (w:IntWidth) : width(w)
            (w) : error("Shouldn't be here")
      val i = HashTable(symbol-hash)
      for x in h do :
         val s = solve(value(x))
         if s typeof Int : i[key(x)] = s as Int
      i

   ; Forward solve
   ; Returns a solved list where each constraint undergoes:
   ;  1) Continuous Solving (using triangular solving)
   ;  2) Remove Cycles
   ;  3) Move to solved if not self-recursive
   val u = unique(l)
   println("Unique Constraints")
   for x in u do : println(x)

   val f = HashTable(symbol-hash)
   val o = Vector()
   for x in u do :
      val [n e] = [key(x) value(x)]
      val e* = substitute(e,f)
      ;val e* = remove-cycle{n,_} $ substitute(e,f)
      ;if not self-rec?(n,e*) : add(o,n)
      f[n] = e*
   println("Forward Solved Constraints")
   for x in f do : println(x)

   ; Backwards Solve
   ;val b = HashTable(symbol-hash)
   ;for i in (length(o) - 1) through 0 by -1 do :
   ;   val n = o[i]
   ;   b[n] = substitute(f[n],b)
   ;println("Backwards Solved Constraints")
   ;for x in b do : println(x)

   ;; Evaluate
   ;val e = evaluate(b)
   ;println("Evaluated Constraints")
   ;for x in e do : println(x)
   ;e
   HashTable(symbol-hash)

;Width of a ground type; errors for aggregate types.
defn width! (t:Type) -> Width :
   match(t) :
      (t:UIntType) : width(t)
      (t:SIntType) : width(t)
      (t) : error("No width!")

;Result width of a primitive operation, in terms of its argument widths.
;The v parameter is currently unused.
defn prim-width (e:DoPrim,v:Vector) -> Width :
   defn e-width (e:Expression) -> Width : width!(type(e))
   defn wpc (l:List,c:List) : PlusWidth(e-width(l[0]),IntWidth(c[0]))
   defn wmc (l:List,c:List) : MinusWidth(e-width(l[0]),IntWidth(c[0]))
   defn maxw (l:List) : MaxWidth(e-width(l[0]),e-width(l[1]))
   defn mp1 (l:List) : PlusWidth(MaxWidth(e-width(l[0]),e-width(l[1])),IntWidth(1))
   defn sum (l:List) : PlusWidth(e-width(l[0]),e-width(l[1]))
   ;Mux args are (pred, conseq, alt). The result is as wide as the wider
   ;data input; the 1-bit predicate does not contribute.
   defn muxw (l:List) : MaxWidth(e-width(l[1]),e-width(l[2]))
   switch {op(e) == _} :
      ADD-UU-OP : mp1(args(e))
      ADD-US-OP : mp1(args(e))
      ADD-SU-OP : mp1(args(e))
      ADD-SS-OP : mp1(args(e))
      SUB-UU-OP : mp1(args(e))
      SUB-US-OP : mp1(args(e))
      SUB-SU-OP : mp1(args(e))
      SUB-SS-OP : mp1(args(e))
      MUL-UU-OP : sum(args(e))
      MUL-US-OP : sum(args(e))
      MUL-SU-OP : sum(args(e))
      MUL-SS-OP : sum(args(e))
      ;(p:DIV-UU-OP) :
      ;(p:DIV-US-OP) :
      ;(p:DIV-SU-OP) :
      ;(p:DIV-SS-OP) :
      ;(p:MOD-UU-OP) :
      ;(p:MOD-US-OP) :
      ;(p:MOD-SU-OP) :
      ;(p:MOD-SS-OP) :
      ;(p:QUO-UU-OP) :
      ;(p:QUO-US-OP) :
      ;(p:QUO-SU-OP) :
      ;(p:QUO-SS-OP) :
      ;(p:REM-UU-OP) :
      ;(p:REM-US-OP) :
      ;(p:REM-SU-OP) :
      ;(p:REM-SS-OP) :
      ADD-WRAP-UU-OP : maxw(args(e))
      ADD-WRAP-US-OP : maxw(args(e))
      ADD-WRAP-SU-OP : maxw(args(e))
      ADD-WRAP-SS-OP : maxw(args(e))
      SUB-WRAP-UU-OP : maxw(args(e))
      SUB-WRAP-US-OP : maxw(args(e))
      SUB-WRAP-SU-OP : maxw(args(e))
      SUB-WRAP-SS-OP : maxw(args(e))
      LESS-UU-OP : IntWidth(1)
      LESS-US-OP : IntWidth(1)
      LESS-SU-OP : IntWidth(1)
      LESS-SS-OP : IntWidth(1)
      LESS-EQ-UU-OP : IntWidth(1)
      LESS-EQ-US-OP : IntWidth(1)
      LESS-EQ-SU-OP : IntWidth(1)
      LESS-EQ-SS-OP : IntWidth(1)
      GREATER-UU-OP : IntWidth(1)
      GREATER-US-OP : IntWidth(1)
      GREATER-SU-OP : IntWidth(1)
      GREATER-SS-OP : IntWidth(1)
      GREATER-EQ-UU-OP : IntWidth(1)
      GREATER-EQ-US-OP : IntWidth(1)
      GREATER-EQ-SU-OP : IntWidth(1)
      GREATER-EQ-SS-OP : IntWidth(1)
      EQUAL-UU-OP : IntWidth(1)
      EQUAL-SS-OP : IntWidth(1)
      MUX-UU-OP : muxw(args(e))
      MUX-SS-OP : muxw(args(e))
      PAD-U-OP : IntWidth(consts(e)[0])
      PAD-S-OP : IntWidth(consts(e)[0])
      AS-UINT-U-OP : e-width(args(e)[0])
      AS-UINT-S-OP : e-width(args(e)[0])
      AS-SINT-U-OP : e-width(args(e)[0])
      AS-SINT-S-OP : e-width(args(e)[0])
      SHIFT-LEFT-U-OP : wpc(args(e),consts(e))
      SHIFT-LEFT-S-OP : wpc(args(e),consts(e))
      SHIFT-RIGHT-U-OP : wmc(args(e),consts(e))
      SHIFT-RIGHT-S-OP : wmc(args(e),consts(e))
      CONVERT-U-OP : PlusWidth(e-width(args(e)[0]),IntWidth(1))
      CONVERT-S-OP : e-width(args(e)[0])
      BIT-AND-OP : maxw(args(e))
      BIT-OR-OP : maxw(args(e))
      BIT-XOR-OP : maxw(args(e))
      CONCAT-OP : sum(args(e))
      BIT-SELECT-OP : IntWidth(1)
      ;Assumes consts are [hi, lo] with both ends inclusive, as in the
      ;FIRRTL bits(e, hi, lo) primop, so the width is hi - lo + 1.
      BITS-SELECT-OP : IntWidth(consts(e)[0] - consts(e)[1] + 1)

;Append to v the width constraints generated by module m of circuit c, and return v.
defn gen-constraints (c:Circuit, m:Module, v:Vector) -> Vector :
   ;Declared type of every name visible in m (module names map to port bundles).
   val h = HashTable(symbol-hash)

   ;Width of e. References are tied to their declarations with a pair of
   ;>= constraints in both directions.
   defn get-width (e:Expression) -> Width :
      match(e) :
         (e:WRef) :
            val [wdec wref] = match(kind(e)) :
               (k:InstanceKind) : error("Shouldn't be here")
               (k:MemKind) : [width! $ type $ (h[name(e)] as VectorType), width! $ type $ (type(e) as VectorType)]
               (k) : [width! $ h[name(e)], width! $ type(e)]
            add(v,WGeq(wref,wdec))
            add(v,WGeq(wdec,wref))
            wref
         (e:WSubfield) : ;assumes only subfields are instances
            val wdec = width! $ bundle-field-type(h[name(exp(e) as WRef)],name(e))
            val wref = width! $ bundle-field-type(type(exp(e)),name(e))
            add(v,WGeq(wref,wdec))
            add(v,WGeq(wdec,wref))
            wref
         (e:WIndex) : error("Shouldn't be here")
         (e:UIntValue) : width(e)
         (e:SIntValue) : width(e)
         (e:DoPrim) : prim-width(e,v)
         (e:ReadPort|WritePort) :
            add(v,WGeq(get-width(enable(e)),IntWidth(1)))
            get-width(mem(e))
         (e:Register) :
            add(v,WGeq(get-width(enable(e)),IntWidth(1)))
            val w = width!(type(e))
            add(v,WGeq(w,get-width(value(e))))
            w

   defn gen-constraints (s:Stmt) -> Stmt :
      match(map(gen-constraints,s)) :
         (s:DefWire) : h[name(s)] = type(s)
         (s:DefInstance) : h[name(s)] = h[name(module(s) as WRef)]
         (s:DefMemory) : h[name(s)] = type(s)
         (s:DefNode) : h[name(s)] = type(value(s))
         (s:Connect) :
            ;Compute each side once, so the reference constraints that
            ;get-width adds are not emitted twice.
            val wl = get-width(loc(s))
            val we = get-width(exp(s))
            add(v,WGeq(wl,we))
            add(v,WGeq(we,wl))
         (s) : ""
      s

   ;Every module is visible as a port bundle (for instances); only this
   ;module's ports are visible as names.
   for md in modules(c) do :
      h[name(md)] = BundleType(map(to-field,ports(md)))
   for p in ports(m) do :
      h[name(p)] = type(p)
   gen-constraints(body(m))
   v

;Replace each VarWidth solved in h (name -> Int) by its IntWidth.
defn replace-var-widths (c:Circuit,h:HashTable) -> Circuit :
   defn replace-var-widths-w (w:Width) -> Width :
      match(w) :
         (w:VarWidth) :
            if key?(h,name(w)) : IntWidth(h[name(w)])
            else : w
         (w) : w
   val modules* =
      for m in modules(c) map :
         Module{name(m),_,body(m)} $ for p in ports(m) map :
            Port(name(p),direction(p),mapr(replace-var-widths-w,type(p)))
   val modules** =
      for m in modules* map :
         Module(name(m),ports(m),mapr(replace-var-widths-w,body(m)))
   Circuit(modules**,main(c))

;Replace every UnknownWidth with a fresh VarWidth.
defn remove-unknown-widths (c:Circuit) -> Circuit :
   defn remove-unknown-widths-w (w:Width) -> Width :
      match(w) :
         (w:UnknownWidth) : VarWidth(gensym(`w))
         (w) : w
   val modules* =
      for m in modules(c) map :
         Module{name(m),_,body(m)} $ for p in ports(m) map :
            Port(name(p),direction(p),mapr(remove-unknown-widths-w,type(p)))
   val modules** =
      for m in modules* map :
         Module(name(m),ports(m),mapr(remove-unknown-widths-w,body(m)))
   Circuit(modules**,main(c))

;Width inference driver (work in progress). It generates and prints the
;constraints, runs the solver, and currently returns the circuit with
;VarWidths still in place.
defn infer-widths (c:Circuit) -> Circuit :
   val c* = remove-unknown-widths(c)
   val v = Vector()
   for m in modules(c*) do :
      gen-constraints(c*,m,v)
   for x in v do :
      println(x)
   val h = solve-constraints(to-list(v))
   println("Solved Constraints")
   ;for x in h do : println(x)
   ;replace-var-widths(c*,h)
   c*

;;==================== WIDTH INFERENCE ======================
;defstruct WidthVar <: Width :
;   name: Symbol
;
;defmethod print (o:OutputStream, w:WidthVar) :
;   print(o, name(w))
;
;defn width! (t:Type) :
;   match(t) :
;      (t:UIntType) : width(t)
;      (t:SIntType) : width(t)
;      (t) : error("No width field.")
;
;defn put-width (t:Type, w:Width) :
;   match(t) :
;      (t:UIntType) : UIntType(w)
;      (t:SIntType) : SIntType(w)
;      (t) : t
;
;defn put-width (e:Expression, w:Width) :
;   val type* = put-width(type(e), w)
;   put-type(e, type*)
;
;defn add-width-vars (t:Type) :
;   defn width?
(w:Width) : ; match(w) : ; (w:UnknownWidth) : WidthVar(gensym()) ; (w) : w ; match(t) : ; (t:UIntType) : UIntType(width?(width(t))) ; (t:SIntType) : SIntType(width?(width(t))) ; (t) : map(add-width-vars, t) ; ;defn uint-width (i:Int) : ; var v:Int = i ; var n:Int = 0 ; while v != 0 : ; v = v >> 1 ; n = n + 1 ; IntWidth(n) ; ;defn sint-width (i:Int) : ; if i > 0 : ; val w = uint-width(i) ; IntWidth(width(w) + 1) ; else : ; val w = uint-width(neg(i) - 1) ; IntWidth(width(w) + 1) ; ;defn to-exp (w:Width) : ; match(w) : ; (w:IntWidth) : ELit(width(w)) ; (w:WidthVar) : EVar(name(w)) ; (w) : error $ string-join $ [ ; "Cannot convert " w " to exp."] ; ;defn primop-width (op:PrimOp, ws:List, ints:List) -> Exp : ; defn wmax (w1:Width, w2:Width) : ; EMax(to-exp(w1), to-exp(w2)) ; defn wplus (w1:Width, w2:Width) : ; EPlus(to-exp(w1), to-exp(w2)) ; defn wplus (w1:Width, w2:Int) : ; EPlus(to-exp(w1), ELit(w2)) ; defn wminus (w1:Width, w2:Width) : ; EMinus(to-exp(w1), to-exp(w2)) ; defn wminus (w1:Width, w2:Int) : ; EMinus(to-exp(w1), ELit(w2)) ; defn wmax-inc (w1:Width, w2:Width) : ; EPlus(wmax(w1, w2), ELit(1)) ; ; switch {op == _} : ; ADD-OP : wmax-inc(ws[0], ws[1]) ; ADD-WRAP-OP : wmax(ws[0], ws[1]) ; SUB-OP : wmax-inc(ws[0], ws[1]) ; SUB-WRAP-OP : wmax(ws[0], ws[1]) ; MUL-OP : wplus(ws[0], ws[1]) ; DIV-OP : wminus(ws[0], ws[1]) ; MOD-OP : to-exp(ws[1]) ; SHIFT-LEFT-OP : wplus(ws[0], ints[0]) ; SHIFT-RIGHT-OP : wminus(ws[0], ints[0]) ; PAD-OP : ELit(ints[0]) ; BIT-AND-OP : wmax(ws[0], ws[1]) ; BIT-OR-OP : wmax(ws[0], ws[1]) ; BIT-XOR-OP : wmax(ws[0], ws[1]) ; CONCAT-OP : wplus(ws[0], ints[0]) ; BIT-SELECT-OP : ELit(1) ; BITS-SELECT-OP : ELit(ints[0]) ; MUX-OP : wmax(ws[1], ws[2]) ; LESS-OP : ELit(1) ; LESS-EQ-OP : ELit(1) ; GREATER-OP : ELit(1) ; GREATER-EQ-OP : ELit(1) ; EQUAL-OP : ELit(1) ; ;defn put-type (el:Element, t:Type) -> Element : ; match(el) : ; (el:Register) : Register(t, value(el), enable(el)) ; (el:Memory) : Memory(t, writers(el)) ; (el:Node) : Node(t, 
value(el)) ; (el:Instance) : Instance(t, module(el), ports(el)) ; ;defn generate-constraints (c:Circuit) -> [Circuit, Vector] : ; ;Constraints ; val cs = Vector() ; defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) : ; match(wvar) : ; (wvar:WidthVar) : ; add(cs, Constraint(name(wvar), to-exp(width))) ; (wvar) : ; false ; ; defn to-width (e:Exp) : ; match(e) : ; (e:ELit) : ; IntWidth(width(e)) ; (e:EVar) : ; WidthVar(name(e)) ; (e) : ; val x = gensym() ; add(cs, WidthEqual(x, e)) ; WidthVar(x) ; ; ;Module types ; val mod-types = HashTable(symbol-hash) ; ; defn add-port-vars (m:Module) -> Module : ; val ports* = ; for p in ports(m) map : ; val type* = add-width-vars(type(p)) ; Port(name(p), gender(p), type*) ; mod-types[name(m)] = BundleType(ports*) ; Module(name(m), ports*, body(m)) ; ; ;Add Width Variables ; defn add-module-vars (m:Module) -> Module : ; val types = HashTable(symbol-hash) ; for p in ports(m) do : ; types[name(p)] = type(p) ; ; defn infer-exp-width (e:Expression) -> Expression : ; match(map(infer-exp-width, e)) : ; (e:WRef) : ; match(kind(e)) : ; (k:ModuleKind) : put-type(e, mod-types[name(e)]) ; (k) : put-type(e, types[name(e)]) ; (e:WSubfield) : ; val t = bundle-field-type(type(exp(e)), name(e)) ; put-width(e, width!(t)) ; (e:UIntValue) : ; match(width(e)) : ; (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e))) ; (w) : e ; (e:SIntValue) : ; match(width(e)) : ; (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e))) ; (w) : e ; (e:DoPrim) : ; val widths = map(width!{type(_)}, args(e)) ; val w = to-width(primop-width(op(e), widths, consts(e))) ; put-width(e, w) ; (e:ReadPort) : ; val elem-type = type(type(mem(e)) as VectorType) ; put-width(e, width!(elem-type)) ; ; defn infer-comm-width (c:Stmt) : ; match(c) : ; (c:LetRec) : ; ;Add width vars to elements ; var entries*: List> = ; for entry in entries(c) map : ; val el-name = key(entry) ; key(entry) => ; match(value(entry)) : ; (el:Register|Node) : ; 
put-type(el, add-width-vars(type(el))) ; (el:Memory) : ; el ; (el:Instance) : ; val mod-type = type(infer-exp-width(module(el))) as BundleType ; val type = BundleType $ to-list $ ; for p in ports(mod-type) filter : ; gender(p) == OUTPUT ; put-type(el, type) ; ; ;Add vars to environment ; for entry in entries* do : ; types[key(entry)] = type(value(entry)) ; ; ;Infer types for elements ; entries* = ; for entry in entries* map : ; key(entry) => map(infer-exp-width, value(entry)) ; ; ;Generate constraints ; for entry in entries* do : ; val el-name = key(entry) ; match(value(entry)) : ; (el:Register) : ; new-constraint(WidthEqual, reg-width, val-width) where : ; val reg-width = width!(types[el-name]) ; val val-width = width!(type(value(el))) ; (el:Node) : ; new-constraint(WidthEqual, node-width, val-width) where : ; val node-width = width!(types[el-name]) ; val val-width = width!(type(value(el))) ; (el:Instance) : ; val mod-type = type(module(el)) as BundleType ; for entry in ports(el) do : ; new-constraint(WidthGreater, port-width, val-width) where : ; val port-name = key(entry) ; val port-width = width!(bundle-field-type(mod-type, port-name)) ; val val-width = width!(type(value(entry))) ; (el) : false ; ; ;Analyze body ; LetRec(entries*, infer-comm-width(body(c))) ; ; (c:Connect) : ; val loc* = infer-exp-width(loc(c)) ; val exp* = infer-exp-width(exp(c)) ; new-constraint(WidthGreater, loc-width, exp-width) where : ; val loc-width = width!(type(loc*)) ; val exp-width = width!(type(exp*)) ; Connect(loc*, exp*) ; ; (c:Begin) : ; map(infer-comm-width, c) ; ; Module(name(m), ports(m), body*) where : ; val body* = infer-comm-width(body(m)) ; ; val c* = ; Circuit(modules*, main(c)) where : ; val ms = map(add-port-vars, modules(c)) ; val modules* = map(add-module-vars, ms) ; [c*, cs] ; ; ;;================== FILL WIDTHS ============================ ;defn fill-widths (c:Circuit, solved:Streamable) : ; ;Populate table ; val table = HashTable(symbol-hash) ; for eq in solved do : 
; table[name(eq)] = IntWidth(width(value(eq) as ELit)) ; ; defn width? (w:Width) : ; match(w) : ; (w:WidthVar) : get?(table, name(w), UnknownWidth()) ; (w) : w ; ; defn fill-type (t:Type) : ; match(t) : ; (t:UIntType) : UIntType(width?(width(t))) ; (t:SIntType) : SIntType(width?(width(t))) ; (t) : map(fill-type, t) ; ; defn fill-exp (e:Expression) -> Expression : ; val e* = map(fill-exp, e) ; val type* = fill-type(type(e)) ; put-type(e*, type*) ; ; defn fill-element (e:Element) : ; val e* = map(fill-exp, e) ; val type* = fill-type(type(e)) ; put-type(e*, type*) ; ; defn fill-comm (c:Stmt) : ; match(c) : ; (c:LetRec) : ; val entries* = ; for e in entries(c) map : ; key(e) => fill-element(value(e)) ; LetRec(entries*, fill-comm(body(c))) ; (c) : ; map{fill-comm, _} $ ; map(fill-exp, c) ; ; defn fill-port (p:Port) : ; Port(name(p), gender(p), fill-type(type(p))) ; ; defn fill-mod (m:Module) : ; Module(name(m), ports*, body*) where : ; val ports* = map(fill-port, ports(m)) ; val body* = fill-comm(body(m)) ; ; Circuit(modules*, main(c)) where : ; val modules* = map(fill-mod, modules(c)) ; ; ;;=============== TYPE INFERENCE DRIVER ===================== ;defn infer-widths (c:Circuit) : ; val [c*, cs] = generate-constraints(c) ; val solved = solve-widths(cs) ; fill-widths(c*, solved) ; ; ;;================ PAD WIDTHS =============================== ;defn pad-widths (c:Circuit) : ; ;Pad an expression to the given width ; defn pad-exp (e:Expression, w:Int) : ; match(type(e)) : ; (t:UIntType|SIntType) : ; val prev-w = width!(t) as IntWidth ; if width(prev-w) < w : ; val type* = put-width(t, IntWidth(w)) ; DoPrim(PAD-OP, list(e), list(w), type*) ; else : ; e ; (t) : ; e ; ; defn pad-exp (e:Expression, w:Width) : ; val w-value = width(w as IntWidth) ; pad-exp(e, w-value) ; ; ;Convenience ; defn max-width (es:Streamable) : ; defn int-width (e:Expression) : ; width(width!(type(e)) as IntWidth) ; maximum(stream(int-width, es)) ; ; defn match-widths (es:List) : ; val w = 
max-width(es) ; map(pad-exp{_, w}, es) ; ; ;Match widths for an expression ; defn match-exp-width (e:Expression) : ; match(map(match-exp-width, e)) : ; (e:DoPrim) : ; if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) : ; val args* = match-widths(args(e)) ; DoPrim(op(e), args*, consts(e), type(e)) ; else if op(e) == MUX-OP : ; val args* = List(head(args(e)), match-widths(tail(args(e)))) ; DoPrim(op(e), args*, consts(e), type(e)) ; else : ; e ; (e) : e ; ; defn match-element-width (e:Element) : ; match(map(match-exp-width, e)) : ; (e:Register) : ; val w = width!(type(e)) ; val value* = pad-exp(value(e), w) ; Register(type(e), value*, enable(e)) ; (e:Memory) : ; val width = width!(type(type(e) as VectorType)) ; val writers* = ; for w in writers(e) map : ; WritePort(index(w), pad-exp(value(w), width), enable(w)) ; Memory(type(e), writers*) ; (e:Node) : ; val w = width!(type(e)) ; val value* = pad-exp(value(e), w) ; Node(type(e), value*) ; (e:Instance) : ; val mod-type = type(module(e)) as BundleType ; val ports* = ; for p in ports(e) map : ; val port-type = bundle-field-type(mod-type, key(p)) ; val port-val = pad-exp(value(p), width!(port-type)) ; key(p) => port-val ; Instance(type(e), module(e), ports*) ; ; ;Match widths for a command ; defn match-comm-width (c:Stmt) : ; match(map(match-exp-width, c)) : ; (c:LetRec) : ; val entries* = ; for e in entries(c) map : ; key(e) => match-element-width(value(e)) ; LetRec(entries*, match-comm-width(body(c))) ; (c:Connect) : ; val w = width!(type(loc(c))) ; val exp* = pad-exp(exp(c), w) ; Connect(loc(c), exp*) ; (c) : ; map(match-comm-width, c) ; ; defn match-mod-width (m:Module) : ; Module(name(m), ports(m), body*) where : ; val body* = match-comm-width(body(m)) ; ; Circuit(modules*, main(c)) where : ; val modules* = map(match-mod-width, modules(c)) ; ; ;;================== INLINING =============================== ;defn inline-instances (c:Circuit) : ; val module-table = HashTable(symbol-hash) ; val inlined? 
= HashTable(symbol-hash) ; for m in modules(c) do : ; module-table[name(m)] = m ; inlined?[name(m)] = false ; ; ;Convert a module into a sequence of elements ; defn to-elements (m:Module, ; inst:Symbol, ; port-exps:List>) -> ; List> : ; defn rename-exp (e:Expression) : ; match(e) : ; (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e)) ; (e) : map(rename-exp, e) ; ; defn to-elements (c:Stmt) -> List> : ; match(c) : ; (c:LetRec) : ; val entries* = ; for entry in entries(c) map : ; val name* = prefix(inst, key(entry)) ; val element* = map(rename-exp, value(entry)) ; name* => element* ; val body* = to-elements(body(c)) ; append(entries*, body*) ; (c:Connect) : ; val ref = loc(c) as WRef ; val name* = prefix(inst, name(ref)) ; list(name* => Node(type(exp(c)), rename-exp(exp(c)))) ; (c:Begin) : ; map-append(to-elements, body(c)) ; ; val inputs = ; for p in ports(m) map-append : ; if gender(p) == INPUT : ; val port-exp = lookup!(port-exps, name(p)) ; val name* = prefix(inst, name(p)) ; list(name* => Node(type(port-exp), port-exp)) ; else : ; List() ; append(inputs, to-elements(body(m))) ; ; ;Inline all instances in the module ; defn inline-instances (m:Module) : ; defn rename-exp (e:Expression) : ; match(e) : ; (e:WSubfield) : ; val inst-exp = exp(e) as WRef ; val name* = prefix(name(inst-exp), name(e)) ; WRef(name*, type(e), NodeKind(), dir(e)) ; (e) : ; map(rename-exp, e) ; ; defn inline-elems (es:List>) : ; for entry in es map-append : ; match(value(entry)) : ; (el:Instance) : ; val mod-name = name(module(el) as WRef) ; val module = inlined-module(mod-name) ; to-elements(module, key(entry), ports(el)) ; (el) : ; list(entry) ; ; defn inline-comm (c:Stmt) : ; match(map(rename-exp, c)) : ; (c:LetRec) : ; val entries* = inline-elems(entries(c)) ; LetRec(entries*, inline-comm(body(c))) ; (c) : ; map(inline-comm, c) ; ; Module(name(m), ports(m), inline-comm(body(m))) ; ; ;Retrieve the inlined instance of a module ; defn inlined-module (name:Symbol) : ; if 
inlined?[name] : ; module-table[name] ; else : ; val module* = inline-instances(module-table[name]) ; module-table[name] = module* ; inlined?[name] = true ; module* ; ; ;Return the fully inlined circuit ; val main-module = inlined-module(main(c)) ; Circuit(list(main-module), main(c)) ;;;================ UTILITIES ================================ ; ; ; ;defn* root-ref (i:Immediate) : ; match(i) : ; (f:Subfield) : root-ref(imm(f)) ; (ind:Index) : root-ref(imm(ind)) ; (r) : r ; ;;defn lookup (e: Streamable>, i:Immediate) : ;; for entry in e search : ;; if eqv?(key(entry), i) : ;; value(entry) ;; ;;defn lookup! (e: Streamable>, i:Immediate) : ;; lookup(e, i) as T ;; ;;============ CHECK IF NAMES ARE UNIQUE ==================== ;defn check-duplicate-symbols (names: Streamable, msg: String) : ; val dict = HashTable(symbol-hash) ; for name in names do: ; if key?(dict, name): ; throw $ PassException $ string-join $ ; [msg ": " name] ; else: ; dict[name] = true ; ;defn check-duplicates (t: Type) : ; match(t) : ; (t:BundleType) : ; val names = map(name, ports(t)) ; check-duplicate-symbols{names, string-join(_)} $ ; ["Duplicate port name in bundle "] ; do(check-duplicates{type(_)}, ports(t)) ; (t:VectorType) : ; check-duplicates(type(t)) ; (t) : false ; ;defn check-duplicates (c: Stmt) : ; match(c) : ; (c:DefWire) : check-duplicates(type(c)) ; (c:DefRegister) : check-duplicates(type(c)) ; (c:DefMemory) : check-duplicates(type(c)) ; (c) : do(check-duplicates, children(c)) ; ;defn defined-names (c: Stmt) : ; generate : ; loop(c) where : ; defn loop (c:Stmt) : ; match(c) : ; (c:Stmt&HasName) : yield(name(c)) ; (c) : do(loop, children(c)) ; ;defn check-duplicates (m: Module): ; ;Check all duplicate names in all types in all ports and body ; do(check-duplicates{type(_)}, ports(m)) ; check-duplicates(body(m)) ; ; ;Check all names defined in module ; val names = concat(stream(name, ports(m)), ; defined-names(body(m))) ; check-duplicate-symbols{names, string-join(_)} $ ; ["Duplicate 
definition name in module " name(m)] ; ;defn check-duplicates (c: Circuit) : ; ;Check all duplicate names in all modules ; do(check-duplicates, modules(c)) ; ; ;Check all defined modules ; val names = stream(name, modules(c)) ; check-duplicate-symbols(names, "Duplicate module name") ; ; ;;================ CLEANUP COMMANDS ========================= ;defn cleanup (c:Stmt) : ; match(c) : ; (c:Begin) : ; to-command $ generate : ; loop(c) where : ; defn loop (c:Stmt) : ; match(c) : ; (c:Begin) : do(loop, body(c)) ; (c:EmptyStmt) : false ; (c) : yield(cleanup(c)) ; (c) : map(cleanup{_ as Stmt}, c) ; ;defn cleanup (c:Circuit) : ; val modules* = ; for m in modules(c) map : ; map(cleanup, m) ; Circuit(modules*, main(c)) ; ;;;============= SHIM ======================================== ;;defn shim (i:Immediate) -> Immediate : ;; match(i) : ;; (i:RegData) : ;; Ref(name(i), gender(i), type(i)) ;; (i:InstPort) : ;; val inst = Ref(name(i), UNKNOWN-GENDER, UnknownType()) ;; Subfield(inst, port(i), gender(i), type(i)) ;; (i:Subfield) : ;; val imm* = shim(imm(i)) ;; put-imm(i, imm*) ;; (i) : i ;; ;;defn shim (c:Stmt) -> Stmt : ;; val c* = map(shim{_ as Immediate}, c) ;; map(shim{_ as Stmt}, c*) ;; ;;defn shim (c:Circuit) -> Circuit : ;; val modules* = ;; for m in modules(c) map : ;; Module(name(m), ports(m), shim(body(m))) ;; Circuit(modules*, main(c)) ;; ;;;================== INLINE MODULES ========================= ;;defn cat-name (p: String|Symbol, s: String|Symbol) -> Symbol : ;; if p == "" or p == `this : ;; TODO: REMOVE THIS WHEN `THIS GETS REMOVED ;; to-symbol(s) ;; else if s == `this : ;; TODO: DITTO ;; to-symbol(p) ;; else : ;; symbol-join([p, "/", s]) ;; ;;defn inline-command (c: Stmt, mods: HashTable, prefix: String, cmds: Vector) : ;; defn rename (n: Symbol) -> Symbol : ;; cat-name(prefix, n) ;; defn inline-name (i:Immediate) -> Symbol : ;; match(i) : ;; (r:Ref) : rename(name(r)) ;; (f:Subfield) : cat-name(inline-name(imm(f)), name(f)) ;; (f:Index) : 
cat-name(inline-name(imm(f)), to-string(value(f))) ;; defn inline-imm (i:Immediate) -> Ref : ;; Ref(inline-name(i), gender(i), type(i)) ;; match(c) : ;; (c:DefUInt) : add(cmds, DefUInt(rename(name(c)), value(c), width(c))) ;; (c:DefSInt) : add(cmds, DefSInt(rename(name(c)), value(c), width(c))) ;; (c:DefWire) : add(cmds, DefWire(rename(name(c)), type(c))) ;; (c:DefRegister) : add(cmds, DefRegister(rename(name(c)), type(c))) ;; (c:DefMemory) : add(cmds, DefMemory(rename(name(c)), type(c), size(c))) ;; (c:DefInstance) : inline-module(mods, mods[name(module(c))], to-string(rename(name(c))), cmds) ;; (c:DoPrim) : add(cmds, DoPrim(rename(name(c)), op(c), map(inline-imm, args(c)), consts(c))) ;; (c:DefAccessor) : add(cmds, DefAccessor(rename(name(c)), inline-imm(source(c)), gender(c), inline-imm(index(c)))) ;; (c:Connect) : add(cmds, Connect(inline-imm(loc(c)), inline-imm(exp(c)))) ;; (c:Begin) : do(inline-command{_, mods, prefix, cmds}, body(c)) ;; (c:EmptyStmt) : c ;; (c) : error("Unsupported command") ;; ;;defn inline-port (p: Port, prefix: String) -> Stmt : ;; DefWire(cat-name(prefix, name(p)), type(p)) ;; ;;defn inline-module (mods: HashTable, mod: Module, prefix: String, cmds: Vector) : ;; do(add{cmds, _}, map(inline-port{_, prefix}, ports(mod))) ;; inline-command(body(mod), mods, prefix, cmds) ;; ;;defn inline-modules (c: Circuit) -> Circuit : ;; val cmds = Vector() ;; val mods = HashTable(symbol-hash) ;; for mod in modules(c) do : ;; mods[name(mod)] = mod ;; val top = mods[main(c)] ;; inline-command(body(top), mods, "", cmds) ;; val main* = Module(name(top), ports(top), Begin(to-list(cmds))) ;; Circuit(list(main*), name(top)) ;; ;; ;;;============= FLO PRINTER ====================================== ;;;;; TODO: ;;;;; not supported gt, lte ;; ;;defn flo-op-name (op:PrimOp) -> String : ;; switch {op == _} : ;; ADD-OP : "add" ;; ADD-MOD-OP : "add" ;; MINUS-OP : "sub" ;; SUB-MOD-OP : "sub" ;; MUL-OP : "mul" ;; todo: signed version ;; DIV-OP : "div" ;; todo: signed 
version ;; MOD-OP : "mod" ;; todo: signed version ;; SHIFT-LEFT-OP : "lsh" ;; todo: signed version ;; SHIFT-RIGHT-OP : "rsh" ;; PAD-OP : "pad" ;; todo: signed version ;; BIT-AND-OP : "and" ;; BIT-OR-OP : "or" ;; BIT-XOR-OP : "xor" ;; CONCAT-OP : "cat" ;; BIT-SELECT-OP : "rsh" ;; BITS-SELECT-OP : "rsh" ;; LESS-OP : "lt" ;; todo: signed version ;; LESS-EQ-OP : "lte" ;; todo: swap args ;; GREATER-OP : "gt" ;; todo: swap args ;; GREATER-EQ-OP : "gte" ;; todo: signed version ;; EQUAL-OP : "eq" ;; MUX-OP : "mux" ;; else : error $ string-join $ ;; ["Unable to print Primop: " op] ;; ;;defn emit (o:OutputStream, top:Symbol, ports:HashTable, lits:HashTable, elt) : ;; match(elt) : ;; (e:String|Symbol|Int) : ;; print(o, e) ;; (e:Ref) : ;; if key?(lits, name(e)) : ;; val lit = lits[name(e)] ;; print-all(o, [value(lit) "'" width(lit)]) ;; else : ;; if key?(ports, name(e)) : ;; print-all(o, [top "::"]) ;; print(o, name(e)) ;; (e:IntWidth) : ;; print(o, value(e)) ;; (e:PrimOp) : ;; print(o, flo-op-name(e)) ;; (e) : ;; println-all(["EMIT " e]) ;; error("Unable to emit") ;; ;;defn emit-all (o:OutputStream, top:Symbol, ports:HashTable, lits:HashTable, elts: Streamable) : ;; for e in elts do : emit(o, top, ports, lits, e) ;; ;;defn prim-width (type:Type) -> Width : ;; match(type) : ;; (t:UIntType) : width(t) ;; (t:SIntType) : width(t) ;; (t) : error("Bad prim width type") ;; ;;defn emit-command (o:OutputStream, cmd:Stmt, top:Symbol, lits:HashTable, regs:HashTable, accs:HashTable, ports:HashTable, outs:HashTable) : ;; match(cmd) : ;; (c:DefUInt) : ;; lits[name(c)] = c ;; (c:DefSInt) : ;; emit-all(o, top, ports, lits, [name(c) " = " value(c) "'" width(c) "\n"]) ;; (c:DoPrim) : ;; NEED TO FIGURE OUT WHEN WIDTHS ARE NECESSARY AND EXTRACT ;; emit-all(o, top, ports, lits, [name(c) " = " op(c)]) ;; for arg in args(c) do : ;; print(o, " ") ;; emit(o, top, ports, lits, arg) ;; for const in consts(c) do : ;; print(o, " ") ;; emit(o, top, ports, lits, const) ;; print("\n") ;; (c:DefRegister) : 
;; regs[name(c)] = c ;; (c:DefMemory) : ;; emit-all(o, top, ports, lits, [name(c) " : mem'" prim-width(type(c)) " " size(c) "\n"]) ;; (c:DefAccessor) : ;; accs[name(c)] = c ;; (c:Connect) : ;; val dst = name(loc(c) as Ref) ;; val src = name(exp(c) as Ref) ;; if key?(regs, dst) : ;; val reg = regs[dst] ;; emit-all(o, top, ports, lits, [dst " = reg'" prim-width(type(reg)) " 0'" prim-width(type(reg)) " " exp(c) "\n"]) ;; else if key?(accs, dst) : ;; val acc = accs[dst] ;; ;; assert(gender(acc) == OUTPUT) ;; emit-all(o, top, ports, lits, [dst " = wr " source(acc) " " index(acc) " " exp(c) "\n"]) ;; else if key?(outs, dst) : ;; val out = outs[dst] ;; emit-all(o, top, ports, lits, [top "::" dst " = out'" prim-width(type(out)) " " exp(c) "\n"]) ;; else if key?(accs, src) : ;; val acc = accs[src] ;; ;; assert(gender(acc) == INPUT) ;; emit-all(o, top, ports, lits, [dst " = rd " source(acc) " " index(acc) "\n"]) ;; else : ;; emit-all(o, top, ports, lits, [dst " = mov " exp(c) "\n"]) ;; (c:Begin) : ;; do(emit-command{o, _, top, lits, regs, accs, ports, outs}, body(c)) ;; (c:DefWire|EmptyStmt) : ;; print("") ;; (c) : ;; error("Unable to print command") ;; ;;defn emit-module (o:OutputStream, m:Module) : ;; val regs = HashTable(symbol-hash) ;; val accs = HashTable(symbol-hash) ;; val lits = HashTable(symbol-hash) ;; val outs = HashTable(symbol-hash) ;; val portz = HashTable(symbol-hash) ;; for port in ports(m) do : ;; portz[name(port)] = port ;; if gender(port) == OUTPUT : ;; outs[name(port)] = port ;; else if name(port) == `reset : ;; print-all(o, [name(m) "::reset = rst\n"]) ;; else : ;; print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"]) ;; emit-command(o, body(m), name(m), lits, regs, accs, portz, outs) ;; ;;public defn emit-circuit (o:OutputStream, c:Circuit) : ;; emit-module(o, modules(c)[0])
;============= DRIVER ======================================
; Runs the enabled compiler passes over circuit `c`.
; Each character in `p` switches on one stage, selected by the letter codes
; below. Stages always run in the fixed order listed here, regardless of the
; order in which the letters appear in `p`. Each stage consumes the circuit
; produced by the previous enabled stage.
; When PRINT-CIRCUITS is set, the original circuit is printed first, and the
; circuit is printed again after every stage that runs.
; NOTE(review): the final circuit is held in the local `c*` but is not
; returned (the trailing `c*` below is commented out). As written, the
; function is only useful for its printed output.
public defn run-passes (c: Circuit, p: List) :
   var c*:Circuit = c
   println("Compiling!")
   if PRINT-CIRCUITS :
      println("Original Circuit")
      print(c)

   ; Applies `pass` to the current circuit, but only if `flag` appears in `p`.
   ; The typed `pass` parameter disambiguates the pass function passed in.
   defn run-stage (flag:Char, stage-name:String, pass: Circuit -> Circuit) :
      if contains(p, flag) :
         if PRINT-CIRCUITS : println(stage-name)
         c* = pass(c*)
         if PRINT-CIRCUITS :
            print(c*)
            println-all(["Finished " stage-name "\n"])

   ; Early passes:
   ; If modules have a reset defined, must be an INPUT and UInt(1)
   run-stage('a', "Working IR", to-working-ir)
   run-stage('b', "Resolve Kinds", resolve-kinds)
   run-stage('c', "Make Explicit Reset", make-explicit-reset)
   run-stage('e', "Infer Types", infer-types)
   run-stage('f', "Resolve Genders", resolve-genders)
   run-stage('g', "Expand Accessors", expand-accessors)
   run-stage('h', "Lower To Ground", lower-to-ground)
   run-stage('i', "Expand Indexed Connects", expand-connect-indexed)
   ; 'p' deliberately runs before 'j' (register initialization precedes when-expansion).
   run-stage('p', "Initialize Registers", initialize-registers)
   run-stage('j', "Expand Whens", expand-whens)
   run-stage('k', "Infer Widths", infer-widths)
   ; Disabled stages:
   ;run-stage('n', "Pad Widths", pad-widths)
   ;run-stage('o', "Inline Instances", inline-instances)
   println("Done!")

   ;; println("Shim for Jonathan's Passes")
   ;; c* = shim(c*)
   ;; println("Inline Modules")
   ;; c* = inline-modules(c*)
   ; c*