aboutsummaryrefslogtreecommitdiff
path: root/src/main/stanza/passes.stanza
diff options
context:
space:
mode:
authorazidar2015-02-13 15:42:47 -0800
committerazidar2015-02-13 15:42:47 -0800
commit4f68f75415eb89427062eb86ff21b0e53bf4cadd (patch)
tree1f6a552e18eed4874a563359e95e5aad87a8ef50 /src/main/stanza/passes.stanza
parent4deb61cefa9c0ef7806e3986231865ce59673bc2 (diff)
First commit.
Added stanza as a .zip, changed names from ch to firrtl, and spec.tex is included. need to add installation instructions. TODO's included in README
Diffstat (limited to 'src/main/stanza/passes.stanza')
-rw-r--r--src/main/stanza/passes.stanza1878
1 files changed, 1878 insertions, 0 deletions
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
new file mode 100644
index 00000000..61eac73c
--- /dev/null
+++ b/src/main/stanza/passes.stanza
@@ -0,0 +1,1878 @@
+defpackage chipper.passes :
+ import core
+ import verse
+ import chipper.ir2
+ import chipper.ir-utils
+ import widthsolver
+
+;============== EXCEPTIONS =================================
+defclass PassException <: Exception
+defn PassException (msg:String) :
+ new PassException :
+ defmethod print (o:OutputStream, this) :
+ print(o, msg)
+
+;=============== WORKING IR ================================
+definterface Kind
+defstruct RegKind <: Kind
+defstruct AccessorKind <: Kind
+defstruct PortKind <: Kind
+defstruct MemKind <: Kind
+defstruct NodeKind <: Kind
+defstruct ModuleKind <: Kind
+defstruct InstanceKind <: Kind
+defstruct StructuralMemKind <: Kind
+
+defstruct WRef <: Expression :
+ name: Symbol
+ type: Type [multi => false]
+ kind: Kind
+ dir: Direction [multi => false]
+
+defstruct WField <: Expression :
+ exp: Expression
+ name: Symbol
+ type: Type [multi => false]
+ dir: Direction [multi => false]
+
+defstruct WIndex <: Expression :
+ exp: Expression
+ value: Int
+ type: Type [multi => false]
+ dir: Direction [multi => false]
+
+defstruct WDefAccessor <: Command :
+ name: Symbol
+ source: Expression
+ index: Expression
+ dir: Direction
+
+;================ WORKING IR UTILS =========================
+;=== Printers ===
+defmethod print (o:OutputStream, k:Kind) :
+ print{o, _} $
+ match(k) :
+ (k:RegKind) : "reg:"
+ (k:AccessorKind) : "accessor:"
+ (k:PortKind) : "port:"
+ (k:MemKind) : "mem:"
+ (k:NodeKind) : "n:"
+ (k:ModuleKind) : "module:"
+ (k:InstanceKind) : "inst:"
+ (k:StructuralMemKind) : "smem:"
+
+defmethod print (o:OutputStream, e:WRef) :
+ print-all(o, [name(e)])
+defmethod print (o:OutputStream, e:WField) :
+ print-all(o, [exp(e) "." name(e)])
+defmethod print (o:OutputStream, e:WIndex) :
+ print-all(o, [exp(e) "." value(e)])
+
+defmethod print (o:OutputStream, c:WDefAccessor) :
+ print-all(o, [dir(c) " accessor " name(c) " = " source(c) "[" index(c) "]"])
+
+defmethod map (f: Expression -> Expression, e: WField) :
+ WField(f(exp(e)), name(e), type(e), dir(e))
+
+defmethod map (f: Expression -> Expression, e: WIndex) :
+ WIndex(f(exp(e)), value(e), type(e), dir(e))
+
+defmethod map (f: Expression -> Expression, c:WDefAccessor) :
+ WDefAccessor(name(c), f(source(c)), f(index(c)), dir(c))
+
+;================= DIRECTION ===============================
+defmulti dir (e:Expression) -> Direction
+defmethod dir (e:Expression) :
+ OUTPUT
+
+;============== Bring to Working IR ========================
+defn to-working-ir (c:Circuit) :
+ defn to-exp (e:Expression) :
+ match(map(to-exp, e)) :
+ (e:Ref) : WRef(name(e), type(e), NodeKind(), UNKNOWN-DIR)
+ (e:Field) : WField(exp(e), name(e), type(e), UNKNOWN-DIR)
+ (e:Index) : WIndex(exp(e), value(e), type(e), UNKNOWN-DIR)
+ (e) : e
+ defn to-command (c:Command) :
+ match(map(to-exp, c)) :
+ (c:DefAccessor) :
+ WDefAccessor(name(c), source(c), index(c), UNKNOWN-DIR)
+ (c) :
+ map(to-command, c)
+
+ Circuit(modules*, main(c)) where :
+ val modules* =
+ for m in modules(c) map :
+ Module(name(m), ports(m), to-command(body(m)))
+
+;=============== Resolve Kinds =============================
+defn resolve-kinds (c:Circuit) :
+ defn resolve-exp (e:Expression, kinds:HashTable<Symbol,Kind>) :
+ match(e) :
+ (e:WRef) : WRef(name(e), type(e), kinds[name(e)], dir(e))
+ (e) : map(resolve-exp{_, kinds}, e)
+
+ defn resolve-comm (c:Command, kinds:HashTable<Symbol,Kind>) -> Command :
+ map{resolve-comm{_, kinds}, _} $
+ map(resolve-exp{_, kinds}, c)
+
+ defn find-kinds (c:Command, kinds:HashTable<Symbol,Kind>) :
+ match(c) :
+ (c:LetRec) :
+ for entry in entries(c) do :
+ kinds[key(entry)] = element-kind(value(entry))
+ (c:DefWire) : kinds[name(c)] = NodeKind()
+ (c:DefRegister) : kinds[name(c)] = RegKind()
+ (c:DefInstance) : kinds[name(c)] = InstanceKind()
+ (c:DefMemory) : kinds[name(c)] = MemKind()
+ (c:WDefAccessor) : kinds[name(c)] = AccessorKind()
+ (c) : false
+ do(find-kinds{_, kinds}, children(c))
+
+ defn element-kind (e:Element) :
+ match(e) :
+ (e:Memory) : StructuralMemKind()
+ (e) : NodeKind()
+
+ defn resolve-mod (m:Module, modules:List<Symbol>) :
+ val kinds = HashTable<Symbol,Kind>(symbol-hash)
+ for module in modules do :
+ kinds[module] = ModuleKind()
+ for port in ports(m) do :
+ kinds[name(port)] = PortKind()
+ find-kinds(body(m), kinds)
+ Module(name(m), ports(m), body*) where :
+ val body* = resolve-comm(body(m), kinds)
+
+ Circuit(modules*, main(c)) where :
+ val mod-names = map(name, modules(c))
+ val modules* = map(resolve-mod{_, mod-names}, modules(c))
+
+;=============== MAKE RESET EXPLICIT =======================
+defn make-explicit-reset (c:Circuit) :
+ defn reset-instances (c:Command, reset?: List<Symbol>) -> Command :
+ match(c) :
+ (c:DefInstance) :
+ val module = module(c) as WRef
+ if contains?(reset?, name(module)) :
+ c
+ else :
+ Begin $ list(c, Connect(WField(inst, `reset, UnknownType(), UNKNOWN-DIR), reset)) where :
+ val inst = WRef(name(c), UnknownType(), InstanceKind(), UNKNOWN-DIR)
+ val reset = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-DIR)
+ (c) :
+ map(reset-instances{_:Command, reset?}, c)
+
+ defn make-explicit-reset (m:Module, reset-list: List<Symbol>) :
+ val reset? = contains?(reset-list, name(m))
+
+ ;Add reset port if necessary
+ val ports* =
+ if reset? :
+ ports(m)
+ else :
+ val reset = Port(`reset, INPUT, UIntType(IntWidth(1)))
+ List(reset, ports(m))
+
+ ;Reset Instances
+ val body* = reset-instances(body(m), reset-list)
+ val m* = Module(name(m), ports*, body*)
+
+ ;Initialize registers if necessary
+ if reset? : m*
+ else : initialize-registers(m*)
+
+ Circuit(modules*, main(c)) where :
+ defn reset? (m:Module) :
+ for p in ports(m) any? :
+ name(p) == `reset
+ val reset-list = to-list(stream(name, filter(reset?, modules(c))))
+ val modules* = map(make-explicit-reset{_, reset-list}, modules(c))
+
+;======= MAKE EXPLICIT REGISTER INITIALIZATION =============
+defn initialize-registers (m:Module) :
+ ;=== Initializing Expressions ===
+ defn init-exps (inits: List<KeyValue<Symbol,Expression>>) :
+ if empty?(inits) :
+ EmptyCommand()
+ else :
+ Conditionally(reset, Begin(map(connect, inits)), EmptyCommand()) where :
+ val reset = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-DIR)
+ defn connect (init: KeyValue<Symbol, Expression>) :
+ val reg-ref = WRef(key(init), UnknownType(), RegKind(), UNKNOWN-DIR)
+ Connect(reg-ref, value(init))
+
+ defn initialize-registers (c: Command
+ inits: List<KeyValue<Symbol,Expression>>) ->
+ [Command, List<KeyValue<Symbol,Expression>>] :
+ ;=== Rename Expressions ===
+ defn rename (e:Expression) :
+ match(e) :
+ (e:WField) :
+ switch {name(e) == _} :
+ `init :
+ if reg?(exp(e)) : init-wire(exp(e))
+ else : map(rename, e)
+ else : map(rename, e)
+ (e) : map(rename, e)
+ defn reg? (e:Expression) :
+ match(e) :
+ (e:WRef) : kind(e) typeof RegKind
+ (e) : false
+ defn init-wire (e:Expression) :
+ lookup!(inits, name(e as WRef))
+
+ ;=== Driver ===
+ match(c) :
+ (c:DefRegister) :
+ [new-command, list(init-entry)] where :
+ val wire-name = gensym()
+ val wire-ref = WRef(wire-name, UnknownType(), NodeKind(), UNKNOWN-DIR)
+ val reg-ref = WRef(name(c), UnknownType(), RegKind(), UNKNOWN-DIR)
+ val def-init-wire = DefWire(wire-name, type(c))
+ val init-wire = Connect(wire-ref, reg-ref)
+ val init-reg = Connect(reg-ref, wire-ref)
+ val new-command = Begin(to-list([c, def-init-wire, init-wire, init-reg]))
+ val init-entry = name(c) => wire-ref
+ (c:Conditionally) :
+ val pred* = rename(pred(c))
+ val [conseq* con-inits] = initialize-registers(conseq(c), inits)
+ val [alt* alt-inits] = initialize-registers(alt(c), inits)
+ val c* = Conditionally(pred*, conseq+inits, alt+inits) where :
+ val conseq+inits = Begin(list(conseq*, init-exps(con-inits)))
+ val alt+inits = Begin(list(alt*, init-exps(alt-inits)))
+ [c*, List()]
+ (c:LetRec) :
+ val c* = map(rename, c)
+ val [body*, body-inits] = initialize-registers(body(c), inits)
+ val new-command =
+ LetRec(entries(c*), body+inits) where :
+ val body+inits = Begin(list(body*, init-exps(body-inits)))
+ [new-command, List()]
+ (c:Begin) :
+ var inits-in:List<KeyValue<Symbol,Expression>> = inits
+ var inits-out:List<KeyValue<Symbol,Expression>> = List()
+ val body* =
+ for c in body(c) map :
+ val [c* inits*] = initialize-registers(c, inits-in)
+ inits-in = append(inits*, inits-in)
+ inits-out = append(inits*, inits-out)
+ c*
+ [Begin(body*), inits-out]
+ (c) :
+ val c* = map(rename, c)
+ [c*, List()]
+
+ Module(name(m), ports(m), body+inits) where :
+ val [body*, inits] = initialize-registers(body(m), List())
+ val body+inits = Begin(list(body*, init-exps(inits)))
+
+
+;============== INFER TYPES ================================
+defmethod type (v:UIntValue) :
+ UIntType(width(v))
+
+defmethod type (v:SIntValue) :
+ SIntType(width(v))
+
+defn put-type (e:Expression, t:Type) -> Expression :
+ match(e) :
+ (e:WRef) : WRef(name(e), t, kind(e), dir(e))
+ (e:WField) : WField(exp(e), name(e), t, dir(e))
+ (e:WIndex) : WIndex(exp(e), value(e), t, dir(e))
+ (e:DoPrim) : DoPrim(op(e), args(e), consts(e), t)
+ (e:ReadPort) : ReadPort(mem(e), index(e), t)
+ (e) : e
+
+defn lookup-port (ports: Streamable<Port>, port-name: Symbol) :
+ for port in ports find :
+ name(port) == port-name
+
+defn infer (op:PrimOp, arg-types: List<Type>) -> Type :
+ defn wipe-width (t:Type) :
+ match(t) :
+ (t:UIntType) : UIntType(UnknownWidth())
+ (t:SIntType) : SIntType(UnknownWidth())
+
+ defn arg0 () : wipe-width(arg-types[0])
+ defn arg1 () : wipe-width(arg-types[1])
+ switch {op == _} :
+ ADD-OP : arg0()
+ ADD-MOD-OP : arg0()
+ SUB-OP : arg0()
+ SUB-MOD-OP : arg0()
+ TIMES-OP : arg0()
+ DIVIDE-OP : arg0()
+ MOD-OP : arg0()
+ SHIFT-LEFT-OP : arg0()
+ SHIFT-RIGHT-OP : arg0()
+ PAD-OP : arg0()
+ BIT-AND-OP : arg0()
+ BIT-OR-OP : arg0()
+ BIT-XOR-OP : arg0()
+ CONCAT-OP : arg0()
+ BIT-SELECT-OP : UIntType(UnknownWidth())
+ BITS-SELECT-OP : arg0()
+ MULTIPLEX-OP : arg0()
+ LESS-OP : UIntType(UnknownWidth())
+ LESS-EQ-OP : UIntType(UnknownWidth())
+ GREATER-OP : UIntType(UnknownWidth())
+ GREATER-EQ-OP : UIntType(UnknownWidth())
+ EQUAL-OP : UIntType(UnknownWidth())
+
+defn bundle-field-type (t:Type, n:Symbol) -> Type :
+ match(t) :
+ (t:BundleType) :
+ match(lookup-port(ports(t), n)) :
+ (p:Port) : type(p)
+ (p) : UnknownType()
+ (t) : UnknownType()
+
+defn vector-elem-type (t:Type) -> Type :
+ match(t) :
+ (t:VectorType) : type(t)
+ (t) : UnknownType()
+
+;e is the environment that contains all definitions seen so far.
+defn infer (c:Command, e:List<KeyValue<Symbol, Type>>) -> [Command, List<KeyValue<Symbol,Type>>] :
+ defn infer-exp (e:Expression, env:List<KeyValue<Symbol,Type>>) :
+ match(map(infer-exp{_, env}, e)) :
+ (e:WRef) :
+ put-type(e, lookup!(env, name(e)))
+ (e:WField) :
+ put-type(e, bundle-field-type(type(exp(e)), name(e)))
+ (e:WIndex) :
+ put-type(e, vector-elem-type(type(exp(e))))
+ (e:UIntValue) : e
+ (e:SIntValue) : e
+ (e:DoPrim) :
+ put-type(e, infer(op(e), map(type, args(e))))
+ (e:ReadPort) :
+ put-type(e, vector-elem-type(type(mem(e))))
+
+ defn element-type (e:Element, env:List<KeyValue<Symbol,Type>>) :
+ match(e) :
+ (e:Instance) :
+ val t = type(infer-exp(module(e), env))
+ match(t) :
+ (t:BundleType) :
+ BundleType $ to-list $
+ for p in ports(t) filter :
+ direction(p) == OUTPUT
+ (t) : UnknownType()
+ (e) : type(e)
+
+ match(c) :
+ (c:LetRec) :
+ val e* = append(elem-types, e) where :
+ val elem-types =
+ for entry in entries(c) map :
+ key(entry) => element-type(value(entry), e)
+ val c* = map(infer-exp{_, e*}, c)
+ val [body*, be] = infer(body(c*), e*)
+ [LetRec(entries(c*), body*), e]
+ (c) :
+ match(map(infer-exp{_, e}, c)) :
+ (c:DefWire) :
+ [c, List(entry, e)] where :
+ val entry = name(c) => type(c)
+ (c:DefRegister) :
+ [c, List(entry, e)] where :
+ val entry = name(c) => type(c)
+ (c:DefInstance) :
+ [c, List(entry, e)] where :
+ val entry = name(c) => type(module(c))
+ (c:DefMemory) :
+ [c, List(entry, e)] where :
+ val entry = name(c) => type(c)
+ (c:WDefAccessor) :
+ [c, List(entry, e)] where :
+ val src-type = type(source(c))
+ val entry = name(c) => vector-elem-type(src-type)
+ (c:Begin) :
+ var current-e: List<KeyValue<Symbol,Type>> = e
+ val body* = for c in body(c) map :
+ val [c*, e*] = infer(c, current-e)
+ current-e = e*
+ c*
+ [Begin(body*), current-e]
+ (c) :
+ defn infer-comm (c:Command) :
+ val [c* e*] = infer(c, e)
+ c*
+ val c* = map(infer-comm, c)
+ [c*, e]
+
+defn infer (m:Module, e:List<KeyValue<Symbol, Type>>) -> Module :
+ val env = append{_, e} $
+ for p in ports(m) map :
+ name(p) => type(p)
+ val [body*, e*] = infer(body(m), env)
+ Module(name(m), ports(m), body*)
+
+defn infer-types (c:Circuit) -> Circuit :
+ val env =
+ for m in modules(c) map :
+ name(m) => BundleType(ports(m))
+ Circuit(map(infer{_, env}, modules(c)),
+ main(c))
+
+;============= INFER DIRECTIONS ============================
+defn flip (d:Direction) :
+ switch {d == _} :
+ INPUT : OUTPUT
+ OUTPUT : INPUT
+ else : d
+
+defn times (d1:Direction, d2:Direction) :
+ if d1 == INPUT : flip(d2)
+ else : d2
+
+defn bundle-field-dir (t:Type, n:Symbol) -> Direction :
+ match(t) :
+ (t:BundleType) :
+ match(lookup-port(ports(t), n)) :
+ (p:Port) : direction(p)
+ (p) : UNKNOWN-DIR
+ (t) : UNKNOWN-DIR
+
+defn infer-dirs (m:Module) :
+ ;=== Direction of all Binders ===
+ val BI-DIR = new Direction
+ val directions = HashTable<Symbol,Direction>(symbol-hash)
+ defn find-dirs (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ for entry in entries(c) do :
+ directions[key(entry)] = OUTPUT
+ find-dirs(body(c))
+ (c:DefWire) :
+ directions[name(c)] = BI-DIR
+ (c:DefRegister) :
+ directions[name(c)] = BI-DIR
+ (c:DefInstance) :
+ directions[name(c)] = OUTPUT
+ (c:DefMemory) :
+ directions[name(c)] = BI-DIR
+ (c:WDefAccessor) :
+ directions[name(c)] = dir(c)
+ (c) :
+ do(find-dirs, children(c))
+ for p in ports(m) do :
+ directions[name(p)] = flip(direction(p))
+ find-dirs(body(m))
+
+ ;=== Fix Point Status ===
+ var changed? = false
+
+ ;=== Infer directions of Expression ===
+ defn infer-exp (e:Expression, desired:Direction) :
+ match(e) :
+ (e:WRef) :
+ val dir* = let :
+ if kind(e) typeof ModuleKind :
+ OUTPUT
+ else :
+ val old-dir = directions[name(e)]
+ switch {old-dir == _} :
+ BI-DIR :
+ desired
+ UNKNOWN-DIR :
+ if directions[name(e)] != desired :
+ directions[name(e)] = desired
+ changed? = true
+ desired
+ else :
+ old-dir
+ WRef(name(e), type(e), kind(e), dir*)
+ (e:WField) :
+ val port-dir = bundle-field-dir(type(exp(e)), name(e))
+ val exp* = infer-exp(exp(e), port-dir * desired)
+ WField(exp*, name(e), type(e), port-dir * dir(exp*))
+ (e:WIndex) :
+ val exp* = infer-exp(exp(e), desired)
+ WIndex(exp*, value(e), type(e), dir(exp*))
+ (e) :
+ map(infer-exp{_, OUTPUT}, e)
+
+ ;=== Infer directions of Commands ===
+ defn infer-comm (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ val c* = map(infer-exp{_, OUTPUT}, c)
+ LetRec(entries(c*), infer-comm(body(c)))
+ (c:DefInstance) :
+ DefInstance(name(c),
+ infer-exp(module(c), OUTPUT))
+ (c:WDefAccessor) :
+ val d = directions[name(c)]
+ WDefAccessor(name(c),
+ infer-exp(source(c), d),
+ infer-exp(index(c), OUTPUT),
+ d)
+ (c:Conditionally) :
+ Conditionally(infer-exp(pred(c), OUTPUT),
+ infer-comm(conseq(c)),
+ infer-comm(alt(c)))
+ (c:Connect) :
+ Connect(infer-exp(loc(c), INPUT), infer-exp(exp(c), OUTPUT))
+ (c) :
+ map(infer-comm, c)
+
+ ;=== Iterate until fix point ===
+ defn* fixpoint (c:Command) :
+ changed? = false
+ val c* = infer-comm(c)
+ if changed? : fixpoint(c*)
+ else : c*
+
+ Module(name(m), ports(m), body*) where :
+ val body* = fixpoint(body(m))
+
+defn infer-directions (c:Circuit) :
+ Circuit(modules*, main(c)) where :
+ val modules* = map(infer-dirs, modules(c))
+
+
+;============== EXPAND VECS ================================
+defstruct ManyConnect <: Command :
+ index: Expression
+ locs: List<Expression>
+ exp: Expression
+
+defstruct ConnectMany <: Command :
+ index: Expression
+ loc: Expression
+ exps: List<Expression>
+
+defmethod print (o:OutputStream, c:ManyConnect) :
+ print-all(o, [locs(c) "[" index(c) "] := " exp(c)])
+defmethod print (o:OutputStream, c:ConnectMany) :
+ print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"])
+
+defmethod map (f: Expression -> Expression, c:ManyConnect) :
+ ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c)))
+defmethod map (f: Expression -> Expression, c:ConnectMany) :
+ ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c)))
+
+defn expand-accessors (m: Module) :
+ defn expand (c:Command) :
+ match(c) :
+ (c:WDefAccessor) :
+ ;Is the source a memory?
+ val mem? =
+ match(source(c)) :
+ (r:WRef) : kind(r) typeof MemKind
+ (r) : false
+
+ if mem? :
+ c
+ else :
+ switch {dir(c) == _} :
+ INPUT :
+ Begin(list(
+ DefWire(name(c), type(src-type))
+ ManyConnect(index(c), elems, wire-ref)))
+ where :
+ val src-type = type(source(c)) as VectorType
+ val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT)
+ val elems = to-list $
+ for i in 0 to size(src-type) stream :
+ WIndex(source(c), i, type(src-type), INPUT)
+ OUTPUT :
+ Begin(list(
+ DefWire(name(c), type(src-type))
+ ConnectMany(index(c), wire-ref, elems)))
+ where :
+ val src-type = type(source(c)) as VectorType
+ val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT)
+ val elems = to-list $
+ for i in 0 to size(src-type) stream :
+ WIndex(source(c), i, type(src-type), OUTPUT)
+ (c) :
+ map(expand, c)
+ Module(name(m), ports(m), expand(body(m)))
+
+defn expand-accessors (c:Circuit) :
+ Circuit(modules*, main(c)) where :
+ val modules* = map(expand-accessors, modules(c))
+
+
+
+;=============== BUNDLE FLATTENING =========================
+defn prefix (prefix, suffix) :
+ symbol-join([prefix "/" suffix])
+
+defn prefix-ports (pre:Symbol, ports:List<Port>) :
+ for p in ports map :
+ Port(prefix(pre, name(p)), direction(p), type(p))
+
+defn flatten-ports (port:Port) -> List<Port> :
+ match(type(port)) :
+ (t:BundleType) :
+ val ports = map-append(flatten-ports, ports(t))
+ for p in ports map :
+ Port(prefix(name(port), name(p)),
+ direction(port) * direction(p),
+ type(p))
+ (t:VectorType) :
+ val type* = flatten-type(t)
+ flatten-ports(Port(name(port), direction(port), type*))
+ (t:Type) :
+ list(port)
+
+defn flatten-type (t:Type) -> Type :
+ match(t) :
+ (t:BundleType) :
+ BundleType $
+ map-append(flatten-ports, ports(t))
+ (t:VectorType) :
+ flatten-type $ BundleType $ to-list $
+ for i in 0 to size(t) stream :
+ Port(to-symbol(i), OUTPUT, type(t))
+ (t:Type) :
+ t
+
+defn flatten-bundles (c:Circuit) :
+ defn flatten-exp (e:Expression) :
+ match(map(flatten-exp, e)) :
+ (e:UIntValue|SIntValue) :
+ e
+ (e:WRef) :
+ match(kind(e)) :
+ (k:MemKind|StructuralMemKind) :
+ val type* = map(flatten-type, type(e))
+ put-type(e, type*)
+ (k) :
+ val type* = flatten-type(type(e))
+ put-type(e, type*)
+ (e) :
+ val type* = flatten-type(type(e))
+ put-type(e, type*)
+
+ defn flatten-element (e:Element) :
+ val t* = flatten-type(type(e))
+ match(map(flatten-exp, e)) :
+ (e:Register) : Register(t*, value(e), enable(e))
+ (e:Memory) : Memory(t*, writers(e))
+ (e:Node) : Node(t*, value(e))
+ (e:Instance) : Instance(t*, module(e), ports(e))
+
+ defn flatten-comm (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ val entries* =
+ for entry in entries(c) map :
+ key(entry) => flatten-element(value(entry))
+ LetRec(entries*, flatten-comm(body(c)))
+ (c:DefWire) :
+ DefWire(name(c), flatten-type(type(c)))
+ (c:DefRegister) :
+ DefRegister(name(c), flatten-type(type(c)))
+ (c:DefMemory) :
+ val type* = map(flatten-type, type(c))
+ DefMemory(name(c), type*)
+ (c) :
+ map{flatten-comm, _} $
+ map(flatten-exp, c)
+
+ defn flatten-module (m:Module) :
+ val ports* = map-append(flatten-ports, ports(m))
+ val body* = flatten-comm(body(m))
+ Module(name(m), ports*, body*)
+
+ Circuit(modules*, main(c)) where :
+ val modules* = map(flatten-module, modules(c))
+
+
+;================== BUNDLE EXPANSION =======================
+defn expand-bundles (m:Module) :
+
+ ;Collapse all field/index expressions
+ defn collapse-exp (e:Expression) -> Expression :
+ match(e) :
+ (e:WField) :
+ match(collapse-exp(exp(e))) :
+ (ei:WRef) :
+ if kind(ei) typeof InstanceKind :
+ e
+ else :
+ WRef(name*, type(e), kind(ei), dir(e)) where :
+ val name* = prefix(name(ei), name(e))
+ (ei:WField) :
+ WField(exp(ei), name*, type(e), dir(e)) where :
+ val name* = prefix(name(ei), name(e))
+ (e:WIndex) :
+ collapse-exp(WField(exp(e), name, type(e), dir(e))) where :
+ val name = to-symbol(value(e))
+ (e) :
+ map(collapse-exp, e)
+
+ ;Expand expressions
+ defn expand-exp (e:Expression) -> List<Expression> :
+ match(type(e)) :
+ (t:BundleType) :
+ for p in ports(t) map :
+ val dir* = direction(p) * dir(e)
+ collapse-exp(WField(e, name(p), type(p), dir*))
+ (t) :
+ list(collapse-exp(e))
+
+ ;Expand commands
+ defn expand-comm (c:Command) :
+ match(c) :
+ (c:DefWire) :
+ match(type(c)) :
+ (t:BundleType) :
+ Begin $
+ for p in ports(t) map :
+ DefWire(prefix(name(c), name(p)), type(p))
+ (t) :
+ c
+ (c:DefRegister) :
+ match(type(c)) :
+ (t:BundleType) :
+ Begin $
+ for p in ports(t) map :
+ DefRegister(prefix(name(c), name(p)), type(p))
+ (t) :
+ c
+ (c:DefMemory) :
+ match(type(type(c))) :
+ (t:BundleType) :
+ Begin $
+ for p in ports(t) map :
+ DefMemory(prefix(name(c), name(p)), type*) where :
+ val s = size(type(c))
+ val type* = VectorType(type(p), s)
+ (t) :
+ c
+ (c:WDefAccessor) :
+ match(type(source(c))) :
+ (t:BundleType) :
+ val srcs = expand-exp(source(c))
+ Begin $
+ for (p in ports(t), src in srcs) map :
+ WDefAccessor(name*, src, index(c), dir*) where :
+ val name* = prefix(name(c), name(p))
+ val dir* = direction(p) * dir(c)
+ (t) :
+ c
+ (c:Connect) :
+ val locs = expand-exp(loc(c))
+ val exps = expand-exp(exp(c))
+ Begin $
+ for (l in locs, e in exps) map :
+ switch {dir(l) == _} :
+ INPUT : Connect(l, e)
+ OUTPUT : Connect(e, l)
+ (c:ManyConnect) :
+ val locs-list = transpose(map(expand-exp, locs(c)))
+ val exps = expand-exp(exp(c))
+ Begin $
+ for (locs in locs-list, e in exps) map :
+ switch {dir(e) == _} :
+ OUTPUT : ManyConnect(index(c), locs, e)
+ INPUT : ConnectMany(index(c), e, locs)
+ (c:ConnectMany) :
+ val locs = expand-exp(loc(c))
+ val exps-list = transpose(map(expand-exp, exps(c)))
+ Begin $
+ for (l in locs, exps in exps-list) map :
+ switch {dir(l) == _} :
+ INPUT : ConnectMany(index(c), l, exps)
+ OUTPUT : ManyConnect(index(c), exps, l)
+ (c) :
+ map{expand-comm, _} $
+ map(collapse-exp, c)
+
+ Module(name(m), ports(m), expand-comm(body(m)))
+
+defn expand-bundles (c:Circuit) :
+ Circuit(modules*, main(c)) where :
+ val modules* = map(expand-bundles, modules(c))
+
+
+;=========== CONVERT MULTI CONNECTS to WHEN ================
+defn expand-multi-connects (c:Circuit) :
+ defn equal-exp (e1:Expression, e2:Expression) :
+ DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth()))
+ defn uint (i:Int) :
+ UIntValue(i, UnknownWidth())
+
+ defn expand-comm (c:Command) :
+ match(c) :
+ (c:ConnectMany) :
+ Begin $ to-list $
+ for (i in 0 to false, e in exps(c)) stream :
+ Conditionally(equal-exp(index(c), uint(i)),
+ Connect(loc(c), e)
+ EmptyCommand())
+ (c:ManyConnect) :
+ Begin $ to-list $
+ for (i in 0 to false, l in locs(c)) stream :
+ Conditionally(equal-exp(index(c), uint(i)),
+ Connect(l, exp(c))
+ EmptyCommand())
+ (c) :
+ map(expand-comm, c)
+
+ defn expand (m:Module) :
+ Module(name(m), ports(m), expand-comm(body(m)))
+
+ Circuit(modules*, main(c)) where :
+ val modules* = map(expand, modules(c))
+
+
+;================ EXPAND WHENS =============================
+definterface SymbolicValue
+defstruct ExpValue <: SymbolicValue :
+ exp: Expression
+defstruct WhenValue <: SymbolicValue :
+ pred: Expression
+ conseq: SymbolicValue
+ alt: SymbolicValue
+defstruct VoidValue <: SymbolicValue
+
+defmethod print (o:OutputStream, sv:SymbolicValue) :
+ match(sv) :
+ (sv:VoidValue) : print(o, "VOID")
+ (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"])
+ (sv:ExpValue) : print(o, exp(sv))
+
+defn key-eqv? (e1:Expression, e2:Expression) :
+ match(e1, e2) :
+ (e1:WRef, e2:WRef) :
+ name(e1) == name(e2)
+ (e1:WField, e2:WField) :
+ name(e1) == name(e2) and
+ key-eqv?(exp(e1), exp(e2))
+ (e1, e2) :
+ false
+
+defn merge-env (pred: Expression,
+ con-env: List<KeyValue<Expression,SymbolicValue>>,
+ alt-env: List<KeyValue<Expression,SymbolicValue>>) :
+ val merged = Vector<KeyValue<Expression, SymbolicValue>>()
+ defn new-key? (k:Expression) :
+ for entry in merged none? :
+ key-eqv?(key(entry), k)
+
+ defn sv (env:List<KeyValue<Expression,SymbolicValue>>, k:Expression) :
+ for entry in env search :
+ if key-eqv?(key(entry), k) :
+ value(entry)
+
+ for k in stream(key, concat(con-env, alt-env)) do :
+ if new-key?(k) :
+ match(sv(con-env, k), sv(alt-env, k)) :
+ (a:SymbolicValue, b:SymbolicValue) :
+ if a == b :
+ add(merged, k => a)
+ else :
+ add(merged, k => WhenValue(pred, a, b))
+ (a:SymbolicValue, b:False) :
+ add(merged, k => a)
+ (a:False, b:SymbolicValue) :
+ add(merged, k => b)
+ (a:False, b:False) :
+ false
+
+ to-list(merged)
+
+defn simplify-env (env: List<KeyValue<Expression,SymbolicValue>>) :
+ val merged = Vector<KeyValue<Expression, SymbolicValue>>()
+ defn new-key? (k:Expression) :
+ for entry in merged none? :
+ key-eqv?(key(entry), k)
+ for entry in env do :
+ if new-key?(key(entry)) :
+ add(merged, entry)
+ to-list(merged)
+
+defn expand-whens (m:Module) :
+ val commands = Vector<Command>()
+ val elements = Vector<KeyValue<Symbol,Element>>()
+ defn eval (c:Command, env:List<KeyValue<Expression,SymbolicValue>>) ->
+ List<KeyValue<Expression,SymbolicValue>> :
+ match(c) :
+ (c:LetRec) :
+ do(add{elements, _}, entries(c))
+ eval(body(c), env)
+ (c:DefWire) :
+ add(commands, c)
+ val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT)
+ List(wire-ref => VoidValue(), env)
+ (c:DefRegister) :
+ add(commands, c)
+ val reg-ref = WRef(name(c), type(c), RegKind(), INPUT)
+ List(reg-ref => VoidValue(), env)
+ (c:DefInstance) :
+ add(commands, c)
+ val entries = let :
+ val module-type = type(module(c)) as BundleType
+ val input-ports = to-list $
+ for p in ports(module-type) filter :
+ direction(p) == INPUT
+ val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT)
+ for p in input-ports map :
+ WField(inst-ref, name(p), type(p), INPUT) => VoidValue()
+ append(entries, env)
+ (c:DefMemory) :
+ add(commands, c)
+ env
+ (c:WDefAccessor) :
+ add(commands, c)
+ if dir(c) == INPUT :
+ val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c))
+ List(access-ref => VoidValue(), env)
+ else :
+ env
+ (c:Conditionally) :
+ val con-env = eval(conseq(c), env)
+ val alt-env = eval(alt(c), env)
+ merge-env(pred(c), con-env, alt-env)
+ (c:Begin) :
+ var env:List<KeyValue<Expression,SymbolicValue>> = env
+ for c in body(c) do :
+ env = eval(c, env)
+ env
+ (c:Connect) :
+ List(loc(c) => ExpValue(exp(c)), env)
+ (c:EmptyCommand) :
+ env
+
+ defn convert-symbolic (key:Expression, sv:SymbolicValue) :
+ match(sv) :
+ (sv:VoidValue) :
+ throw $ PassException $ string-join $ [
+ "No default value for " key "."]
+ (sv:ExpValue) :
+ exp(sv)
+ (sv:WhenValue) :
+ defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) :
+ DoPrim(MULTIPLEX-OP, list(pred, conseq, alt), List(), type(conseq))
+ multiplex-exp(pred(sv),
+ convert-symbolic(key, conseq(sv))
+ convert-symbolic(key, alt(sv)))
+
+ ;Compute final environment
+ val env0 = let :
+ val output-ports = to-list $
+ for p in ports(m) filter :
+ direction(p) == OUTPUT
+ for p in output-ports map :
+ val port-ref = WRef(name(p), type(p), PortKind(), INPUT)
+ port-ref => VoidValue()
+ val env* = simplify-env(eval(body(m), env0))
+
+ ;Make new body
+ val body* = Begin(list(defs, LetRec(elems, connections))) where :
+ val defs = Begin(to-list(commands))
+ val elems = to-list(elements)
+ val connections = Begin $
+ for entry in env* map :
+ val sv = convert-symbolic(key(entry), value(entry))
+ Connect(key(entry), sv)
+
+ ;Final module
+ Module(name(m), ports(m), body*)
+
+defn expand-whens (c:Circuit) :
+ val modules* = map(expand-whens, modules(c))
+ Circuit(modules*, main(c))
+
+
+;================ STRUCTURAL FORM ==========================
+
+defn structural-form (m:Module) :
+ val elements = Vector<(() -> KeyValue<Symbol,Element>)>()
+ val connected = HashTable<Symbol, Expression>(symbol-hash)
+ val write-accessors = HashTable<Symbol, List<WDefAccessor>>(symbol-hash)
+ val read-accessors = HashTable<Symbol, WDefAccessor>(symbol-hash)
+ val inst-ports = HashTable<Symbol, List<KeyValue<Symbol, Expression>>>(symbol-hash)
+ val port-connects = Vector<Connect>()
+
+ defn scan (c:Command) :
+ match(c) :
+ (c:Connect) :
+ match(loc(c)) :
+ (loc:WRef) :
+ match(kind(loc)) :
+ (k:PortKind) : add(port-connects, c)
+ (k) : connected[name(loc)] = exp(c)
+ (loc:WField) :
+ val inst = exp(loc) as WRef
+ val entry = name(loc) => exp(c)
+ inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List()))
+ (c:LetRec) :
+ for e in entries(c) do :
+ add(elements, {e})
+ scan(body(c))
+ (c:DefWire) :
+ add{elements, _} $ fn () :
+ name(c) => Node(type(c), connected[name(c)])
+ (c:DefRegister) :
+ add{elements, _} $ fn () :
+ val one = UIntValue(1, UnknownWidth())
+ name(c) => Register(type(c), connected[name(c)], one)
+ (c:DefInstance) :
+ add{elements, _} $ fn () :
+ name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)])
+ (c:DefMemory) :
+ add{elements, _} $ fn () :
+ val ports = for a in get?(write-accessors, name(c), List()) map :
+ val one = UIntValue(1, UnknownWidth())
+ WritePort(index(a), connected[name(a)], one)
+ name(c) => Memory(type(c), ports)
+ (c:WDefAccessor) :
+ val mem = source(c) as WRef
+ switch {dir(c) == _} :
+ INPUT :
+ write-accessors[name(mem)] = List(c,
+ get?(write-accessors, name(mem), List()))
+ OUTPUT :
+ read-accessors[name(c)] = c
+ (c) :
+ do(scan, children(c))
+
+ defn make-read-ports (e:Expression) :
+ match(e) :
+ (e:WRef) :
+ match(kind(e)) :
+ (k:AccessorKind) :
+ val accessor = read-accessors[name(e)]
+ ReadPort(source(accessor), index(accessor), type(e))
+ (k) : e
+ (e) : map(make-read-ports, e)
+
+ Module(name(m), ports(m), body*) where :
+ scan(body(m))
+ val elems = to-list $
+ for e in elements stream :
+ val entry = e()
+ key(entry) => map(make-read-ports, value(entry))
+ val connect-ports = Begin $ to-list $
+ for c in port-connects stream :
+ Connect(loc(c), make-read-ports(exp(c)))
+ val body* =
+ if empty?(elems) : connect-ports
+ else : LetRec(elems, connect-ports)
+
+defn structural-form (c:Circuit) :
+ val modules* = map(structural-form, modules(c))
+ Circuit(modules*, main(c))
+
+
+;==================== WIDTH INFERENCE ======================
+defstruct WidthVar <: Width :
+ name: Symbol
+
+defmethod print (o:OutputStream, w:WidthVar) :
+ print(o, name(w))
+
+defn width! (t:Type) :
+ match(t) :
+ (t:UIntType) : width(t)
+ (t:SIntType) : width(t)
+ (t) : error("No width field.")
+
+defn put-width (t:Type, w:Width) :
+ match(t) :
+ (t:UIntType) : UIntType(w)
+ (t:SIntType) : SIntType(w)
+ (t) : t
+
+defn put-width (e:Expression, w:Width) :
+ val type* = put-width(type(e), w)
+ put-type(e, type*)
+
+defn add-width-vars (t:Type) :
+ defn width? (w:Width) :
+ match(w) :
+ (w:UnknownWidth) : WidthVar(gensym())
+ (w) : w
+ match(t) :
+ (t:UIntType) : UIntType(width?(width(t)))
+ (t:SIntType) : SIntType(width?(width(t)))
+ (t) : map(add-width-vars, t)
+
+defn uint-width (i:Int) :
+ var v:Int = i
+ var n:Int = 0
+ while v != 0 :
+ v = v >> 1
+ n = n + 1
+ IntWidth(n)
+
+defn sint-width (i:Int) :
+ if i > 0 :
+ val w = uint-width(i)
+ IntWidth(width(w) + 1)
+ else :
+ val w = uint-width(neg(i) - 1)
+ IntWidth(width(w) + 1)
+
+defn to-exp (w:Width) :
+ match(w) :
+ (w:IntWidth) : ELit(width(w))
+ (w:WidthVar) : EVar(name(w))
+ (w) : error $ string-join $ [
+ "Cannot convert " w " to exp."]
+
+defn primop-width (op:PrimOp, ws:List<Width>, ints:List<Int>) -> Exp :
+ defn wmax (w1:Width, w2:Width) :
+ EMax(to-exp(w1), to-exp(w2))
+ defn wplus (w1:Width, w2:Width) :
+ EPlus(to-exp(w1), to-exp(w2))
+ defn wplus (w1:Width, w2:Int) :
+ EPlus(to-exp(w1), ELit(w2))
+ defn wminus (w1:Width, w2:Width) :
+ EMinus(to-exp(w1), to-exp(w2))
+ defn wminus (w1:Width, w2:Int) :
+ EMinus(to-exp(w1), ELit(w2))
+ defn wmax-inc (w1:Width, w2:Width) :
+ EPlus(wmax(w1, w2), ELit(1))
+
+ switch {op == _} :
+ ADD-OP : wmax-inc(ws[0], ws[1])
+ ADD-MOD-OP : wmax(ws[0], ws[1])
+ SUB-OP : wmax-inc(ws[0], ws[1])
+ SUB-MOD-OP : wmax(ws[0], ws[1])
+ TIMES-OP : wplus(ws[0], ws[1])
+ DIVIDE-OP : wminus(ws[0], ws[1])
+ MOD-OP : to-exp(ws[1])
+ SHIFT-LEFT-OP : wplus(ws[0], ints[0])
+ SHIFT-RIGHT-OP : wminus(ws[0], ints[0])
+ PAD-OP : ELit(ints[0])
+ BIT-AND-OP : wmax(ws[0], ws[1])
+ BIT-OR-OP : wmax(ws[0], ws[1])
+ BIT-XOR-OP : wmax(ws[0], ws[1])
+ CONCAT-OP : wplus(ws[0], ints[0])
+ BIT-SELECT-OP : ELit(1)
+ BITS-SELECT-OP : ELit(ints[0])
+ MULTIPLEX-OP : wmax(ws[1], ws[2])
+ LESS-OP : ELit(1)
+ LESS-EQ-OP : ELit(1)
+ GREATER-OP : ELit(1)
+ GREATER-EQ-OP : ELit(1)
+ EQUAL-OP : ELit(1)
+
+defn put-type (el:Element, t:Type) :
+ match(el) :
+ (el:Register) : Register(t, value(el), enable(el))
+ (el:Memory) : Memory(t, writers(el))
+ (el:Node) : Node(t, value(el))
+ (el:Instance) : Instance(t, module(el), ports(el))
+
+defn generate-constraints (c:Circuit) -> [Circuit, Vector<WConstraint>] :
+ ;Constraints
+ val cs = Vector<WConstraint>()
+ defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) :
+ match(wvar) :
+ (wvar:WidthVar) :
+ add(cs, Constraint(name(wvar), to-exp(width)))
+ (wvar) :
+ false
+
+ defn to-width (e:Exp) :
+ match(e) :
+ (e:ELit) :
+ IntWidth(width(e))
+ (e:EVar) :
+ WidthVar(name(e))
+ (e) :
+ val x = gensym()
+ add(cs, WidthEqual(x, e))
+ WidthVar(x)
+
+ ;Module types
+ val mod-types = HashTable<Symbol,Type>(symbol-hash)
+
+ defn add-port-vars (m:Module) -> Module :
+ val ports* =
+ for p in ports(m) map :
+ val type* = add-width-vars(type(p))
+ Port(name(p), direction(p), type*)
+ mod-types[name(m)] = BundleType(ports*)
+ Module(name(m), ports*, body(m))
+
+ ;Add Width Variables
+ defn add-module-vars (m:Module) -> Module :
+ val types = HashTable<Symbol,Type>(symbol-hash)
+ for p in ports(m) do :
+ types[name(p)] = type(p)
+
+ defn infer-exp-width (e:Expression) :
+ match(map(infer-exp-width, e)) :
+ (e:WRef) :
+ match(kind(e)) :
+ (k:ModuleKind) : put-type(e, mod-types[name(e)])
+ (k) : put-type(e, types[name(e)])
+ (e:WField) :
+ val t = bundle-field-type(type(exp(e)), name(e))
+ put-width(e, width!(t))
+ (e:UIntValue) :
+ match(width(e)) :
+ (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e)))
+ (w) : e
+ (e:SIntValue) :
+ match(width(e)) :
+ (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e)))
+ (w) : e
+ (e:DoPrim) :
+ val widths = map(width!{type(_)}, args(e))
+ val w = to-width(primop-width(op(e), widths, consts(e)))
+ put-width(e, w)
+ (e:ReadPort) :
+ val elem-type = type(type(mem(e)) as VectorType)
+ put-width(e, width!(elem-type))
+
+ defn infer-comm-width (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ ;Add width vars to elements
+ var entries*: List<KeyValue<Symbol,Element>> =
+ for entry in entries(c) map :
+ val el-name = key(entry)
+ key(entry) =>
+ match(value(entry)) :
+ (el:Register|Node) :
+ put-type(el, add-width-vars(type(el)))
+ (el:Memory) :
+ el
+ (el:Instance) :
+ val mod-type = type(infer-exp-width(module(el))) as BundleType
+ val type = BundleType $ to-list $
+ for p in ports(mod-type) filter :
+ direction(p) == OUTPUT
+ put-type(el, type)
+
+ ;Add vars to environment
+ for entry in entries* do :
+ types[key(entry)] = type(value(entry))
+
+ ;Infer types for elements
+ entries* =
+ for entry in entries* map :
+ key(entry) => map(infer-exp-width, value(entry))
+
+ ;Generate constraints
+ for entry in entries* do :
+ val el-name = key(entry)
+ match(value(entry)) :
+ (el:Register) :
+ new-constraint(WidthEqual, reg-width, val-width) where :
+ val reg-width = width!(types[el-name])
+ val val-width = width!(type(value(el)))
+ (el:Node) :
+ new-constraint(WidthEqual, node-width, val-width) where :
+ val node-width = width!(types[el-name])
+ val val-width = width!(type(value(el)))
+ (el:Instance) :
+ val mod-type = type(module(el)) as BundleType
+ for entry in ports(el) do :
+ new-constraint(WidthGreater, port-width, val-width) where :
+ val port-name = key(entry)
+ val port-width = width!(bundle-field-type(mod-type, port-name))
+ val val-width = width!(type(value(entry)))
+ (el) : false
+
+ ;Analyze body
+ LetRec(entries*, infer-comm-width(body(c)))
+
+ (c:Connect) :
+ val loc* = infer-exp-width(loc(c))
+ val exp* = infer-exp-width(exp(c))
+ new-constraint(WidthGreater, loc-width, exp-width) where :
+ val loc-width = width!(type(loc*))
+ val exp-width = width!(type(exp*))
+ Connect(loc*, exp*)
+
+ (c:Begin) :
+ map(infer-comm-width, c)
+
+ Module(name(m), ports(m), body*) where :
+ val body* = infer-comm-width(body(m))
+
+ val c* =
+ Circuit(modules*, main(c)) where :
+ val ms = map(add-port-vars, modules(c))
+ val modules* = map(add-module-vars, ms)
+ [c*, cs]
+
+
+;================== FILL WIDTHS ============================
+defn fill-widths (c:Circuit, solved:Streamable<WidthEqual>) :
+ ;Populate table
+ val table = HashTable<Symbol, Width>(symbol-hash)
+ for eq in solved do :
+ table[name(eq)] = IntWidth(width(value(eq) as ELit))
+
+ defn width? (w:Width) :
+ match(w) :
+ (w:WidthVar) : get?(table, name(w), UnknownWidth())
+ (w) : w
+
+ defn fill-type (t:Type) :
+ match(t) :
+ (t:UIntType) : UIntType(width?(width(t)))
+ (t:SIntType) : SIntType(width?(width(t)))
+ (t) : map(fill-type, t)
+
+ defn fill-exp (e:Expression) :
+ val e* = map(fill-exp, e)
+ val type* = fill-type(type(e))
+ put-type(e*, type*)
+
+ defn fill-element (e:Element) :
+ val e* = map(fill-exp, e)
+ val type* = fill-type(type(e))
+ put-type(e*, type*)
+
+ defn fill-comm (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ val entries* =
+ for e in entries(c) map :
+ key(e) => fill-element(value(e))
+ LetRec(entries*, fill-comm(body(c)))
+ (c) :
+ map{fill-comm, _} $
+ map(fill-exp, c)
+
+ defn fill-port (p:Port) :
+ Port(name(p), direction(p), fill-type(type(p)))
+
+ defn fill-mod (m:Module) :
+ Module(name(m), ports*, body*) where :
+ val ports* = map(fill-port, ports(m))
+ val body* = fill-comm(body(m))
+
+ Circuit(modules*, main(c)) where :
+ val modules* = map(fill-mod, modules(c))
+
+
+;=============== TYPE INFERENCE DRIVER =====================
+defn infer-widths (c:Circuit) :
+ val [c*, cs] = generate-constraints(c)
+ val solved = solve-widths(cs)
+ fill-widths(c*, solved)
+
+
+;================ PAD WIDTHS ===============================
+defn pad-widths (c:Circuit) :
+ ;Pad an expression to the given width
+ defn pad-exp (e:Expression, w:Int) :
+ match(type(e)) :
+ (t:UIntType|SIntType) :
+ val prev-w = width!(t) as IntWidth
+ if width(prev-w) < w :
+ val type* = put-width(t, IntWidth(w))
+ DoPrim(PAD-OP, list(e), list(w), type*)
+ else :
+ e
+ (t) :
+ e
+
+ defn pad-exp (e:Expression, w:Width) :
+ val w-value = width(w as IntWidth)
+ pad-exp(e, w-value)
+
+ ;Convenience
+ defn max-width (es:Streamable<Expression>) :
+ defn int-width (e:Expression) :
+ width(width!(type(e)) as IntWidth)
+ maximum(stream(int-width, es))
+
+ defn match-widths (es:List<Expression>) :
+ val w = max-width(es)
+ map(pad-exp{_, w}, es)
+
+ ;Match widths for an expression
+ defn match-exp-width (e:Expression) :
+ match(map(match-exp-width, e)) :
+ (e:DoPrim) :
+ if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) :
+ val args* = match-widths(args(e))
+ DoPrim(op(e), args*, consts(e), type(e))
+ else if op(e) == MULTIPLEX-OP :
+ val args* = List(head(args(e)), match-widths(tail(args(e))))
+ DoPrim(op(e), args*, consts(e), type(e))
+ else :
+ e
+ (e) : e
+
+ defn match-element-width (e:Element) :
+ match(map(match-exp-width, e)) :
+ (e:Register) :
+ val w = width!(type(e))
+ val value* = pad-exp(value(e), w)
+ Register(type(e), value*, enable(e))
+ (e:Memory) :
+ val width = width!(type(type(e) as VectorType))
+ val writers* =
+ for w in writers(e) map :
+ WritePort(index(w), pad-exp(value(w), width), enable(w))
+ Memory(type(e), writers*)
+ (e:Node) :
+ val w = width!(type(e))
+ val value* = pad-exp(value(e), w)
+ Node(type(e), value*)
+ (e:Instance) :
+ val mod-type = type(module(e)) as BundleType
+ val ports* =
+ for p in ports(e) map :
+ val port-type = bundle-field-type(mod-type, key(p))
+ val port-val = pad-exp(value(p), width!(port-type))
+ key(p) => port-val
+ Instance(type(e), module(e), ports*)
+
+ ;Match widths for a command
+ defn match-comm-width (c:Command) :
+ match(map(match-exp-width, c)) :
+ (c:LetRec) :
+ val entries* =
+ for e in entries(c) map :
+ key(e) => match-element-width(value(e))
+ LetRec(entries*, match-comm-width(body(c)))
+ (c:Connect) :
+ val w = width!(type(loc(c)))
+ val exp* = pad-exp(exp(c), w)
+ Connect(loc(c), exp*)
+ (c) :
+ map(match-comm-width, c)
+
+ defn match-mod-width (m:Module) :
+ Module(name(m), ports(m), body*) where :
+ val body* = match-comm-width(body(m))
+
+ Circuit(modules*, main(c)) where :
+ val modules* = map(match-mod-width, modules(c))
+
+
+;================== INLINING ===============================
+defn inline-instances (c:Circuit) :
+ val module-table = HashTable<Symbol,Module>(symbol-hash)
+ val inlined? = HashTable<Symbol,True|False>(symbol-hash)
+ for m in modules(c) do :
+ module-table[name(m)] = m
+ inlined?[name(m)] = false
+
+ ;Convert a module into a sequence of elements
+ defn to-elements (m:Module,
+ inst:Symbol,
+ port-exps:List<KeyValue<Symbol,Expression>>) ->
+ List<KeyValue<Symbol, Element>> :
+ defn rename-exp (e:Expression) :
+ match(e) :
+ (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e))
+ (e) : map(rename-exp, e)
+
+ defn to-elements (c:Command) -> List<KeyValue<Symbol,Element>> :
+ match(c) :
+ (c:LetRec) :
+ val entries* =
+ for entry in entries(c) map :
+ val name* = prefix(inst, key(entry))
+ val element* = map(rename-exp, value(entry))
+ name* => element*
+ val body* = to-elements(body(c))
+ append(entries*, body*)
+ (c:Connect) :
+ val ref = loc(c) as WRef
+ val name* = prefix(inst, name(ref))
+ list(name* => Node(type(exp(c)), rename-exp(exp(c))))
+ (c:Begin) :
+ map-append(to-elements, body(c))
+
+ val inputs =
+ for p in ports(m) map-append :
+ if direction(p) == INPUT :
+ val port-exp = lookup!(port-exps, name(p))
+ val name* = prefix(inst, name(p))
+ list(name* => Node(type(port-exp), port-exp))
+ else :
+ List()
+ append(inputs, to-elements(body(m)))
+
+ ;Inline all instances in the module
+ defn inline-instances (m:Module) :
+ defn rename-exp (e:Expression) :
+ match(e) :
+ (e:WField) :
+ val inst-exp = exp(e) as WRef
+ val name* = prefix(name(inst-exp), name(e))
+ WRef(name*, type(e), NodeKind(), dir(e))
+ (e) :
+ map(rename-exp, e)
+
+ defn inline-elems (es:List<KeyValue<Symbol,Element>>) :
+ for entry in es map-append :
+ match(value(entry)) :
+ (el:Instance) :
+ val mod-name = name(module(el) as WRef)
+ val module = inlined-module(mod-name)
+ to-elements(module, key(entry), ports(el))
+ (el) :
+ list(entry)
+
+ defn inline-comm (c:Command) :
+ match(map(rename-exp, c)) :
+ (c:LetRec) :
+ val entries* = inline-elems(entries(c))
+ LetRec(entries*, inline-comm(body(c)))
+ (c) :
+ map(inline-comm, c)
+
+ Module(name(m), ports(m), inline-comm(body(m)))
+
+ ;Retrieve the inlined instance of a module
+ defn inlined-module (name:Symbol) :
+ if inlined?[name] :
+ module-table[name]
+ else :
+ val module* = inline-instances(module-table[name])
+ module-table[name] = module*
+ inlined?[name] = true
+ module*
+
+ ;Return the fully inlined circuit
+ val main-module = inlined-module(main(c))
+ Circuit(list(main-module), main(c))
+
+
+;;;================ UTILITIES ================================
+;
+;
+;
+;defn* root-ref (i:Immediate) :
+; match(i) :
+; (f:Field) : root-ref(imm(f))
+; (ind:Index) : root-ref(imm(ind))
+; (r) : r
+;
+;;defn lookup<?T> (e: Streamable<KeyValue<Immediate,?T>>, i:Immediate) :
+;; for entry in e search :
+;; if eqv?(key(entry), i) :
+;; value(entry)
+;;
+;;defn lookup!<?T> (e: Streamable<KeyValue<Immediate,?T>>, i:Immediate) :
+;; lookup(e, i) as T
+;;
+;;============ CHECK IF NAMES ARE UNIQUE ====================
+;defn check-duplicate-symbols (names: Streamable<Symbol>, msg: String) :
+; val dict = HashTable<Symbol, True>(symbol-hash)
+; for name in names do:
+; if key?(dict, name):
+; throw $ PassException $ string-join $
+; [msg ": " name]
+; else:
+; dict[name] = true
+;
+;defn check-duplicates (t: Type) :
+; match(t) :
+; (t:BundleType) :
+; val names = map(name, ports(t))
+; check-duplicate-symbols{names, string-join(_)} $
+; ["Duplicate port name in bundle "]
+; do(check-duplicates{type(_)}, ports(t))
+; (t:VectorType) :
+; check-duplicates(type(t))
+; (t) : false
+;
+;defn check-duplicates (c: Command) :
+; match(c) :
+; (c:DefWire) : check-duplicates(type(c))
+; (c:DefRegister) : check-duplicates(type(c))
+; (c:DefMemory) : check-duplicates(type(c))
+; (c) : do(check-duplicates, children(c))
+;
+;defn defined-names (c: Command) :
+; generate<Symbol> :
+; loop(c) where :
+; defn loop (c:Command) :
+; match(c) :
+; (c:Command&HasName) : yield(name(c))
+; (c) : do(loop, children(c))
+;
+;defn check-duplicates (m: Module):
+; ;Check all duplicate names in all types in all ports and body
+; do(check-duplicates{type(_)}, ports(m))
+; check-duplicates(body(m))
+;
+; ;Check all names defined in module
+; val names = concat(stream(name, ports(m)),
+; defined-names(body(m)))
+; check-duplicate-symbols{names, string-join(_)} $
+; ["Duplicate definition name in module " name(m)]
+;
+;defn check-duplicates (c: Circuit) :
+; ;Check all duplicate names in all modules
+; do(check-duplicates, modules(c))
+;
+; ;Check all defined modules
+; val names = stream(name, modules(c))
+; check-duplicate-symbols(names, "Duplicate module name")
+;
+;
+
+
+
+
+
+
+
+;;================ CLEANUP COMMANDS =========================
+;defn cleanup (c:Command) :
+; match(c) :
+; (c:Begin) :
+; to-command $ generate<Command> :
+; loop(c) where :
+; defn loop (c:Command) :
+; match(c) :
+; (c:Begin) : do(loop, body(c))
+; (c:EmptyCommand) : false
+; (c) : yield(cleanup(c))
+; (c) : map(cleanup{_ as Command}, c)
+;
+;defn cleanup (c:Circuit) :
+; val modules* =
+; for m in modules(c) map :
+; map(cleanup, m)
+; Circuit(modules*, main(c))
+;
+
+
+
+;;;============= SHIM ========================================
+;;defn shim (i:Immediate) -> Immediate :
+;; match(i) :
+;; (i:RegData) :
+;; Ref(name(i), direction(i), type(i))
+;; (i:InstPort) :
+;; val inst = Ref(name(i), UNKNOWN-DIR, UnknownType())
+;; Field(inst, port(i), direction(i), type(i))
+;; (i:Field) :
+;; val imm* = shim(imm(i))
+;; put-imm(i, imm*)
+;; (i) : i
+;;
+;;defn shim (c:Command) -> Command :
+;; val c* = map(shim{_ as Immediate}, c)
+;; map(shim{_ as Command}, c*)
+;;
+;;defn shim (c:Circuit) -> Circuit :
+;; val modules* =
+;; for m in modules(c) map :
+;; Module(name(m), ports(m), shim(body(m)))
+;; Circuit(modules*, main(c))
+;;
+;;;================== INLINE MODULES =========================
+;;defn cat-name (p: String|Symbol, s: String|Symbol) -> Symbol :
+;; if p == "" or p == `this : ;; TODO: REMOVE THIS WHEN `THIS GETS REMOVED
+;; to-symbol(s)
+;; else if s == `this : ;; TODO: DITTO
+;; to-symbol(p)
+;; else :
+;; symbol-join([p, "/", s])
+;;
+;;defn inline-command (c: Command, mods: HashTable<Symbol, Module>, prefix: String, cmds: Vector<Command>) :
+;; defn rename (n: Symbol) -> Symbol :
+;; cat-name(prefix, n)
+;; defn inline-name (i:Immediate) -> Symbol :
+;; match(i) :
+;; (r:Ref) : rename(name(r))
+;; (f:Field) : cat-name(inline-name(imm(f)), name(f))
+;; (f:Index) : cat-name(inline-name(imm(f)), to-string(value(f)))
+;; defn inline-imm (i:Immediate) -> Ref :
+;; Ref(inline-name(i), direction(i), type(i))
+;; match(c) :
+;; (c:DefUInt) : add(cmds, DefUInt(rename(name(c)), value(c), width(c)))
+;; (c:DefSInt) : add(cmds, DefSInt(rename(name(c)), value(c), width(c)))
+;; (c:DefWire) : add(cmds, DefWire(rename(name(c)), type(c)))
+;; (c:DefRegister) : add(cmds, DefRegister(rename(name(c)), type(c)))
+;; (c:DefMemory) : add(cmds, DefMemory(rename(name(c)), type(c), size(c)))
+;; (c:DefInstance) : inline-module(mods, mods[name(module(c))], to-string(rename(name(c))), cmds)
+;; (c:DoPrim) : add(cmds, DoPrim(rename(name(c)), op(c), map(inline-imm, args(c)), consts(c)))
+;; (c:DefAccessor) : add(cmds, DefAccessor(rename(name(c)), inline-imm(source(c)), direction(c), inline-imm(index(c))))
+;; (c:Connect) : add(cmds, Connect(inline-imm(loc(c)), inline-imm(exp(c))))
+;; (c:Begin) : do(inline-command{_, mods, prefix, cmds}, body(c))
+;; (c:EmptyCommand) : c
+;; (c) : error("Unsupported command")
+;;
+;;defn inline-port (p: Port, prefix: String) -> Command :
+;; DefWire(cat-name(prefix, name(p)), type(p))
+;;
+;;defn inline-module (mods: HashTable<Symbol, Module>, mod: Module, prefix: String, cmds: Vector<Command>) :
+;; do(add{cmds, _}, map(inline-port{_, prefix}, ports(mod)))
+;; inline-command(body(mod), mods, prefix, cmds)
+;;
+;;defn inline-modules (c: Circuit) -> Circuit :
+;; val cmds = Vector<Command>()
+;; val mods = HashTable<Symbol, Module>(symbol-hash)
+;; for mod in modules(c) do :
+;; mods[name(mod)] = mod
+;; val top = mods[main(c)]
+;; inline-command(body(top), mods, "", cmds)
+;; val main* = Module(name(top), ports(top), Begin(to-list(cmds)))
+;; Circuit(list(main*), name(top))
+;;
+;;
+;;;============= FLO PRINTER ======================================
+;;;;; TODO:
+;;;;; not supported gt, lte
+;;
+;;defn flo-op-name (op:PrimOp) -> String :
+;; switch {op == _} :
+;; ADD-OP : "add"
+;; ADD-MOD-OP : "add"
+;; MINUS-OP : "sub"
+;; SUB-MOD-OP : "sub"
+;; TIMES-OP : "mul" ;; todo: signed version
+;; DIVIDE-OP : "div" ;; todo: signed version
+;; MOD-OP : "mod" ;; todo: signed version
+;; SHIFT-LEFT-OP : "lsh" ;; todo: signed version
+;; SHIFT-RIGHT-OP : "rsh"
+;; PAD-OP : "pad" ;; todo: signed version
+;; BIT-AND-OP : "and"
+;; BIT-OR-OP : "or"
+;; BIT-XOR-OP : "xor"
+;; CONCAT-OP : "cat"
+;; BIT-SELECT-OP : "rsh"
+;; BITS-SELECT-OP : "rsh"
+;; LESS-OP : "lt" ;; todo: signed version
+;; LESS-EQ-OP : "lte" ;; todo: swap args
+;; GREATER-OP : "gt" ;; todo: swap args
+;; GREATER-EQ-OP : "gte" ;; todo: signed version
+;; EQUAL-OP : "eq"
+;; MULTIPLEX-OP : "mux"
+;; else : error $ string-join $
+;; ["Unable to print Primop: " op]
+;;
+;;defn emit (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elt) :
+;; match(elt) :
+;; (e:String|Symbol|Int) :
+;; print(o, e)
+;; (e:Ref) :
+;; if key?(lits, name(e)) :
+;; val lit = lits[name(e)]
+;; print-all(o, [value(lit) "'" width(lit)])
+;; else :
+;; if key?(ports, name(e)) :
+;; print-all(o, [top "::"])
+;; print(o, name(e))
+;; (e:IntWidth) :
+;; print(o, value(e))
+;; (e:PrimOp) :
+;; print(o, flo-op-name(e))
+;; (e) :
+;; println-all(["EMIT " e])
+;; error("Unable to emit")
+;;
+;;defn emit-all (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elts: Streamable) :
+;; for e in elts do : emit(o, top, ports, lits, e)
+;;
+;;defn prim-width (type:Type) -> Width :
+;; match(type) :
+;; (t:UIntType) : width(t)
+;; (t:SIntType) : width(t)
+;; (t) : error("Bad prim width type")
+;;
+;;defn emit-command (o:OutputStream, cmd:Command, top:Symbol, lits:HashTable<Symbol, DefUInt>, regs:HashTable<Symbol, DefRegister>, accs:HashTable<Symbol, DefAccessor>, ports:HashTable<Symbol, Port>, outs:HashTable<Symbol, Port>) :
+;; match(cmd) :
+;; (c:DefUInt) :
+;; lits[name(c)] = c
+;; (c:DefSInt) :
+;; emit-all(o, top, ports, lits, [name(c) " = " value(c) "'" width(c) "\n"])
+;; (c:DoPrim) : ;; NEED TO FIGURE OUT WHEN WIDTHS ARE NECESSARY AND EXTRACT
+;; emit-all(o, top, ports, lits, [name(c) " = " op(c)])
+;; for arg in args(c) do :
+;; print(o, " ")
+;; emit(o, top, ports, lits, arg)
+;; for const in consts(c) do :
+;; print(o, " ")
+;; emit(o, top, ports, lits, const)
+;; print("\n")
+;; (c:DefRegister) :
+;; regs[name(c)] = c
+;; (c:DefMemory) :
+;; emit-all(o, top, ports, lits, [name(c) " : mem'" prim-width(type(c)) " " size(c) "\n"])
+;; (c:DefAccessor) :
+;; accs[name(c)] = c
+;; (c:Connect) :
+;; val dst = name(loc(c) as Ref)
+;; val src = name(exp(c) as Ref)
+;; if key?(regs, dst) :
+;; val reg = regs[dst]
+;; emit-all(o, top, ports, lits, [dst " = reg'" prim-width(type(reg)) " 0'" prim-width(type(reg)) " " exp(c) "\n"])
+;; else if key?(accs, dst) :
+;; val acc = accs[dst]
+;; ;; assert(direction(acc) == OUTPUT)
+;; emit-all(o, top, ports, lits, [dst " = wr " source(acc) " " index(acc) " " exp(c) "\n"])
+;; else if key?(outs, dst) :
+;; val out = outs[dst]
+;; emit-all(o, top, ports, lits, [top "::" dst " = out'" prim-width(type(out)) " " exp(c) "\n"])
+;; else if key?(accs, src) :
+;; val acc = accs[src]
+;; ;; assert(direction(acc) == INPUT)
+;; emit-all(o, top, ports, lits, [dst " = rd " source(acc) " " index(acc) "\n"])
+;; else :
+;; emit-all(o, top, ports, lits, [dst " = mov " exp(c) "\n"])
+;; (c:Begin) :
+;; do(emit-command{o, _, top, lits, regs, accs, ports, outs}, body(c))
+;; (c:DefWire|EmptyCommand) :
+;; print("")
+;; (c) :
+;; error("Unable to print command")
+;;
+;;defn emit-module (o:OutputStream, m:Module) :
+;; val regs = HashTable<Symbol, DefRegister>(symbol-hash)
+;; val accs = HashTable<Symbol, DefAccessor>(symbol-hash)
+;; val lits = HashTable<Symbol, DefUInt>(symbol-hash)
+;; val outs = HashTable<Symbol, Port>(symbol-hash)
+;; val portz = HashTable<Symbol, Port>(symbol-hash)
+;; for port in ports(m) do :
+;; portz[name(port)] = port
+;; if direction(port) == OUTPUT :
+;; outs[name(port)] = port
+;; else if name(port) == `reset :
+;; print-all(o, [name(m) "::reset = rst\n"])
+;; else :
+;; print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"])
+;; emit-command(o, body(m), name(m), lits, regs, accs, portz, outs)
+;;
+;;public defn emit-circuit (o:OutputStream, c:Circuit) :
+;; emit-module(o, modules(c)[0])
+
+
+;============= DRIVER ======================================
+public defn run-passes (c: Circuit) :
+ var c*:Circuit = c
+ defn do-stage (name:String, f: Circuit -> Circuit) :
+ println(name)
+ c* = f(c*)
+ println(c*)
+ println("\n\n\n\n")
+
+ do-stage("Working IR", to-working-ir)
+ do-stage("Resolve Kinds", resolve-kinds)
+ do-stage("Make Explicit Reset", make-explicit-reset)
+ do-stage("Infer Types", infer-types)
+ do-stage("Infer Directions", infer-directions)
+ do-stage("Expand Accessors", expand-accessors)
+ do-stage("Flatten Bundles", flatten-bundles)
+ do-stage("Expand Bundles", expand-bundles)
+ do-stage("Expand Multi Connects", expand-multi-connects)
+ do-stage("Expand Whens", expand-whens)
+ do-stage("Structural Form", structural-form)
+ do-stage("Infer Widths", infer-widths)
+ do-stage("Pad Widths", pad-widths)
+ do-stage("Inline Instances", inline-instances)
+
+
+ ;; println("Shim for Jonathan's Passes")
+ ;; c* = shim(c*)
+ ;; println("Inline Modules")
+ ;; c* = inline-modules(c*)
+ ; c*