defpackage firrtl.passes :
   import core
   import verse
   import firrtl.ir2
   import firrtl.ir-utils
   import widthsolver
   import firrtl-main

;============== EXCEPTIONS =================================

; Exception raised by compiler passes; it prints as the message it was
; constructed with.
defclass PassException <: Exception
defn PassException (msg:String) :
   new PassException :
      defmethod print (o:OutputStream, this) :
         print(o, msg)

;=============== WORKING IR ================================
; Kind records what sort of declaration a WRef refers to.
definterface Kind
defstruct RegKind <: Kind
defstruct AccessorKind <: Kind
defstruct PortKind <: Kind
defstruct MemKind <: Kind
defstruct NodeKind <: Kind ; All elems except structural memory, wires
defstruct ModuleKind <: Kind
defstruct InstanceKind <: Kind
defstruct StructuralMemKind <: Kind ; Separate kind because need special treatment

; Working-IR reference: a name plus its type, kind and direction.
defstruct WRef <: Expression :
   name: Symbol
   type: Type [multi => false]
   kind: Kind
   dir: Direction [multi => false]

; Working-IR field access exp.name.
defstruct WField <: Expression :
   exp: Expression
   name: Symbol
   type: Type [multi => false]
   dir: Direction [multi => false]

; Working-IR constant index into exp.
defstruct WIndex <: Expression :
   exp: Expression
   value: Int
   type: Type [multi => false]
   dir: Direction [multi => false]

; Working-IR accessor declaration: name = source[index] with a direction.
defstruct WDefAccessor <: Stmt :
   name: Symbol
   source: Expression
   index: Expression
   dir: Direction

;================ WORKING IR UTILS =========================

;============== DEBUG STUFF =============================
; Flags selecting which annotations print-debug appends to printed nodes.
public var PRINT-TYPES : True|False = false
public var PRINT-KINDS : True|False = false
public var PRINT-WIDTHS : True|False = false
public var PRINT-CIRCUITS : True|False = false

;=== Printers ===

; Prints the short tag of a kind.
defmethod print (o:OutputStream, k:Kind) :
   print{o, _} $ match(k) :
      (k:RegKind) : "reg"
      (k:AccessorKind) : "accessor"
      (k:PortKind) : "port"
      (k:MemKind) : "mem"
      (k:NodeKind) : "n"
      (k:ModuleKind) : "module"
      (k:InstanceKind) : "inst"
      (k:StructuralMemKind) : "smem"

; True for the node classes that carry a width field.
defn hasWidth (e:Expression|Stmt|Type|Element|Port) :
   e typeof UIntType|SIntType|UIntValue|SIntValue

; True for the node classes that carry a type field.
defn hasType (e:Expression|Stmt|Type|Element|Port) :
   e typeof Ref|Field|Index|DoPrim|ReadPort|WRef|WField
            |WIndex|DefWire|DefRegister|DefMemory|Register
            |Memory|Node|Instance|VectorType|Port

; True for the node classes that carry a kind field (only WRef).
defn hasKind (e:Expression|Stmt|Type|Element|Port) :
   e typeof WRef

; True when at least one enabled debug flag applies to e.
defn any-debug? (e:Expression|Stmt|Type|Element|Port) :
   (hasType(e) and PRINT-TYPES) or
   (hasWidth(e) and PRINT-WIDTHS) or
   (hasKind(e) and PRINT-KINDS)

; Appends "@" followed by one annotation per enabled flag that applies
; to e. Each annotation prints the value its guard tested for; previously
; every branch printed an empty string, so the output was a bare "@".
defmethod print-debug (o:OutputStream, e:Expression|Stmt|Type|Element|Port) :
   if any-debug?(e) : print(o,"@")
   if PRINT-KINDS and hasKind(e) : print-all(o,["<k:" kind(e as WRef) ">"])
   if PRINT-TYPES and hasType(e) : print-all(o,["<t:" type(e as ?) ">"])
   if PRINT-WIDTHS and hasWidth(e): print-all(o,["<w:" width(e as ?) ">"])

defmethod print (o:OutputStream, e:WRef) :
   print(o,name(e))
   print-debug(o,e as ?)

defmethod print (o:OutputStream, e:WField) :
   print-all(o,[exp(e) "." name(e)])
   print-debug(o,e as ?)

defmethod print (o:OutputStream, e:WIndex) :
   print-all(o,[exp(e) "." value(e)])
   print-debug(o,e as ?)

defmethod print (o:OutputStream, s:WDefAccessor) :
   print-all(o,[dir(s) " accessor " name(s) " = " source(s) "[" index(s) "]"])
   print-debug(o,s)

; Expression mappers for the working-IR nodes: f is applied to each
; direct sub-expression and the node is rebuilt around the results.
defmethod map (f: Expression -> Expression, e: WField) :
   WField(f(exp(e)), name(e), type(e), dir(e))

defmethod map (f: Expression -> Expression, e: WIndex) :
   WIndex(f(exp(e)), value(e), type(e), dir(e))

defmethod map (f: Expression -> Expression, c:WDefAccessor) :
   WDefAccessor(name(c), f(source(c)), f(index(c)), dir(c))

;================= DIRECTION ===============================
; Default direction of an expression that does not store one.
defmulti dir (e:Expression) -> Direction
defmethod dir (e:Expression) :
   OUTPUT

;================= Bring to Working IR ========================
; Returns a new Circuit with Refs, Fields, Indexes and DefAccessors
; replaced with IR-internal nodes that contain additional
; information (kind, direction)

defn to-working-ir (c:Circuit) :
   ; Children are converted first (map), then the node itself. Every
   ; new node starts with NodeKind / UNKNOWN-DIR placeholders.
   defn to-exp (e:Expression) :
      match(map(to-exp,e)) :
         (e:Ref) : WRef(name(e), type(e), NodeKind(), UNKNOWN-DIR)
         (e:Field) : WField(exp(e), name(e), type(e), UNKNOWN-DIR)
         (e:Index) : WIndex(exp(e), value(e), type(e), UNKNOWN-DIR)
         (e) : e
   defn to-stmt (s:Stmt) :
      match(map(to-exp,s)) :
         (s:DefAccessor) : WDefAccessor(name(s),source(s),index(s), UNKNOWN-DIR)
         (s) : map(to-stmt,s)

   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         Module(name(m), ports(m), to-stmt(body(m)))

;=============== Resolve Kinds =============================
; It is useful for the compiler to know information about
; objects referenced. This information is stored in the kind
; field in WRef. This pass walks the graph and returns a new
; Circuit where all WRef kinds are resolved

defn resolve-kinds (c:Circuit) :
   ; Rewrites every WRef in body with the kind recorded for its name.
   defn resolve (body:Stmt, kinds:HashTable) :
      defn resolve-stmt (s:Stmt) -> Stmt :
         map{resolve-expr,_} $ map(resolve-stmt,s)

      defn resolve-expr (e:Expression) -> Expression :
         match(e) :
            (e:WRef) : WRef(name(e),type(e),kinds[name(e)],dir(e))
            (e) : map(resolve-expr,e)
      resolve-stmt(body)

   ; Records in kinds the kind of the module itself, of its ports and of
   ; every declaration found in its body.
   defn find (m:Module, kinds:HashTable) :
      defn find-stmt (s:Stmt) -> Stmt :
         match(s) :
            (s:LetRec) :
               for e in entries(s) do :
                  kinds[key(e)] = get-elem-kind(value(e))
            (s:DefWire) : kinds[name(s)] = NodeKind()
            (s:DefRegister) : kinds[name(s)] = RegKind()
            (s:DefInstance) : kinds[name(s)] = InstanceKind()
            (s:DefMemory) : kinds[name(s)] = MemKind()
            (s:WDefAccessor) : kinds[name(s)] = AccessorKind()
            (s) : false
         ; Recurse into nested statements after handling s itself.
         map(find-stmt,s)

      defn get-elem-kind (e:Element) :
         match(e) :
            (e: Memory) : StructuralMemKind()
            (e) : NodeKind()

      kinds[name(m)] = ModuleKind()
      for p in ports(m) do :
         kinds[name(p)] = PortKind()
      find-stmt(body(m))

   ; Resolves one module. Every module name in the circuit is registered
   ; as ModuleKind before the module's own declarations are collected.
   defn resolve-kinds (m:Module, c:Circuit) -> Module :
      val kinds = HashTable<Symbol,Kind>(symbol-hash)
      for m in modules(c) do :
         kinds[name(m)] = ModuleKind()
      find(m,kinds)
      val body! = resolve(body(m),kinds)
      Module(name(m),ports(m),body!)

   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         resolve-kinds(m,c)

;=============== MAKE EXPLICIT RESET =======================
; All modules have an implicit reset signal - however, the
; programmer can explicitly reference this signal if desired.
; This pass makes all implicit resets explicit while
; preserving any previously explicit resets
; If reset is not explicitly passed to instantiations, then this
; pass automatically connects the parent module's reset to the
; instantiation's reset
; NOTE(review): the code below adds the reset connection for every
; DefInstance unconditionally; it does not check whether the instance's
; reset is already connected. Confirm which behavior is intended.

