diff options
Diffstat (limited to 'src/main/stanza')
| -rw-r--r-- | src/main/stanza/firrtl-ir.stanza | 146 | ||||
| -rw-r--r-- | src/main/stanza/firrtl-main.stanza | 26 | ||||
| -rw-r--r-- | src/main/stanza/ir-parser.stanza | 204 | ||||
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 227 | ||||
| -rw-r--r-- | src/main/stanza/passes.stanza | 1878 | ||||
| -rw-r--r-- | src/main/stanza/widthsolver.stanza | 298 |
6 files changed, 2779 insertions, 0 deletions
diff --git a/src/main/stanza/firrtl-ir.stanza b/src/main/stanza/firrtl-ir.stanza new file mode 100644 index 00000000..1a6df3d5 --- /dev/null +++ b/src/main/stanza/firrtl-ir.stanza @@ -0,0 +1,146 @@ +defpackage chipper.ir2 : + import core + import verse + +public definterface Direction +public val INPUT = new Direction +public val OUTPUT = new Direction +public val UNKNOWN-DIR = new Direction + +public definterface Width +public defstruct UnknownWidth <: Width +public defstruct IntWidth <: Width : + width: Int + +public defstruct PrimOp +public val ADD-OP = PrimOp() +public val ADD-MOD-OP = PrimOp() +public val SUB-OP = PrimOp() +public val SUB-MOD-OP = PrimOp() +public val TIMES-OP = PrimOp() +public val DIVIDE-OP = PrimOp() +public val MOD-OP = PrimOp() +public val SHIFT-LEFT-OP = PrimOp() +public val SHIFT-RIGHT-OP = PrimOp() +public val PAD-OP = PrimOp() +public val BIT-AND-OP = PrimOp() +public val BIT-OR-OP = PrimOp() +public val BIT-XOR-OP = PrimOp() +public val CONCAT-OP = PrimOp() +public val BIT-SELECT-OP = PrimOp() +public val BITS-SELECT-OP = PrimOp() +public val MULTIPLEX-OP = PrimOp() +public val LESS-OP = PrimOp() +public val LESS-EQ-OP = PrimOp() +public val GREATER-OP = PrimOp() +public val GREATER-EQ-OP = PrimOp() +public val EQUAL-OP = PrimOp() + +public definterface Expression +public defmulti type (e:Expression) -> Type + +public defstruct Ref <: Expression : + name: Symbol + type: Type [multi => false] +public defstruct Field <: Expression : + exp: Expression + name: Symbol + type: Type [multi => false] +public defstruct Index <: Expression : + exp: Expression + value: Int + type: Type [multi => false] +public defstruct UIntValue <: Expression : + value: Int + width: Width +public defstruct SIntValue <: Expression : + value: Int + width: Width +public defstruct DoPrim <: Expression : + op: PrimOp + args: List<Expression> + consts: List<Int> + type: Type [multi => false] +public defstruct ReadPort <: Expression : + mem: Expression + index: Expression + type: Type [multi => false] + +public definterface Command +public defstruct LetRec <: Command : + entries: List<KeyValue<Symbol, Element>> + body: Command +public defstruct DefWire <: Command : + name: Symbol + type: Type +public defstruct DefRegister <: Command : + name: Symbol + type: Type +public defstruct DefInstance <: Command : + name: Symbol + module: Expression +public defstruct DefMemory <: Command : + name: Symbol + type: VectorType +public defstruct DefAccessor <: Command : + name: Symbol + source: Expression + index: Expression +public defstruct Conditionally <: Command : + pred: Expression + conseq: Command + alt: Command +public defstruct Begin <: Command : + body: List<Command> +public defstruct Connect <: Command : + loc: Expression + exp: Expression +public defstruct EmptyCommand <: Command + +public definterface Element +public defmulti type (e:Element) -> Type + +public defstruct Register <: Element : + type: Type [multi => false] + value: Expression + enable: Expression +public defstruct Memory <: Element : + type: Type [multi => false] + writers: List<WritePort> +public defstruct WritePort : + index: Expression + value: Expression + enable: Expression +public defstruct Node <: Element : + type: Type [multi => false] + value: Expression +public defstruct Instance <: Element : + type: Type [multi => false] + module: Expression + ports: List<KeyValue<Symbol,Expression>> + +public definterface Type +public defstruct UIntType <: Type : + width: Width +public defstruct SIntType <: Type : + width: Width +public defstruct BundleType <: Type : + ports: List<Port> +public defstruct VectorType <: Type : + type: Type + size: Int +public defstruct UnknownType <: Type + +public defstruct Port : + name: Symbol + direction: Direction + type: Type + +public defstruct Module : + name: Symbol + ports: List<Port> + body: Command + +public defstruct Circuit : + modules: List<Module> + main: Symbol
\ No newline at end of file diff --git a/src/main/stanza/firrtl-main.stanza b/src/main/stanza/firrtl-main.stanza new file mode 100644 index 00000000..bafb9dd7 --- /dev/null +++ b/src/main/stanza/firrtl-main.stanza @@ -0,0 +1,26 @@ +include<"core/stringeater.stanza"> +include<"compiler/lexer.stanza"> +include<"compiler/parser.stanza"> +include<"compiler/rdparser2.stanza"> +include<"compiler/macro-utils.stanza"> +include("firrtl-ir.stanza") +include("ir-utils.stanza") +include("ir-parser.stanza") +include("passes.stanza") +include("widthsolver.stanza") + +defpackage chmain : + import core + import verse + import chipper.parser + import chipper.passes + import stanza.lexer + import stanza.parser + +defn main () : + val lexed = lex-file("../../test/firrtl/firrtl-test.txt") + val c = parse-firrtl(lexed) + println(c) + run-passes(c) + +main() diff --git a/src/main/stanza/ir-parser.stanza b/src/main/stanza/ir-parser.stanza new file mode 100644 index 00000000..0da99033 --- /dev/null +++ b/src/main/stanza/ir-parser.stanza @@ -0,0 +1,204 @@ +defpackage chipper.parser : + import core + import verse + import chipper.ir2 + import stanza.rdparser + import stanza.lexer + +;======= Convenience Functions ==== +defn throw-error (x) : + throw $ new Exception : + defmethod print (o:OutputStream, this) : + print(o, x) + +defn ut (x) : + unwrap-token(x) + +;======== String Splitting ======== +defn substring? (s:String, look:String) : + index-of-string(s, look) != false + +defn index-of-string (s:String, look:String) : + for i in 0 through length(s) - length(look) index-when : + for j in 0 to length(look) all? : + s[i + j] == look[j] + +defn split-string (s:String, split:String) -> List<String> : + defn loop (s:String) -> List<String> : + if length(s) == 0 : + List() + else : + match(index-of-string(s, split)) : + (i:Int) : + val rest = loop(substring(s, i + length(split))) + if i == 0 : List(split, rest) + else : List(substring(s, 0, i), split, rest) + (f:False) : list(s) + loop(s) + +;======= Unwrap Prefix Forms ============ +defn unwrap-prefix-form (form) : + match(form) : + (form:Token) : + val fs = unwrap-prefix-form(item(form)) + List(Token(head(fs), info(form)), tail(fs)) + (form:List) : + if tagged-list?(form, `(@get @do @do-afn @of)) : + val rest = map-append(unwrap-prefix-form, tailn(form, 2)) + val form* = List(form[0], rest) + append(unwrap-prefix-form(form[1]), list(form*)) + else : + list(map-append(unwrap-prefix-form, form)) + (form) : + list(form) + +;======= Split Dots ============ +defn split-dots (forms:List) : + defn split (form) : + match(ut(form)) : + (f:Symbol) : + val fstr = to-string(f) + if contains?(fstr, '.') : map(to-symbol, split-string(fstr, ".")) + else : list(form) + (f:List) : + list(map-append(split, f)) + (f) : + list(f) + head(split(forms)) + +;====== Normalize Dots ======== +defn normalize-dots (forms:List) : + val forms* = head(unwrap-prefix-form(forms)) + split-dots(forms*) + +;======== SYNTAX ======================= +rd.defsyntax firrtl : + defrule circuit : + (circuit ?name:#symbol : (?module-form ...)) : + rd.match-syntax(normalize-dots(module-form)) : + (?modules:#module ...) : + Circuit(modules, ut(name)) + + defrule module : + (module ?name:#symbol : (?ports:#port ... ?body:#comm ...)) : + Module(ut(name), ports, Begin(body)) + + defrule port : + (input ?name:#symbol : ?type:#type) : + Port(ut(name), INPUT, type) + (output ?name:#symbol : ?type:#type) : + Port(ut(name), OUTPUT, type) + + defrule type : + (?type:#type (@get ?size:#int)) : + VectorType(type, ut(size)) + (UInt (@do ?width:#int)) : + UIntType(IntWidth(ut(width))) + (UInt) : + UIntType(UnknownWidth()) + (SInt (@do ?width:#int)) : + SIntType(IntWidth(ut(width))) + (SInt) : + SIntType(UnknownWidth()) + ({?ports:#port ...}) : + BundleType(ports) + + defrule comm : + (wire ?name:#symbol : ?type:#type) : + DefWire(ut(name), type) + (reg ?name:#symbol : ?type:#type) : + DefRegister(ut(name), type) + (mem ?name:#symbol : ?type:#type) : + DefMemory(ut(name), type) + (inst ?name:#symbol of ?module:#exp) : + DefInstance(ut(name), module) + (accessor ?name:#symbol = ?source:#exp (@get ?index:#exp)) : + DefAccessor(ut(name), source, index) + ((?body:#comm ...)) : + Begin(body) + (letrec : (?elems:#element ...) in : ?body:#comm) : + LetRec(elems, body) + (?x:#exp := ?y:#exp) : + Connect(x, y) + (?c:#comm/when) : + c + + defrule comm/when : + (when ?pred:#exp : ?conseq:#comm else : ?alt:#comm) : + Conditionally(pred, conseq, alt) + (when ?pred:#exp : ?conseq:#comm else ?alt:#comm/when) : + Conditionally(pred, conseq, alt) + (when ?pred:#exp : ?conseq:#comm) : + Conditionally(pred, conseq, EmptyCommand()) + + defrule element : + (reg ?name:#symbol : ?type:#type = Register (@do ?value:#exp ?en:#exp)) : + ut(name) => Register(type, value, en) + (mem ?name:#symbol : ?type:#type = Memory + (@do (?i:#exp => WritePort (@do ?value:#exp ?en:#exp) @...))) : + val ports = map(WritePort, i, value, en) + ut(name) => Memory(type, ports) + (node ?name:#symbol : ?type:#type = ?exp:#exp) : + ut(name) => Node(type, exp) + (inst ?name:#symbol = Instance (@do ?module:#exp + (?names:#symbol => ?values:#exp @...))) : + val ports = map({ut(_) => _}, names, values) + ut(name) => Instance(UnknownType(), module, ports) + + defrule exp : + (?x:#exp . ?f:#symbol) : + Field(x, ut(f), UnknownType()) + (?x:#exp . ?f:#int) : + Index(x, ut(f), UnknownType()) + (?x:#exp-form) : + x + + val operators = HashTable<Symbol, PrimOp>(symbol-hash) + operators[`add] = ADD-OP + operators[`add-mod] = ADD-MOD-OP + operators[`sub] = SUB-OP + operators[`sub-mod] = SUB-MOD-OP + operators[`times] = TIMES-OP + operators[`mod] = MOD-OP + operators[`bit-and] = BIT-AND-OP + operators[`bit-or] = BIT-OR-OP + operators[`bit-xor] = BIT-XOR-OP + operators[`concat] = CONCAT-OP + operators[`less] = LESS-OP + operators[`less-eq] = LESS-EQ-OP + operators[`greater] = GREATER-OP + operators[`greater-eq] = GREATER-EQ-OP + operators[`equal] = EQUAL-OP + operators[`multiplex] = MULTIPLEX-OP + operators[`pad] = PAD-OP + operators[`shift-left] = SHIFT-LEFT-OP + operators[`shift-right] = SHIFT-RIGHT-OP + operators[`bit] = BIT-SELECT-OP + operators[`bits] = BITS-SELECT-OP + + defrule exp-form : + (UInt (@do ?value:#int ?width:#int)) : + UIntValue(ut(value), IntWidth(ut(width))) + (UInt (@do ?value:#int)) : + UIntValue(ut(value), UnknownWidth()) + (SInt (@do ?value:#int ?width:#int)) : + SIntValue(ut(value), IntWidth(ut(width))) + (SInt (@do ?value:#int)) : + SIntValue(ut(value), UnknownWidth()) + (ReadPort (@do ?mem:#exp ?index:#exp)) : + ReadPort(mem, index, UnknownType()) + (?op:#symbol (@do ?es:#exp ... ?ints:#int ...)) : + match(get?(operators, ut(op), false)) : + (op:PrimOp) : + DoPrim(op, es, map(ut, ints), UnknownType()) + (f:False) : + throw-error $ string-join $ [ + "Invalid operator: " op] + (?x:#symbol) : + Ref(ut(x), UnknownType()) + +public defn parse-firrtl (forms:List) : + with-parser{`firrtl, _} $ fn () : + rd.match-syntax(forms) : + (?c:#circuit) : + c
\ No newline at end of file diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza new file mode 100644 index 00000000..be08cac8 --- /dev/null +++ b/src/main/stanza/ir-utils.stanza @@ -0,0 +1,227 @@ +defpackage chipper.ir-utils : + import core + import verse + import chipper.ir2 + +;============== PRINTERS =================================== +defmethod print (o:OutputStream, d:Direction) : + print{o, _} $ + switch {d == _} : + INPUT : "input" + OUTPUT: "output" + UNKNOWN-DIR : "unknown" + +defmethod print (o:OutputStream, w:Width) : + print{o, _} $ + match(w) : + (w:UnknownWidth) : "?" + (w:IntWidth) : width(w) + +defmethod print (o:OutputStream, op:PrimOp) : + print{o, _} $ + switch {op == _} : + ADD-OP : "ADD" + ADD-MOD-OP : "ADD-MOD" + SUB-OP : "MINUS" + SUB-MOD-OP : "SUB-MOD" + TIMES-OP : "TIMES" + DIVIDE-OP : "DIVIDE" + MOD-OP : "MOD" + SHIFT-LEFT-OP : "SHIFT-LEFT" + SHIFT-RIGHT-OP : "SHIFT-RIGHT" + PAD-OP : "PAD" + BIT-AND-OP : "BIT-AND" + BIT-OR-OP : "BIT-OR" + BIT-XOR-OP : "BIT-XOR" + CONCAT-OP : "CONCAT" + BIT-SELECT-OP : "BIT-SELECT" + BITS-SELECT-OP : "BITS-SELECT" + MULTIPLEX-OP : "MULTIPLEX" + LESS-OP : "LESS" + LESS-EQ-OP : "LESS-EQ" + GREATER-OP : "GREATER" + GREATER-EQ-OP : "GREATER-EQ" + EQUAL-OP : "EQUAL" + +defmethod print (o:OutputStream, e:Expression) : + match(e) : + (e:Ref) : print(o, name(e)) + (e:Field) : print-all(o, [exp(e) "." name(e)]) + (e:Index) : print-all(o, [exp(e) "." value(e)]) + (e:UIntValue) : print-all(o, ["UInt(" value(e) ")"]) + (e:SIntValue) : print-all(o, ["SInt(" value(e) ")"]) + (e:DoPrim) : + print-all(o, [op(e) "("]) + print-all(o, join(concat(args(e), consts(e)), ", ")) + print(o, ")") + (e:ReadPort) : print-all(o, ["ReadPort(" mem(e) ", " index(e) ")"]) + +defmethod print (o:OutputStream, c:Command) : + match(c) : + (c:LetRec) : + println(o, "let : ") + indented{o, _} $ fn () : + for entry in entries(c) do : + println-all([key(entry) " = " value(entry)]) + println(o, "in :") + indented(o, print{o, body(c)}) + (c:DefWire) : + print-all(["wire " name(c) " : " type(c)]) + (c:DefRegister) : + print-all(["reg " name(c) " : " type(c)]) + (c:DefMemory) : + print-all(["mem " name(c) " : " type(c)]) + (c:DefInstance) : + print-all(["inst " name(c) " of " module(c)]) + (c:DefAccessor) : + print-all(["accessor " name(c) " = " source(c) "[" index(c) "]"]) + (c:Conditionally) : + println-all(o, ["when " pred(c) " :"]) + indented(o, print{conseq(c)}) + if alt(c) not-typeof EmptyCommand : + println(o, "\nelse :") + indented(o, print{alt(c)}) + (c:Begin) : + do(print, join(body(c), "\n")) + (c:Connect) : + print-all(o, [loc(c) " := " exp(c)]) + (c:EmptyCommand) : + print(o, "skip") + +defmethod print (o:OutputStream, e:Element) : + match(e) : + (e:Register) : + print-all(o, ["Register(" type(e) ", " value(e) ", " enable(e) ")"]) + (e:Memory) : + print-all(o, ["Memory(" type(e) ", "]) + print-all(o, join(writers(e), ", ")) + print(o, ")") + (e:Node) : + print-all(o, ["Node(" type(e) ", " value(e) ")"]) + (e:Instance) : + print-all(o, ["Instance(" module(e) ", "]) + print-all(o, join(ports(e), ", ")) + print(o, ")") + +defmethod print (o:OutputStream, p:WritePort) : + print-all(o, [index(p) " => WritePort(" value(p) ", " enable(p) ")"]) + +defmethod print (o:OutputStream, t:Type) : + match(t) : + (t:UnknownType) : + print(o, "?") + (t:UIntType) : + print-all(o, ["UInt(" width(t) ")"]) + (t:SIntType) : + print-all(o, ["SInt(" width(t) ")"]) + (t:BundleType) : + print(o, "{") + print-all(o, join(ports(t), ", ")) + print(o, "}") + (t:VectorType) : + print-all(o, [type(t) "[" size(t) "]"]) + +defmethod print (o:OutputStream, p:Port) : + print-all(o, [direction(p) " " name(p) " : " type(p)]) + +defmethod print (o:OutputStream, m:Module) : + println-all(o, ["module " name(m) " :"]) + indented{o, _} $ fn () : + do(println, ports(m)) + print(body(m)) + +defmethod print (o:OutputStream, c:Circuit) : + println-all(o, ["circuit " main(c) " :"]) + indented(o, do{println, modules(c)}) + +;================== INDENTATION ============================ +defn IndentedStream (o:OutputStream, n:Int) : + var indent? = true + defn put (c:Char) : + if indent? : + do(print{o, " "}, 0 to n) + indent? = false + print(o, c) + if c == '\n' : + indent? = true + + new OutputStream : + defmethod print (this, s:String) : do(put, s) + defmethod print (this, c:Char) : put(c) + +defn indented (o:OutputStream, f: () -> ?) : + val prev-stream = CURRENT-OUTPUT-STREAM + dynamic-wind( + fn () : CURRENT-OUTPUT-STREAM = IndentedStream(o, 3) + f + fn (f) : CURRENT-OUTPUT-STREAM = prev-stream) + + +;=================== MAPPERS =============================== +public defn map<?T> (f: Type -> Type, t:?T&Type) -> T : + val type = + match(t) : + (t:T&BundleType) : + BundleType $ + for p in ports(t) map : + Port(name(p), direction(p), f(type(p))) + (t:T&VectorType) : + VectorType(f(type(t)), size(t)) + (t) : + t + type as T&Type + +public defmulti map<?T> (f: Expression -> Expression, e:?T&Expression) -> T +defmethod map (f: Expression -> Expression, e:Expression) -> Expression : + match(e) : + (e:Field) : Field(f(exp(e)), name(e), type(e)) + (e:Index) : Index(f(exp(e)), value(e), type(e)) + (e:DoPrim) : DoPrim(op(e), map(f, args(e)), consts(e), type(e)) + (e:ReadPort) : ReadPort(f(mem(e)), f(index(e)), type(e)) + (e) : e + +public defmulti map<?T> (f: Expression -> Expression, e:?T&Element) -> T +defmethod map (f: Expression -> Expression, e:Element) -> Element : + match(e) : + (e:Register) : + Register(type(e), f(value(e)), f(enable(e))) + (e:Memory) : + val writers* = for w in writers(e) map : + WritePort(f(index(w)), f(value(w)), f(enable(w))) + Memory(type(e), writers*) + (e:Node) : + Node(type(e), f(value(e))) + (e:Instance) : + val ports* = for p in ports(e) map : + key(p) => f(value(p)) + Instance(type(e), f(module(e)), ports*) + +public defmulti map<?T> (f: Expression -> Expression, c:?T&Command) -> T +defmethod map (f: Expression -> Expression, c:Command) -> Command : + match(c) : + (c:LetRec) : + val entries* = for entry in entries(c) map : + key(entry) => map(f, value(entry)) + LetRec(entries*, body(c)) + (c:DefAccessor) : DefAccessor(name(c), f(source(c)), f(index(c))) + (c:DefInstance) : DefInstance(name(c), f(module(c))) + (c:Conditionally) : Conditionally(f(pred(c)), conseq(c), alt(c)) + (c:Connect) : Connect(f(loc(c)), f(exp(c))) + (c) : c + +public defmulti map<?T> (f: Command -> Command, c:?T&Command) -> T +defmethod map (f: Command -> Command, c:Command) -> Command : + match(c) : + (c:LetRec) : LetRec(entries(c), f(body(c))) + (c:Conditionally) : Conditionally(pred(c), f(conseq(c)), f(alt(c))) + (c:Begin) : Begin(map(f, body(c))) + (c) : c + +public defmulti children (c:Command) -> List<Command> +defmethod children (c:Command) : + match(c) : + (c:LetRec) : list(body(c)) + (c:Conditionally) : list(conseq(c), alt(c)) + (c:Begin) : body(c) + (c) : List() + diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza new file mode 100644 index 00000000..61eac73c --- /dev/null +++ b/src/main/stanza/passes.stanza @@ -0,0 +1,1878 @@ +defpackage chipper.passes : + import core + import verse + import chipper.ir2 + import chipper.ir-utils + import widthsolver + +;============== EXCEPTIONS ================================= +defclass PassException <: Exception +defn PassException (msg:String) : + new PassException : + defmethod print (o:OutputStream, this) : + print(o, msg) + +;=============== WORKING IR ================================ +definterface Kind +defstruct RegKind <: Kind +defstruct AccessorKind <: Kind +defstruct PortKind <: Kind +defstruct MemKind <: Kind +defstruct NodeKind <: Kind +defstruct ModuleKind <: Kind +defstruct InstanceKind <: Kind +defstruct StructuralMemKind <: Kind + +defstruct WRef <: Expression : + name: Symbol + type: Type [multi => false] + kind: Kind + dir: Direction [multi => false] + +defstruct WField <: Expression : + exp: Expression + name: Symbol + type: Type [multi => false] + dir: Direction [multi => false] + +defstruct WIndex <: Expression : + exp: Expression + value: Int + type: Type [multi => false] + dir: Direction [multi => false] + +defstruct WDefAccessor <: Command : + name: Symbol + source: Expression + index: Expression + dir: Direction + +;================ WORKING IR UTILS ========================= +;=== Printers === +defmethod print (o:OutputStream, k:Kind) : + print{o, _} $ + match(k) : + (k:RegKind) : "reg:" + (k:AccessorKind) : "accessor:" + (k:PortKind) : "port:" + (k:MemKind) : "mem:" + (k:NodeKind) : "n:" + (k:ModuleKind) : "module:" + (k:InstanceKind) : "inst:" + (k:StructuralMemKind) : "smem:" + +defmethod print (o:OutputStream, e:WRef) : + print-all(o, [name(e)]) +defmethod print (o:OutputStream, e:WField) : + print-all(o, [exp(e) "." name(e)]) +defmethod print (o:OutputStream, e:WIndex) : + print-all(o, [exp(e) "." value(e)]) + +defmethod print (o:OutputStream, c:WDefAccessor) : + print-all(o, [dir(c) " accessor " name(c) " = " source(c) "[" index(c) "]"]) + +defmethod map (f: Expression -> Expression, e: WField) : + WField(f(exp(e)), name(e), type(e), dir(e)) + +defmethod map (f: Expression -> Expression, e: WIndex) : + WIndex(f(exp(e)), value(e), type(e), dir(e)) + +defmethod map (f: Expression -> Expression, c:WDefAccessor) : + WDefAccessor(name(c), f(source(c)), f(index(c)), dir(c)) + +;================= DIRECTION =============================== +defmulti dir (e:Expression) -> Direction +defmethod dir (e:Expression) : + OUTPUT + +;============== Bring to Working IR ======================== +defn to-working-ir (c:Circuit) : + defn to-exp (e:Expression) : + match(map(to-exp, e)) : + (e:Ref) : WRef(name(e), type(e), NodeKind(), UNKNOWN-DIR) + (e:Field) : WField(exp(e), name(e), type(e), UNKNOWN-DIR) + (e:Index) : WIndex(exp(e), value(e), type(e), UNKNOWN-DIR) + (e) : e + defn to-command (c:Command) : + match(map(to-exp, c)) : + (c:DefAccessor) : + WDefAccessor(name(c), source(c), index(c), UNKNOWN-DIR) + (c) : + map(to-command, c) + + Circuit(modules*, main(c)) where : + val modules* = + for m in modules(c) map : + Module(name(m), ports(m), to-command(body(m))) + +;=============== Resolve Kinds ============================= +defn resolve-kinds (c:Circuit) : + defn resolve-exp (e:Expression, kinds:HashTable<Symbol,Kind>) : + match(e) : + (e:WRef) : WRef(name(e), type(e), kinds[name(e)], dir(e)) + (e) : map(resolve-exp{_, kinds}, e) + + defn resolve-comm (c:Command, kinds:HashTable<Symbol,Kind>) -> Command : + map{resolve-comm{_, kinds}, _} $ + map(resolve-exp{_, kinds}, c) + + defn find-kinds (c:Command, kinds:HashTable<Symbol,Kind>) : + match(c) : + (c:LetRec) : + for entry in entries(c) do : + kinds[key(entry)] = element-kind(value(entry)) + (c:DefWire) : kinds[name(c)] = NodeKind() + (c:DefRegister) : kinds[name(c)] = RegKind() + (c:DefInstance) : kinds[name(c)] = InstanceKind() + (c:DefMemory) : kinds[name(c)] = MemKind() + (c:WDefAccessor) : kinds[name(c)] = AccessorKind() + (c) : false + do(find-kinds{_, kinds}, children(c)) + + defn element-kind (e:Element) : + match(e) : + (e:Memory) : StructuralMemKind() + (e) : NodeKind() + + defn resolve-mod (m:Module, modules:List<Symbol>) : + val kinds = HashTable<Symbol,Kind>(symbol-hash) + for module in modules do : + kinds[module] = ModuleKind() + for port in ports(m) do : + kinds[name(port)] = PortKind() + find-kinds(body(m), kinds) + Module(name(m), ports(m), body*) where : + val body* = resolve-comm(body(m), kinds) + + Circuit(modules*, main(c)) where : + val mod-names = map(name, modules(c)) + val modules* = map(resolve-mod{_, mod-names}, modules(c)) + +;=============== MAKE RESET EXPLICIT ======================= +defn make-explicit-reset (c:Circuit) : + defn reset-instances (c:Command, reset?: List<Symbol>) -> Command : + match(c) : + (c:DefInstance) : + val module = module(c) as WRef + if contains?(reset?, name(module)) : + c + else : + Begin $ list(c, Connect(WField(inst, `reset, UnknownType(), UNKNOWN-DIR), reset)) where : + val inst = WRef(name(c), UnknownType(), InstanceKind(), UNKNOWN-DIR) + val reset = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-DIR) + (c) : + map(reset-instances{_:Command, reset?}, c) + + defn make-explicit-reset (m:Module, reset-list: List<Symbol>) : + val reset? = contains?(reset-list, name(m)) + + ;Add reset port if necessary + val ports* = + if reset? : + ports(m) + else : + val reset = Port(`reset, INPUT, UIntType(IntWidth(1))) + List(reset, ports(m)) + + ;Reset Instances + val body* = reset-instances(body(m), reset-list) + val m* = Module(name(m), ports*, body*) + + ;Initialize registers if necessary + if reset? : m* + else : initialize-registers(m*) + + Circuit(modules*, main(c)) where : + defn reset? (m:Module) : + for p in ports(m) any? : + name(p) == `reset + val reset-list = to-list(stream(name, filter(reset?, modules(c)))) + val modules* = map(make-explicit-reset{_, reset-list}, modules(c)) + +;======= MAKE EXPLICIT REGISTER INITIALIZATION ============= +defn initialize-registers (m:Module) : + ;=== Initializing Expressions === + defn init-exps (inits: List<KeyValue<Symbol,Expression>>) : + if empty?(inits) : + EmptyCommand() + else : + Conditionally(reset, Begin(map(connect, inits)), EmptyCommand()) where : + val reset = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-DIR) + defn connect (init: KeyValue<Symbol, Expression>) : + val reg-ref = WRef(key(init), UnknownType(), RegKind(), UNKNOWN-DIR) + Connect(reg-ref, value(init)) + + defn initialize-registers (c: Command + inits: List<KeyValue<Symbol,Expression>>) -> + [Command, List<KeyValue<Symbol,Expression>>] : + ;=== Rename Expressions === + defn rename (e:Expression) : + match(e) : + (e:WField) : + switch {name(e) == _} : + `init : + if reg?(exp(e)) : init-wire(exp(e)) + else : map(rename, e) + else : map(rename, e) + (e) : map(rename, e) + defn reg? (e:Expression) : + match(e) : + (e:WRef) : kind(e) typeof RegKind + (e) : false + defn init-wire (e:Expression) : + lookup!(inits, name(e as WRef)) + + ;=== Driver === + match(c) : + (c:DefRegister) : + [new-command, list(init-entry)] where : + val wire-name = gensym() + val wire-ref = WRef(wire-name, UnknownType(), NodeKind(), UNKNOWN-DIR) + val reg-ref = WRef(name(c), UnknownType(), RegKind(), UNKNOWN-DIR) + val def-init-wire = DefWire(wire-name, type(c)) + val init-wire = Connect(wire-ref, reg-ref) + val init-reg = Connect(reg-ref, wire-ref) + val new-command = Begin(to-list([c, def-init-wire, init-wire, init-reg])) + val init-entry = name(c) => wire-ref + (c:Conditionally) : + val pred* = rename(pred(c)) + val [conseq* con-inits] = initialize-registers(conseq(c), inits) + val [alt* alt-inits] = initialize-registers(alt(c), inits) + val c* = Conditionally(pred*, conseq+inits, alt+inits) where : + val conseq+inits = Begin(list(conseq*, init-exps(con-inits))) + val alt+inits = Begin(list(alt*, init-exps(alt-inits))) + [c*, List()] + (c:LetRec) : + val c* = map(rename, c) + val [body*, body-inits] = initialize-registers(body(c), inits) + val new-command = + LetRec(entries(c*), body+inits) where : + val body+inits = Begin(list(body*, init-exps(body-inits))) + [new-command, List()] + (c:Begin) : + var inits-in:List<KeyValue<Symbol,Expression>> = inits + var inits-out:List<KeyValue<Symbol,Expression>> = List() + val body* = + for c in body(c) map : + val [c* inits*] = initialize-registers(c, inits-in) + inits-in = append(inits*, inits-in) + inits-out = append(inits*, inits-out) + c* + [Begin(body*), inits-out] + (c) : + val c* = map(rename, c) + [c*, List()] + + Module(name(m), ports(m), body+inits) where : + val [body*, inits] = initialize-registers(body(m), List()) + val body+inits = Begin(list(body*, init-exps(inits))) + + +;============== INFER TYPES ================================ +defmethod type (v:UIntValue) : + UIntType(width(v)) + +defmethod type (v:SIntValue) : + SIntType(width(v)) + +defn put-type (e:Expression, t:Type) -> Expression : + match(e) : + (e:WRef) : WRef(name(e), t, kind(e), dir(e)) + (e:WField) : WField(exp(e), name(e), t, dir(e)) + (e:WIndex) : WIndex(exp(e), value(e), t, dir(e)) + (e:DoPrim) : DoPrim(op(e), args(e), consts(e), t) + (e:ReadPort) : ReadPort(mem(e), index(e), t) + (e) : e + +defn lookup-port (ports: Streamable<Port>, port-name: Symbol) : + for port in ports find : + name(port) == port-name + +defn infer (op:PrimOp, arg-types: List<Type>) -> Type : + defn wipe-width (t:Type) : + match(t) : + (t:UIntType) : UIntType(UnknownWidth()) + (t:SIntType) : SIntType(UnknownWidth()) + + defn arg0 () : wipe-width(arg-types[0]) + defn arg1 () : wipe-width(arg-types[1]) + switch {op == _} : + ADD-OP : arg0() + ADD-MOD-OP : arg0() + SUB-OP : arg0() + SUB-MOD-OP : arg0() + TIMES-OP : arg0() + DIVIDE-OP : arg0() + MOD-OP : arg0() + SHIFT-LEFT-OP : arg0() + SHIFT-RIGHT-OP : arg0() + PAD-OP : arg0() + BIT-AND-OP : arg0() + BIT-OR-OP : arg0() + BIT-XOR-OP : arg0() + CONCAT-OP : arg0() + BIT-SELECT-OP : UIntType(UnknownWidth()) + BITS-SELECT-OP : arg0() + MULTIPLEX-OP : arg0() + LESS-OP : UIntType(UnknownWidth()) + LESS-EQ-OP : UIntType(UnknownWidth()) + GREATER-OP : UIntType(UnknownWidth()) + GREATER-EQ-OP : UIntType(UnknownWidth()) + EQUAL-OP : UIntType(UnknownWidth()) + +defn bundle-field-type (t:Type, n:Symbol) -> Type : + match(t) : + (t:BundleType) : + match(lookup-port(ports(t), n)) : + (p:Port) : type(p) + (p) : UnknownType() + (t) : UnknownType() + +defn vector-elem-type (t:Type) -> Type : + match(t) : + (t:VectorType) : type(t) + (t) : UnknownType() + +;e is the environment that contains all definitions seen so far. +defn infer (c:Command, e:List<KeyValue<Symbol, Type>>) -> [Command, List<KeyValue<Symbol,Type>>] : + defn infer-exp (e:Expression, env:List<KeyValue<Symbol,Type>>) : + match(map(infer-exp{_, env}, e)) : + (e:WRef) : + put-type(e, lookup!(env, name(e))) + (e:WField) : + put-type(e, bundle-field-type(type(exp(e)), name(e))) + (e:WIndex) : + put-type(e, vector-elem-type(type(exp(e)))) + (e:UIntValue) : e + (e:SIntValue) : e + (e:DoPrim) : + put-type(e, infer(op(e), map(type, args(e)))) + (e:ReadPort) : + put-type(e, vector-elem-type(type(mem(e)))) + + defn element-type (e:Element, env:List<KeyValue<Symbol,Type>>) : + match(e) : + (e:Instance) : + val t = type(infer-exp(module(e), env)) + match(t) : + (t:BundleType) : + BundleType $ to-list $ + for p in ports(t) filter : + direction(p) == OUTPUT + (t) : UnknownType() + (e) : type(e) + + match(c) : + (c:LetRec) : + val e* = append(elem-types, e) where : + val elem-types = + for entry in entries(c) map : + key(entry) => element-type(value(entry), e) + val c* = map(infer-exp{_, e*}, c) + val [body*, be] = infer(body(c*), e*) + [LetRec(entries(c*), body*), e] + (c) : + match(map(infer-exp{_, e}, c)) : + (c:DefWire) : + [c, List(entry, e)] where : + val entry = name(c) => type(c) + (c:DefRegister) : + [c, List(entry, e)] where : + val entry = name(c) => type(c) + (c:DefInstance) : + [c, List(entry, e)] where : + val entry = name(c) => type(module(c)) + (c:DefMemory) : + [c, List(entry, e)] where : + val entry = name(c) => type(c) + (c:WDefAccessor) : + [c, List(entry, e)] where : + val src-type = type(source(c)) + val entry = name(c) => vector-elem-type(src-type) + (c:Begin) : + var current-e: List<KeyValue<Symbol,Type>> = e + val body* = for c in body(c) map : + val [c*, e*] = infer(c, current-e) + current-e = e* + c* + [Begin(body*), current-e] + (c) : + defn infer-comm (c:Command) : + val [c* e*] = infer(c, e) + c* + val c* = map(infer-comm, c) + [c*, e] + +defn infer (m:Module, e:List<KeyValue<Symbol, Type>>) -> Module : + val env = append{_, e} $ + for p in ports(m) map : + name(p) => type(p) + val [body*, e*] = infer(body(m), env) + Module(name(m), ports(m), body*) + +defn infer-types (c:Circuit) -> Circuit : + val env = + for m in modules(c) map : + name(m) => BundleType(ports(m)) + Circuit(map(infer{_, env}, modules(c)), + main(c)) + +;============= INFER DIRECTIONS ============================ +defn flip (d:Direction) : + switch {d == _} : + INPUT : OUTPUT + OUTPUT : INPUT + else : d + +defn times (d1:Direction, d2:Direction) : + if d1 == INPUT : flip(d2) + else : d2 + +defn bundle-field-dir (t:Type, n:Symbol) -> Direction : + match(t) : + (t:BundleType) : + match(lookup-port(ports(t), n)) : + (p:Port) : direction(p) + (p) : UNKNOWN-DIR + (t) : UNKNOWN-DIR + +defn infer-dirs (m:Module) : + ;=== Direction of all Binders === + val BI-DIR = new Direction + val directions = HashTable<Symbol,Direction>(symbol-hash) + defn find-dirs (c:Command) : + match(c) : + (c:LetRec) : + for entry in entries(c) do : + directions[key(entry)] = OUTPUT + find-dirs(body(c)) + (c:DefWire) : + directions[name(c)] = BI-DIR + (c:DefRegister) : + directions[name(c)] = BI-DIR + (c:DefInstance) : + directions[name(c)] = OUTPUT + (c:DefMemory) : + directions[name(c)] = BI-DIR + (c:WDefAccessor) : + directions[name(c)] = dir(c) + (c) : + do(find-dirs, children(c)) + for p in ports(m) do : + directions[name(p)] = flip(direction(p)) + find-dirs(body(m)) + + ;=== Fix Point Status === + var changed? = false + + ;=== Infer directions of Expression === + defn infer-exp (e:Expression, desired:Direction) : + match(e) : + (e:WRef) : + val dir* = let : + if kind(e) typeof ModuleKind : + OUTPUT + else : + val old-dir = directions[name(e)] + switch {old-dir == _} : + BI-DIR : + desired + UNKNOWN-DIR : + if directions[name(e)] != desired : + directions[name(e)] = desired + changed? = true + desired + else : + old-dir + WRef(name(e), type(e), kind(e), dir*) + (e:WField) : + val port-dir = bundle-field-dir(type(exp(e)), name(e)) + val exp* = infer-exp(exp(e), port-dir * desired) + WField(exp*, name(e), type(e), port-dir * dir(exp*)) + (e:WIndex) : + val exp* = infer-exp(exp(e), desired) + WIndex(exp*, value(e), type(e), dir(exp*)) + (e) : + map(infer-exp{_, OUTPUT}, e) + + ;=== Infer directions of Commands === + defn infer-comm (c:Command) : + match(c) : + (c:LetRec) : + val c* = map(infer-exp{_, OUTPUT}, c) + LetRec(entries(c*), infer-comm(body(c))) + (c:DefInstance) : + DefInstance(name(c), + infer-exp(module(c), OUTPUT)) + (c:WDefAccessor) : + val d = directions[name(c)] + WDefAccessor(name(c), + infer-exp(source(c), d), + infer-exp(index(c), OUTPUT), + d) + (c:Conditionally) : + Conditionally(infer-exp(pred(c), OUTPUT), + infer-comm(conseq(c)), + infer-comm(alt(c))) + (c:Connect) : + Connect(infer-exp(loc(c), INPUT), infer-exp(exp(c), OUTPUT)) + (c) : + map(infer-comm, c) + + ;=== Iterate until fix point === + defn* fixpoint (c:Command) : + changed? = false + val c* = infer-comm(c) + if changed? : fixpoint(c*) + else : c* + + Module(name(m), ports(m), body*) where : + val body* = fixpoint(body(m)) + +defn infer-directions (c:Circuit) : + Circuit(modules*, main(c)) where : + val modules* = map(infer-dirs, modules(c)) + + +;============== EXPAND VECS ================================ +defstruct ManyConnect <: Command : + index: Expression + locs: List<Expression> + exp: Expression + +defstruct ConnectMany <: Command : + index: Expression + loc: Expression + exps: List<Expression> + +defmethod print (o:OutputStream, c:ManyConnect) : + print-all(o, [locs(c) "[" index(c) "] := " exp(c)]) +defmethod print (o:OutputStream, c:ConnectMany) : + print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"]) + +defmethod map (f: Expression -> Expression, c:ManyConnect) : + ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c))) +defmethod map (f: Expression -> Expression, c:ConnectMany) : + ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c))) + +defn expand-accessors (m: Module) : + defn expand (c:Command) : + match(c) : + (c:WDefAccessor) : + ;Is the source a memory? + val mem? = + match(source(c)) : + (r:WRef) : kind(r) typeof MemKind + (r) : false + + if mem? : + c + else : + switch {dir(c) == _} : + INPUT : + Begin(list( + DefWire(name(c), type(src-type)) + ManyConnect(index(c), elems, wire-ref))) + where : + val src-type = type(source(c)) as VectorType + val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT) + val elems = to-list $ + for i in 0 to size(src-type) stream : + WIndex(source(c), i, type(src-type), INPUT) + OUTPUT : + Begin(list( + DefWire(name(c), type(src-type)) + ConnectMany(index(c), wire-ref, elems))) + where : + val src-type = type(source(c)) as VectorType + val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT) + val elems = to-list $ + for i in 0 to size(src-type) stream : + WIndex(source(c), i, type(src-type), OUTPUT) + (c) : + map(expand, c) + Module(name(m), ports(m), expand(body(m))) + +defn expand-accessors (c:Circuit) : + Circuit(modules*, main(c)) where : + val modules* = map(expand-accessors, modules(c)) + + + +;=============== BUNDLE FLATTENING ========================= +defn prefix (prefix, suffix) : + symbol-join([prefix "/" suffix]) + +defn prefix-ports (pre:Symbol, ports:List<Port>) : + for p in ports map : + Port(prefix(pre, name(p)), direction(p), type(p)) + +defn flatten-ports (port:Port) -> List<Port> : + match(type(port)) : + (t:BundleType) : + val ports = map-append(flatten-ports, ports(t)) + for p in ports map : + Port(prefix(name(port), name(p)), + direction(port) * direction(p), + type(p)) + (t:VectorType) : + val type* = flatten-type(t) + flatten-ports(Port(name(port), direction(port), type*)) + (t:Type) : + list(port) + +defn flatten-type (t:Type) -> Type : + match(t) : + (t:BundleType) : + BundleType $ + map-append(flatten-ports, ports(t)) + (t:VectorType) : + flatten-type $ BundleType $ to-list $ + for i in 0 to size(t) stream : + Port(to-symbol(i), OUTPUT, type(t)) + (t:Type) : + t + +defn flatten-bundles (c:Circuit) : + defn flatten-exp (e:Expression) : + match(map(flatten-exp, e)) : + (e:UIntValue|SIntValue) : + e + (e:WRef) : + match(kind(e)) : + (k:MemKind|StructuralMemKind) : + val type* = map(flatten-type, type(e)) + put-type(e, type*) + (k) : + val type* = flatten-type(type(e)) + put-type(e, type*) + (e) : + val type* = flatten-type(type(e)) + put-type(e, type*) + + defn flatten-element (e:Element) : + val t* = flatten-type(type(e)) + match(map(flatten-exp, e)) : + (e:Register) : Register(t*, value(e), enable(e)) + (e:Memory) : Memory(t*, writers(e)) + (e:Node) : Node(t*, value(e)) + (e:Instance) : Instance(t*, module(e), ports(e)) + + defn flatten-comm (c:Command) : + match(c) : + (c:LetRec) : + val entries* = + for entry in entries(c) map : + key(entry) => flatten-element(value(entry)) + LetRec(entries*, flatten-comm(body(c))) + (c:DefWire) : + DefWire(name(c), flatten-type(type(c))) + (c:DefRegister) : + DefRegister(name(c), flatten-type(type(c))) + (c:DefMemory) : + val type* = map(flatten-type, type(c)) + DefMemory(name(c), type*) + (c) : + map{flatten-comm, _} $ + map(flatten-exp, c) + + defn flatten-module (m:Module) : + val ports* = map-append(flatten-ports, ports(m)) + val body* = flatten-comm(body(m)) + Module(name(m), ports*, body*) + + Circuit(modules*, main(c)) where : + val modules* = map(flatten-module, modules(c)) + + +;================== BUNDLE EXPANSION ======================= +defn expand-bundles (m:Module) : + + ;Collapse all field/index expressions + defn collapse-exp (e:Expression) -> Expression : + match(e) : + (e:WField) : + match(collapse-exp(exp(e))) : + (ei:WRef) : + if kind(ei) typeof InstanceKind : + e + else : + WRef(name*, type(e), kind(ei), dir(e)) where : + val name* = prefix(name(ei), name(e)) + (ei:WField) : + WField(exp(ei), name*, type(e), dir(e)) where : + val name* = prefix(name(ei), name(e)) + (e:WIndex) : + collapse-exp(WField(exp(e), name, type(e), dir(e))) where : + val name = to-symbol(value(e)) + (e) : + map(collapse-exp, e) + + ;Expand expressions + defn expand-exp (e:Expression) -> List<Expression> : + match(type(e)) : + (t:BundleType) : + for p in ports(t) map : + val dir* = direction(p) * dir(e) + collapse-exp(WField(e, name(p), type(p), dir*)) + (t) : + list(collapse-exp(e)) + + ;Expand commands + defn expand-comm (c:Command) : + match(c) : + (c:DefWire) : + match(type(c)) : + (t:BundleType) : + Begin $ + for p in ports(t) map : + DefWire(prefix(name(c), name(p)), type(p)) + (t) : + c + (c:DefRegister) : + match(type(c)) : + (t:BundleType) : + Begin $ + for p in ports(t) map : + DefRegister(prefix(name(c), name(p)), type(p)) + (t) : + c + (c:DefMemory) : + match(type(type(c))) : + (t:BundleType) : + Begin $ + for p in ports(t) map : + DefMemory(prefix(name(c), name(p)), type*) where : + val s = size(type(c)) + val type* = VectorType(type(p), s) + (t) : + c + (c:WDefAccessor) : + match(type(source(c))) : + (t:BundleType) : + val srcs = expand-exp(source(c)) + Begin $ + for (p in ports(t), src in srcs) map : + WDefAccessor(name*, src, index(c), dir*) where : + val name* = prefix(name(c), name(p)) + val dir* = direction(p) * dir(c) + (t) : + c + (c:Connect) : + val locs = expand-exp(loc(c)) + val exps = expand-exp(exp(c)) + Begin $ + for (l in locs, e in exps) map : + switch {dir(l) == _} : + INPUT : Connect(l, e) + OUTPUT : Connect(e, l) + (c:ManyConnect) : + val locs-list = transpose(map(expand-exp, locs(c))) + val exps = expand-exp(exp(c)) + Begin $ + for (locs in locs-list, e in exps) map : + switch {dir(e) == _} : + OUTPUT : ManyConnect(index(c), locs, e) + INPUT : ConnectMany(index(c), e, locs) + (c:ConnectMany) : + val locs = expand-exp(loc(c)) + val exps-list = transpose(map(expand-exp, exps(c))) + Begin $ + for (l in locs, exps in exps-list) map : + switch {dir(l) == _} : + INPUT : ConnectMany(index(c), l, exps) + OUTPUT : ManyConnect(index(c), exps, l) + (c) : + map{expand-comm, _} $ + map(collapse-exp, c) + + Module(name(m), ports(m), expand-comm(body(m))) + +defn expand-bundles (c:Circuit) : + Circuit(modules*, main(c)) where : + val modules* = map(expand-bundles, modules(c)) + + +;=========== CONVERT MULTI CONNECTS to WHEN ================ +defn expand-multi-connects (c:Circuit) : + defn equal-exp (e1:Expression, e2:Expression) : + DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth())) + defn uint (i:Int) : + UIntValue(i, UnknownWidth()) + + defn expand-comm (c:Command) : + match(c) : + (c:ConnectMany) : + Begin $ to-list $ + for (i in 0 to false, e in exps(c)) stream : + Conditionally(equal-exp(index(c), uint(i)), + Connect(loc(c), e) + EmptyCommand()) + (c:ManyConnect) : + Begin $ to-list $ + for (i in 0 to false, l in locs(c)) stream : + Conditionally(equal-exp(index(c), uint(i)), + Connect(l, exp(c)) + EmptyCommand()) + (c) : + map(expand-comm, c) + + defn expand (m:Module) : + Module(name(m), ports(m), expand-comm(body(m))) + + Circuit(modules*, main(c)) where : + val modules* = map(expand, modules(c)) + + +;================ EXPAND WHENS ============================= +definterface SymbolicValue +defstruct ExpValue <: SymbolicValue : + exp: Expression +defstruct WhenValue <: SymbolicValue : + pred: Expression + conseq: SymbolicValue + alt: SymbolicValue +defstruct VoidValue <: SymbolicValue + +defmethod print (o:OutputStream, sv:SymbolicValue) : + match(sv) : + (sv:VoidValue) : print(o, "VOID") + (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"]) + (sv:ExpValue) : print(o, exp(sv)) + +defn key-eqv? (e1:Expression, e2:Expression) : + match(e1, e2) : + (e1:WRef, e2:WRef) : + name(e1) == name(e2) + (e1:WField, e2:WField) : + name(e1) == name(e2) and + key-eqv?(exp(e1), exp(e2)) + (e1, e2) : + false + +defn merge-env (pred: Expression, + con-env: List<KeyValue<Expression,SymbolicValue>>, + alt-env: List<KeyValue<Expression,SymbolicValue>>) : + val merged = Vector<KeyValue<Expression, SymbolicValue>>() + defn new-key? (k:Expression) : + for entry in merged none? : + key-eqv?(key(entry), k) + + defn sv (env:List<KeyValue<Expression,SymbolicValue>>, k:Expression) : + for entry in env search : + if key-eqv?(key(entry), k) : + value(entry) + + for k in stream(key, concat(con-env, alt-env)) do : + if new-key?(k) : + match(sv(con-env, k), sv(alt-env, k)) : + (a:SymbolicValue, b:SymbolicValue) : + if a == b : + add(merged, k => a) + else : + add(merged, k => WhenValue(pred, a, b)) + (a:SymbolicValue, b:False) : + add(merged, k => a) + (a:False, b:SymbolicValue) : + add(merged, k => b) + (a:False, b:False) : + false + + to-list(merged) + +defn simplify-env (env: List<KeyValue<Expression,SymbolicValue>>) : + val merged = Vector<KeyValue<Expression, SymbolicValue>>() + defn new-key? (k:Expression) : + for entry in merged none? : + key-eqv?(key(entry), k) + for entry in env do : + if new-key?(key(entry)) : + add(merged, entry) + to-list(merged) + +defn expand-whens (m:Module) : + val commands = Vector<Command>() + val elements = Vector<KeyValue<Symbol,Element>>() + defn eval (c:Command, env:List<KeyValue<Expression,SymbolicValue>>) -> + List<KeyValue<Expression,SymbolicValue>> : + match(c) : + (c:LetRec) : + do(add{elements, _}, entries(c)) + eval(body(c), env) + (c:DefWire) : + add(commands, c) + val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT) + List(wire-ref => VoidValue(), env) + (c:DefRegister) : + add(commands, c) + val reg-ref = WRef(name(c), type(c), RegKind(), INPUT) + List(reg-ref => VoidValue(), env) + (c:DefInstance) : + add(commands, c) + val entries = let : + val module-type = type(module(c)) as BundleType + val input-ports = to-list $ + for p in ports(module-type) filter : + direction(p) == INPUT + val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT) + for p in input-ports map : + WField(inst-ref, name(p), type(p), INPUT) => VoidValue() + append(entries, env) + (c:DefMemory) : + add(commands, c) + env + (c:WDefAccessor) : + add(commands, c) + if dir(c) == INPUT : + val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c)) + List(access-ref => VoidValue(), env) + else : + env + (c:Conditionally) : + val con-env = eval(conseq(c), env) + val alt-env = eval(alt(c), env) + merge-env(pred(c), con-env, alt-env) + (c:Begin) : + var env:List<KeyValue<Expression,SymbolicValue>> = env + for c in body(c) do : + env = eval(c, env) + env + (c:Connect) : + List(loc(c) => ExpValue(exp(c)), env) + (c:EmptyCommand) : + env + + defn convert-symbolic (key:Expression, sv:SymbolicValue) : + match(sv) : + (sv:VoidValue) : + throw $ PassException $ string-join $ [ + "No default value for " key "."] + (sv:ExpValue) : + exp(sv) + (sv:WhenValue) : + defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) : + DoPrim(MULTIPLEX-OP, list(pred, conseq, alt), List(), type(conseq)) + multiplex-exp(pred(sv), + convert-symbolic(key, conseq(sv)) + convert-symbolic(key, alt(sv))) + + ;Compute final environment + val env0 = let : + val output-ports = to-list $ + for p in ports(m) filter : + direction(p) == OUTPUT + for p in output-ports map : + val port-ref = WRef(name(p), type(p), PortKind(), INPUT) + port-ref => VoidValue() + val env* = simplify-env(eval(body(m), env0)) + + ;Make new body + val body* = Begin(list(defs, LetRec(elems, connections))) where : + val defs = Begin(to-list(commands)) + val elems = to-list(elements) + val connections = Begin $ + for entry in env* map : + val sv = convert-symbolic(key(entry), value(entry)) + Connect(key(entry), sv) + + ;Final module + Module(name(m), ports(m), body*) + +defn expand-whens (c:Circuit) : + val modules* = map(expand-whens, modules(c)) + Circuit(modules*, main(c)) + + +;================ STRUCTURAL FORM ========================== + +defn structural-form (m:Module) : + val elements = Vector<(() -> KeyValue<Symbol,Element>)>() + val connected = HashTable<Symbol, Expression>(symbol-hash) + val write-accessors = HashTable<Symbol, List<WDefAccessor>>(symbol-hash) + val read-accessors = HashTable<Symbol, WDefAccessor>(symbol-hash) + val inst-ports = HashTable<Symbol, List<KeyValue<Symbol, Expression>>>(symbol-hash) + val port-connects = Vector<Connect>() + + defn scan (c:Command) : + match(c) : + (c:Connect) : + match(loc(c)) : + (loc:WRef) : + match(kind(loc)) : + (k:PortKind) : add(port-connects, c) + (k) : connected[name(loc)] = exp(c) + (loc:WField) : + val inst = exp(loc) as WRef + val entry = name(loc) => exp(c) + inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List())) + (c:LetRec) : + for e in entries(c) do : + add(elements, {e}) + scan(body(c)) + (c:DefWire) : + add{elements, _} $ fn () : + name(c) => Node(type(c), connected[name(c)]) + (c:DefRegister) : + add{elements, _} $ fn () : + val one = UIntValue(1, UnknownWidth()) + name(c) => Register(type(c), connected[name(c)], one) + (c:DefInstance) : + add{elements, _} $ fn () : + name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)]) + (c:DefMemory) : + add{elements, _} $ fn () : + val ports = for a in get?(write-accessors, name(c), List()) map : + val one = UIntValue(1, UnknownWidth()) + WritePort(index(a), connected[name(a)], one) + name(c) => Memory(type(c), ports) + (c:WDefAccessor) : + val mem = source(c) as WRef + switch {dir(c) == _} : + INPUT : + write-accessors[name(mem)] = List(c, + get?(write-accessors, name(mem), List())) + OUTPUT : + read-accessors[name(c)] = c + (c) : + do(scan, children(c)) + + defn make-read-ports (e:Expression) : + match(e) : + (e:WRef) : + match(kind(e)) : + (k:AccessorKind) : + val accessor = read-accessors[name(e)] + ReadPort(source(accessor), index(accessor), type(e)) + (k) : e + (e) : map(make-read-ports, e) + + Module(name(m), ports(m), body*) where : + scan(body(m)) + val elems = to-list $ + for e in elements stream : + val entry = e() + key(entry) => map(make-read-ports, value(entry)) + val connect-ports = Begin $ to-list $ + for c in port-connects stream : + Connect(loc(c), make-read-ports(exp(c))) + val body* = + if empty?(elems) : connect-ports + else : LetRec(elems, connect-ports) + +defn structural-form (c:Circuit) : + val modules* = map(structural-form, modules(c)) + Circuit(modules*, main(c)) + + +;==================== WIDTH INFERENCE ====================== +defstruct WidthVar <: Width : + name: Symbol + +defmethod print (o:OutputStream, w:WidthVar) : + print(o, name(w)) + +defn width! (t:Type) : + match(t) : + (t:UIntType) : width(t) + (t:SIntType) : width(t) + (t) : error("No width field.") + +defn put-width (t:Type, w:Width) : + match(t) : + (t:UIntType) : UIntType(w) + (t:SIntType) : SIntType(w) + (t) : t + +defn put-width (e:Expression, w:Width) : + val type* = put-width(type(e), w) + put-type(e, type*) + +defn add-width-vars (t:Type) : + defn width? (w:Width) : + match(w) : + (w:UnknownWidth) : WidthVar(gensym()) + (w) : w + match(t) : + (t:UIntType) : UIntType(width?(width(t))) + (t:SIntType) : SIntType(width?(width(t))) + (t) : map(add-width-vars, t) + +defn uint-width (i:Int) : + var v:Int = i + var n:Int = 0 + while v != 0 : + v = v >> 1 + n = n + 1 + IntWidth(n) + +defn sint-width (i:Int) : + if i > 0 : + val w = uint-width(i) + IntWidth(width(w) + 1) + else : + val w = uint-width(neg(i) - 1) + IntWidth(width(w) + 1) + +defn to-exp (w:Width) : + match(w) : + (w:IntWidth) : ELit(width(w)) + (w:WidthVar) : EVar(name(w)) + (w) : error $ string-join $ [ + "Cannot convert " w " to exp."] + +defn primop-width (op:PrimOp, ws:List<Width>, ints:List<Int>) -> Exp : + defn wmax (w1:Width, w2:Width) : + EMax(to-exp(w1), to-exp(w2)) + defn wplus (w1:Width, w2:Width) : + EPlus(to-exp(w1), to-exp(w2)) + defn wplus (w1:Width, w2:Int) : + EPlus(to-exp(w1), ELit(w2)) + defn wminus (w1:Width, w2:Width) : + EMinus(to-exp(w1), to-exp(w2)) + defn wminus (w1:Width, w2:Int) : + EMinus(to-exp(w1), ELit(w2)) + defn wmax-inc (w1:Width, w2:Width) : + EPlus(wmax(w1, w2), ELit(1)) + + switch {op == _} : + ADD-OP : wmax-inc(ws[0], ws[1]) + ADD-MOD-OP : wmax(ws[0], ws[1]) + SUB-OP : wmax-inc(ws[0], ws[1]) + SUB-MOD-OP : wmax(ws[0], ws[1]) + TIMES-OP : wplus(ws[0], ws[1]) + DIVIDE-OP : wminus(ws[0], ws[1]) + MOD-OP : to-exp(ws[1]) + SHIFT-LEFT-OP : wplus(ws[0], ints[0]) + SHIFT-RIGHT-OP : wminus(ws[0], ints[0]) + PAD-OP : ELit(ints[0]) + BIT-AND-OP : wmax(ws[0], ws[1]) + BIT-OR-OP : wmax(ws[0], ws[1]) + BIT-XOR-OP : wmax(ws[0], ws[1]) + CONCAT-OP : wplus(ws[0], ints[0]) + BIT-SELECT-OP : ELit(1) + BITS-SELECT-OP : ELit(ints[0]) + MULTIPLEX-OP : wmax(ws[1], ws[2]) + LESS-OP : ELit(1) + LESS-EQ-OP : ELit(1) + GREATER-OP : ELit(1) + GREATER-EQ-OP : ELit(1) + EQUAL-OP : ELit(1) + +defn put-type (el:Element, t:Type) : + match(el) : + (el:Register) : Register(t, value(el), enable(el)) + (el:Memory) : Memory(t, writers(el)) + (el:Node) : Node(t, value(el)) + (el:Instance) : Instance(t, module(el), ports(el)) + +defn generate-constraints (c:Circuit) -> [Circuit, Vector<WConstraint>] : + ;Constraints + val cs = Vector<WConstraint>() + defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) : + match(wvar) : + (wvar:WidthVar) : + add(cs, Constraint(name(wvar), to-exp(width))) + (wvar) : + false + + defn to-width (e:Exp) : + match(e) : + (e:ELit) : + IntWidth(width(e)) + (e:EVar) : + WidthVar(name(e)) + (e) : + val x = gensym() + add(cs, WidthEqual(x, e)) + WidthVar(x) + + ;Module types + val mod-types = HashTable<Symbol,Type>(symbol-hash) + + defn add-port-vars (m:Module) -> Module : + val ports* = + for p in ports(m) map : + val type* = add-width-vars(type(p)) + Port(name(p), direction(p), type*) + mod-types[name(m)] = BundleType(ports*) + Module(name(m), ports*, body(m)) + + ;Add Width Variables + defn add-module-vars (m:Module) -> Module : + val types = HashTable<Symbol,Type>(symbol-hash) + for p in ports(m) do : + types[name(p)] = type(p) + + defn infer-exp-width (e:Expression) : + match(map(infer-exp-width, e)) : + (e:WRef) : + match(kind(e)) : + (k:ModuleKind) : put-type(e, mod-types[name(e)]) + (k) : put-type(e, types[name(e)]) + (e:WField) : + val t = bundle-field-type(type(exp(e)), name(e)) + put-width(e, width!(t)) + (e:UIntValue) : + match(width(e)) : + (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e))) + (w) : e + (e:SIntValue) : + match(width(e)) : + (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e))) + (w) : e + (e:DoPrim) : + val widths = map(width!{type(_)}, args(e)) + val w = to-width(primop-width(op(e), widths, consts(e))) + put-width(e, w) + (e:ReadPort) : + val elem-type = type(type(mem(e)) as VectorType) + put-width(e, width!(elem-type)) + + defn infer-comm-width (c:Command) : + match(c) : + (c:LetRec) : + ;Add width vars to elements + var entries*: List<KeyValue<Symbol,Element>> = + for entry in entries(c) map : + val el-name = key(entry) + key(entry) => + match(value(entry)) : + (el:Register|Node) : + put-type(el, add-width-vars(type(el))) + (el:Memory) : + el + (el:Instance) : + val mod-type = type(infer-exp-width(module(el))) as BundleType + val type = BundleType $ to-list $ + for p in ports(mod-type) filter : + direction(p) == OUTPUT + put-type(el, type) + + ;Add vars to environment + for entry in entries* do : + types[key(entry)] = type(value(entry)) + + ;Infer types for elements + entries* = + for entry in entries* map : + key(entry) => map(infer-exp-width, value(entry)) + + ;Generate constraints + for entry in entries* do : + val el-name = key(entry) + match(value(entry)) : + (el:Register) : + new-constraint(WidthEqual, reg-width, val-width) where : + val reg-width = width!(types[el-name]) + val val-width = width!(type(value(el))) + (el:Node) : + new-constraint(WidthEqual, node-width, val-width) where : + val node-width = width!(types[el-name]) + val val-width = width!(type(value(el))) + (el:Instance) : + val mod-type = type(module(el)) as BundleType + for entry in ports(el) do : + new-constraint(WidthGreater, port-width, val-width) where : + val port-name = key(entry) + val port-width = width!(bundle-field-type(mod-type, port-name)) + val val-width = width!(type(value(entry))) + (el) : false + + ;Analyze body + LetRec(entries*, infer-comm-width(body(c))) + + (c:Connect) : + val loc* = infer-exp-width(loc(c)) + val exp* = infer-exp-width(exp(c)) + new-constraint(WidthGreater, loc-width, exp-width) where : + val loc-width = width!(type(loc*)) + val exp-width = width!(type(exp*)) + Connect(loc*, exp*) + + (c:Begin) : + map(infer-comm-width, c) + + Module(name(m), ports(m), body*) where : + val body* = infer-comm-width(body(m)) + + val c* = + Circuit(modules*, main(c)) where : + val ms = map(add-port-vars, modules(c)) + val modules* = map(add-module-vars, ms) + [c*, cs] + + +;================== FILL WIDTHS ============================ +defn fill-widths (c:Circuit, solved:Streamable<WidthEqual>) : + ;Populate table + val table = HashTable<Symbol, Width>(symbol-hash) + for eq in solved do : + table[name(eq)] = IntWidth(width(value(eq) as ELit)) + + defn width? (w:Width) : + match(w) : + (w:WidthVar) : get?(table, name(w), UnknownWidth()) + (w) : w + + defn fill-type (t:Type) : + match(t) : + (t:UIntType) : UIntType(width?(width(t))) + (t:SIntType) : SIntType(width?(width(t))) + (t) : map(fill-type, t) + + defn fill-exp (e:Expression) : + val e* = map(fill-exp, e) + val type* = fill-type(type(e)) + put-type(e*, type*) + + defn fill-element (e:Element) : + val e* = map(fill-exp, e) + val type* = fill-type(type(e)) + put-type(e*, type*) + + defn fill-comm (c:Command) : + match(c) : + (c:LetRec) : + val entries* = + for e in entries(c) map : + key(e) => fill-element(value(e)) + LetRec(entries*, fill-comm(body(c))) + (c) : + map{fill-comm, _} $ + map(fill-exp, c) + + defn fill-port (p:Port) : + Port(name(p), direction(p), fill-type(type(p))) + + defn fill-mod (m:Module) : + Module(name(m), ports*, body*) where : + val ports* = map(fill-port, ports(m)) + val body* = fill-comm(body(m)) + + Circuit(modules*, main(c)) where : + val modules* = map(fill-mod, modules(c)) + + +;=============== TYPE INFERENCE DRIVER ===================== +defn infer-widths (c:Circuit) : + val [c*, cs] = generate-constraints(c) + val solved = solve-widths(cs) + fill-widths(c*, solved) + + +;================ PAD WIDTHS =============================== +defn pad-widths (c:Circuit) : + ;Pad an expression to the given width + defn pad-exp (e:Expression, w:Int) : + match(type(e)) : + (t:UIntType|SIntType) : + val prev-w = width!(t) as IntWidth + if width(prev-w) < w : + val type* = put-width(t, IntWidth(w)) + DoPrim(PAD-OP, list(e), list(w), type*) + else : + e + (t) : + e + + defn pad-exp (e:Expression, w:Width) : + val w-value = width(w as IntWidth) + pad-exp(e, w-value) + + ;Convenience + defn max-width (es:Streamable<Expression>) : + defn int-width (e:Expression) : + width(width!(type(e)) as IntWidth) + maximum(stream(int-width, es)) + + defn match-widths (es:List<Expression>) : + val w = max-width(es) + map(pad-exp{_, w}, es) + + ;Match widths for an expression + defn match-exp-width (e:Expression) : + match(map(match-exp-width, e)) : + (e:DoPrim) : + if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) : + val args* = match-widths(args(e)) + DoPrim(op(e), args*, consts(e), type(e)) + else if op(e) == MULTIPLEX-OP : + val args* = List(head(args(e)), match-widths(tail(args(e)))) + DoPrim(op(e), args*, consts(e), type(e)) + else : + e + (e) : e + + defn match-element-width (e:Element) : + match(map(match-exp-width, e)) : + (e:Register) : + val w = width!(type(e)) + val value* = pad-exp(value(e), w) + Register(type(e), value*, enable(e)) + (e:Memory) : + val width = width!(type(type(e) as VectorType)) + val writers* = + for w in writers(e) map : + WritePort(index(w), pad-exp(value(w), width), enable(w)) + Memory(type(e), writers*) + (e:Node) : + val w = width!(type(e)) + val value* = pad-exp(value(e), w) + Node(type(e), value*) + (e:Instance) : + val mod-type = type(module(e)) as BundleType + val ports* = + for p in ports(e) map : + val port-type = bundle-field-type(mod-type, key(p)) + val port-val = pad-exp(value(p), width!(port-type)) + key(p) => port-val + Instance(type(e), module(e), ports*) + + ;Match widths for a command + defn match-comm-width (c:Command) : + match(map(match-exp-width, c)) : + (c:LetRec) : + val entries* = + for e in entries(c) map : + key(e) => match-element-width(value(e)) + LetRec(entries*, match-comm-width(body(c))) + (c:Connect) : + val w = width!(type(loc(c))) + val exp* = pad-exp(exp(c), w) + Connect(loc(c), exp*) + (c) : + map(match-comm-width, c) + + defn match-mod-width (m:Module) : + Module(name(m), ports(m), body*) where : + val body* = match-comm-width(body(m)) + + Circuit(modules*, main(c)) where : + val modules* = map(match-mod-width, modules(c)) + + +;================== INLINING =============================== +defn inline-instances (c:Circuit) : + val module-table = HashTable<Symbol,Module>(symbol-hash) + val inlined? = HashTable<Symbol,True|False>(symbol-hash) + for m in modules(c) do : + module-table[name(m)] = m + inlined?[name(m)] = false + + ;Convert a module into a sequence of elements + defn to-elements (m:Module, + inst:Symbol, + port-exps:List<KeyValue<Symbol,Expression>>) -> + List<KeyValue<Symbol, Element>> : + defn rename-exp (e:Expression) : + match(e) : + (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e)) + (e) : map(rename-exp, e) + + defn to-elements (c:Command) -> List<KeyValue<Symbol,Element>> : + match(c) : + (c:LetRec) : + val entries* = + for entry in entries(c) map : + val name* = prefix(inst, key(entry)) + val element* = map(rename-exp, value(entry)) + name* => element* + val body* = to-elements(body(c)) + append(entries*, body*) + (c:Connect) : + val ref = loc(c) as WRef + val name* = prefix(inst, name(ref)) + list(name* => Node(type(exp(c)), rename-exp(exp(c)))) + (c:Begin) : + map-append(to-elements, body(c)) + + val inputs = + for p in ports(m) map-append : + if direction(p) == INPUT : + val port-exp = lookup!(port-exps, name(p)) + val name* = prefix(inst, name(p)) + list(name* => Node(type(port-exp), port-exp)) + else : + List() + append(inputs, to-elements(body(m))) + + ;Inline all instances in the module + defn inline-instances (m:Module) : + defn rename-exp (e:Expression) : + match(e) : + (e:WField) : + val inst-exp = exp(e) as WRef + val name* = prefix(name(inst-exp), name(e)) + WRef(name*, type(e), NodeKind(), dir(e)) + (e) : + map(rename-exp, e) + + defn inline-elems (es:List<KeyValue<Symbol,Element>>) : + for entry in es map-append : + match(value(entry)) : + (el:Instance) : + val mod-name = name(module(el) as WRef) + val module = inlined-module(mod-name) + to-elements(module, key(entry), ports(el)) + (el) : + list(entry) + + defn inline-comm (c:Command) : + match(map(rename-exp, c)) : + (c:LetRec) : + val entries* = inline-elems(entries(c)) + LetRec(entries*, inline-comm(body(c))) + (c) : + map(inline-comm, c) + + Module(name(m), ports(m), inline-comm(body(m))) + + ;Retrieve the inlined instance of a module + defn inlined-module (name:Symbol) : + if inlined?[name] : + module-table[name] + else : + val module* = inline-instances(module-table[name]) + module-table[name] = module* + inlined?[name] = true + module* + + ;Return the fully inlined circuit + val main-module = inlined-module(main(c)) + Circuit(list(main-module), main(c)) + + +;;;================ UTILITIES ================================ +; +; +; +;defn* root-ref (i:Immediate) : +; match(i) : +; (f:Field) : root-ref(imm(f)) +; (ind:Index) : root-ref(imm(ind)) +; (r) : r +; +;;defn lookup<?T> (e: Streamable<KeyValue<Immediate,?T>>, i:Immediate) : +;; for entry in e search : +;; if eqv?(key(entry), i) : +;; value(entry) +;; +;;defn lookup!<?T> (e: Streamable<KeyValue<Immediate,?T>>, i:Immediate) : +;; lookup(e, i) as T +;; +;;============ CHECK IF NAMES ARE UNIQUE ==================== +;defn check-duplicate-symbols (names: Streamable<Symbol>, msg: String) : +; val dict = HashTable<Symbol, True>(symbol-hash) +; for name in names do: +; if key?(dict, name): +; throw $ PassException $ string-join $ +; [msg ": " name] +; else: +; dict[name] = true +; +;defn check-duplicates (t: Type) : +; match(t) : +; (t:BundleType) : +; val names = map(name, ports(t)) +; check-duplicate-symbols{names, string-join(_)} $ +; ["Duplicate port name in bundle "] +; do(check-duplicates{type(_)}, ports(t)) +; (t:VectorType) : +; check-duplicates(type(t)) +; (t) : false +; +;defn check-duplicates (c: Command) : +; match(c) : +; (c:DefWire) : check-duplicates(type(c)) +; (c:DefRegister) : check-duplicates(type(c)) +; (c:DefMemory) : check-duplicates(type(c)) +; (c) : do(check-duplicates, children(c)) +; +;defn defined-names (c: Command) : +; generate<Symbol> : +; loop(c) where : +; defn loop (c:Command) : +; match(c) : +; (c:Command&HasName) : yield(name(c)) +; (c) : do(loop, children(c)) +; +;defn check-duplicates (m: Module): +; ;Check all duplicate names in all types in all ports and body +; do(check-duplicates{type(_)}, ports(m)) +; check-duplicates(body(m)) +; +; ;Check all names defined in module +; val names = concat(stream(name, ports(m)), +; defined-names(body(m))) +; check-duplicate-symbols{names, string-join(_)} $ +; ["Duplicate definition name in module " name(m)] +; +;defn check-duplicates (c: Circuit) : +; ;Check all duplicate names in all modules +; do(check-duplicates, modules(c)) +; +; ;Check all defined modules +; val names = stream(name, modules(c)) +; check-duplicate-symbols(names, "Duplicate module name") +; +; + + + + + + + +;;================ CLEANUP COMMANDS ========================= +;defn cleanup (c:Command) : +; match(c) : +; (c:Begin) : +; to-command $ generate<Command> : +; loop(c) where : +; defn loop (c:Command) : +; match(c) : +; (c:Begin) : do(loop, body(c)) +; (c:EmptyCommand) : false +; (c) : yield(cleanup(c)) +; (c) : map(cleanup{_ as Command}, c) +; +;defn cleanup (c:Circuit) : +; val modules* = +; for m in modules(c) map : +; map(cleanup, m) +; Circuit(modules*, main(c)) +; + + + +;;;============= SHIM ======================================== +;;defn shim (i:Immediate) -> Immediate : +;; match(i) : +;; (i:RegData) : +;; Ref(name(i), direction(i), type(i)) +;; (i:InstPort) : +;; val inst = Ref(name(i), UNKNOWN-DIR, UnknownType()) +;; Field(inst, port(i), direction(i), type(i)) +;; (i:Field) : +;; val imm* = shim(imm(i)) +;; put-imm(i, imm*) +;; (i) : i +;; +;;defn shim (c:Command) -> Command : +;; val c* = map(shim{_ as Immediate}, c) +;; map(shim{_ as Command}, c*) +;; +;;defn shim (c:Circuit) -> Circuit : +;; val modules* = +;; for m in modules(c) map : +;; Module(name(m), ports(m), shim(body(m))) +;; Circuit(modules*, main(c)) +;; +;;;================== INLINE MODULES ========================= +;;defn cat-name (p: String|Symbol, s: String|Symbol) -> Symbol : +;; if p == "" or p == `this : ;; TODO: REMOVE THIS WHEN `THIS GETS REMOVED +;; to-symbol(s) +;; else if s == `this : ;; TODO: DITTO +;; to-symbol(p) +;; else : +;; symbol-join([p, "/", s]) +;; +;;defn inline-command (c: Command, mods: HashTable<Symbol, Module>, prefix: String, cmds: Vector<Command>) : +;; defn rename (n: Symbol) -> Symbol : +;; cat-name(prefix, n) +;; defn inline-name (i:Immediate) -> Symbol : +;; match(i) : +;; (r:Ref) : rename(name(r)) +;; (f:Field) : cat-name(inline-name(imm(f)), name(f)) +;; (f:Index) : cat-name(inline-name(imm(f)), to-string(value(f))) +;; defn inline-imm (i:Immediate) -> Ref : +;; Ref(inline-name(i), direction(i), type(i)) +;; match(c) : +;; (c:DefUInt) : add(cmds, DefUInt(rename(name(c)), value(c), width(c))) +;; (c:DefSInt) : add(cmds, DefSInt(rename(name(c)), value(c), width(c))) +;; (c:DefWire) : add(cmds, DefWire(rename(name(c)), type(c))) +;; (c:DefRegister) : add(cmds, DefRegister(rename(name(c)), type(c))) +;; (c:DefMemory) : add(cmds, DefMemory(rename(name(c)), type(c), size(c))) +;; (c:DefInstance) : inline-module(mods, mods[name(module(c))], to-string(rename(name(c))), cmds) +;; (c:DoPrim) : add(cmds, DoPrim(rename(name(c)), op(c), map(inline-imm, args(c)), consts(c))) +;; (c:DefAccessor) : add(cmds, DefAccessor(rename(name(c)), inline-imm(source(c)), direction(c), inline-imm(index(c)))) +;; (c:Connect) : add(cmds, Connect(inline-imm(loc(c)), inline-imm(exp(c)))) +;; (c:Begin) : do(inline-command{_, mods, prefix, cmds}, body(c)) +;; (c:EmptyCommand) : c +;; (c) : error("Unsupported command") +;; +;;defn inline-port (p: Port, prefix: String) -> Command : +;; DefWire(cat-name(prefix, name(p)), type(p)) +;; +;;defn inline-module (mods: HashTable<Symbol, Module>, mod: Module, prefix: String, cmds: Vector<Command>) : +;; do(add{cmds, _}, map(inline-port{_, prefix}, ports(mod))) +;; inline-command(body(mod), mods, prefix, cmds) +;; +;;defn inline-modules (c: Circuit) -> Circuit : +;; val cmds = Vector<Command>() +;; val mods = HashTable<Symbol, Module>(symbol-hash) +;; for mod in modules(c) do : +;; mods[name(mod)] = mod +;; val top = mods[main(c)] +;; inline-command(body(top), mods, "", cmds) +;; val main* = Module(name(top), ports(top), Begin(to-list(cmds))) +;; Circuit(list(main*), name(top)) +;; +;; +;;;============= FLO PRINTER ====================================== +;;;;; TODO: +;;;;; not supported gt, lte +;; +;;defn flo-op-name (op:PrimOp) -> String : +;; switch {op == _} : +;; ADD-OP : "add" +;; ADD-MOD-OP : "add" +;; MINUS-OP : "sub" +;; SUB-MOD-OP : "sub" +;; TIMES-OP : "mul" ;; todo: signed version +;; DIVIDE-OP : "div" ;; todo: signed version +;; MOD-OP : "mod" ;; todo: signed version +;; SHIFT-LEFT-OP : "lsh" ;; todo: signed version +;; SHIFT-RIGHT-OP : "rsh" +;; PAD-OP : "pad" ;; todo: signed version +;; BIT-AND-OP : "and" +;; BIT-OR-OP : "or" +;; BIT-XOR-OP : "xor" +;; CONCAT-OP : "cat" +;; BIT-SELECT-OP : "rsh" +;; BITS-SELECT-OP : "rsh" +;; LESS-OP : "lt" ;; todo: signed version +;; LESS-EQ-OP : "lte" ;; todo: swap args +;; GREATER-OP : "gt" ;; todo: swap args +;; GREATER-EQ-OP : "gte" ;; todo: signed version +;; EQUAL-OP : "eq" +;; MULTIPLEX-OP : "mux" +;; else : error $ string-join $ +;; ["Unable to print Primop: " op] +;; +;;defn emit (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elt) : +;; match(elt) : +;; (e:String|Symbol|Int) : +;; print(o, e) +;; (e:Ref) : +;; if key?(lits, name(e)) : +;; val lit = lits[name(e)] +;; print-all(o, [value(lit) "'" width(lit)]) +;; else : +;; if key?(ports, name(e)) : +;; print-all(o, [top "::"]) +;; print(o, name(e)) +;; (e:IntWidth) : +;; print(o, value(e)) +;; (e:PrimOp) : +;; print(o, flo-op-name(e)) +;; (e) : +;; println-all(["EMIT " e]) +;; error("Unable to emit") +;; +;;defn emit-all (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elts: Streamable) : +;; for e in elts do : emit(o, top, ports, lits, e) +;; +;;defn prim-width (type:Type) -> Width : +;; match(type) : +;; (t:UIntType) : width(t) +;; (t:SIntType) : width(t) +;; (t) : error("Bad prim width type") +;; +;;defn emit-command (o:OutputStream, cmd:Command, top:Symbol, lits:HashTable<Symbol, DefUInt>, regs:HashTable<Symbol, DefRegister>, accs:HashTable<Symbol, DefAccessor>, ports:HashTable<Symbol, Port>, outs:HashTable<Symbol, Port>) : +;; match(cmd) : +;; (c:DefUInt) : +;; lits[name(c)] = c +;; (c:DefSInt) : +;; emit-all(o, top, ports, lits, [name(c) " = " value(c) "'" width(c) "\n"]) +;; (c:DoPrim) : ;; NEED TO FIGURE OUT WHEN WIDTHS ARE NECESSARY AND EXTRACT +;; emit-all(o, top, ports, lits, [name(c) " = " op(c)]) +;; for arg in args(c) do : +;; print(o, " ") +;; emit(o, top, ports, lits, arg) +;; for const in consts(c) do : +;; print(o, " ") +;; emit(o, top, ports, lits, const) +;; print("\n") +;; (c:DefRegister) : +;; regs[name(c)] = c +;; (c:DefMemory) : +;; emit-all(o, top, ports, lits, [name(c) " : mem'" prim-width(type(c)) " " size(c) "\n"]) +;; (c:DefAccessor) : +;; accs[name(c)] = c +;; (c:Connect) : +;; val dst = name(loc(c) as Ref) +;; val src = name(exp(c) as Ref) +;; if key?(regs, dst) : +;; val reg = regs[dst] +;; emit-all(o, top, ports, lits, [dst " = reg'" prim-width(type(reg)) " 0'" prim-width(type(reg)) " " exp(c) "\n"]) +;; else if key?(accs, dst) : +;; val acc = accs[dst] +;; ;; assert(direction(acc) == OUTPUT) +;; emit-all(o, top, ports, lits, [dst " = wr " source(acc) " " index(acc) " " exp(c) "\n"]) +;; else if key?(outs, dst) : +;; val out = outs[dst] +;; emit-all(o, top, ports, lits, [top "::" dst " = out'" prim-width(type(out)) " " exp(c) "\n"]) +;; else if key?(accs, src) : +;; val acc = accs[src] +;; ;; assert(direction(acc) == INPUT) +;; emit-all(o, top, ports, lits, [dst " = rd " source(acc) " " index(acc) "\n"]) +;; else : +;; emit-all(o, top, ports, lits, [dst " = mov " exp(c) "\n"]) +;; (c:Begin) : +;; do(emit-command{o, _, top, lits, regs, accs, ports, outs}, body(c)) +;; (c:DefWire|EmptyCommand) : +;; print("") +;; (c) : +;; error("Unable to print command") +;; +;;defn emit-module (o:OutputStream, m:Module) : +;; val regs = HashTable<Symbol, DefRegister>(symbol-hash) +;; val accs = HashTable<Symbol, DefAccessor>(symbol-hash) +;; val lits = HashTable<Symbol, DefUInt>(symbol-hash) +;; val outs = HashTable<Symbol, Port>(symbol-hash) +;; val portz = HashTable<Symbol, Port>(symbol-hash) +;; for port in ports(m) do : +;; portz[name(port)] = port +;; if direction(port) == OUTPUT : +;; outs[name(port)] = port +;; else if name(port) == `reset : +;; print-all(o, [name(m) "::reset = rst\n"]) +;; else : +;; print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"]) +;; emit-command(o, body(m), name(m), lits, regs, accs, portz, outs) +;; +;;public defn emit-circuit (o:OutputStream, c:Circuit) : +;; emit-module(o, modules(c)[0]) + + +;============= DRIVER ====================================== +public defn run-passes (c: Circuit) : + var c*:Circuit = c + defn do-stage (name:String, f: Circuit -> Circuit) : + println(name) + c* = f(c*) + println(c*) + println("\n\n\n\n") + + do-stage("Working IR", to-working-ir) + do-stage("Resolve Kinds", resolve-kinds) + do-stage("Make Explicit Reset", make-explicit-reset) + do-stage("Infer Types", infer-types) + do-stage("Infer Directions", infer-directions) + do-stage("Expand Accessors", expand-accessors) + do-stage("Flatten Bundles", flatten-bundles) + do-stage("Expand Bundles", expand-bundles) + do-stage("Expand Multi Connects", expand-multi-connects) + do-stage("Expand Whens", expand-whens) + do-stage("Structural Form", structural-form) + do-stage("Infer Widths", infer-widths) + do-stage("Pad Widths", pad-widths) + do-stage("Inline Instances", inline-instances) + + + ;; println("Shim for Jonathan's Passes") + ;; c* = shim(c*) + ;; println("Inline Modules") + ;; c* = inline-modules(c*) + ; c* diff --git a/src/main/stanza/widthsolver.stanza b/src/main/stanza/widthsolver.stanza new file mode 100644 index 00000000..4c9da2c6 --- /dev/null +++ b/src/main/stanza/widthsolver.stanza @@ -0,0 +1,298 @@ +;Define the STANDALONE flag to run STANDALONE +if-defined(STANDALONE) : + include<"core/stringeater.stanza"> + include<"compiler/lexer.stanza"> + +defpackage widthsolver : + import core + import verse + import stanza.lexer + +;============= Language of Constraints ====================== +public definterface WConstraint +public defstruct WidthEqual <: WConstraint : + name: Symbol + value: Exp +public defstruct WidthGreater <: WConstraint : + name: Symbol + value: Exp + +defmethod print (o:OutputStream, c:WConstraint) : + print-all{o, _} $ + match(c) : + (c:WidthEqual) : [name(c) " = " value(c)] + (c:WidthGreater) : [name(c) " >= " value(c)] + +defn construct-eqns (cs: Streamable<WConstraint>) : + val eqns = HashTable<Symbol, False|Exp>(symbol-hash) + val lower-bounds = HashTable<Symbol, List<Exp>>(symbol-hash) + for c in cs do : + match(c) : + (c:WidthEqual) : + eqns[name(c)] = value(c) + (c:WidthGreater) : + lower-bounds[name(c)] = + List(value(c), + get?(lower-bounds, name(c), List())) + + ;Create minimum expressions for lower-bounds + for entry in lower-bounds do : + val v = key(entry) + val exps = value(entry) + if not key?(eqns, v) : + eqns[v] = reduce(EMax, ELit(0), exps) + + ;Return equations + eqns +;============================================================ + +;============= Language of Expressions ====================== +public definterface Exp +public defstruct EVar <: Exp : + name: Symbol +public defstruct EMax <: Exp : + a: Exp + b: Exp +public defstruct EPlus <: Exp : + a: Exp + b: Exp +public defstruct EMinus <: Exp : + a: Exp + b: Exp +public defstruct ELit <: Exp : + width: Int + +defmethod print (o:OutputStream, e:Exp) : + match(e) : + (e:EVar) : print(o, name(e)) + (e:EMax) : print-all(o, ["max(" a(e) ", " b(e) ")"]) + (e:EPlus) : print-all(o, [a(e) " + " b(e)]) + (e:EMinus) : print-all(o, [a(e) " - " b(e)]) + (e:ELit) : print(o, width(e)) + +defn map (f: (Exp) -> Exp, e: Exp) -> Exp : + match(e) : + (e:EMax) : EMax(f(a(e)), f(b(e))) + (e:EPlus) : EPlus(f(a(e)), f(b(e))) + (e:EMinus) : EMinus(f(a(e)), f(b(e))) + (e:Exp) : e + +defn children (e: Exp) -> List<Exp> : + match(e) : + (e:EMax) : list(a(e), b(e)) + (e:EPlus) : list(a(e), b(e)) + (e:EMinus) : list(a(e), b(e)) + (e:Exp) : list() +;============================================================ + +;================== Reading from File ======================= +defn read-exp (x) : + match(unwrap-token(x)) : + (x:Symbol) : + EVar(x) + (x:Int) : + ELit(x) + (x:List) : + val tag = unwrap-token(x[1]) + switch {tag == _} : + `plus : EPlus(read-exp(x[2]), read-exp(x[3])) + `minus : EMinus(read-exp(x[2]), read-exp(x[3])) + `max : EMax(read-exp(x[2]), read-exp(x[3])) + else : error $ string-join $ + ["Improper expression: " x] + +defn read (filename: String) : + var form:List = lex-file(filename) + val cs = Vector<WConstraint>() + while not empty?(form) : + val x = unwrap-token(form[0]) + val op = form[1] + val e = read-exp(form[2]) + form = tailn(form, 3) + add{cs, _} $ + switch {unwrap-token(op) == _} : + `= : WidthEqual(x, e) + `>= : WidthGreater(x, e) + else : error $ string-join $ ["Unsupported Operator: " op] + cs +;============================================================ + +;============ Operations on Expressions ===================== +defn occurs? (v: Symbol, exp: Exp) : + match(exp) : + (exp: EVar) : name(exp) == v + (exp: Exp) : any?(occurs?{v, _}, children(exp)) + +defn freevars (exp: Exp) : + to-list $ generate<Symbol> : + defn loop (exp: Exp) : + match(exp) : + (exp: EVar) : yield(name(exp)) + (exp: Exp) : do(loop, children(exp)) + loop(exp) + +defn contains-only-max? (exp: Exp) : + match(exp) : + (exp:EVar|EMax|ELit) : all?(contains-only-max?, children(exp)) + (exp) : false + +defn simplify (exp: Exp) : + match(map(simplify,exp)) : + (exp: EPlus) : + match(a(exp), b(exp)) : + (a: ELit, b: ELit) : + ELit(width(a) + width(b)) + (a: ELit, b) : + if width(a) == 0 : b + else : exp + (a, b: ELit) : + if width(b) == 0 : a + else : exp + (a, b) : + exp + (exp: EMinus) : + match(a(exp), b(exp)) : + (a: ELit, b: ELit) : + ELit(width(a) - width(b)) + (a, b: ELit) : + if width(b) == 0 : a + else : exp + (a, b) : + exp + (exp: EMax) : + match(a(exp), b(exp)) : + (a: ELit, b: ELit) : + ELit(max(width(a), width(b))) + (a: ELit, b) : + if width(a) == 0 : b + else : exp + (a, b: ELit) : + if width(b) == 0 : a + else : exp + (a, b) : + exp + (exp: Exp) : + exp + +defn eval (exp: Exp, state: HashTable<Symbol,Int>) -> Int : + defn loop (e: Exp) -> Int : + match(e) : + (e: EVar) : state[name(e)] + (e: EMax) : max(loop(a(e)), loop(b(e))) + (e: EPlus) : loop(a(e)) + loop(b(e)) + (e: EMinus) : loop(a(e)) - loop(b(e)) + (e: ELit) : width(e) + loop(exp) +;============================================================ + + +;================ Constraint Solver ========================= +defn substitute (solns: HashTable<Symbol, Exp>, exp: Exp) : + match(exp) : + (exp: EVar) : + match(get?(solns, name(exp), false)) : + (s:Exp) : substitute(solns, s) + (f:False) : exp + (exp) : + map(substitute{solns, _}, exp) + +defn dataflow (eqns: HashTable<Symbol, False|Exp>, solns: HashTable<Symbol,Exp>) : + var progress?:True|False = false + for entry in eqns do : + if value(entry) != false : + val v = key(entry) + val exp = simplify(substitute(solns, value(entry) as Exp)) + if occurs?(v, exp) : + eqns[v] = exp + else : + eqns[v] = false + solns[v] = exp + progress? = true + progress? + +defn fixpoint (eqns: HashTable<Symbol, False|Exp>, solns: HashTable<Symbol,Exp>) : + label<False|True> break : + for v in keys(eqns) do : + if eqns[v] != false : + val fix-eqns = fixpoint-eqns(v, eqns) + val has-fixpoint? = all?(contains-only-max?{value(_)}, fix-eqns) + if has-fixpoint? : + val soln = solve-fixpoint(fix-eqns) + for s in soln do : + solns[key(s)] = ELit(value(s)) + eqns[key(s)] = false + break(true) + false + +defn fixpoint-eqns (v: Symbol, eqns: HashTable<Symbol,False|Exp>) : + val vs = HashTable<Symbol,Exp>(symbol-hash) + defn loop (v: Symbol) : + if not key?(vs, v) : + val eqn = eqns[v] as Exp + vs[v] = eqn + do(loop, freevars(eqn)) + loop(v) + to-list(vs) + +defn solve-fixpoint (eqns: List<KeyValue<Symbol,Exp>>) : + ;Solve for fixpoint + val sol = HashTable<Symbol,Int>(symbol-hash) + do({sol[key(_)] = 0}, eqns) + defn loop () : + var progress?:True|False = false + for eqn in eqns do : + val v = key(eqn) + val x = eval(value(eqn), sol) + if x != sol[v] : + sol[v] = x + progress? = true + progress? + while loop() : false + + ;Return solutions + to-list(sol) + +defn backsubstitute (vs:Streamable<Symbol>, solns: HashTable<Symbol,Exp>) : + val widths = HashTable<Symbol,False|Int>(symbol-hash) + defn get-width (v:Symbol) : + if key?(solns, v) : + val vs = freevars(solns[v]) + ;Calculate dependencies + for v in vs do : + if not key?(widths, v) : + widths[v] = get-width(v) + ;Compute value + if none?({widths[_] == false}, vs) : + eval(solns[v], widths as HashTable<Symbol,Int>) + + ;Compute all widths + for v in vs do : + widths[v] = get-width(v) + + ;Return widths + to-list $ generate<WidthEqual> : + for entry in widths do : + if value(entry) != false : + yield $ WidthEqual(key(entry), ELit(value(entry) as Int)) + +public defn solve-widths (cs: Streamable<WConstraint>) : + ;Copy to new hashtable + val eqns = construct-eqns(cs) + val solns = HashTable<Symbol,Exp>(symbol-hash) + defn loop () : + dataflow(eqns, solns) or + fixpoint(eqns, solns) + while loop() : false + backsubstitute(keys(eqns), solns) + +;================= Main ===================================== +if-defined(STANDALONE) : + defn main () : + val input = lex(commandline-arguments()) + error("No input file!") when length(input) < 2 + val cs = read(to-string(input[1])) + do(println, solve-widths(cs)) + + main() +;============================================================ + |