defn make-explicit-reset (c:Circuit) :
   ; Returns the names of the modules that already declare a port
   ; named `reset.
   defn find-explicit (c:Circuit) -> List :
      defn explicit? (m:Module) -> True|False :
         for p in ports(m) any? :
            name(p) == `reset
      val explicit-reset = Vector<Symbol>()
      for m in modules(c) do :
         if explicit?(m) : add(explicit-reset,name(m))
      to-list(explicit-reset)

   ; Returns m with a 1-bit `reset input port appended (unless m is
   ; listed in explicit-reset) and with each instance's reset field
   ; connected to this module's reset port.
   defn make-explicit (m:Module, explicit-reset:List) -> Module :
      defn route-reset (s:Stmt) -> Stmt :
         match(s) :
            (s:DefInstance) :
               val iref = WField(WRef(name(s), UnknownType(), InstanceKind(), UNKNOWN-DIR),`reset,UnknownType(),UNKNOWN-DIR)
               val pref = WRef(`reset, UnknownType(), PortKind(), INPUT)
               Begin(to-list([s,Connect(iref,pref)]))
            (s) : map(route-reset,s)

      val ports! =
         if contains?(explicit-reset,name(m)) : ports(m)
         else : append(ports(m),list(Port(`reset,INPUT,UIntType(IntWidth(1)))))
      val body! = route-reset(body(m))
      Module(name(m),ports!,body!)

   ; The list of explicit-reset modules depends only on the circuit, so
   ; it is computed once instead of once per module.
   val explicit-reset = find-explicit(c)
   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         make-explicit(m,explicit-reset)

;======= MAKE EXPLICIT REGISTER INITIALIZATION =============
; This pass replaces the reg.init construct by creating a new
; wire that holds the value at initialization. This wire
; is then connected to the register conditionally on reset,
; at the end of the scope containing the register
; declaration
; If a register has no initial value, the wire is connected to
; a NULL node. Later passes will remove these with the base
; case Mux(reset,NULL,a) -> a, and Mux(reset,a,NULL) -> a.
; This ensures proper behavior if this pass is run multiple
; times.
defn initialize-registers (c:Circuit) :
   ; Appends "when reset : reg := init-wire" for every register recorded
   ; in renames; returns s unchanged when there is nothing to initialize.
   defn add-when (s:Stmt,renames:HashTable) -> Stmt :
      var inits = List<Stmt>()
      for kv in renames do :
         val refreg = WRef(key(kv),UnknownType(),RegKind(),UNKNOWN-DIR)
         val refwire = WRef(value(kv),UnknownType(),NodeKind(),UNKNOWN-DIR)
         val connect = Connect(refreg,refwire)
         inits = append(inits,list(connect))
      if empty?(inits) : s
      else :
         val pred = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-DIR)
         val when-reset = Conditionally(pred,Begin(inits),Begin(List()))
         Begin(list(s,when-reset))

   ; Replaces every reg.init expression by a reference to the init wire
   ; that l maps the register name to.
   defn rename (s:Stmt,l:HashTable) -> Stmt :
      defn rename-stmt (s:Stmt) -> Stmt :
         map{rename-expr,_} $ map(rename-stmt,s)
      defn rename-expr (e:Expression) -> Expression :
         match(e) :
            (e:WField) :
               if name(e) == `init and register?(exp(e)) :
                  ;TODO Error if l does not contain register
                  val new-name = l[name(exp(e) as WRef)]
                  WRef(new-name,UnknownType(),NodeKind(),UNKNOWN-DIR)
               else :
                  ; Keep descending: the field may wrap a reg.init
                  ; (e.g. r.init.x). Returning e here left the nested
                  ; r.init unrenamed.
                  map(rename-expr,e)
            (e) : map(rename-expr,e)
      defn register? (e:Expression) -> True|False :
         match(e) :
            (e:WRef) : kind(e) typeof RegKind
            (e) : false
      rename-stmt(s)

   ; Returns the rewritten statement together with the table mapping each
   ; register declared directly in this scope to its init wire.
   ; Conditionally and LetRec bodies are separate scopes: they are handled
   ; through initialize-scope and contribute no renames to the parent.
   defn initialize-registers (s:Stmt) -> [Stmt,HashTable] :
      val empty-hash = HashTable<Symbol,Symbol>(symbol-hash)
      match(s) :
         (s:Begin) :
            var body! = List<Stmt>()
            val renames = HashTable<Symbol,Symbol>(symbol-hash)
            for s in body(s) do :
               val [s!,renames!] = initialize-registers(s)
               body! = append(body!,list(s!))
               merge!(renames,renames!)
            [Begin(body!),renames]
         (s:DefRegister) :
            val wire-name = gensym()
            val renames = HashTable<Symbol,Symbol>(symbol-hash)
            renames[name(s)] = wire-name
            [Begin(body!),renames] where :
               val defreg = s
               val defwire = DefWire(wire-name,type(s))
               ; The init wire defaults to Null until the program assigns it.
               val conwire = Connect(WRef(wire-name,UnknownType(),NodeKind(),UNKNOWN-DIR),Null())
               val body! = list(defreg,defwire,conwire)
         (s:Conditionally) :
            [Conditionally(pred(s),initialize-scope(conseq(s)),initialize-scope(alt(s))),empty-hash]
         (s:LetRec) :
            [LetRec(entries(s),initialize-scope(body(s))),empty-hash]
         ;TODO Add Letrec
         (s) : [s,empty-hash]

   ; Processes one scope: declare init wires, rename reg.init uses, then
   ; append the reset-guarded connections at the end of the scope.
   defn initialize-scope (s:Stmt) -> Stmt :
      val [s!,renames] = initialize-registers(s)
      val s!! = rename(s!,renames)
      val s!!! = add-when(s!!,renames)
      s!!!

   defn initialize-module (m:Module) -> Module :
      Module(name(m), ports(m), body!) where :
         val body! = initialize-scope(body(m))

   Circuit(modules*, main(c)) where :
      val modules* = for m in modules(c) map :
         initialize-module(m)

;============== INFER TYPES ================================
; This pass infers the type field in all IR nodes by updating
; and passing an environment to all statements in pre-order
; traversal, and resolving types in expressions in post-
; order traversal.
; Type propagation for primary ops are defined here.
; Notable cases: LetRec requires updating environment before
; resolving the subexpressions in its elements.
; Type errors are not checked in this pass, as this is
; postponed for a later/earlier pass.

defmethod type (v:UIntValue) :
   UIntType(width(v))

defmethod type (v:SIntValue) :
   SIntType(width(v))

; Returns a copy of e carrying type t; expressions without a settable
; type field are returned unchanged.
defn put-type (e:Expression, t:Type) -> Expression :
   match(e) :
      (e:WRef) : WRef(name(e), t, kind(e), dir(e))
      (e:WField) : WField(exp(e), name(e), t, dir(e))
      (e:WIndex) : WIndex(exp(e), value(e), t, dir(e))
      (e:DoPrim) : DoPrim(op(e), args(e), consts(e), t)
      (e:ReadPort) : ReadPort(mem(e), index(e), t)
      (e) : e

; Finds the port named port-name in ports (result of `find`).
defn lookup-port (ports: Streamable, port-name: Symbol) :
   for port in ports find :
      name(port) == port-name

; Result type of a primitive op, with the width left unknown.
defn infer (op:PrimOp, arg-types: List) -> Type :
   defn wipe-width (t:Type) :
      match(t) :
         (t:UIntType) : UIntType(UnknownWidth())
         (t:SIntType) : SIntType(UnknownWidth())
   defn arg0 () : wipe-width(arg-types[0])
   defn arg1 () : wipe-width(arg-types[1])
   ; TODO subtle, not entirely figured out
   switch {op == _} :
      ADD-OP : arg0()
      ADD-MOD-OP : arg0()
      SUB-OP : arg0()
      SUB-MOD-OP : arg0()
      TIMES-OP : arg0()
      DIVIDE-OP : arg0()
      MOD-OP : arg0()
      SHIFT-LEFT-OP : arg0()
      SHIFT-RIGHT-OP : arg0()
      PAD-OP : arg0()
      BIT-AND-OP : arg0()
      BIT-OR-OP : arg0()
      BIT-XOR-OP : arg0()
      CONCAT-OP : arg0()
      BIT-SELECT-OP : UIntType(UnknownWidth())
      BITS-SELECT-OP : arg0()
      MULTIPLEX-OP : arg0()
      LESS-OP : UIntType(UnknownWidth())
      LESS-EQ-OP : UIntType(UnknownWidth())
      GREATER-OP : UIntType(UnknownWidth())
      GREATER-EQ-OP : UIntType(UnknownWidth())
      EQUAL-OP : UIntType(UnknownWidth())

; Type of field n of bundle type t; UnknownType when t is not a bundle
; or has no such field.
defn bundle-field-type (t:Type, n:Symbol) -> Type :
   match(t) :
      (t:BundleType) :
         match(lookup-port(ports(t), n)) :
            (p:Port) : type(p)
            (p) : UnknownType()
      (t) : UnknownType()

; Element type of vector type t; UnknownType when t is not a vector.
defn vector-elem-type (t:Type) -> Type :
   match(t) :
      (t:VectorType) : type(t)
      (t) : UnknownType()

;e is the environment that contains all definitions seen so far.
defn infer (c:Stmt, e:List<KeyValue<Symbol,Type>>) -> [Stmt, List<KeyValue<Symbol,Type>>] :
   ; Types sub-expressions first, then the expression itself.
   defn infer-exp (e:Expression, env:List<KeyValue<Symbol,Type>>) :
      match(map(infer-exp{_, env}, e)) :
         (e:WRef) : put-type(e, lookup!(env, name(e)))
         (e:WField) : put-type(e, bundle-field-type(type(exp(e)), name(e)))
         (e:WIndex) : put-type(e, vector-elem-type(type(exp(e))))
         (e:UIntValue) : e
         (e:SIntValue) : e
         (e:DoPrim) : put-type(e, infer(op(e), map(type, args(e))))
         (e:ReadPort) : put-type(e, vector-elem-type(type(mem(e))))
         (e:Null) : e

   ; Type under which a LetRec element is visible: an instance exposes
   ; only the OUTPUT ports of its module.
   defn element-type (e:Element, env:List<KeyValue<Symbol,Type>>) :
      match(e) :
         (e:Instance) :
            val t = type(infer-exp(module(e), env))
            match(t) :
               (t:BundleType) :
                  BundleType $ to-list $ for p in ports(t) filter :
                     direction(p) == OUTPUT
               (t) : UnknownType()
         (e) : type(e)

   match(c) :
      (c:LetRec) :
         ; All entries enter the environment before any of them is
         ; typed, since they may refer to each other.
         val e* = append(elem-types, e) where :
            val elem-types = for entry in entries(c) map :
               key(entry) => element-type(value(entry), e)
         val c* = map(infer-exp{_, e*}, c)
         val [body*, be] = infer(body(c*), e*)
         ; The original environment is returned, not the extended one.
         [LetRec(entries(c*), body*), e]
      (c) :
         match(map(infer-exp{_, e}, c)) :
            (c:DefWire) :
               [c, List(entry, e)] where :
                  val entry = name(c) => type(c)
            (c:DefRegister) :
               [c, List(entry, e)] where :
                  val entry = name(c) => type(c)
            (c:DefInstance) :
               [c, List(entry, e)] where :
                  val entry = name(c) => type(module(c))
            (c:DefMemory) :
               [c, List(entry, e)] where :
                  val entry = name(c) => type(c)
            (c:WDefAccessor) :
               [c, List(entry, e)] where :
                  val src-type = type(source(c))
                  val entry = name(c) => vector-elem-type(src-type)
            (c:Begin) :
               ; Thread the environment through the statements in order.
               var current-e: List<KeyValue<Symbol,Type>> = e
               val body* = for c in body(c) map :
                  val [c*, e*] = infer(c, current-e)
                  current-e = e*
                  c*
               [Begin(body*), current-e]
            (c) :
               ; Other statements: declarations made inside them are not
               ; propagated to the statements that follow.
               defn infer-comm (c:Stmt) :
                  val [c* e*] = infer(c, e)
                  c*
               val c* = map(infer-comm, c)
               [c*, e]

; Infers types in a module body; the module's ports are added in front
; of the circuit-level environment e.
defn infer (m:Module, e:List<KeyValue<Symbol,Type>>) -> Module :
   val env = append{_, e} $ for p in ports(m) map :
      name(p) => type(p)
   val [body*, e*] = infer(body(m), env)
   Module(name(m), ports(m), body*)

; Entry point: every module name is bound to the bundle of its ports.
defn infer-types (c:Circuit) -> Circuit :
   val env = for m in modules(c) map :
      name(m) => BundleType(ports(m))
   Circuit(map(infer{_, env}, modules(c)), main(c))

;============= INFER DIRECTIONS ============================
; Swaps INPUT and OUTPUT; any other direction is returned unchanged.
defn flip (d:Direction) :
   switch {d == _} :
      INPUT : OUTPUT
      OUTPUT : INPUT
      else : d

; Composition of directions: an INPUT on the left flips the right side.
defn times (d1:Direction, d2:Direction) :
   if d1 == INPUT : flip(d2)
   else : d2

; Declared direction of field n of bundle type t; UNKNOWN-DIR when t is
; not a bundle or has no such field.
defn bundle-field-dir (t:Type, n:Symbol) -> Direction :
   match(t) :
      (t:BundleType) :
         match(lookup-port(ports(t), n)) :
            (p:Port) : direction(p)
            (p) : UNKNOWN-DIR
      (t) : UNKNOWN-DIR

defn infer-dirs (m:Module) :
   ;=== Direction of all Binders ===
   ; BI-DIR is a direction value private to this pass; it marks names
   ; that take whatever direction the context asks for.
   val BI-DIR = new Direction
   val directions = HashTable<Symbol,Direction>(symbol-hash)
   defn find-dirs (c:Stmt) :
      match(c) :
         (c:LetRec) :
            for entry in entries(c) do :
               directions[key(entry)] = OUTPUT
            find-dirs(body(c))
         (c:DefWire) :
            directions[name(c)] = BI-DIR
         (c:DefRegister) :
            directions[name(c)] = BI-DIR
         (c:DefInstance) :
            directions[name(c)] = OUTPUT
         (c:DefMemory) :
            directions[name(c)] = BI-DIR
         (c:WDefAccessor) :
            directions[name(c)] = dir(c)
         (c) :
            do(find-dirs, children(c))
   ; Ports are recorded with their declared direction flipped.
   for p in ports(m) do :
      directions[name(p)] = flip(direction(p))
   find-dirs(body(m))

   ;=== Fix Point Status ===
   var changed? = false

   ;=== Infer directions of Expression ===
   defn infer-exp (e:Expression, desired:Direction) :
      match(e) :
         (e:WRef) :
            val dir* = let :
               if kind(e) typeof ModuleKind :
                  OUTPUT
               else :
                  val old-dir = directions[name(e)]
                  switch {old-dir == _} :
                     BI-DIR :
                        desired
                     UNKNOWN-DIR :
                        ; First time this name gets a direction: record
                        ; it and request another fix-point iteration.
                        if directions[name(e)] != desired :
                           directions[name(e)] = desired
                           changed? = true
                        desired
                     else :
                        old-dir
            WRef(name(e), type(e), kind(e), dir*)
         (e:WField) :
            val port-dir = bundle-field-dir(type(exp(e)), name(e))
            val exp* = infer-exp(exp(e), port-dir * desired)
            WField(exp*, name(e), type(e), port-dir * dir(exp*))
         (e:WIndex) :
            val exp* = infer-exp(exp(e), desired)
            WIndex(exp*, value(e), type(e), dir(exp*))
         (e) :
            map(infer-exp{_, OUTPUT}, e)

   ;=== Infer directions of Stmts ===
   defn infer-comm (c:Stmt) :
      match(c) :
         (c:LetRec) :
            val c* = map(infer-exp{_, OUTPUT}, c)
            LetRec(entries(c*), infer-comm(body(c)))
         (c:DefInstance) :
            DefInstance(name(c), infer-exp(module(c), OUTPUT))
         (c:WDefAccessor) :
            val d = directions[name(c)]
            WDefAccessor(name(c), infer-exp(source(c), d), infer-exp(index(c), OUTPUT), d)
         (c:Conditionally) :
            Conditionally(infer-exp(pred(c), OUTPUT), infer-comm(conseq(c)), infer-comm(alt(c)))
         (c:Connect) :
            Connect(infer-exp(loc(c), INPUT), infer-exp(exp(c), OUTPUT))
         (c) :
            map(infer-comm, c)

   ;=== Iterate until fix point ===
   defn* fixpoint (c:Stmt) :
      changed? = false
      val c* = infer-comm(c)
      if changed? : fixpoint(c*)
      else : c*

   Module(name(m), ports(m), body*) where :
      val body* = fixpoint(body(m))

defn infer-directions (c:Circuit) :
   Circuit(modules*, main(c)) where :
      val modules* = map(infer-dirs, modules(c))

;============== EXPAND VECS ================================
; locs[index] := exp
defstruct ManyConnect <: Stmt :
   index: Expression
   locs: List
   exp: Expression

; loc := exps[index]
defstruct ConnectMany <: Stmt :
   index: Expression
   loc: Expression
   exps: List

defmethod print (o:OutputStream, c:ManyConnect) :
   print-all(o, [locs(c) "[" index(c) "] := " exp(c)])

defmethod print (o:OutputStream, c:ConnectMany) :
   print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"])

defmethod map (f: Expression -> Expression, c:ManyConnect) :
   ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c)))

defmethod map (f: Expression -> Expression, c:ConnectMany) :
   ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c)))

; Replaces each accessor whose source is not a memory by a wire plus a
; ManyConnect (INPUT accessor) or ConnectMany (OUTPUT accessor) over all
; elements of the source vector. Accessors on memories are kept.
defn expand-accessors (m: Module) :
   defn expand (c:Stmt) :
      match(c) :
         (c:WDefAccessor) :
            ;Is the source a memory?
            val mem? = match(source(c)) :
               (r:WRef) : kind(r) typeof MemKind
               (r) : false
            if mem? : c
            else :
               switch {dir(c) == _} :
                  INPUT :
                     Begin(list(
                        DefWire(name(c), type(src-type))
                        ManyConnect(index(c), elems, wire-ref))) where :
                           val src-type = type(source(c)) as VectorType
                           val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT)
                           val elems = to-list $ for i in 0 to size(src-type) stream :
                              WIndex(source(c), i, type(src-type), INPUT)
                  OUTPUT :
                     Begin(list(
                        DefWire(name(c), type(src-type))
                        ConnectMany(index(c), wire-ref, elems))) where :
                           val src-type = type(source(c)) as VectorType
                           val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT)
                           val elems = to-list $ for i in 0 to size(src-type) stream :
                              WIndex(source(c), i, type(src-type), OUTPUT)
         (c) :
            map(expand, c)
   Module(name(m), ports(m), expand(body(m)))

defn expand-accessors (c:Circuit) :
   Circuit(modules*, main(c)) where :
      val modules* = map(expand-accessors, modules(c))

;=============== BUNDLE FLATTENING =========================
; Joins two name parts with "/".
defn prefix (prefix, suffix) :
   symbol-join([prefix "/" suffix])

defn prefix-ports (pre:Symbol, ports:List) :
   for p in ports map :
      Port(prefix(pre, name(p)), direction(p), type(p))

; Flattens a port of bundle or vector type into a list of ports with
; prefixed names; directions are composed along the way.
defn flatten-ports (port:Port) -> List :
   match(type(port)) :
      (t:BundleType) :
         val ports = map-append(flatten-ports, ports(t))
         for p in ports map :
            Port(prefix(name(port), name(p)),
                 direction(port) * direction(p),
                 type(p))
      (t:VectorType) :
         val type* = flatten-type(t)
         flatten-ports(Port(name(port), direction(port), type*))
      (t:Type) :
         list(port)

; Flattens nested bundles into a single-level bundle. A vector is first
; rewritten as a bundle whose fields are named by the element indices.
defn flatten-type (t:Type) -> Type :
   match(t) :
      (t:BundleType) :
         BundleType $ map-append(flatten-ports, ports(t))
      (t:VectorType) :
         flatten-type $ BundleType $ to-list $ for i in 0 to size(t) stream :
            Port(to-symbol(i), OUTPUT, type(t))
      (t:Type) :
         t

defn flatten-bundles (c:Circuit) :
   defn flatten-exp (e:Expression) :
      match(map(flatten-exp, e)) :
         (e:UIntValue|SIntValue) :
            e
         (e:WRef) :
            match(kind(e)) :
               ; References to memories: flatten-type is mapped over the
               ; type rather than applied to it directly.
               (k:MemKind|StructuralMemKind) :
                  val type* = map(flatten-type, type(e))
                  put-type(e, type*)
               (k) :
                  val type* = flatten-type(type(e))
                  put-type(e, type*)
         (e) :
            val type* = flatten-type(type(e))
            put-type(e, type*)

   defn flatten-element (e:Element) :
      val t* = flatten-type(type(e))
      match(map(flatten-exp, e)) :
         (e:Register) : Register(t*, value(e), enable(e))
         (e:Memory) : Memory(t*, writers(e))
         (e:Node) : Node(t*, value(e))
         (e:Instance) : Instance(t*, module(e), ports(e))

   defn flatten-comm (c:Stmt) :
      match(c) :
         (c:LetRec) :
            val entries* = for entry in entries(c) map :
               key(entry) => flatten-element(value(entry))
            LetRec(entries*, flatten-comm(body(c)))
         (c:DefWire) :
            DefWire(name(c), flatten-type(type(c)))
         (c:DefRegister) :
            DefRegister(name(c), flatten-type(type(c)))
         (c:DefMemory) :
            val type* = map(flatten-type, type(c))
            DefMemory(name(c), type*)
         (c) :
            map{flatten-comm, _} $ map(flatten-exp, c)

   defn flatten-module (m:Module) :
      val ports* = map-append(flatten-ports, ports(m))
      val body* = flatten-comm(body(m))
      Module(name(m), ports*, body*)

   Circuit(modules*, main(c)) where :
      val modules* = map(flatten-module, modules(c))

;================== BUNDLE EXPANSION =======================
defn expand-bundles (m:Module) :
   ;Collapse all field/index expressions
   defn collapse-exp (e:Expression) -> Expression :
      match(e) :
         (e:WField) :
            match(collapse-exp(exp(e))) :
               (ei:WRef) :
                  ; Fields of an instance stay as field accesses.
                  if kind(ei) typeof InstanceKind :
                     e
                  else :
                     WRef(name*, type(e), kind(ei), dir(e)) where :
                        val name* = prefix(name(ei), name(e))
               (ei:WField) :
                  WField(exp(ei), name*, type(e), dir(e)) where :
                     val name* = prefix(name(ei), name(e))
         (e:WIndex) :
            ; An index is handled as a field named by its value.
            collapse-exp(WField(exp(e), name, type(e), dir(e))) where :
               val name = to-symbol(value(e))
         (e) :
            map(collapse-exp, e)

   ;Expand expressions
   defn expand-exp (e:Expression) -> List :
      match(type(e)) :
         (t:BundleType) :
            for p in ports(t) map :
               val dir* = direction(p) * dir(e)
               collapse-exp(WField(e, name(p), type(p), dir*))
         (t) :
            list(collapse-exp(e))

   ;Expand commands
   defn expand-comm (c:Stmt) :
      match(c) :
         (c:DefWire) :
            match(type(c)) :
               (t:BundleType) :
                  Begin $ for p in ports(t) map :
                     DefWire(prefix(name(c), name(p)), type(p))
               (t) : c
         (c:DefRegister) :
            match(type(c)) :
               (t:BundleType) :
                  Begin $ for p in ports(t) map :
                     DefRegister(prefix(name(c), name(p)), type(p))
               (t) : c
         (c:DefMemory) :
            match(type(type(c))) :
               (t:BundleType) :
                  Begin $ for p in ports(t) map :
                     DefMemory(prefix(name(c), name(p)), type*) where :
                        val s = size(type(c))
                        val type* = VectorType(type(p), s)
               (t) : c
         (c:WDefAccessor) :
            match(type(source(c))) :
               (t:BundleType) :
                  val srcs = expand-exp(source(c))
                  Begin $ for (p in ports(t), src in srcs) map :
                     WDefAccessor(name*, src, index(c), dir*) where :
                        val name* = prefix(name(c), name(p))
                        val dir* = direction(p) * dir(c)
               (t) : c
         (c:Connect) :
            val locs = expand-exp(loc(c))
            val exps = expand-exp(exp(c))
            ; Per field, the operand whose direction is INPUT becomes
            ; the location of the connect.
            Begin $ for (l in locs, e in exps) map :
               switch {dir(l) == _} :
                  INPUT : Connect(l, e)
                  OUTPUT : Connect(e, l)
         (c:ManyConnect) :
            val locs-list = transpose(map(expand-exp, locs(c)))
            val exps = expand-exp(exp(c))
            Begin $ for (locs in locs-list, e in exps) map :
               switch {dir(e) == _} :
                  OUTPUT : ManyConnect(index(c), locs, e)
                  INPUT : ConnectMany(index(c), e, locs)
         (c:ConnectMany) :
            val locs = expand-exp(loc(c))
            val exps-list = transpose(map(expand-exp, exps(c)))
            Begin $ for (l in locs, exps in exps-list) map :
               switch {dir(l) == _} :
                  INPUT : ConnectMany(index(c), l, exps)
                  OUTPUT : ManyConnect(index(c), exps, l)
         (c) :
            map{expand-comm, _} $ map(collapse-exp, c)

   Module(name(m), ports(m), expand-comm(body(m)))

defn expand-bundles (c:Circuit) :
   Circuit(modules*, main(c)) where :
      val modules* = map(expand-bundles, modules(c))

;=========== CONVERT MULTI CONNECTS to WHEN ================
; Rewrites each ManyConnect/ConnectMany into one
; "when index == i : connect" statement per element.
defn expand-multi-connects (c:Circuit) :
   defn equal-exp (e1:Expression, e2:Expression) :
      DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth()))
   defn uint (i:Int) :
      UIntValue(i, UnknownWidth())

   defn expand-comm (c:Stmt) :
      match(c) :
         (c:ConnectMany) :
            ; "0 to false" is an unbounded counter paired with the
            ; element list.
            Begin $ to-list $ for (i in 0 to false, e in exps(c)) stream :
               Conditionally(equal-exp(index(c), uint(i)),
                  Connect(loc(c), e)
                  EmptyStmt())
         (c:ManyConnect) :
            Begin $ to-list $ for (i in 0 to false, l in locs(c)) stream :
               Conditionally(equal-exp(index(c), uint(i)),
                  Connect(l, exp(c))
                  EmptyStmt())
         (c) :
            map(expand-comm, c)

   defn expand (m:Module) :
      Module(name(m), ports(m), expand-comm(body(m)))

   Circuit(modules*, main(c)) where :
      val modules* = map(expand, modules(c))

;================ EXPAND WHENS =============================
; Value a location holds at the end of a scope: a plain expression, a
; choice between two values on a predicate, or nothing assigned yet.
definterface SymbolicValue
defstruct ExpValue <: SymbolicValue :
   exp: Expression
defstruct WhenValue <: SymbolicValue :
   pred: Expression
   conseq: SymbolicValue
   alt: SymbolicValue
defstruct VoidValue <: SymbolicValue

defmethod print (o:OutputStream, sv:SymbolicValue) :
   match(sv) :
      (sv:VoidValue) : print(o, "VOID")
      (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"])
      (sv:ExpValue) : print(o, exp(sv))

; Two environment keys are the same location when they are the same
; chain of names (WRef / WField). Any other pairing is never equal.
defn key-eqv? (e1:Expression, e2:Expression) :
   match(e1, e2) :
      (e1:WRef, e2:WRef) :
         name(e1) == name(e2)
      (e1:WField, e2:WField) :
         name(e1) == name(e2) and
         key-eqv?(exp(e1), exp(e2))
      (e1, e2) :
         false

; Merges the environments of the two branches of a conditional. A key
; bound to different values in both becomes a WhenValue on pred; a key
; bound in only one branch keeps that branch's value.
defn merge-env (pred: Expression,
                con-env: List<KeyValue<Expression,SymbolicValue>>,
                alt-env: List<KeyValue<Expression,SymbolicValue>>) :
   val merged = Vector<KeyValue<Expression,SymbolicValue>>()
   defn new-key? (k:Expression) :
      for entry in merged none? :
         key-eqv?(key(entry), k)

   ; First binding of k in env (result of `search`).
   defn sv (env:List<KeyValue<Expression,SymbolicValue>>, k:Expression) :
      for entry in env search :
         if key-eqv?(key(entry), k) :
            value(entry)

   for k in stream(key, concat(con-env, alt-env)) do :
      if new-key?(k) :
         match(sv(con-env, k), sv(alt-env, k)) :
            (a:SymbolicValue, b:SymbolicValue) :
               if a == b :
                  add(merged, k => a)
               else :
                  add(merged, k => WhenValue(pred, a, b))
            (a:SymbolicValue, b:False) :
               add(merged, k => a)
            (a:False, b:SymbolicValue) :
               add(merged, k => b)
            (a:False, b:False) :
               false

   to-list(merged)

; Keeps only the first binding of each key.
defn simplify-env (env: List<KeyValue<Expression,SymbolicValue>>) :
   val merged = Vector<KeyValue<Expression,SymbolicValue>>()
   defn new-key? (k:Expression) :
      for entry in merged none? :
         key-eqv?(key(entry), k)
   for entry in env do :
      if new-key?(key(entry)) :
         add(merged, entry)
   to-list(merged)

defn expand-whens (m:Module) :
   val commands = Vector<Stmt>()
   val elements = Vector<KeyValue<Symbol,Element>>()
   ; Symbolically evaluates c. Declarations are collected in commands /
   ; elements; the result is env extended (at the front) with the
   ; bindings c introduces.
   defn eval (c:Stmt, env:List<KeyValue<Expression,SymbolicValue>>) ->
              List<KeyValue<Expression,SymbolicValue>> :
      match(c) :
         (c:LetRec) :
            do(add{elements, _}, entries(c))
            eval(body(c), env)
         (c:DefWire) :
            add(commands, c)
            val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT)
            List(wire-ref => VoidValue(), env)
         (c:DefRegister) :
            add(commands, c)
            val reg-ref = WRef(name(c), type(c), RegKind(), INPUT)
            List(reg-ref => VoidValue(), env)
         (c:DefInstance) :
            add(commands, c)
            ; Every input port of the instance starts out unassigned.
            val entries = let :
               val module-type = type(module(c)) as BundleType
               val input-ports = to-list $ for p in ports(module-type) filter :
                  direction(p) == INPUT
               val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT)
               for p in input-ports map :
                  WField(inst-ref, name(p), type(p), INPUT) => VoidValue()
            append(entries, env)
         (c:DefMemory) :
            add(commands, c)
            env
         (c:WDefAccessor) :
            add(commands, c)
            if dir(c) == INPUT :
               val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c))
               List(access-ref => VoidValue(), env)
            else :
               env
         (c:Conditionally) :
            val con-env = eval(conseq(c), env)
            val alt-env = eval(alt(c), env)
            merge-env(pred(c), con-env, alt-env)
         (c:Begin) :
            var env:List<KeyValue<Expression,SymbolicValue>> = env
            for c in body(c) do :
               env = eval(c, env)
            env
         (c:Connect) :
            List(loc(c) => ExpValue(exp(c)), env)
         (c:EmptyStmt) :
            env

   ; Turns a symbolic value into an expression. WhenValues become
   ; multiplexers; a VoidValue raises a PassException.
   defn convert-symbolic (key:Expression, sv:SymbolicValue) :
      match(sv) :
         (sv:VoidValue) :
            throw $ PassException $ string-join $ [
               "No default value for " key "."]
         (sv:ExpValue) :
            exp(sv)
         (sv:WhenValue) :
            defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) :
               DoPrim(MULTIPLEX-OP, list(pred, conseq, alt), List(), type(conseq))
            multiplex-exp(pred(sv),
                          convert-symbolic(key, conseq(sv))
                          convert-symbolic(key, alt(sv)))

   ;Compute final environment
   val env0 = let :
      val output-ports = to-list $ for p in ports(m) filter :
         direction(p) == OUTPUT
      for p in output-ports map :
         val port-ref = WRef(name(p), type(p), PortKind(), INPUT)
         port-ref => VoidValue()
   val env* = simplify-env(eval(body(m), env0))

   ;Make new body
   val body* = Begin(list(defs, LetRec(elems, connections))) where :
      val defs = Begin(to-list(commands))
      val elems = to-list(elements)
      val connections = Begin $ for entry in env* map :
         val sv = convert-symbolic(key(entry), value(entry))
         Connect(key(entry), sv)

   ;Final module
   Module(name(m), ports(m), body*)

defn expand-whens (c:Circuit) :
   val modules* = map(expand-whens, modules(c))
   Circuit(modules*, main(c))

;================ STRUCTURAL FORM ==========================
defn structural-form (m:Module) :
   ; Element constructors are stored as thunks; they are run only after
   ; the whole body has been scanned and the tables below are filled.
   val elements = Vector<(() -> KeyValue)>()
   val connected = HashTable<Symbol,Expression>(symbol-hash)
   val write-accessors = HashTable<Symbol,List<WDefAccessor>>(symbol-hash)
   val read-accessors = HashTable<Symbol,WDefAccessor>(symbol-hash)
   val inst-ports = HashTable<Symbol,List<KeyValue<Symbol,Expression>>>(symbol-hash)
   val port-connects = Vector<Connect>()

   defn scan (c:Stmt) :
      match(c) :
         (c:Connect) :
            match(loc(c)) :
               (loc:WRef) :
                  match(kind(loc)) :
                     (k:PortKind) : add(port-connects, c)
                     (k) : connected[name(loc)] = exp(c)
               (loc:WField) :
                  val inst = exp(loc) as WRef
                  val entry = name(loc) => exp(c)
                  inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List()))
         (c:LetRec) :
            for e in entries(c) do :
               add(elements, {e})
            scan(body(c))
         (c:DefWire) :
            add{elements, _} $ fn () :
               name(c) => Node(type(c), connected[name(c)])
         (c:DefRegister) :
            add{elements, _} $ fn () :
               val one = UIntValue(1, UnknownWidth())
               name(c) => Register(type(c), connected[name(c)], one)
         (c:DefInstance) :
            add{elements, _} $ fn () :
               name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)])
         (c:DefMemory) :
            add{elements, _} $ fn () :
               val ports = for a in get?(write-accessors, name(c), List()) map :
                  val one = UIntValue(1, UnknownWidth())
                  WritePort(index(a), connected[name(a)], one)
               name(c) => Memory(type(c), ports)
         (c:WDefAccessor) :
            val mem = source(c) as WRef
            switch {dir(c) == _} :
               INPUT : write-accessors[name(mem)] = List(c, get?(write-accessors, name(mem), List()))
               OUTPUT : read-accessors[name(c)] = c
         (c) :
            do(scan, children(c))

   ; Replaces references to read accessors by ReadPort expressions.
   defn make-read-ports (e:Expression) :
      match(e) :
         (e:WRef) :
            match(kind(e)) :
               (k:AccessorKind) :
                  val accessor = read-accessors[name(e)]
                  ReadPort(source(accessor), index(accessor), type(e))
               (k) : e
         (e) : map(make-read-ports, e)

   Module(name(m), ports(m), body*) where :
      scan(body(m))
      val elems = to-list $ for e in elements stream :
         val entry = e()
         key(entry) => map(make-read-ports, value(entry))
      val connect-ports = Begin $ to-list $ for c in port-connects stream :
         Connect(loc(c), make-read-ports(exp(c)))
      val body* =
         if empty?(elems) : connect-ports
         else : LetRec(elems, connect-ports)

defn structural-form (c:Circuit) :
   val modules* = map(structural-form, modules(c))
   Circuit(modules*, main(c))

;==================== WIDTH INFERENCE ======================
; Width that is still to be solved, identified by name.
defstruct WidthVar <: Width :
   name: Symbol

defmethod print (o:OutputStream, w:WidthVar) :
   print(o, name(w))

; Width of a UInt/SInt type; calls error for any other type.
defn width! (t:Type) :
   match(t) :
      (t:UIntType) : width(t)
      (t:SIntType) : width(t)
      (t) : error("No width field.")

; Returns t with width w; types without a width are returned unchanged.
defn put-width (t:Type, w:Width) :
   match(t) :
      (t:UIntType) : UIntType(w)
      (t:SIntType) : SIntType(w)
      (t) : t

defn put-width (e:Expression, w:Width) :
   val type* = put-width(type(e), w)
   put-type(e, type*)

; Replaces every unknown width inside t by a fresh WidthVar.
defn add-width-vars (t:Type) :
   defn width? (w:Width) :
      match(w) :
         (w:UnknownWidth) : WidthVar(gensym())
         (w) : w
   match(t) :
      (t:UIntType) : UIntType(width?(width(t)))
      (t:SIntType) : SIntType(width?(width(t)))
      (t) : map(add-width-vars, t)

; Number of right shifts needed to bring i to zero, as an IntWidth.
; The result for 0 is IntWidth(0).
defn uint-width (i:Int) :
   var v:Int = i
   var n:Int = 0
   while v != 0 :
      v = v >> 1
      n = n + 1
   IntWidth(n)

; Width for a signed literal: the width of its magnitude (i when
; non-negative, -i - 1 when negative) plus one sign bit.
; Zero must take the non-negative branch: with the former test "i > 0",
; sint-width(0) evaluated uint-width(-1) and produced a wrong width.
defn sint-width (i:Int) :
   val magnitude =
      if i >= 0 : i
      else : neg(i) - 1
   val w = uint-width(magnitude)
   IntWidth(width(w) + 1)

; Converts a width into a solver expression; only IntWidth and WidthVar
; are convertible.
defn to-exp (w:Width) :
   match(w) :
      (w:IntWidth) : ELit(width(w))
      (w:WidthVar) : EVar(name(w))
      (w) : error $ string-join $ [
         "Cannot convert " w " to exp."]

; Solver expression for the result width of op, given the argument
; widths ws and the op's integer constants ints.
; NOTE(review): CONCAT-OP uses ws[0] + ints[0] (same form as
; SHIFT-LEFT-OP) and BITS-SELECT-OP uses ints[0] alone; confirm both
; against the primop definitions, which are not visible here.
defn primop-width (op:PrimOp, ws:List, ints:List) -> Exp :
   defn wmax (w1:Width, w2:Width) :
      EMax(to-exp(w1), to-exp(w2))
   defn wplus (w1:Width, w2:Width) :
      EPlus(to-exp(w1), to-exp(w2))
   defn wplus (w1:Width, w2:Int) :
      EPlus(to-exp(w1), ELit(w2))
   defn wminus (w1:Width, w2:Width) :
      EMinus(to-exp(w1), to-exp(w2))
   defn wminus (w1:Width, w2:Int) :
      EMinus(to-exp(w1), ELit(w2))
   defn wmax-inc (w1:Width, w2:Width) :
      EPlus(wmax(w1, w2), ELit(1))

   switch {op == _} :
      ADD-OP : wmax-inc(ws[0], ws[1])
      ADD-MOD-OP : wmax(ws[0], ws[1])
      SUB-OP : wmax-inc(ws[0], ws[1])
      SUB-MOD-OP : wmax(ws[0], ws[1])
      TIMES-OP : wplus(ws[0], ws[1])
      DIVIDE-OP : wminus(ws[0], ws[1])
      MOD-OP : to-exp(ws[1])
      SHIFT-LEFT-OP : wplus(ws[0], ints[0])
      SHIFT-RIGHT-OP : wminus(ws[0], ints[0])
      PAD-OP : ELit(ints[0])
      BIT-AND-OP : wmax(ws[0], ws[1])
      BIT-OR-OP : wmax(ws[0], ws[1])
      BIT-XOR-OP : wmax(ws[0], ws[1])
      CONCAT-OP : wplus(ws[0], ints[0])
      BIT-SELECT-OP : ELit(1)
      BITS-SELECT-OP : ELit(ints[0])
      MULTIPLEX-OP : wmax(ws[1], ws[2])
      LESS-OP : ELit(1)
      LESS-EQ-OP : ELit(1)
      GREATER-OP : ELit(1)
      GREATER-EQ-OP : ELit(1)
      EQUAL-OP : ELit(1)

; Returns a copy of element el carrying type t.
defn put-type (el:Element, t:Type) :
   match(el) :
      (el:Register) : Register(t, value(el), enable(el))
      (el:Memory) : Memory(t, writers(el))
      (el:Node) : Node(t, value(el))
      (el:Instance) : Instance(t, module(el), ports(el))

; Introduces width variables for all unknown widths and returns the
; annotated circuit together with the collected width constraints.
defn generate-constraints (c:Circuit) -> [Circuit, Vector] :
   ;Constraints
   val cs = Vector<WConstraint>()
   ; Records Constraint(wvar, width) only when wvar is a variable.
   defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) :
      match(wvar) :
         (wvar:WidthVar) :
            add(cs, Constraint(name(wvar), to-exp(width)))
         (wvar) :
            false

   ; Converts a solver expression into a width; a compound expression
   ; gets a fresh variable constrained to equal it.
   defn to-width (e:Exp) :
      match(e) :
         (e:ELit) :
            IntWidth(width(e))
         (e:EVar) :
            WidthVar(name(e))
         (e) :
            val x = gensym()
            add(cs, WidthEqual(x, e))
            WidthVar(x)

   ;Module types
   val mod-types = HashTable<Symbol,Type>(symbol-hash)
   defn add-port-vars (m:Module) -> Module :
      val ports* = for p in ports(m) map :
         val type* = add-width-vars(type(p))
         Port(name(p), direction(p), type*)
      mod-types[name(m)] = BundleType(ports*)
      Module(name(m), ports*, body(m))

   ;Add Width Variables
   defn add-module-vars (m:Module) -> Module :
      val types = HashTable<Symbol,Type>(symbol-hash)
      for p in ports(m) do :
         types[name(p)] = type(p)

      defn infer-exp-width (e:Expression) :
         match(map(infer-exp-width, e)) :
            (e:WRef) :
               match(kind(e)) :
                  (k:ModuleKind) : put-type(e, mod-types[name(e)])
                  (k) : put-type(e, types[name(e)])
            (e:WField) :
               val t = bundle-field-type(type(exp(e)), name(e))
               put-width(e, width!(t))
            (e:UIntValue) :
               match(width(e)) :
                  (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e)))
                  (w) : e
            (e:SIntValue) :
               match(width(e)) :
                  (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e)))
                  (w) : e
            (e:DoPrim) :
               val widths = map(width!{type(_)}, args(e))
               val w = to-width(primop-width(op(e), widths, consts(e)))
               put-width(e, w)
            (e:ReadPort) :
               val elem-type = type(type(mem(e)) as VectorType)
               put-width(e, width!(elem-type))

      defn infer-comm-width (c:Stmt) :
         match(c) :
            (c:LetRec) :
               ;Add width vars to elements
               var entries*: List<KeyValue<Symbol,Element>> =
                  for entry in entries(c) map :
                     key(entry) => match(value(entry)) :
                        (el:Register|Node) :
                           put-type(el, add-width-vars(type(el)))
                        (el:Memory) :
                           el
                        (el:Instance) :
                           val mod-type = type(infer-exp-width(module(el))) as BundleType
                           val type = BundleType $ to-list $ for p in ports(mod-type) filter :
                              direction(p) == OUTPUT
                           put-type(el, type)

               ;Add vars to environment
               for entry in entries* do :
                  types[key(entry)] = type(value(entry))

               ;Infer types for elements
               entries* = for entry in entries* map :
                  key(entry) => map(infer-exp-width, value(entry))

               ;Generate constraints
               for entry in entries* do :
                  val el-name = key(entry)
                  match(value(entry)) :
                     (el:Register) :
                        new-constraint(WidthEqual, reg-width, val-width) where :
                           val reg-width = width!(types[el-name])
                           val val-width = width!(type(value(el)))
                     (el:Node) :
                        new-constraint(WidthEqual, node-width, val-width) where :
                           val node-width = width!(types[el-name])
                           val val-width = width!(type(value(el)))
                     (el:Instance) :
                        val mod-type = type(module(el)) as BundleType
                        for entry in ports(el) do :
                           new-constraint(WidthGreater, port-width, val-width) where :
                              val port-name = key(entry)
                              val port-width = width!(bundle-field-type(mod-type, port-name))
                              val val-width = width!(type(value(entry)))
                     (el) : false

               ;Analyze body
               LetRec(entries*, infer-comm-width(body(c)))

            (c:Connect) :
               val loc* = infer-exp-width(loc(c))
               val exp* = infer-exp-width(exp(c))
               new-constraint(WidthGreater, loc-width, exp-width) where :
                  val loc-width = width!(type(loc*))
                  val exp-width = width!(type(exp*))
               Connect(loc*, exp*)

            (c:Begin) :
               map(infer-comm-width, c)

      Module(name(m), ports(m), body*) where :
         val body* = infer-comm-width(body(m))

   ; Port variables are added to every module first, so mod-types is
   ; complete before any module body is analyzed.
   val c* = Circuit(modules*, main(c)) where :
      val ms = map(add-port-vars, modules(c))
      val modules* = map(add-module-vars, ms)
   [c*, cs]

;================== FILL WIDTHS ============================
defn fill-widths (c:Circuit, solved:Streamable) :
   ;Populate table
   val table = HashTable(symbol-hash)
   for eq in solved do :
      table[name(eq)] = IntWidth(width(value(eq) as ELit))

   defn width?
(w:Width) : match(w) : (w:WidthVar) : get?(table, name(w), UnknownWidth()) (w) : w defn fill-type (t:Type) : match(t) : (t:UIntType) : UIntType(width?(width(t))) (t:SIntType) : SIntType(width?(width(t))) (t) : map(fill-type, t) defn fill-exp (e:Expression) : val e* = map(fill-exp, e) val type* = fill-type(type(e)) put-type(e*, type*) defn fill-element (e:Element) : val e* = map(fill-exp, e) val type* = fill-type(type(e)) put-type(e*, type*) defn fill-comm (c:Stmt) : match(c) : (c:LetRec) : val entries* = for e in entries(c) map : key(e) => fill-element(value(e)) LetRec(entries*, fill-comm(body(c))) (c) : map{fill-comm, _} $ map(fill-exp, c) defn fill-port (p:Port) : Port(name(p), direction(p), fill-type(type(p))) defn fill-mod (m:Module) : Module(name(m), ports*, body*) where : val ports* = map(fill-port, ports(m)) val body* = fill-comm(body(m)) Circuit(modules*, main(c)) where : val modules* = map(fill-mod, modules(c)) ;=============== TYPE INFERENCE DRIVER ===================== defn infer-widths (c:Circuit) : val [c*, cs] = generate-constraints(c) val solved = solve-widths(cs) fill-widths(c*, solved) ;================ PAD WIDTHS =============================== defn pad-widths (c:Circuit) : ;Pad an expression to the given width defn pad-exp (e:Expression, w:Int) : match(type(e)) : (t:UIntType|SIntType) : val prev-w = width!(t) as IntWidth if width(prev-w) < w : val type* = put-width(t, IntWidth(w)) DoPrim(PAD-OP, list(e), list(w), type*) else : e (t) : e defn pad-exp (e:Expression, w:Width) : val w-value = width(w as IntWidth) pad-exp(e, w-value) ;Convenience defn max-width (es:Streamable) : defn int-width (e:Expression) : width(width!(type(e)) as IntWidth) maximum(stream(int-width, es)) defn match-widths (es:List) : val w = max-width(es) map(pad-exp{_, w}, es) ;Match widths for an expression defn match-exp-width (e:Expression) : match(map(match-exp-width, e)) : (e:DoPrim) : if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) : val args* = 
match-widths(args(e)) DoPrim(op(e), args*, consts(e), type(e)) else if op(e) == MULTIPLEX-OP : val args* = List(head(args(e)), match-widths(tail(args(e)))) DoPrim(op(e), args*, consts(e), type(e)) else : e (e) : e defn match-element-width (e:Element) : match(map(match-exp-width, e)) : (e:Register) : val w = width!(type(e)) val value* = pad-exp(value(e), w) Register(type(e), value*, enable(e)) (e:Memory) : val width = width!(type(type(e) as VectorType)) val writers* = for w in writers(e) map : WritePort(index(w), pad-exp(value(w), width), enable(w)) Memory(type(e), writers*) (e:Node) : val w = width!(type(e)) val value* = pad-exp(value(e), w) Node(type(e), value*) (e:Instance) : val mod-type = type(module(e)) as BundleType val ports* = for p in ports(e) map : val port-type = bundle-field-type(mod-type, key(p)) val port-val = pad-exp(value(p), width!(port-type)) key(p) => port-val Instance(type(e), module(e), ports*) ;Match widths for a command defn match-comm-width (c:Stmt) : match(map(match-exp-width, c)) : (c:LetRec) : val entries* = for e in entries(c) map : key(e) => match-element-width(value(e)) LetRec(entries*, match-comm-width(body(c))) (c:Connect) : val w = width!(type(loc(c))) val exp* = pad-exp(exp(c), w) Connect(loc(c), exp*) (c) : map(match-comm-width, c) defn match-mod-width (m:Module) : Module(name(m), ports(m), body*) where : val body* = match-comm-width(body(m)) Circuit(modules*, main(c)) where : val modules* = map(match-mod-width, modules(c)) ;================== INLINING =============================== defn inline-instances (c:Circuit) : val module-table = HashTable(symbol-hash) val inlined? 
= HashTable(symbol-hash) for m in modules(c) do : module-table[name(m)] = m inlined?[name(m)] = false ;Convert a module into a sequence of elements defn to-elements (m:Module, inst:Symbol, port-exps:List>) -> List> : defn rename-exp (e:Expression) : match(e) : (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e)) (e) : map(rename-exp, e) defn to-elements (c:Stmt) -> List> : match(c) : (c:LetRec) : val entries* = for entry in entries(c) map : val name* = prefix(inst, key(entry)) val element* = map(rename-exp, value(entry)) name* => element* val body* = to-elements(body(c)) append(entries*, body*) (c:Connect) : val ref = loc(c) as WRef val name* = prefix(inst, name(ref)) list(name* => Node(type(exp(c)), rename-exp(exp(c)))) (c:Begin) : map-append(to-elements, body(c)) val inputs = for p in ports(m) map-append : if direction(p) == INPUT : val port-exp = lookup!(port-exps, name(p)) val name* = prefix(inst, name(p)) list(name* => Node(type(port-exp), port-exp)) else : List() append(inputs, to-elements(body(m))) ;Inline all instances in the module defn inline-instances (m:Module) : defn rename-exp (e:Expression) : match(e) : (e:WField) : val inst-exp = exp(e) as WRef val name* = prefix(name(inst-exp), name(e)) WRef(name*, type(e), NodeKind(), dir(e)) (e) : map(rename-exp, e) defn inline-elems (es:List>) : for entry in es map-append : match(value(entry)) : (el:Instance) : val mod-name = name(module(el) as WRef) val module = inlined-module(mod-name) to-elements(module, key(entry), ports(el)) (el) : list(entry) defn inline-comm (c:Stmt) : match(map(rename-exp, c)) : (c:LetRec) : val entries* = inline-elems(entries(c)) LetRec(entries*, inline-comm(body(c))) (c) : map(inline-comm, c) Module(name(m), ports(m), inline-comm(body(m))) ;Retrieve the inlined instance of a module defn inlined-module (name:Symbol) : if inlined?[name] : module-table[name] else : val module* = inline-instances(module-table[name]) module-table[name] = module* inlined?[name] = true module* 
;Return the fully inlined circuit val main-module = inlined-module(main(c)) Circuit(list(main-module), main(c)) ;;;================ UTILITIES ================================ ; ; ; ;defn* root-ref (i:Immediate) : ; match(i) : ; (f:Field) : root-ref(imm(f)) ; (ind:Index) : root-ref(imm(ind)) ; (r) : r ; ;;defn lookup (e: Streamable>, i:Immediate) : ;; for entry in e search : ;; if eqv?(key(entry), i) : ;; value(entry) ;; ;;defn lookup! (e: Streamable>, i:Immediate) : ;; lookup(e, i) as T ;; ;;============ CHECK IF NAMES ARE UNIQUE ==================== ;defn check-duplicate-symbols (names: Streamable, msg: String) : ; val dict = HashTable(symbol-hash) ; for name in names do: ; if key?(dict, name): ; throw $ PassException $ string-join $ ; [msg ": " name] ; else: ; dict[name] = true ; ;defn check-duplicates (t: Type) : ; match(t) : ; (t:BundleType) : ; val names = map(name, ports(t)) ; check-duplicate-symbols{names, string-join(_)} $ ; ["Duplicate port name in bundle "] ; do(check-duplicates{type(_)}, ports(t)) ; (t:VectorType) : ; check-duplicates(type(t)) ; (t) : false ; ;defn check-duplicates (c: Stmt) : ; match(c) : ; (c:DefWire) : check-duplicates(type(c)) ; (c:DefRegister) : check-duplicates(type(c)) ; (c:DefMemory) : check-duplicates(type(c)) ; (c) : do(check-duplicates, children(c)) ; ;defn defined-names (c: Stmt) : ; generate : ; loop(c) where : ; defn loop (c:Stmt) : ; match(c) : ; (c:Stmt&HasName) : yield(name(c)) ; (c) : do(loop, children(c)) ; ;defn check-duplicates (m: Module): ; ;Check all duplicate names in all types in all ports and body ; do(check-duplicates{type(_)}, ports(m)) ; check-duplicates(body(m)) ; ; ;Check all names defined in module ; val names = concat(stream(name, ports(m)), ; defined-names(body(m))) ; check-duplicate-symbols{names, string-join(_)} $ ; ["Duplicate definition name in module " name(m)] ; ;defn check-duplicates (c: Circuit) : ; ;Check all duplicate names in all modules ; do(check-duplicates, modules(c)) ; ; ;Check all 
defined modules ; val names = stream(name, modules(c)) ; check-duplicate-symbols(names, "Duplicate module name") ; ; ;;================ CLEANUP COMMANDS ========================= ;defn cleanup (c:Stmt) : ; match(c) : ; (c:Begin) : ; to-command $ generate : ; loop(c) where : ; defn loop (c:Stmt) : ; match(c) : ; (c:Begin) : do(loop, body(c)) ; (c:EmptyStmt) : false ; (c) : yield(cleanup(c)) ; (c) : map(cleanup{_ as Stmt}, c) ; ;defn cleanup (c:Circuit) : ; val modules* = ; for m in modules(c) map : ; map(cleanup, m) ; Circuit(modules*, main(c)) ; ;;;============= SHIM ======================================== ;;defn shim (i:Immediate) -> Immediate : ;; match(i) : ;; (i:RegData) : ;; Ref(name(i), direction(i), type(i)) ;; (i:InstPort) : ;; val inst = Ref(name(i), UNKNOWN-DIR, UnknownType()) ;; Field(inst, port(i), direction(i), type(i)) ;; (i:Field) : ;; val imm* = shim(imm(i)) ;; put-imm(i, imm*) ;; (i) : i ;; ;;defn shim (c:Stmt) -> Stmt : ;; val c* = map(shim{_ as Immediate}, c) ;; map(shim{_ as Stmt}, c*) ;; ;;defn shim (c:Circuit) -> Circuit : ;; val modules* = ;; for m in modules(c) map : ;; Module(name(m), ports(m), shim(body(m))) ;; Circuit(modules*, main(c)) ;; ;;;================== INLINE MODULES ========================= ;;defn cat-name (p: String|Symbol, s: String|Symbol) -> Symbol : ;; if p == "" or p == `this : ;; TODO: REMOVE THIS WHEN `THIS GETS REMOVED ;; to-symbol(s) ;; else if s == `this : ;; TODO: DITTO ;; to-symbol(p) ;; else : ;; symbol-join([p, "/", s]) ;; ;;defn inline-command (c: Stmt, mods: HashTable, prefix: String, cmds: Vector) : ;; defn rename (n: Symbol) -> Symbol : ;; cat-name(prefix, n) ;; defn inline-name (i:Immediate) -> Symbol : ;; match(i) : ;; (r:Ref) : rename(name(r)) ;; (f:Field) : cat-name(inline-name(imm(f)), name(f)) ;; (f:Index) : cat-name(inline-name(imm(f)), to-string(value(f))) ;; defn inline-imm (i:Immediate) -> Ref : ;; Ref(inline-name(i), direction(i), type(i)) ;; match(c) : ;; (c:DefUInt) : add(cmds, 
DefUInt(rename(name(c)), value(c), width(c))) ;; (c:DefSInt) : add(cmds, DefSInt(rename(name(c)), value(c), width(c))) ;; (c:DefWire) : add(cmds, DefWire(rename(name(c)), type(c))) ;; (c:DefRegister) : add(cmds, DefRegister(rename(name(c)), type(c))) ;; (c:DefMemory) : add(cmds, DefMemory(rename(name(c)), type(c), size(c))) ;; (c:DefInstance) : inline-module(mods, mods[name(module(c))], to-string(rename(name(c))), cmds) ;; (c:DoPrim) : add(cmds, DoPrim(rename(name(c)), op(c), map(inline-imm, args(c)), consts(c))) ;; (c:DefAccessor) : add(cmds, DefAccessor(rename(name(c)), inline-imm(source(c)), direction(c), inline-imm(index(c)))) ;; (c:Connect) : add(cmds, Connect(inline-imm(loc(c)), inline-imm(exp(c)))) ;; (c:Begin) : do(inline-command{_, mods, prefix, cmds}, body(c)) ;; (c:EmptyStmt) : c ;; (c) : error("Unsupported command") ;; ;;defn inline-port (p: Port, prefix: String) -> Stmt : ;; DefWire(cat-name(prefix, name(p)), type(p)) ;; ;;defn inline-module (mods: HashTable, mod: Module, prefix: String, cmds: Vector) : ;; do(add{cmds, _}, map(inline-port{_, prefix}, ports(mod))) ;; inline-command(body(mod), mods, prefix, cmds) ;; ;;defn inline-modules (c: Circuit) -> Circuit : ;; val cmds = Vector() ;; val mods = HashTable(symbol-hash) ;; for mod in modules(c) do : ;; mods[name(mod)] = mod ;; val top = mods[main(c)] ;; inline-command(body(top), mods, "", cmds) ;; val main* = Module(name(top), ports(top), Begin(to-list(cmds))) ;; Circuit(list(main*), name(top)) ;; ;; ;;;============= FLO PRINTER ====================================== ;;;;; TODO: ;;;;; not supported gt, lte ;; ;;defn flo-op-name (op:PrimOp) -> String : ;; switch {op == _} : ;; ADD-OP : "add" ;; ADD-MOD-OP : "add" ;; MINUS-OP : "sub" ;; SUB-MOD-OP : "sub" ;; TIMES-OP : "mul" ;; todo: signed version ;; DIVIDE-OP : "div" ;; todo: signed version ;; MOD-OP : "mod" ;; todo: signed version ;; SHIFT-LEFT-OP : "lsh" ;; todo: signed version ;; SHIFT-RIGHT-OP : "rsh" ;; PAD-OP : "pad" ;; todo: signed version ;; 
BIT-AND-OP : "and" ;; BIT-OR-OP : "or" ;; BIT-XOR-OP : "xor" ;; CONCAT-OP : "cat" ;; BIT-SELECT-OP : "rsh" ;; BITS-SELECT-OP : "rsh" ;; LESS-OP : "lt" ;; todo: signed version ;; LESS-EQ-OP : "lte" ;; todo: swap args ;; GREATER-OP : "gt" ;; todo: swap args ;; GREATER-EQ-OP : "gte" ;; todo: signed version ;; EQUAL-OP : "eq" ;; MULTIPLEX-OP : "mux" ;; else : error $ string-join $ ;; ["Unable to print Primop: " op] ;; ;;defn emit (o:OutputStream, top:Symbol, ports:HashTable, lits:HashTable, elt) : ;; match(elt) : ;; (e:String|Symbol|Int) : ;; print(o, e) ;; (e:Ref) : ;; if key?(lits, name(e)) : ;; val lit = lits[name(e)] ;; print-all(o, [value(lit) "'" width(lit)]) ;; else : ;; if key?(ports, name(e)) : ;; print-all(o, [top "::"]) ;; print(o, name(e)) ;; (e:IntWidth) : ;; print(o, value(e)) ;; (e:PrimOp) : ;; print(o, flo-op-name(e)) ;; (e) : ;; println-all(["EMIT " e]) ;; error("Unable to emit") ;; ;;defn emit-all (o:OutputStream, top:Symbol, ports:HashTable, lits:HashTable, elts: Streamable) : ;; for e in elts do : emit(o, top, ports, lits, e) ;; ;;defn prim-width (type:Type) -> Width : ;; match(type) : ;; (t:UIntType) : width(t) ;; (t:SIntType) : width(t) ;; (t) : error("Bad prim width type") ;; ;;defn emit-command (o:OutputStream, cmd:Stmt, top:Symbol, lits:HashTable, regs:HashTable, accs:HashTable, ports:HashTable, outs:HashTable) : ;; match(cmd) : ;; (c:DefUInt) : ;; lits[name(c)] = c ;; (c:DefSInt) : ;; emit-all(o, top, ports, lits, [name(c) " = " value(c) "'" width(c) "\n"]) ;; (c:DoPrim) : ;; NEED TO FIGURE OUT WHEN WIDTHS ARE NECESSARY AND EXTRACT ;; emit-all(o, top, ports, lits, [name(c) " = " op(c)]) ;; for arg in args(c) do : ;; print(o, " ") ;; emit(o, top, ports, lits, arg) ;; for const in consts(c) do : ;; print(o, " ") ;; emit(o, top, ports, lits, const) ;; print("\n") ;; (c:DefRegister) : ;; regs[name(c)] = c ;; (c:DefMemory) : ;; emit-all(o, top, ports, lits, [name(c) " : mem'" prim-width(type(c)) " " size(c) "\n"]) ;; (c:DefAccessor) : ;; 
;; accs[name(c)] = c ;; (c:Connect) : ;; val dst = name(loc(c) as Ref) ;; val src = name(exp(c) as Ref) ;; if key?(regs, dst) : ;; val reg = regs[dst] ;; emit-all(o, top, ports, lits, [dst " = reg'" prim-width(type(reg)) " 0'" prim-width(type(reg)) " " exp(c) "\n"]) ;; else if key?(accs, dst) : ;; val acc = accs[dst] ;; ;; assert(direction(acc) == OUTPUT) ;; emit-all(o, top, ports, lits, [dst " = wr " source(acc) " " index(acc) " " exp(c) "\n"]) ;; else if key?(outs, dst) : ;; val out = outs[dst] ;; emit-all(o, top, ports, lits, [top "::" dst " = out'" prim-width(type(out)) " " exp(c) "\n"]) ;; else if key?(accs, src) : ;; val acc = accs[src] ;; ;; assert(direction(acc) == INPUT) ;; emit-all(o, top, ports, lits, [dst " = rd " source(acc) " " index(acc) "\n"]) ;; else : ;; emit-all(o, top, ports, lits, [dst " = mov " exp(c) "\n"]) ;; (c:Begin) : ;; do(emit-command{o, _, top, lits, regs, accs, ports, outs}, body(c)) ;; (c:DefWire|EmptyStmt) : ;; print("") ;; (c) : ;; error("Unable to print command") ;; ;;defn emit-module (o:OutputStream, m:Module) : ;; val regs = HashTable(symbol-hash) ;; val accs = HashTable(symbol-hash) ;; val lits = HashTable(symbol-hash) ;; val outs = HashTable(symbol-hash) ;; val portz = HashTable(symbol-hash) ;; for port in ports(m) do : ;; portz[name(port)] = port ;; if direction(port) == OUTPUT : ;; outs[name(port)] = port ;; else if name(port) == `reset : ;; print-all(o, [name(m) "::reset = rst\n"]) ;; else : ;; print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"]) ;; emit-command(o, body(m), name(m), lits, regs, accs, portz, outs) ;; ;;public defn emit-circuit (o:OutputStream, c:Circuit) : ;; emit-module(o, modules(c)[0])

;============= DRIVER ======================================
; Runs the selected compiler passes over circuit `c`, in the fixed order
; listed below.
;   c : the circuit to compile
;   p : pass selectors; a pass runs only when its letter ('a' .. 'o') is
;       contained in `p`
; When PRINT-CIRCUITS is set, the original circuit and the circuit after
; each executed pass are printed.
; The transformed circuit is kept in the local `c*`; the final `c*`
; expression is still commented out, so it is not returned.
public defn run-passes (c: Circuit, p: List) :
   var c*:Circuit = c
   println("Compiling!")
   if PRINT-CIRCUITS :
      println("Original Circuit")
      print(c)

   ; Applies pass `f` to the current circuit, tracing it when requested.
   defn do-stage (name:String, f: Circuit -> Circuit) :
      if PRINT-CIRCUITS : println(name)
      c* = f(c*)
      if PRINT-CIRCUITS :
         print(c*)
         println-all(["Finished " name "\n"])

   ; Early passes:
   ;  If modules have a reset defined, must be an INPUT and UInt(1)
   if contains(p,'a') : do-stage("Working IR", to-working-ir)
   if contains(p,'b') : do-stage("Resolve Kinds", resolve-kinds)
   if contains(p,'c') : do-stage("Make Explicit Reset", make-explicit-reset)
   if contains(p,'d') : do-stage("Initialize Registers", initialize-registers)
   if contains(p,'e') : do-stage("Infer Types", infer-types)
   if contains(p,'f') : do-stage("Infer Directions", infer-directions)
   if contains(p,'g') : do-stage("Expand Accessors", expand-accessors)
   if contains(p,'h') : do-stage("Flatten Bundles", flatten-bundles)
   if contains(p,'i') : do-stage("Expand Bundles", expand-bundles)
   if contains(p,'j') : do-stage("Expand Multi Connects", expand-multi-connects)
   if contains(p,'k') : do-stage("Expand Whens", expand-whens)
   if contains(p,'l') : do-stage("Structural Form", structural-form)
   if contains(p,'m') : do-stage("Infer Widths", infer-widths)
   if contains(p,'n') : do-stage("Pad Widths", pad-widths)
   if contains(p,'o') : do-stage("Inline Instances", inline-instances)

   println("Done!")

   ;; println("Shim for Jonathan's Passes")
   ;; c* = shim(c*)
   ;; println("Inline Modules")
   ;; c* = inline-modules(c*)
   ; c*