aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorazidar2015-02-13 15:42:47 -0800
committerazidar2015-02-13 15:42:47 -0800
commit4f68f75415eb89427062eb86ff21b0e53bf4cadd (patch)
tree1f6a552e18eed4874a563359e95e5aad87a8ef50 /src
parent4deb61cefa9c0ef7806e3986231865ce59673bc2 (diff)
First commit.
Added stanza as a .zip, changed names from ch to firrtl, and spec.tex is included. need to add installation instructions. TODO's included in README
Diffstat (limited to 'src')
-rw-r--r--src/lib/stanzam.zipbin0 -> 2167258 bytes
-rw-r--r--src/main/stanza/firrtl-ir.stanza146
-rw-r--r--src/main/stanza/firrtl-main.stanza26
-rw-r--r--src/main/stanza/ir-parser.stanza204
-rw-r--r--src/main/stanza/ir-utils.stanza227
-rw-r--r--src/main/stanza/passes.stanza1878
-rw-r--r--src/main/stanza/widthsolver.stanza298
-rw-r--r--src/test/firrtl/firrtl-test.txt56
8 files changed, 2835 insertions, 0 deletions
diff --git a/src/lib/stanzam.zip b/src/lib/stanzam.zip
new file mode 100644
index 00000000..cc396d61
--- /dev/null
+++ b/src/lib/stanzam.zip
Binary files differ
diff --git a/src/main/stanza/firrtl-ir.stanza b/src/main/stanza/firrtl-ir.stanza
new file mode 100644
index 00000000..1a6df3d5
--- /dev/null
+++ b/src/main/stanza/firrtl-ir.stanza
@@ -0,0 +1,146 @@
+defpackage chipper.ir2 :
+ import core
+ import verse
+
+public definterface Direction
+public val INPUT = new Direction
+public val OUTPUT = new Direction
+public val UNKNOWN-DIR = new Direction
+
+public definterface Width
+public defstruct UnknownWidth <: Width
+public defstruct IntWidth <: Width :
+ width: Int
+
+public defstruct PrimOp
+public val ADD-OP = PrimOp()
+public val ADD-MOD-OP = PrimOp()
+public val SUB-OP = PrimOp()
+public val SUB-MOD-OP = PrimOp()
+public val TIMES-OP = PrimOp()
+public val DIVIDE-OP = PrimOp()
+public val MOD-OP = PrimOp()
+public val SHIFT-LEFT-OP = PrimOp()
+public val SHIFT-RIGHT-OP = PrimOp()
+public val PAD-OP = PrimOp()
+public val BIT-AND-OP = PrimOp()
+public val BIT-OR-OP = PrimOp()
+public val BIT-XOR-OP = PrimOp()
+public val CONCAT-OP = PrimOp()
+public val BIT-SELECT-OP = PrimOp()
+public val BITS-SELECT-OP = PrimOp()
+public val MULTIPLEX-OP = PrimOp()
+public val LESS-OP = PrimOp()
+public val LESS-EQ-OP = PrimOp()
+public val GREATER-OP = PrimOp()
+public val GREATER-EQ-OP = PrimOp()
+public val EQUAL-OP = PrimOp()
+
+public definterface Expression
+public defmulti type (e:Expression) -> Type
+
+public defstruct Ref <: Expression :
+ name: Symbol
+ type: Type [multi => false]
+public defstruct Field <: Expression :
+ exp: Expression
+ name: Symbol
+ type: Type [multi => false]
+public defstruct Index <: Expression :
+ exp: Expression
+ value: Int
+ type: Type [multi => false]
+public defstruct UIntValue <: Expression :
+ value: Int
+ width: Width
+public defstruct SIntValue <: Expression :
+ value: Int
+ width: Width
+public defstruct DoPrim <: Expression :
+ op: PrimOp
+ args: List<Expression>
+ consts: List<Int>
+ type: Type [multi => false]
+public defstruct ReadPort <: Expression :
+ mem: Expression
+ index: Expression
+ type: Type [multi => false]
+
+public definterface Command
+public defstruct LetRec <: Command :
+ entries: List<KeyValue<Symbol, Element>>
+ body: Command
+public defstruct DefWire <: Command :
+ name: Symbol
+ type: Type
+public defstruct DefRegister <: Command :
+ name: Symbol
+ type: Type
+public defstruct DefInstance <: Command :
+ name: Symbol
+ module: Expression
+public defstruct DefMemory <: Command :
+ name: Symbol
+ type: VectorType
+public defstruct DefAccessor <: Command :
+ name: Symbol
+ source: Expression
+ index: Expression
+public defstruct Conditionally <: Command :
+ pred: Expression
+ conseq: Command
+ alt: Command
+public defstruct Begin <: Command :
+ body: List<Command>
+public defstruct Connect <: Command :
+ loc: Expression
+ exp: Expression
+public defstruct EmptyCommand <: Command
+
+public definterface Element
+public defmulti type (e:Element) -> Type
+
+public defstruct Register <: Element :
+ type: Type [multi => false]
+ value: Expression
+ enable: Expression
+public defstruct Memory <: Element :
+ type: Type [multi => false]
+ writers: List<WritePort>
+public defstruct WritePort :
+ index: Expression
+ value: Expression
+ enable: Expression
+public defstruct Node <: Element :
+ type: Type [multi => false]
+ value: Expression
+public defstruct Instance <: Element :
+ type: Type [multi => false]
+ module: Expression
+ ports: List<KeyValue<Symbol,Expression>>
+
+public definterface Type
+public defstruct UIntType <: Type :
+ width: Width
+public defstruct SIntType <: Type :
+ width: Width
+public defstruct BundleType <: Type :
+ ports: List<Port>
+public defstruct VectorType <: Type :
+ type: Type
+ size: Int
+public defstruct UnknownType <: Type
+
+public defstruct Port :
+ name: Symbol
+ direction: Direction
+ type: Type
+
+public defstruct Module :
+ name: Symbol
+ ports: List<Port>
+ body: Command
+
+public defstruct Circuit :
+ modules: List<Module>
+ main: Symbol \ No newline at end of file
diff --git a/src/main/stanza/firrtl-main.stanza b/src/main/stanza/firrtl-main.stanza
new file mode 100644
index 00000000..bafb9dd7
--- /dev/null
+++ b/src/main/stanza/firrtl-main.stanza
@@ -0,0 +1,26 @@
+include<"core/stringeater.stanza">
+include<"compiler/lexer.stanza">
+include<"compiler/parser.stanza">
+include<"compiler/rdparser2.stanza">
+include<"compiler/macro-utils.stanza">
+include("firrtl-ir.stanza")
+include("ir-utils.stanza")
+include("ir-parser.stanza")
+include("passes.stanza")
+include("widthsolver.stanza")
+
+defpackage chmain :
+ import core
+ import verse
+ import chipper.parser
+ import chipper.passes
+ import stanza.lexer
+ import stanza.parser
+
+defn main () :
+ val lexed = lex-file("../../test/firrtl/firrtl-test.txt")
+ val c = parse-firrtl(lexed)
+ println(c)
+ run-passes(c)
+
+main()
diff --git a/src/main/stanza/ir-parser.stanza b/src/main/stanza/ir-parser.stanza
new file mode 100644
index 00000000..0da99033
--- /dev/null
+++ b/src/main/stanza/ir-parser.stanza
@@ -0,0 +1,204 @@
+defpackage chipper.parser :
+ import core
+ import verse
+ import chipper.ir2
+ import stanza.rdparser
+ import stanza.lexer
+
+;======= Convenience Functions ====
+defn throw-error (x) :
+ throw $ new Exception :
+ defmethod print (o:OutputStream, this) :
+ print(o, x)
+
+defn ut (x) :
+ unwrap-token(x)
+
+;======== String Splitting ========
+defn substring? (s:String, look:String) :
+ index-of-string(s, look) != false
+
+defn index-of-string (s:String, look:String) :
+ for i in 0 through length(s) - length(look) index-when :
+ for j in 0 to length(look) all? :
+ s[i + j] == look[j]
+
+defn split-string (s:String, split:String) -> List<String> :
+ defn loop (s:String) -> List<String> :
+ if length(s) == 0 :
+ List()
+ else :
+ match(index-of-string(s, split)) :
+ (i:Int) :
+ val rest = loop(substring(s, i + length(split)))
+ if i == 0 : List(split, rest)
+ else : List(substring(s, 0, i), split, rest)
+ (f:False) : list(s)
+ loop(s)
+
+;======= Unwrap Prefix Forms ============
+defn unwrap-prefix-form (form) :
+ match(form) :
+ (form:Token) :
+ val fs = unwrap-prefix-form(item(form))
+ List(Token(head(fs), info(form)), tail(fs))
+ (form:List) :
+ if tagged-list?(form, `(@get @do @do-afn @of)) :
+ val rest = map-append(unwrap-prefix-form, tailn(form, 2))
+ val form* = List(form[0], rest)
+ append(unwrap-prefix-form(form[1]), list(form*))
+ else :
+ list(map-append(unwrap-prefix-form, form))
+ (form) :
+ list(form)
+
+;======= Split Dots ============
+defn split-dots (forms:List) :
+ defn split (form) :
+ match(ut(form)) :
+ (f:Symbol) :
+ val fstr = to-string(f)
+ if contains?(fstr, '.') : map(to-symbol, split-string(fstr, "."))
+ else : list(form)
+ (f:List) :
+ list(map-append(split, f))
+ (f) :
+ list(f)
+ head(split(forms))
+
+;====== Normalize Dots ========
+defn normalize-dots (forms:List) :
+ val forms* = head(unwrap-prefix-form(forms))
+ split-dots(forms*)
+
+;======== SYNTAX =======================
+rd.defsyntax firrtl :
+ defrule circuit :
+ (circuit ?name:#symbol : (?module-form ...)) :
+ rd.match-syntax(normalize-dots(module-form)) :
+ (?modules:#module ...) :
+ Circuit(modules, ut(name))
+
+ defrule module :
+ (module ?name:#symbol : (?ports:#port ... ?body:#comm ...)) :
+ Module(ut(name), ports, Begin(body))
+
+ defrule port :
+ (input ?name:#symbol : ?type:#type) :
+ Port(ut(name), INPUT, type)
+ (output ?name:#symbol : ?type:#type) :
+ Port(ut(name), OUTPUT, type)
+
+ defrule type :
+ (?type:#type (@get ?size:#int)) :
+ VectorType(type, ut(size))
+ (UInt (@do ?width:#int)) :
+ UIntType(IntWidth(ut(width)))
+ (UInt) :
+ UIntType(UnknownWidth())
+ (SInt (@do ?width:#int)) :
+ SIntType(IntWidth(ut(width)))
+ (SInt) :
+ SIntType(UnknownWidth())
+ ({?ports:#port ...}) :
+ BundleType(ports)
+
+ defrule comm :
+ (wire ?name:#symbol : ?type:#type) :
+ DefWire(ut(name), type)
+ (reg ?name:#symbol : ?type:#type) :
+ DefRegister(ut(name), type)
+ (mem ?name:#symbol : ?type:#type) :
+ DefMemory(ut(name), type)
+ (inst ?name:#symbol of ?module:#exp) :
+ DefInstance(ut(name), module)
+ (accessor ?name:#symbol = ?source:#exp (@get ?index:#exp)) :
+ DefAccessor(ut(name), source, index)
+ ((?body:#comm ...)) :
+ Begin(body)
+ (letrec : (?elems:#element ...) in : ?body:#comm) :
+ LetRec(elems, body)
+ (?x:#exp := ?y:#exp) :
+ Connect(x, y)
+ (?c:#comm/when) :
+ c
+
+ defrule comm/when :
+ (when ?pred:#exp : ?conseq:#comm else : ?alt:#comm) :
+ Conditionally(pred, conseq, alt)
+ (when ?pred:#exp : ?conseq:#comm else ?alt:#comm/when) :
+ Conditionally(pred, conseq, alt)
+ (when ?pred:#exp : ?conseq:#comm) :
+ Conditionally(pred, conseq, EmptyCommand())
+
+ defrule element :
+ (reg ?name:#symbol : ?type:#type = Register (@do ?value:#exp ?en:#exp)) :
+ ut(name) => Register(type, value, en)
+ (mem ?name:#symbol : ?type:#type = Memory
+ (@do (?i:#exp => WritePort (@do ?value:#exp ?en:#exp) @...))) :
+ val ports = map(WritePort, i, value, en)
+ ut(name) => Memory(type, ports)
+ (node ?name:#symbol : ?type:#type = ?exp:#exp) :
+ ut(name) => Node(type, exp)
+ (inst ?name:#symbol = Instance (@do ?module:#exp
+ (?names:#symbol => ?values:#exp @...))) :
+ val ports = map({ut(_) => _}, names, values)
+ ut(name) => Instance(UnknownType(), module, ports)
+
+ defrule exp :
+ (?x:#exp . ?f:#symbol) :
+ Field(x, ut(f), UnknownType())
+ (?x:#exp . ?f:#int) :
+ Index(x, ut(f), UnknownType())
+ (?x:#exp-form) :
+ x
+
+ val operators = HashTable<Symbol, PrimOp>(symbol-hash)
+ operators[`add] = ADD-OP
+ operators[`add-mod] = ADD-MOD-OP
+ operators[`sub] = SUB-OP
+ operators[`sub-mod] = SUB-MOD-OP
+ operators[`times] = TIMES-OP
+ operators[`mod] = MOD-OP
+ operators[`bit-and] = BIT-AND-OP
+ operators[`bit-or] = BIT-OR-OP
+ operators[`bit-xor] = BIT-XOR-OP
+ operators[`concat] = CONCAT-OP
+ operators[`less] = LESS-OP
+ operators[`less-eq] = LESS-EQ-OP
+ operators[`greater] = GREATER-OP
+ operators[`greater-eq] = GREATER-EQ-OP
+ operators[`equal] = EQUAL-OP
+ operators[`multiplex] = MULTIPLEX-OP
+ operators[`pad] = PAD-OP
+ operators[`shift-left] = SHIFT-LEFT-OP
+ operators[`shift-right] = SHIFT-RIGHT-OP
+ operators[`bit] = BIT-SELECT-OP
+ operators[`bits] = BITS-SELECT-OP
+
+ defrule exp-form :
+ (UInt (@do ?value:#int ?width:#int)) :
+ UIntValue(ut(value), IntWidth(ut(width)))
+ (UInt (@do ?value:#int)) :
+ UIntValue(ut(value), UnknownWidth())
+ (SInt (@do ?value:#int ?width:#int)) :
+ SIntValue(ut(value), IntWidth(ut(width)))
+ (SInt (@do ?value:#int)) :
+ SIntValue(ut(value), UnknownWidth())
+ (ReadPort (@do ?mem:#exp ?index:#exp)) :
+ ReadPort(mem, index, UnknownType())
+ (?op:#symbol (@do ?es:#exp ... ?ints:#int ...)) :
+ match(get?(operators, ut(op), false)) :
+ (op:PrimOp) :
+ DoPrim(op, es, map(ut, ints), UnknownType())
+ (f:False) :
+ throw-error $ string-join $ [
+ "Invalid operator: " op]
+ (?x:#symbol) :
+ Ref(ut(x), UnknownType())
+
+public defn parse-firrtl (forms:List) :
+ with-parser{`firrtl, _} $ fn () :
+ rd.match-syntax(forms) :
+ (?c:#circuit) :
+ c \ No newline at end of file
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza
new file mode 100644
index 00000000..be08cac8
--- /dev/null
+++ b/src/main/stanza/ir-utils.stanza
@@ -0,0 +1,227 @@
+defpackage chipper.ir-utils :
+ import core
+ import verse
+ import chipper.ir2
+
+;============== PRINTERS ===================================
+defmethod print (o:OutputStream, d:Direction) :
+ print{o, _} $
+ switch {d == _} :
+ INPUT : "input"
+ OUTPUT: "output"
+ UNKNOWN-DIR : "unknown"
+
+defmethod print (o:OutputStream, w:Width) :
+ print{o, _} $
+ match(w) :
+ (w:UnknownWidth) : "?"
+ (w:IntWidth) : width(w)
+
+defmethod print (o:OutputStream, op:PrimOp) :
+ print{o, _} $
+ switch {op == _} :
+ ADD-OP : "ADD"
+ ADD-MOD-OP : "ADD-MOD"
+ SUB-OP : "MINUS"
+ SUB-MOD-OP : "SUB-MOD"
+ TIMES-OP : "TIMES"
+ DIVIDE-OP : "DIVIDE"
+ MOD-OP : "MOD"
+ SHIFT-LEFT-OP : "SHIFT-LEFT"
+ SHIFT-RIGHT-OP : "SHIFT-RIGHT"
+ PAD-OP : "PAD"
+ BIT-AND-OP : "BIT-AND"
+ BIT-OR-OP : "BIT-OR"
+ BIT-XOR-OP : "BIT-XOR"
+ CONCAT-OP : "CONCAT"
+ BIT-SELECT-OP : "BIT-SELECT"
+ BITS-SELECT-OP : "BITS-SELECT"
+ MULTIPLEX-OP : "MULTIPLEX"
+ LESS-OP : "LESS"
+ LESS-EQ-OP : "LESS-EQ"
+ GREATER-OP : "GREATER"
+ GREATER-EQ-OP : "GREATER-EQ"
+ EQUAL-OP : "EQUAL"
+
+defmethod print (o:OutputStream, e:Expression) :
+ match(e) :
+ (e:Ref) : print(o, name(e))
+ (e:Field) : print-all(o, [exp(e) "." name(e)])
+ (e:Index) : print-all(o, [exp(e) "." value(e)])
+ (e:UIntValue) : print-all(o, ["UInt(" value(e) ")"])
+ (e:SIntValue) : print-all(o, ["SInt(" value(e) ")"])
+ (e:DoPrim) :
+ print-all(o, [op(e) "("])
+ print-all(o, join(concat(args(e), consts(e)), ", "))
+ print(o, ")")
+ (e:ReadPort) : print-all(o, ["ReadPort(" mem(e) ", " index(e) ")"])
+
+defmethod print (o:OutputStream, c:Command) :
+ match(c) :
+ (c:LetRec) :
+ println(o, "let : ")
+ indented{o, _} $ fn () :
+ for entry in entries(c) do :
+ println-all([key(entry) " = " value(entry)])
+ println(o, "in :")
+ indented(o, print{o, body(c)})
+ (c:DefWire) :
+ print-all(["wire " name(c) " : " type(c)])
+ (c:DefRegister) :
+ print-all(["reg " name(c) " : " type(c)])
+ (c:DefMemory) :
+ print-all(["mem " name(c) " : " type(c)])
+ (c:DefInstance) :
+ print-all(["inst " name(c) " of " module(c)])
+ (c:DefAccessor) :
+ print-all(["accessor " name(c) " = " source(c) "[" index(c) "]"])
+ (c:Conditionally) :
+ println-all(o, ["when " pred(c) " :"])
+ indented(o, print{conseq(c)})
+ if alt(c) not-typeof EmptyCommand :
+ println(o, "\nelse :")
+ indented(o, print{alt(c)})
+ (c:Begin) :
+ do(print, join(body(c), "\n"))
+ (c:Connect) :
+ print-all(o, [loc(c) " := " exp(c)])
+ (c:EmptyCommand) :
+ print(o, "skip")
+
+defmethod print (o:OutputStream, e:Element) :
+ match(e) :
+ (e:Register) :
+ print-all(o, ["Register(" type(e) ", " value(e) ", " enable(e) ")"])
+ (e:Memory) :
+ print-all(o, ["Memory(" type(e) ", "])
+ print-all(o, join(writers(e), ", "))
+ print(o, ")")
+ (e:Node) :
+ print-all(o, ["Node(" type(e) ", " value(e) ")"])
+ (e:Instance) :
+ print-all(o, ["Instance(" module(e) ", "])
+ print-all(o, join(ports(e), ", "))
+ print(o, ")")
+
+defmethod print (o:OutputStream, p:WritePort) :
+ print-all(o, [index(p) " => WritePort(" value(p) ", " enable(p) ")"])
+
+defmethod print (o:OutputStream, t:Type) :
+ match(t) :
+ (t:UnknownType) :
+ print(o, "?")
+ (t:UIntType) :
+ print-all(o, ["UInt(" width(t) ")"])
+ (t:SIntType) :
+ print-all(o, ["SInt(" width(t) ")"])
+ (t:BundleType) :
+ print(o, "{")
+ print-all(o, join(ports(t), ", "))
+ print(o, "}")
+ (t:VectorType) :
+ print-all(o, [type(t) "[" size(t) "]"])
+
+defmethod print (o:OutputStream, p:Port) :
+ print-all(o, [direction(p) " " name(p) " : " type(p)])
+
+defmethod print (o:OutputStream, m:Module) :
+ println-all(o, ["module " name(m) " :"])
+ indented{o, _} $ fn () :
+ do(println, ports(m))
+ print(body(m))
+
+defmethod print (o:OutputStream, c:Circuit) :
+ println-all(o, ["circuit " main(c) " :"])
+ indented(o, do{println, modules(c)})
+
+;================== INDENTATION ============================
+defn IndentedStream (o:OutputStream, n:Int) :
+ var indent? = true
+ defn put (c:Char) :
+ if indent? :
+ do(print{o, " "}, 0 to n)
+ indent? = false
+ print(o, c)
+ if c == '\n' :
+ indent? = true
+
+ new OutputStream :
+ defmethod print (this, s:String) : do(put, s)
+ defmethod print (this, c:Char) : put(c)
+
+defn indented (o:OutputStream, f: () -> ?) :
+ val prev-stream = CURRENT-OUTPUT-STREAM
+ dynamic-wind(
+ fn () : CURRENT-OUTPUT-STREAM = IndentedStream(o, 3)
+ f
+ fn (f) : CURRENT-OUTPUT-STREAM = prev-stream)
+
+
+;=================== MAPPERS ===============================
+public defn map<?T> (f: Type -> Type, t:?T&Type) -> T :
+ val type =
+ match(t) :
+ (t:T&BundleType) :
+ BundleType $
+ for p in ports(t) map :
+ Port(name(p), direction(p), f(type(p)))
+ (t:T&VectorType) :
+ VectorType(f(type(t)), size(t))
+ (t) :
+ t
+ type as T&Type
+
+public defmulti map<?T> (f: Expression -> Expression, e:?T&Expression) -> T
+defmethod map (f: Expression -> Expression, e:Expression) -> Expression :
+ match(e) :
+ (e:Field) : Field(f(exp(e)), name(e), type(e))
+ (e:Index) : Index(f(exp(e)), value(e), type(e))
+ (e:DoPrim) : DoPrim(op(e), map(f, args(e)), consts(e), type(e))
+ (e:ReadPort) : ReadPort(f(mem(e)), f(index(e)), type(e))
+ (e) : e
+
+public defmulti map<?T> (f: Expression -> Expression, e:?T&Element) -> T
+defmethod map (f: Expression -> Expression, e:Element) -> Element :
+ match(e) :
+ (e:Register) :
+ Register(type(e), f(value(e)), f(enable(e)))
+ (e:Memory) :
+ val writers* = for w in writers(e) map :
+ WritePort(f(index(w)), f(value(w)), f(enable(w)))
+ Memory(type(e), writers*)
+ (e:Node) :
+ Node(type(e), f(value(e)))
+ (e:Instance) :
+ val ports* = for p in ports(e) map :
+ key(p) => f(value(p))
+ Instance(type(e), f(module(e)), ports*)
+
+public defmulti map<?T> (f: Expression -> Expression, c:?T&Command) -> T
+defmethod map (f: Expression -> Expression, c:Command) -> Command :
+ match(c) :
+ (c:LetRec) :
+ val entries* = for entry in entries(c) map :
+ key(entry) => map(f, value(entry))
+ LetRec(entries*, body(c))
+ (c:DefAccessor) : DefAccessor(name(c), f(source(c)), f(index(c)))
+ (c:DefInstance) : DefInstance(name(c), f(module(c)))
+ (c:Conditionally) : Conditionally(f(pred(c)), conseq(c), alt(c))
+ (c:Connect) : Connect(f(loc(c)), f(exp(c)))
+ (c) : c
+
+public defmulti map<?T> (f: Command -> Command, c:?T&Command) -> T
+defmethod map (f: Command -> Command, c:Command) -> Command :
+ match(c) :
+ (c:LetRec) : LetRec(entries(c), f(body(c)))
+ (c:Conditionally) : Conditionally(pred(c), f(conseq(c)), f(alt(c)))
+ (c:Begin) : Begin(map(f, body(c)))
+ (c) : c
+
+public defmulti children (c:Command) -> List<Command>
+defmethod children (c:Command) :
+ match(c) :
+ (c:LetRec) : list(body(c))
+ (c:Conditionally) : list(conseq(c), alt(c))
+ (c:Begin) : body(c)
+ (c) : List()
+
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
new file mode 100644
index 00000000..61eac73c
--- /dev/null
+++ b/src/main/stanza/passes.stanza
@@ -0,0 +1,1878 @@
+defpackage chipper.passes :
+ import core
+ import verse
+ import chipper.ir2
+ import chipper.ir-utils
+ import widthsolver
+
+;============== EXCEPTIONS =================================
+defclass PassException <: Exception
+defn PassException (msg:String) :
+ new PassException :
+ defmethod print (o:OutputStream, this) :
+ print(o, msg)
+
+;=============== WORKING IR ================================
+definterface Kind
+defstruct RegKind <: Kind
+defstruct AccessorKind <: Kind
+defstruct PortKind <: Kind
+defstruct MemKind <: Kind
+defstruct NodeKind <: Kind
+defstruct ModuleKind <: Kind
+defstruct InstanceKind <: Kind
+defstruct StructuralMemKind <: Kind
+
+defstruct WRef <: Expression :
+ name: Symbol
+ type: Type [multi => false]
+ kind: Kind
+ dir: Direction [multi => false]
+
+defstruct WField <: Expression :
+ exp: Expression
+ name: Symbol
+ type: Type [multi => false]
+ dir: Direction [multi => false]
+
+defstruct WIndex <: Expression :
+ exp: Expression
+ value: Int
+ type: Type [multi => false]
+ dir: Direction [multi => false]
+
+defstruct WDefAccessor <: Command :
+ name: Symbol
+ source: Expression
+ index: Expression
+ dir: Direction
+
+;================ WORKING IR UTILS =========================
+;=== Printers ===
+defmethod print (o:OutputStream, k:Kind) :
+ print{o, _} $
+ match(k) :
+ (k:RegKind) : "reg:"
+ (k:AccessorKind) : "accessor:"
+ (k:PortKind) : "port:"
+ (k:MemKind) : "mem:"
+ (k:NodeKind) : "n:"
+ (k:ModuleKind) : "module:"
+ (k:InstanceKind) : "inst:"
+ (k:StructuralMemKind) : "smem:"
+
+defmethod print (o:OutputStream, e:WRef) :
+ print-all(o, [name(e)])
+defmethod print (o:OutputStream, e:WField) :
+ print-all(o, [exp(e) "." name(e)])
+defmethod print (o:OutputStream, e:WIndex) :
+ print-all(o, [exp(e) "." value(e)])
+
+defmethod print (o:OutputStream, c:WDefAccessor) :
+ print-all(o, [dir(c) " accessor " name(c) " = " source(c) "[" index(c) "]"])
+
+defmethod map (f: Expression -> Expression, e: WField) :
+ WField(f(exp(e)), name(e), type(e), dir(e))
+
+defmethod map (f: Expression -> Expression, e: WIndex) :
+ WIndex(f(exp(e)), value(e), type(e), dir(e))
+
+defmethod map (f: Expression -> Expression, c:WDefAccessor) :
+ WDefAccessor(name(c), f(source(c)), f(index(c)), dir(c))
+
+;================= DIRECTION ===============================
+defmulti dir (e:Expression) -> Direction
+defmethod dir (e:Expression) :
+ OUTPUT
+
+;============== Bring to Working IR ========================
+defn to-working-ir (c:Circuit) :
+ defn to-exp (e:Expression) :
+ match(map(to-exp, e)) :
+ (e:Ref) : WRef(name(e), type(e), NodeKind(), UNKNOWN-DIR)
+ (e:Field) : WField(exp(e), name(e), type(e), UNKNOWN-DIR)
+ (e:Index) : WIndex(exp(e), value(e), type(e), UNKNOWN-DIR)
+ (e) : e
+ defn to-command (c:Command) :
+ match(map(to-exp, c)) :
+ (c:DefAccessor) :
+ WDefAccessor(name(c), source(c), index(c), UNKNOWN-DIR)
+ (c) :
+ map(to-command, c)
+
+ Circuit(modules*, main(c)) where :
+ val modules* =
+ for m in modules(c) map :
+ Module(name(m), ports(m), to-command(body(m)))
+
+;=============== Resolve Kinds =============================
+defn resolve-kinds (c:Circuit) :
+ defn resolve-exp (e:Expression, kinds:HashTable<Symbol,Kind>) :
+ match(e) :
+ (e:WRef) : WRef(name(e), type(e), kinds[name(e)], dir(e))
+ (e) : map(resolve-exp{_, kinds}, e)
+
+ defn resolve-comm (c:Command, kinds:HashTable<Symbol,Kind>) -> Command :
+ map{resolve-comm{_, kinds}, _} $
+ map(resolve-exp{_, kinds}, c)
+
+ defn find-kinds (c:Command, kinds:HashTable<Symbol,Kind>) :
+ match(c) :
+ (c:LetRec) :
+ for entry in entries(c) do :
+ kinds[key(entry)] = element-kind(value(entry))
+ (c:DefWire) : kinds[name(c)] = NodeKind()
+ (c:DefRegister) : kinds[name(c)] = RegKind()
+ (c:DefInstance) : kinds[name(c)] = InstanceKind()
+ (c:DefMemory) : kinds[name(c)] = MemKind()
+ (c:WDefAccessor) : kinds[name(c)] = AccessorKind()
+ (c) : false
+ do(find-kinds{_, kinds}, children(c))
+
+ defn element-kind (e:Element) :
+ match(e) :
+ (e:Memory) : StructuralMemKind()
+ (e) : NodeKind()
+
+ defn resolve-mod (m:Module, modules:List<Symbol>) :
+ val kinds = HashTable<Symbol,Kind>(symbol-hash)
+ for module in modules do :
+ kinds[module] = ModuleKind()
+ for port in ports(m) do :
+ kinds[name(port)] = PortKind()
+ find-kinds(body(m), kinds)
+ Module(name(m), ports(m), body*) where :
+ val body* = resolve-comm(body(m), kinds)
+
+ Circuit(modules*, main(c)) where :
+ val mod-names = map(name, modules(c))
+ val modules* = map(resolve-mod{_, mod-names}, modules(c))
+
+;=============== MAKE RESET EXPLICIT =======================
+defn make-explicit-reset (c:Circuit) :
+ defn reset-instances (c:Command, reset?: List<Symbol>) -> Command :
+ match(c) :
+ (c:DefInstance) :
+ val module = module(c) as WRef
+ if contains?(reset?, name(module)) :
+ c
+ else :
+ Begin $ list(c, Connect(WField(inst, `reset, UnknownType(), UNKNOWN-DIR), reset)) where :
+ val inst = WRef(name(c), UnknownType(), InstanceKind(), UNKNOWN-DIR)
+ val reset = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-DIR)
+ (c) :
+ map(reset-instances{_:Command, reset?}, c)
+
+ defn make-explicit-reset (m:Module, reset-list: List<Symbol>) :
+ val reset? = contains?(reset-list, name(m))
+
+ ;Add reset port if necessary
+ val ports* =
+ if reset? :
+ ports(m)
+ else :
+ val reset = Port(`reset, INPUT, UIntType(IntWidth(1)))
+ List(reset, ports(m))
+
+ ;Reset Instances
+ val body* = reset-instances(body(m), reset-list)
+ val m* = Module(name(m), ports*, body*)
+
+ ;Initialize registers if necessary
+ if reset? : m*
+ else : initialize-registers(m*)
+
+ Circuit(modules*, main(c)) where :
+ defn reset? (m:Module) :
+ for p in ports(m) any? :
+ name(p) == `reset
+ val reset-list = to-list(stream(name, filter(reset?, modules(c))))
+ val modules* = map(make-explicit-reset{_, reset-list}, modules(c))
+
+;======= MAKE EXPLICIT REGISTER INITIALIZATION =============
+defn initialize-registers (m:Module) :
+ ;=== Initializing Expressions ===
+ defn init-exps (inits: List<KeyValue<Symbol,Expression>>) :
+ if empty?(inits) :
+ EmptyCommand()
+ else :
+ Conditionally(reset, Begin(map(connect, inits)), EmptyCommand()) where :
+ val reset = WRef(`reset, UnknownType(), PortKind(), UNKNOWN-DIR)
+ defn connect (init: KeyValue<Symbol, Expression>) :
+ val reg-ref = WRef(key(init), UnknownType(), RegKind(), UNKNOWN-DIR)
+ Connect(reg-ref, value(init))
+
+ defn initialize-registers (c: Command
+ inits: List<KeyValue<Symbol,Expression>>) ->
+ [Command, List<KeyValue<Symbol,Expression>>] :
+ ;=== Rename Expressions ===
+ defn rename (e:Expression) :
+ match(e) :
+ (e:WField) :
+ switch {name(e) == _} :
+ `init :
+ if reg?(exp(e)) : init-wire(exp(e))
+ else : map(rename, e)
+ else : map(rename, e)
+ (e) : map(rename, e)
+ defn reg? (e:Expression) :
+ match(e) :
+ (e:WRef) : kind(e) typeof RegKind
+ (e) : false
+ defn init-wire (e:Expression) :
+ lookup!(inits, name(e as WRef))
+
+ ;=== Driver ===
+ match(c) :
+ (c:DefRegister) :
+ [new-command, list(init-entry)] where :
+ val wire-name = gensym()
+ val wire-ref = WRef(wire-name, UnknownType(), NodeKind(), UNKNOWN-DIR)
+ val reg-ref = WRef(name(c), UnknownType(), RegKind(), UNKNOWN-DIR)
+ val def-init-wire = DefWire(wire-name, type(c))
+ val init-wire = Connect(wire-ref, reg-ref)
+ val init-reg = Connect(reg-ref, wire-ref)
+ val new-command = Begin(to-list([c, def-init-wire, init-wire, init-reg]))
+ val init-entry = name(c) => wire-ref
+ (c:Conditionally) :
+ val pred* = rename(pred(c))
+ val [conseq* con-inits] = initialize-registers(conseq(c), inits)
+ val [alt* alt-inits] = initialize-registers(alt(c), inits)
+ val c* = Conditionally(pred*, conseq+inits, alt+inits) where :
+ val conseq+inits = Begin(list(conseq*, init-exps(con-inits)))
+ val alt+inits = Begin(list(alt*, init-exps(alt-inits)))
+ [c*, List()]
+ (c:LetRec) :
+ val c* = map(rename, c)
+ val [body*, body-inits] = initialize-registers(body(c), inits)
+ val new-command =
+ LetRec(entries(c*), body+inits) where :
+ val body+inits = Begin(list(body*, init-exps(body-inits)))
+ [new-command, List()]
+ (c:Begin) :
+ var inits-in:List<KeyValue<Symbol,Expression>> = inits
+ var inits-out:List<KeyValue<Symbol,Expression>> = List()
+ val body* =
+ for c in body(c) map :
+ val [c* inits*] = initialize-registers(c, inits-in)
+ inits-in = append(inits*, inits-in)
+ inits-out = append(inits*, inits-out)
+ c*
+ [Begin(body*), inits-out]
+ (c) :
+ val c* = map(rename, c)
+ [c*, List()]
+
+ Module(name(m), ports(m), body+inits) where :
+ val [body*, inits] = initialize-registers(body(m), List())
+ val body+inits = Begin(list(body*, init-exps(inits)))
+
+
+;============== INFER TYPES ================================
+defmethod type (v:UIntValue) :
+ UIntType(width(v))
+
+defmethod type (v:SIntValue) :
+ SIntType(width(v))
+
+defn put-type (e:Expression, t:Type) -> Expression :
+ match(e) :
+ (e:WRef) : WRef(name(e), t, kind(e), dir(e))
+ (e:WField) : WField(exp(e), name(e), t, dir(e))
+ (e:WIndex) : WIndex(exp(e), value(e), t, dir(e))
+ (e:DoPrim) : DoPrim(op(e), args(e), consts(e), t)
+ (e:ReadPort) : ReadPort(mem(e), index(e), t)
+ (e) : e
+
+defn lookup-port (ports: Streamable<Port>, port-name: Symbol) :
+ for port in ports find :
+ name(port) == port-name
+
+defn infer (op:PrimOp, arg-types: List<Type>) -> Type :
+ defn wipe-width (t:Type) :
+ match(t) :
+ (t:UIntType) : UIntType(UnknownWidth())
+ (t:SIntType) : SIntType(UnknownWidth())
+
+ defn arg0 () : wipe-width(arg-types[0])
+ defn arg1 () : wipe-width(arg-types[1])
+ switch {op == _} :
+ ADD-OP : arg0()
+ ADD-MOD-OP : arg0()
+ SUB-OP : arg0()
+ SUB-MOD-OP : arg0()
+ TIMES-OP : arg0()
+ DIVIDE-OP : arg0()
+ MOD-OP : arg0()
+ SHIFT-LEFT-OP : arg0()
+ SHIFT-RIGHT-OP : arg0()
+ PAD-OP : arg0()
+ BIT-AND-OP : arg0()
+ BIT-OR-OP : arg0()
+ BIT-XOR-OP : arg0()
+ CONCAT-OP : arg0()
+ BIT-SELECT-OP : UIntType(UnknownWidth())
+ BITS-SELECT-OP : arg0()
+ MULTIPLEX-OP : arg0()
+ LESS-OP : UIntType(UnknownWidth())
+ LESS-EQ-OP : UIntType(UnknownWidth())
+ GREATER-OP : UIntType(UnknownWidth())
+ GREATER-EQ-OP : UIntType(UnknownWidth())
+ EQUAL-OP : UIntType(UnknownWidth())
+
+defn bundle-field-type (t:Type, n:Symbol) -> Type :
+ match(t) :
+ (t:BundleType) :
+ match(lookup-port(ports(t), n)) :
+ (p:Port) : type(p)
+ (p) : UnknownType()
+ (t) : UnknownType()
+
+defn vector-elem-type (t:Type) -> Type :
+ match(t) :
+ (t:VectorType) : type(t)
+ (t) : UnknownType()
+
+;e is the environment that contains all definitions seen so far.
+defn infer (c:Command, e:List<KeyValue<Symbol, Type>>) -> [Command, List<KeyValue<Symbol,Type>>] :
+ defn infer-exp (e:Expression, env:List<KeyValue<Symbol,Type>>) :
+ match(map(infer-exp{_, env}, e)) :
+ (e:WRef) :
+ put-type(e, lookup!(env, name(e)))
+ (e:WField) :
+ put-type(e, bundle-field-type(type(exp(e)), name(e)))
+ (e:WIndex) :
+ put-type(e, vector-elem-type(type(exp(e))))
+ (e:UIntValue) : e
+ (e:SIntValue) : e
+ (e:DoPrim) :
+ put-type(e, infer(op(e), map(type, args(e))))
+ (e:ReadPort) :
+ put-type(e, vector-elem-type(type(mem(e))))
+
+ defn element-type (e:Element, env:List<KeyValue<Symbol,Type>>) :
+ match(e) :
+ (e:Instance) :
+ val t = type(infer-exp(module(e), env))
+ match(t) :
+ (t:BundleType) :
+ BundleType $ to-list $
+ for p in ports(t) filter :
+ direction(p) == OUTPUT
+ (t) : UnknownType()
+ (e) : type(e)
+
+ match(c) :
+ (c:LetRec) :
+ val e* = append(elem-types, e) where :
+ val elem-types =
+ for entry in entries(c) map :
+ key(entry) => element-type(value(entry), e)
+ val c* = map(infer-exp{_, e*}, c)
+ val [body*, be] = infer(body(c*), e*)
+ [LetRec(entries(c*), body*), e]
+ (c) :
+ match(map(infer-exp{_, e}, c)) :
+ (c:DefWire) :
+ [c, List(entry, e)] where :
+ val entry = name(c) => type(c)
+ (c:DefRegister) :
+ [c, List(entry, e)] where :
+ val entry = name(c) => type(c)
+ (c:DefInstance) :
+ [c, List(entry, e)] where :
+ val entry = name(c) => type(module(c))
+ (c:DefMemory) :
+ [c, List(entry, e)] where :
+ val entry = name(c) => type(c)
+ (c:WDefAccessor) :
+ [c, List(entry, e)] where :
+ val src-type = type(source(c))
+ val entry = name(c) => vector-elem-type(src-type)
+ (c:Begin) :
+ var current-e: List<KeyValue<Symbol,Type>> = e
+ val body* = for c in body(c) map :
+ val [c*, e*] = infer(c, current-e)
+ current-e = e*
+ c*
+ [Begin(body*), current-e]
+ (c) :
+ defn infer-comm (c:Command) :
+ val [c* e*] = infer(c, e)
+ c*
+ val c* = map(infer-comm, c)
+ [c*, e]
+
+defn infer (m:Module, e:List<KeyValue<Symbol, Type>>) -> Module :
+ val env = append{_, e} $
+ for p in ports(m) map :
+ name(p) => type(p)
+ val [body*, e*] = infer(body(m), env)
+ Module(name(m), ports(m), body*)
+
+defn infer-types (c:Circuit) -> Circuit :
+ val env =
+ for m in modules(c) map :
+ name(m) => BundleType(ports(m))
+ Circuit(map(infer{_, env}, modules(c)),
+ main(c))
+
+;============= INFER DIRECTIONS ============================
+defn flip (d:Direction) :
+ switch {d == _} :
+ INPUT : OUTPUT
+ OUTPUT : INPUT
+ else : d
+
+defn times (d1:Direction, d2:Direction) :
+ if d1 == INPUT : flip(d2)
+ else : d2
+
+defn bundle-field-dir (t:Type, n:Symbol) -> Direction :
+ match(t) :
+ (t:BundleType) :
+ match(lookup-port(ports(t), n)) :
+ (p:Port) : direction(p)
+ (p) : UNKNOWN-DIR
+ (t) : UNKNOWN-DIR
+
+defn infer-dirs (m:Module) :
+ ;=== Direction of all Binders ===
+ val BI-DIR = new Direction
+ val directions = HashTable<Symbol,Direction>(symbol-hash)
+ defn find-dirs (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ for entry in entries(c) do :
+ directions[key(entry)] = OUTPUT
+ find-dirs(body(c))
+ (c:DefWire) :
+ directions[name(c)] = BI-DIR
+ (c:DefRegister) :
+ directions[name(c)] = BI-DIR
+ (c:DefInstance) :
+ directions[name(c)] = OUTPUT
+ (c:DefMemory) :
+ directions[name(c)] = BI-DIR
+ (c:WDefAccessor) :
+ directions[name(c)] = dir(c)
+ (c) :
+ do(find-dirs, children(c))
+ for p in ports(m) do :
+ directions[name(p)] = flip(direction(p))
+ find-dirs(body(m))
+
+ ;=== Fix Point Status ===
+ var changed? = false
+
+ ;=== Infer directions of Expression ===
+ defn infer-exp (e:Expression, desired:Direction) :
+ match(e) :
+ (e:WRef) :
+ val dir* = let :
+ if kind(e) typeof ModuleKind :
+ OUTPUT
+ else :
+ val old-dir = directions[name(e)]
+ switch {old-dir == _} :
+ BI-DIR :
+ desired
+ UNKNOWN-DIR :
+ if directions[name(e)] != desired :
+ directions[name(e)] = desired
+ changed? = true
+ desired
+ else :
+ old-dir
+ WRef(name(e), type(e), kind(e), dir*)
+ (e:WField) :
+ val port-dir = bundle-field-dir(type(exp(e)), name(e))
+ val exp* = infer-exp(exp(e), port-dir * desired)
+ WField(exp*, name(e), type(e), port-dir * dir(exp*))
+ (e:WIndex) :
+ val exp* = infer-exp(exp(e), desired)
+ WIndex(exp*, value(e), type(e), dir(exp*))
+ (e) :
+ map(infer-exp{_, OUTPUT}, e)
+
+ ;=== Infer directions of Commands ===
+ defn infer-comm (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ val c* = map(infer-exp{_, OUTPUT}, c)
+ LetRec(entries(c*), infer-comm(body(c)))
+ (c:DefInstance) :
+ DefInstance(name(c),
+ infer-exp(module(c), OUTPUT))
+ (c:WDefAccessor) :
+ val d = directions[name(c)]
+ WDefAccessor(name(c),
+ infer-exp(source(c), d),
+ infer-exp(index(c), OUTPUT),
+ d)
+ (c:Conditionally) :
+ Conditionally(infer-exp(pred(c), OUTPUT),
+ infer-comm(conseq(c)),
+ infer-comm(alt(c)))
+ (c:Connect) :
+ Connect(infer-exp(loc(c), INPUT), infer-exp(exp(c), OUTPUT))
+ (c) :
+ map(infer-comm, c)
+
+ ;=== Iterate until fix point ===
+ defn* fixpoint (c:Command) :
+ changed? = false
+ val c* = infer-comm(c)
+ if changed? : fixpoint(c*)
+ else : c*
+
+ Module(name(m), ports(m), body*) where :
+ val body* = fixpoint(body(m))
+
+defn infer-directions (c:Circuit) :
+ Circuit(modules*, main(c)) where :
+ val modules* = map(infer-dirs, modules(c))
+
+
+;============== EXPAND VECS ================================
+defstruct ManyConnect <: Command :
+ index: Expression
+ locs: List<Expression>
+ exp: Expression
+
+defstruct ConnectMany <: Command :
+ index: Expression
+ loc: Expression
+ exps: List<Expression>
+
+defmethod print (o:OutputStream, c:ManyConnect) :
+ print-all(o, [locs(c) "[" index(c) "] := " exp(c)])
+defmethod print (o:OutputStream, c:ConnectMany) :
+ print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"])
+
+defmethod map (f: Expression -> Expression, c:ManyConnect) :
+ ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c)))
+defmethod map (f: Expression -> Expression, c:ConnectMany) :
+ ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c)))
+
+defn expand-accessors (m: Module) :
+ defn expand (c:Command) :
+ match(c) :
+ (c:WDefAccessor) :
+ ;Is the source a memory?
+ val mem? =
+ match(source(c)) :
+ (r:WRef) : kind(r) typeof MemKind
+ (r) : false
+
+ if mem? :
+ c
+ else :
+ switch {dir(c) == _} :
+ INPUT :
+ Begin(list(
+ DefWire(name(c), type(src-type))
+ ManyConnect(index(c), elems, wire-ref)))
+ where :
+ val src-type = type(source(c)) as VectorType
+ val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT)
+ val elems = to-list $
+ for i in 0 to size(src-type) stream :
+ WIndex(source(c), i, type(src-type), INPUT)
+ OUTPUT :
+ Begin(list(
+ DefWire(name(c), type(src-type))
+ ConnectMany(index(c), wire-ref, elems)))
+ where :
+ val src-type = type(source(c)) as VectorType
+ val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT)
+ val elems = to-list $
+ for i in 0 to size(src-type) stream :
+ WIndex(source(c), i, type(src-type), OUTPUT)
+ (c) :
+ map(expand, c)
+ Module(name(m), ports(m), expand(body(m)))
+
+defn expand-accessors (c:Circuit) :
+ Circuit(modules*, main(c)) where :
+ val modules* = map(expand-accessors, modules(c))
+
+
+
+;=============== BUNDLE FLATTENING =========================
+defn prefix (prefix, suffix) :
+ symbol-join([prefix "/" suffix])
+
+defn prefix-ports (pre:Symbol, ports:List<Port>) :
+ for p in ports map :
+ Port(prefix(pre, name(p)), direction(p), type(p))
+
+defn flatten-ports (port:Port) -> List<Port> :
+ match(type(port)) :
+ (t:BundleType) :
+ val ports = map-append(flatten-ports, ports(t))
+ for p in ports map :
+ Port(prefix(name(port), name(p)),
+ direction(port) * direction(p),
+ type(p))
+ (t:VectorType) :
+ val type* = flatten-type(t)
+ flatten-ports(Port(name(port), direction(port), type*))
+ (t:Type) :
+ list(port)
+
+defn flatten-type (t:Type) -> Type :
+ match(t) :
+ (t:BundleType) :
+ BundleType $
+ map-append(flatten-ports, ports(t))
+ (t:VectorType) :
+ flatten-type $ BundleType $ to-list $
+ for i in 0 to size(t) stream :
+ Port(to-symbol(i), OUTPUT, type(t))
+ (t:Type) :
+ t
+
+defn flatten-bundles (c:Circuit) :
+ defn flatten-exp (e:Expression) :
+ match(map(flatten-exp, e)) :
+ (e:UIntValue|SIntValue) :
+ e
+ (e:WRef) :
+ match(kind(e)) :
+ (k:MemKind|StructuralMemKind) :
+ val type* = map(flatten-type, type(e))
+ put-type(e, type*)
+ (k) :
+ val type* = flatten-type(type(e))
+ put-type(e, type*)
+ (e) :
+ val type* = flatten-type(type(e))
+ put-type(e, type*)
+
+ defn flatten-element (e:Element) :
+ val t* = flatten-type(type(e))
+ match(map(flatten-exp, e)) :
+ (e:Register) : Register(t*, value(e), enable(e))
+ (e:Memory) : Memory(t*, writers(e))
+ (e:Node) : Node(t*, value(e))
+ (e:Instance) : Instance(t*, module(e), ports(e))
+
+ defn flatten-comm (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ val entries* =
+ for entry in entries(c) map :
+ key(entry) => flatten-element(value(entry))
+ LetRec(entries*, flatten-comm(body(c)))
+ (c:DefWire) :
+ DefWire(name(c), flatten-type(type(c)))
+ (c:DefRegister) :
+ DefRegister(name(c), flatten-type(type(c)))
+ (c:DefMemory) :
+ val type* = map(flatten-type, type(c))
+ DefMemory(name(c), type*)
+ (c) :
+ map{flatten-comm, _} $
+ map(flatten-exp, c)
+
+ defn flatten-module (m:Module) :
+ val ports* = map-append(flatten-ports, ports(m))
+ val body* = flatten-comm(body(m))
+ Module(name(m), ports*, body*)
+
+ Circuit(modules*, main(c)) where :
+ val modules* = map(flatten-module, modules(c))
+
+
+;================== BUNDLE EXPANSION =======================
+defn expand-bundles (m:Module) :
+
+ ;Collapse all field/index expressions
+ defn collapse-exp (e:Expression) -> Expression :
+ match(e) :
+ (e:WField) :
+ match(collapse-exp(exp(e))) :
+ (ei:WRef) :
+ if kind(ei) typeof InstanceKind :
+ e
+ else :
+ WRef(name*, type(e), kind(ei), dir(e)) where :
+ val name* = prefix(name(ei), name(e))
+ (ei:WField) :
+ WField(exp(ei), name*, type(e), dir(e)) where :
+ val name* = prefix(name(ei), name(e))
+ (e:WIndex) :
+ collapse-exp(WField(exp(e), name, type(e), dir(e))) where :
+ val name = to-symbol(value(e))
+ (e) :
+ map(collapse-exp, e)
+
+ ;Expand expressions
+ defn expand-exp (e:Expression) -> List<Expression> :
+ match(type(e)) :
+ (t:BundleType) :
+ for p in ports(t) map :
+ val dir* = direction(p) * dir(e)
+ collapse-exp(WField(e, name(p), type(p), dir*))
+ (t) :
+ list(collapse-exp(e))
+
+ ;Expand commands
+ defn expand-comm (c:Command) :
+ match(c) :
+ (c:DefWire) :
+ match(type(c)) :
+ (t:BundleType) :
+ Begin $
+ for p in ports(t) map :
+ DefWire(prefix(name(c), name(p)), type(p))
+ (t) :
+ c
+ (c:DefRegister) :
+ match(type(c)) :
+ (t:BundleType) :
+ Begin $
+ for p in ports(t) map :
+ DefRegister(prefix(name(c), name(p)), type(p))
+ (t) :
+ c
+ (c:DefMemory) :
+ match(type(type(c))) :
+ (t:BundleType) :
+ Begin $
+ for p in ports(t) map :
+ DefMemory(prefix(name(c), name(p)), type*) where :
+ val s = size(type(c))
+ val type* = VectorType(type(p), s)
+ (t) :
+ c
+ (c:WDefAccessor) :
+ match(type(source(c))) :
+ (t:BundleType) :
+ val srcs = expand-exp(source(c))
+ Begin $
+ for (p in ports(t), src in srcs) map :
+ WDefAccessor(name*, src, index(c), dir*) where :
+ val name* = prefix(name(c), name(p))
+ val dir* = direction(p) * dir(c)
+ (t) :
+ c
+ (c:Connect) :
+ val locs = expand-exp(loc(c))
+ val exps = expand-exp(exp(c))
+ Begin $
+ for (l in locs, e in exps) map :
+ switch {dir(l) == _} :
+ INPUT : Connect(l, e)
+ OUTPUT : Connect(e, l)
+ (c:ManyConnect) :
+ val locs-list = transpose(map(expand-exp, locs(c)))
+ val exps = expand-exp(exp(c))
+ Begin $
+ for (locs in locs-list, e in exps) map :
+ switch {dir(e) == _} :
+ OUTPUT : ManyConnect(index(c), locs, e)
+ INPUT : ConnectMany(index(c), e, locs)
+ (c:ConnectMany) :
+ val locs = expand-exp(loc(c))
+ val exps-list = transpose(map(expand-exp, exps(c)))
+ Begin $
+ for (l in locs, exps in exps-list) map :
+ switch {dir(l) == _} :
+ INPUT : ConnectMany(index(c), l, exps)
+ OUTPUT : ManyConnect(index(c), exps, l)
+ (c) :
+ map{expand-comm, _} $
+ map(collapse-exp, c)
+
+ Module(name(m), ports(m), expand-comm(body(m)))
+
+defn expand-bundles (c:Circuit) :
+ Circuit(modules*, main(c)) where :
+ val modules* = map(expand-bundles, modules(c))
+
+
+;=========== CONVERT MULTI CONNECTS to WHEN ================
+defn expand-multi-connects (c:Circuit) :
+ defn equal-exp (e1:Expression, e2:Expression) :
+ DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth()))
+ defn uint (i:Int) :
+ UIntValue(i, UnknownWidth())
+
+ defn expand-comm (c:Command) :
+ match(c) :
+ (c:ConnectMany) :
+ Begin $ to-list $
+ for (i in 0 to false, e in exps(c)) stream :
+ Conditionally(equal-exp(index(c), uint(i)),
+ Connect(loc(c), e)
+ EmptyCommand())
+ (c:ManyConnect) :
+ Begin $ to-list $
+ for (i in 0 to false, l in locs(c)) stream :
+ Conditionally(equal-exp(index(c), uint(i)),
+ Connect(l, exp(c))
+ EmptyCommand())
+ (c) :
+ map(expand-comm, c)
+
+ defn expand (m:Module) :
+ Module(name(m), ports(m), expand-comm(body(m)))
+
+ Circuit(modules*, main(c)) where :
+ val modules* = map(expand, modules(c))
+
+
+;================ EXPAND WHENS =============================
+definterface SymbolicValue
+defstruct ExpValue <: SymbolicValue :
+ exp: Expression
+defstruct WhenValue <: SymbolicValue :
+ pred: Expression
+ conseq: SymbolicValue
+ alt: SymbolicValue
+defstruct VoidValue <: SymbolicValue
+
+defmethod print (o:OutputStream, sv:SymbolicValue) :
+ match(sv) :
+ (sv:VoidValue) : print(o, "VOID")
+ (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"])
+ (sv:ExpValue) : print(o, exp(sv))
+
+defn key-eqv? (e1:Expression, e2:Expression) :
+ match(e1, e2) :
+ (e1:WRef, e2:WRef) :
+ name(e1) == name(e2)
+ (e1:WField, e2:WField) :
+ name(e1) == name(e2) and
+ key-eqv?(exp(e1), exp(e2))
+ (e1, e2) :
+ false
+
+defn merge-env (pred: Expression,
+ con-env: List<KeyValue<Expression,SymbolicValue>>,
+ alt-env: List<KeyValue<Expression,SymbolicValue>>) :
+ val merged = Vector<KeyValue<Expression, SymbolicValue>>()
+ defn new-key? (k:Expression) :
+ for entry in merged none? :
+ key-eqv?(key(entry), k)
+
+ defn sv (env:List<KeyValue<Expression,SymbolicValue>>, k:Expression) :
+ for entry in env search :
+ if key-eqv?(key(entry), k) :
+ value(entry)
+
+ for k in stream(key, concat(con-env, alt-env)) do :
+ if new-key?(k) :
+ match(sv(con-env, k), sv(alt-env, k)) :
+ (a:SymbolicValue, b:SymbolicValue) :
+ if a == b :
+ add(merged, k => a)
+ else :
+ add(merged, k => WhenValue(pred, a, b))
+ (a:SymbolicValue, b:False) :
+ add(merged, k => a)
+ (a:False, b:SymbolicValue) :
+ add(merged, k => b)
+ (a:False, b:False) :
+ false
+
+ to-list(merged)
+
+defn simplify-env (env: List<KeyValue<Expression,SymbolicValue>>) :
+ val merged = Vector<KeyValue<Expression, SymbolicValue>>()
+ defn new-key? (k:Expression) :
+ for entry in merged none? :
+ key-eqv?(key(entry), k)
+ for entry in env do :
+ if new-key?(key(entry)) :
+ add(merged, entry)
+ to-list(merged)
+
+defn expand-whens (m:Module) :
+ val commands = Vector<Command>()
+ val elements = Vector<KeyValue<Symbol,Element>>()
+ defn eval (c:Command, env:List<KeyValue<Expression,SymbolicValue>>) ->
+ List<KeyValue<Expression,SymbolicValue>> :
+ match(c) :
+ (c:LetRec) :
+ do(add{elements, _}, entries(c))
+ eval(body(c), env)
+ (c:DefWire) :
+ add(commands, c)
+ val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT)
+ List(wire-ref => VoidValue(), env)
+ (c:DefRegister) :
+ add(commands, c)
+ val reg-ref = WRef(name(c), type(c), RegKind(), INPUT)
+ List(reg-ref => VoidValue(), env)
+ (c:DefInstance) :
+ add(commands, c)
+ val entries = let :
+ val module-type = type(module(c)) as BundleType
+ val input-ports = to-list $
+ for p in ports(module-type) filter :
+ direction(p) == INPUT
+ val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT)
+ for p in input-ports map :
+ WField(inst-ref, name(p), type(p), INPUT) => VoidValue()
+ append(entries, env)
+ (c:DefMemory) :
+ add(commands, c)
+ env
+ (c:WDefAccessor) :
+ add(commands, c)
+ if dir(c) == INPUT :
+ val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c))
+ List(access-ref => VoidValue(), env)
+ else :
+ env
+ (c:Conditionally) :
+ val con-env = eval(conseq(c), env)
+ val alt-env = eval(alt(c), env)
+ merge-env(pred(c), con-env, alt-env)
+ (c:Begin) :
+ var env:List<KeyValue<Expression,SymbolicValue>> = env
+ for c in body(c) do :
+ env = eval(c, env)
+ env
+ (c:Connect) :
+ List(loc(c) => ExpValue(exp(c)), env)
+ (c:EmptyCommand) :
+ env
+
+ defn convert-symbolic (key:Expression, sv:SymbolicValue) :
+ match(sv) :
+ (sv:VoidValue) :
+ throw $ PassException $ string-join $ [
+ "No default value for " key "."]
+ (sv:ExpValue) :
+ exp(sv)
+ (sv:WhenValue) :
+ defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) :
+ DoPrim(MULTIPLEX-OP, list(pred, conseq, alt), List(), type(conseq))
+ multiplex-exp(pred(sv),
+ convert-symbolic(key, conseq(sv))
+ convert-symbolic(key, alt(sv)))
+
+ ;Compute final environment
+ val env0 = let :
+ val output-ports = to-list $
+ for p in ports(m) filter :
+ direction(p) == OUTPUT
+ for p in output-ports map :
+ val port-ref = WRef(name(p), type(p), PortKind(), INPUT)
+ port-ref => VoidValue()
+ val env* = simplify-env(eval(body(m), env0))
+
+ ;Make new body
+ val body* = Begin(list(defs, LetRec(elems, connections))) where :
+ val defs = Begin(to-list(commands))
+ val elems = to-list(elements)
+ val connections = Begin $
+ for entry in env* map :
+ val sv = convert-symbolic(key(entry), value(entry))
+ Connect(key(entry), sv)
+
+ ;Final module
+ Module(name(m), ports(m), body*)
+
+defn expand-whens (c:Circuit) :
+ val modules* = map(expand-whens, modules(c))
+ Circuit(modules*, main(c))
+
+
+;================ STRUCTURAL FORM ==========================
+
+defn structural-form (m:Module) :
+ val elements = Vector<(() -> KeyValue<Symbol,Element>)>()
+ val connected = HashTable<Symbol, Expression>(symbol-hash)
+ val write-accessors = HashTable<Symbol, List<WDefAccessor>>(symbol-hash)
+ val read-accessors = HashTable<Symbol, WDefAccessor>(symbol-hash)
+ val inst-ports = HashTable<Symbol, List<KeyValue<Symbol, Expression>>>(symbol-hash)
+ val port-connects = Vector<Connect>()
+
+ defn scan (c:Command) :
+ match(c) :
+ (c:Connect) :
+ match(loc(c)) :
+ (loc:WRef) :
+ match(kind(loc)) :
+ (k:PortKind) : add(port-connects, c)
+ (k) : connected[name(loc)] = exp(c)
+ (loc:WField) :
+ val inst = exp(loc) as WRef
+ val entry = name(loc) => exp(c)
+ inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List()))
+ (c:LetRec) :
+ for e in entries(c) do :
+ add(elements, {e})
+ scan(body(c))
+ (c:DefWire) :
+ add{elements, _} $ fn () :
+ name(c) => Node(type(c), connected[name(c)])
+ (c:DefRegister) :
+ add{elements, _} $ fn () :
+ val one = UIntValue(1, UnknownWidth())
+ name(c) => Register(type(c), connected[name(c)], one)
+ (c:DefInstance) :
+ add{elements, _} $ fn () :
+ name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)])
+ (c:DefMemory) :
+ add{elements, _} $ fn () :
+ val ports = for a in get?(write-accessors, name(c), List()) map :
+ val one = UIntValue(1, UnknownWidth())
+ WritePort(index(a), connected[name(a)], one)
+ name(c) => Memory(type(c), ports)
+ (c:WDefAccessor) :
+ val mem = source(c) as WRef
+ switch {dir(c) == _} :
+ INPUT :
+ write-accessors[name(mem)] = List(c,
+ get?(write-accessors, name(mem), List()))
+ OUTPUT :
+ read-accessors[name(c)] = c
+ (c) :
+ do(scan, children(c))
+
+ defn make-read-ports (e:Expression) :
+ match(e) :
+ (e:WRef) :
+ match(kind(e)) :
+ (k:AccessorKind) :
+ val accessor = read-accessors[name(e)]
+ ReadPort(source(accessor), index(accessor), type(e))
+ (k) : e
+ (e) : map(make-read-ports, e)
+
+ Module(name(m), ports(m), body*) where :
+ scan(body(m))
+ val elems = to-list $
+ for e in elements stream :
+ val entry = e()
+ key(entry) => map(make-read-ports, value(entry))
+ val connect-ports = Begin $ to-list $
+ for c in port-connects stream :
+ Connect(loc(c), make-read-ports(exp(c)))
+ val body* =
+ if empty?(elems) : connect-ports
+ else : LetRec(elems, connect-ports)
+
+defn structural-form (c:Circuit) :
+ val modules* = map(structural-form, modules(c))
+ Circuit(modules*, main(c))
+
+
+;==================== WIDTH INFERENCE ======================
+defstruct WidthVar <: Width :
+ name: Symbol
+
+defmethod print (o:OutputStream, w:WidthVar) :
+ print(o, name(w))
+
+defn width! (t:Type) :
+ match(t) :
+ (t:UIntType) : width(t)
+ (t:SIntType) : width(t)
+ (t) : error("No width field.")
+
+defn put-width (t:Type, w:Width) :
+ match(t) :
+ (t:UIntType) : UIntType(w)
+ (t:SIntType) : SIntType(w)
+ (t) : t
+
+defn put-width (e:Expression, w:Width) :
+ val type* = put-width(type(e), w)
+ put-type(e, type*)
+
+defn add-width-vars (t:Type) :
+ defn width? (w:Width) :
+ match(w) :
+ (w:UnknownWidth) : WidthVar(gensym())
+ (w) : w
+ match(t) :
+ (t:UIntType) : UIntType(width?(width(t)))
+ (t:SIntType) : SIntType(width?(width(t)))
+ (t) : map(add-width-vars, t)
+
+defn uint-width (i:Int) :
+ var v:Int = i
+ var n:Int = 0
+ while v != 0 :
+ v = v >> 1
+ n = n + 1
+ IntWidth(n)
+
+defn sint-width (i:Int) :
+ if i > 0 :
+ val w = uint-width(i)
+ IntWidth(width(w) + 1)
+ else :
+ val w = uint-width(neg(i) - 1)
+ IntWidth(width(w) + 1)
+
+defn to-exp (w:Width) :
+ match(w) :
+ (w:IntWidth) : ELit(width(w))
+ (w:WidthVar) : EVar(name(w))
+ (w) : error $ string-join $ [
+ "Cannot convert " w " to exp."]
+
+defn primop-width (op:PrimOp, ws:List<Width>, ints:List<Int>) -> Exp :
+ defn wmax (w1:Width, w2:Width) :
+ EMax(to-exp(w1), to-exp(w2))
+ defn wplus (w1:Width, w2:Width) :
+ EPlus(to-exp(w1), to-exp(w2))
+ defn wplus (w1:Width, w2:Int) :
+ EPlus(to-exp(w1), ELit(w2))
+ defn wminus (w1:Width, w2:Width) :
+ EMinus(to-exp(w1), to-exp(w2))
+ defn wminus (w1:Width, w2:Int) :
+ EMinus(to-exp(w1), ELit(w2))
+ defn wmax-inc (w1:Width, w2:Width) :
+ EPlus(wmax(w1, w2), ELit(1))
+
+ switch {op == _} :
+ ADD-OP : wmax-inc(ws[0], ws[1])
+ ADD-MOD-OP : wmax(ws[0], ws[1])
+ SUB-OP : wmax-inc(ws[0], ws[1])
+ SUB-MOD-OP : wmax(ws[0], ws[1])
+ TIMES-OP : wplus(ws[0], ws[1])
+ DIVIDE-OP : wminus(ws[0], ws[1])
+ MOD-OP : to-exp(ws[1])
+ SHIFT-LEFT-OP : wplus(ws[0], ints[0])
+ SHIFT-RIGHT-OP : wminus(ws[0], ints[0])
+ PAD-OP : ELit(ints[0])
+ BIT-AND-OP : wmax(ws[0], ws[1])
+ BIT-OR-OP : wmax(ws[0], ws[1])
+ BIT-XOR-OP : wmax(ws[0], ws[1])
+ CONCAT-OP : wplus(ws[0], ints[0])
+ BIT-SELECT-OP : ELit(1)
+ BITS-SELECT-OP : ELit(ints[0])
+ MULTIPLEX-OP : wmax(ws[1], ws[2])
+ LESS-OP : ELit(1)
+ LESS-EQ-OP : ELit(1)
+ GREATER-OP : ELit(1)
+ GREATER-EQ-OP : ELit(1)
+ EQUAL-OP : ELit(1)
+
+defn put-type (el:Element, t:Type) :
+ match(el) :
+ (el:Register) : Register(t, value(el), enable(el))
+ (el:Memory) : Memory(t, writers(el))
+ (el:Node) : Node(t, value(el))
+ (el:Instance) : Instance(t, module(el), ports(el))
+
+defn generate-constraints (c:Circuit) -> [Circuit, Vector<WConstraint>] :
+ ;Constraints
+ val cs = Vector<WConstraint>()
+ defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) :
+ match(wvar) :
+ (wvar:WidthVar) :
+ add(cs, Constraint(name(wvar), to-exp(width)))
+ (wvar) :
+ false
+
+ defn to-width (e:Exp) :
+ match(e) :
+ (e:ELit) :
+ IntWidth(width(e))
+ (e:EVar) :
+ WidthVar(name(e))
+ (e) :
+ val x = gensym()
+ add(cs, WidthEqual(x, e))
+ WidthVar(x)
+
+ ;Module types
+ val mod-types = HashTable<Symbol,Type>(symbol-hash)
+
+ defn add-port-vars (m:Module) -> Module :
+ val ports* =
+ for p in ports(m) map :
+ val type* = add-width-vars(type(p))
+ Port(name(p), direction(p), type*)
+ mod-types[name(m)] = BundleType(ports*)
+ Module(name(m), ports*, body(m))
+
+ ;Add Width Variables
+ defn add-module-vars (m:Module) -> Module :
+ val types = HashTable<Symbol,Type>(symbol-hash)
+ for p in ports(m) do :
+ types[name(p)] = type(p)
+
+ defn infer-exp-width (e:Expression) :
+ match(map(infer-exp-width, e)) :
+ (e:WRef) :
+ match(kind(e)) :
+ (k:ModuleKind) : put-type(e, mod-types[name(e)])
+ (k) : put-type(e, types[name(e)])
+ (e:WField) :
+ val t = bundle-field-type(type(exp(e)), name(e))
+ put-width(e, width!(t))
+ (e:UIntValue) :
+ match(width(e)) :
+ (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e)))
+ (w) : e
+ (e:SIntValue) :
+ match(width(e)) :
+ (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e)))
+ (w) : e
+ (e:DoPrim) :
+ val widths = map(width!{type(_)}, args(e))
+ val w = to-width(primop-width(op(e), widths, consts(e)))
+ put-width(e, w)
+ (e:ReadPort) :
+ val elem-type = type(type(mem(e)) as VectorType)
+ put-width(e, width!(elem-type))
+
+ defn infer-comm-width (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ ;Add width vars to elements
+ var entries*: List<KeyValue<Symbol,Element>> =
+ for entry in entries(c) map :
+ val el-name = key(entry)
+ key(entry) =>
+ match(value(entry)) :
+ (el:Register|Node) :
+ put-type(el, add-width-vars(type(el)))
+ (el:Memory) :
+ el
+ (el:Instance) :
+ val mod-type = type(infer-exp-width(module(el))) as BundleType
+ val type = BundleType $ to-list $
+ for p in ports(mod-type) filter :
+ direction(p) == OUTPUT
+ put-type(el, type)
+
+ ;Add vars to environment
+ for entry in entries* do :
+ types[key(entry)] = type(value(entry))
+
+ ;Infer types for elements
+ entries* =
+ for entry in entries* map :
+ key(entry) => map(infer-exp-width, value(entry))
+
+ ;Generate constraints
+ for entry in entries* do :
+ val el-name = key(entry)
+ match(value(entry)) :
+ (el:Register) :
+ new-constraint(WidthEqual, reg-width, val-width) where :
+ val reg-width = width!(types[el-name])
+ val val-width = width!(type(value(el)))
+ (el:Node) :
+ new-constraint(WidthEqual, node-width, val-width) where :
+ val node-width = width!(types[el-name])
+ val val-width = width!(type(value(el)))
+ (el:Instance) :
+ val mod-type = type(module(el)) as BundleType
+ for entry in ports(el) do :
+ new-constraint(WidthGreater, port-width, val-width) where :
+ val port-name = key(entry)
+ val port-width = width!(bundle-field-type(mod-type, port-name))
+ val val-width = width!(type(value(entry)))
+ (el) : false
+
+ ;Analyze body
+ LetRec(entries*, infer-comm-width(body(c)))
+
+ (c:Connect) :
+ val loc* = infer-exp-width(loc(c))
+ val exp* = infer-exp-width(exp(c))
+ new-constraint(WidthGreater, loc-width, exp-width) where :
+ val loc-width = width!(type(loc*))
+ val exp-width = width!(type(exp*))
+ Connect(loc*, exp*)
+
+ (c:Begin) :
+ map(infer-comm-width, c)
+
+ Module(name(m), ports(m), body*) where :
+ val body* = infer-comm-width(body(m))
+
+ val c* =
+ Circuit(modules*, main(c)) where :
+ val ms = map(add-port-vars, modules(c))
+ val modules* = map(add-module-vars, ms)
+ [c*, cs]
+
+
+;================== FILL WIDTHS ============================
+defn fill-widths (c:Circuit, solved:Streamable<WidthEqual>) :
+ ;Populate table
+ val table = HashTable<Symbol, Width>(symbol-hash)
+ for eq in solved do :
+ table[name(eq)] = IntWidth(width(value(eq) as ELit))
+
+ defn width? (w:Width) :
+ match(w) :
+ (w:WidthVar) : get?(table, name(w), UnknownWidth())
+ (w) : w
+
+ defn fill-type (t:Type) :
+ match(t) :
+ (t:UIntType) : UIntType(width?(width(t)))
+ (t:SIntType) : SIntType(width?(width(t)))
+ (t) : map(fill-type, t)
+
+ defn fill-exp (e:Expression) :
+ val e* = map(fill-exp, e)
+ val type* = fill-type(type(e))
+ put-type(e*, type*)
+
+ defn fill-element (e:Element) :
+ val e* = map(fill-exp, e)
+ val type* = fill-type(type(e))
+ put-type(e*, type*)
+
+ defn fill-comm (c:Command) :
+ match(c) :
+ (c:LetRec) :
+ val entries* =
+ for e in entries(c) map :
+ key(e) => fill-element(value(e))
+ LetRec(entries*, fill-comm(body(c)))
+ (c) :
+ map{fill-comm, _} $
+ map(fill-exp, c)
+
+ defn fill-port (p:Port) :
+ Port(name(p), direction(p), fill-type(type(p)))
+
+ defn fill-mod (m:Module) :
+ Module(name(m), ports*, body*) where :
+ val ports* = map(fill-port, ports(m))
+ val body* = fill-comm(body(m))
+
+ Circuit(modules*, main(c)) where :
+ val modules* = map(fill-mod, modules(c))
+
+
+;=============== TYPE INFERENCE DRIVER =====================
+defn infer-widths (c:Circuit) :
+ val [c*, cs] = generate-constraints(c)
+ val solved = solve-widths(cs)
+ fill-widths(c*, solved)
+
+
+;================ PAD WIDTHS ===============================
+defn pad-widths (c:Circuit) :
+ ;Pad an expression to the given width
+ defn pad-exp (e:Expression, w:Int) :
+ match(type(e)) :
+ (t:UIntType|SIntType) :
+ val prev-w = width!(t) as IntWidth
+ if width(prev-w) < w :
+ val type* = put-width(t, IntWidth(w))
+ DoPrim(PAD-OP, list(e), list(w), type*)
+ else :
+ e
+ (t) :
+ e
+
+ defn pad-exp (e:Expression, w:Width) :
+ val w-value = width(w as IntWidth)
+ pad-exp(e, w-value)
+
+ ;Convenience
+ defn max-width (es:Streamable<Expression>) :
+ defn int-width (e:Expression) :
+ width(width!(type(e)) as IntWidth)
+ maximum(stream(int-width, es))
+
+ defn match-widths (es:List<Expression>) :
+ val w = max-width(es)
+ map(pad-exp{_, w}, es)
+
+ ;Match widths for an expression
+ defn match-exp-width (e:Expression) :
+ match(map(match-exp-width, e)) :
+ (e:DoPrim) :
+ if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) :
+ val args* = match-widths(args(e))
+ DoPrim(op(e), args*, consts(e), type(e))
+ else if op(e) == MULTIPLEX-OP :
+ val args* = List(head(args(e)), match-widths(tail(args(e))))
+ DoPrim(op(e), args*, consts(e), type(e))
+ else :
+ e
+ (e) : e
+
+ defn match-element-width (e:Element) :
+ match(map(match-exp-width, e)) :
+ (e:Register) :
+ val w = width!(type(e))
+ val value* = pad-exp(value(e), w)
+ Register(type(e), value*, enable(e))
+ (e:Memory) :
+ val width = width!(type(type(e) as VectorType))
+ val writers* =
+ for w in writers(e) map :
+ WritePort(index(w), pad-exp(value(w), width), enable(w))
+ Memory(type(e), writers*)
+ (e:Node) :
+ val w = width!(type(e))
+ val value* = pad-exp(value(e), w)
+ Node(type(e), value*)
+ (e:Instance) :
+ val mod-type = type(module(e)) as BundleType
+ val ports* =
+ for p in ports(e) map :
+ val port-type = bundle-field-type(mod-type, key(p))
+ val port-val = pad-exp(value(p), width!(port-type))
+ key(p) => port-val
+ Instance(type(e), module(e), ports*)
+
+ ;Match widths for a command
+ defn match-comm-width (c:Command) :
+ match(map(match-exp-width, c)) :
+ (c:LetRec) :
+ val entries* =
+ for e in entries(c) map :
+ key(e) => match-element-width(value(e))
+ LetRec(entries*, match-comm-width(body(c)))
+ (c:Connect) :
+ val w = width!(type(loc(c)))
+ val exp* = pad-exp(exp(c), w)
+ Connect(loc(c), exp*)
+ (c) :
+ map(match-comm-width, c)
+
+ defn match-mod-width (m:Module) :
+ Module(name(m), ports(m), body*) where :
+ val body* = match-comm-width(body(m))
+
+ Circuit(modules*, main(c)) where :
+ val modules* = map(match-mod-width, modules(c))
+
+
+;================== INLINING ===============================
+defn inline-instances (c:Circuit) :
+ val module-table = HashTable<Symbol,Module>(symbol-hash)
+ val inlined? = HashTable<Symbol,True|False>(symbol-hash)
+ for m in modules(c) do :
+ module-table[name(m)] = m
+ inlined?[name(m)] = false
+
+ ;Convert a module into a sequence of elements
+ defn to-elements (m:Module,
+ inst:Symbol,
+ port-exps:List<KeyValue<Symbol,Expression>>) ->
+ List<KeyValue<Symbol, Element>> :
+ defn rename-exp (e:Expression) :
+ match(e) :
+ (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e))
+ (e) : map(rename-exp, e)
+
+ defn to-elements (c:Command) -> List<KeyValue<Symbol,Element>> :
+ match(c) :
+ (c:LetRec) :
+ val entries* =
+ for entry in entries(c) map :
+ val name* = prefix(inst, key(entry))
+ val element* = map(rename-exp, value(entry))
+ name* => element*
+ val body* = to-elements(body(c))
+ append(entries*, body*)
+ (c:Connect) :
+ val ref = loc(c) as WRef
+ val name* = prefix(inst, name(ref))
+ list(name* => Node(type(exp(c)), rename-exp(exp(c))))
+ (c:Begin) :
+ map-append(to-elements, body(c))
+
+ val inputs =
+ for p in ports(m) map-append :
+ if direction(p) == INPUT :
+ val port-exp = lookup!(port-exps, name(p))
+ val name* = prefix(inst, name(p))
+ list(name* => Node(type(port-exp), port-exp))
+ else :
+ List()
+ append(inputs, to-elements(body(m)))
+
+ ;Inline all instances in the module
+ defn inline-instances (m:Module) :
+ defn rename-exp (e:Expression) :
+ match(e) :
+ (e:WField) :
+ val inst-exp = exp(e) as WRef
+ val name* = prefix(name(inst-exp), name(e))
+ WRef(name*, type(e), NodeKind(), dir(e))
+ (e) :
+ map(rename-exp, e)
+
+ defn inline-elems (es:List<KeyValue<Symbol,Element>>) :
+ for entry in es map-append :
+ match(value(entry)) :
+ (el:Instance) :
+ val mod-name = name(module(el) as WRef)
+ val module = inlined-module(mod-name)
+ to-elements(module, key(entry), ports(el))
+ (el) :
+ list(entry)
+
+ defn inline-comm (c:Command) :
+ match(map(rename-exp, c)) :
+ (c:LetRec) :
+ val entries* = inline-elems(entries(c))
+ LetRec(entries*, inline-comm(body(c)))
+ (c) :
+ map(inline-comm, c)
+
+ Module(name(m), ports(m), inline-comm(body(m)))
+
+ ;Retrieve the inlined instance of a module
+ defn inlined-module (name:Symbol) :
+ if inlined?[name] :
+ module-table[name]
+ else :
+ val module* = inline-instances(module-table[name])
+ module-table[name] = module*
+ inlined?[name] = true
+ module*
+
+ ;Return the fully inlined circuit
+ val main-module = inlined-module(main(c))
+ Circuit(list(main-module), main(c))
+
+
+;;;================ UTILITIES ================================
+;
+;
+;
+;defn* root-ref (i:Immediate) :
+; match(i) :
+; (f:Field) : root-ref(imm(f))
+; (ind:Index) : root-ref(imm(ind))
+; (r) : r
+;
+;;defn lookup<?T> (e: Streamable<KeyValue<Immediate,?T>>, i:Immediate) :
+;; for entry in e search :
+;; if eqv?(key(entry), i) :
+;; value(entry)
+;;
+;;defn lookup!<?T> (e: Streamable<KeyValue<Immediate,?T>>, i:Immediate) :
+;; lookup(e, i) as T
+;;
+;;============ CHECK IF NAMES ARE UNIQUE ====================
+;defn check-duplicate-symbols (names: Streamable<Symbol>, msg: String) :
+; val dict = HashTable<Symbol, True>(symbol-hash)
+; for name in names do:
+; if key?(dict, name):
+; throw $ PassException $ string-join $
+; [msg ": " name]
+; else:
+; dict[name] = true
+;
+;defn check-duplicates (t: Type) :
+; match(t) :
+; (t:BundleType) :
+; val names = map(name, ports(t))
+; check-duplicate-symbols{names, string-join(_)} $
+; ["Duplicate port name in bundle "]
+; do(check-duplicates{type(_)}, ports(t))
+; (t:VectorType) :
+; check-duplicates(type(t))
+; (t) : false
+;
+;defn check-duplicates (c: Command) :
+; match(c) :
+; (c:DefWire) : check-duplicates(type(c))
+; (c:DefRegister) : check-duplicates(type(c))
+; (c:DefMemory) : check-duplicates(type(c))
+; (c) : do(check-duplicates, children(c))
+;
+;defn defined-names (c: Command) :
+; generate<Symbol> :
+; loop(c) where :
+; defn loop (c:Command) :
+; match(c) :
+; (c:Command&HasName) : yield(name(c))
+; (c) : do(loop, children(c))
+;
+;defn check-duplicates (m: Module):
+; ;Check all duplicate names in all types in all ports and body
+; do(check-duplicates{type(_)}, ports(m))
+; check-duplicates(body(m))
+;
+; ;Check all names defined in module
+; val names = concat(stream(name, ports(m)),
+; defined-names(body(m)))
+; check-duplicate-symbols{names, string-join(_)} $
+; ["Duplicate definition name in module " name(m)]
+;
+;defn check-duplicates (c: Circuit) :
+; ;Check all duplicate names in all modules
+; do(check-duplicates, modules(c))
+;
+; ;Check all defined modules
+; val names = stream(name, modules(c))
+; check-duplicate-symbols(names, "Duplicate module name")
+;
+;
+
+
+
+
+
+
+
+;;================ CLEANUP COMMANDS =========================
+;defn cleanup (c:Command) :
+; match(c) :
+; (c:Begin) :
+; to-command $ generate<Command> :
+; loop(c) where :
+; defn loop (c:Command) :
+; match(c) :
+; (c:Begin) : do(loop, body(c))
+; (c:EmptyCommand) : false
+; (c) : yield(cleanup(c))
+; (c) : map(cleanup{_ as Command}, c)
+;
+;defn cleanup (c:Circuit) :
+; val modules* =
+; for m in modules(c) map :
+; map(cleanup, m)
+; Circuit(modules*, main(c))
+;
+
+
+
+;;;============= SHIM ========================================
+;;defn shim (i:Immediate) -> Immediate :
+;; match(i) :
+;; (i:RegData) :
+;; Ref(name(i), direction(i), type(i))
+;; (i:InstPort) :
+;; val inst = Ref(name(i), UNKNOWN-DIR, UnknownType())
+;; Field(inst, port(i), direction(i), type(i))
+;; (i:Field) :
+;; val imm* = shim(imm(i))
+;; put-imm(i, imm*)
+;; (i) : i
+;;
+;;defn shim (c:Command) -> Command :
+;; val c* = map(shim{_ as Immediate}, c)
+;; map(shim{_ as Command}, c*)
+;;
+;;defn shim (c:Circuit) -> Circuit :
+;; val modules* =
+;; for m in modules(c) map :
+;; Module(name(m), ports(m), shim(body(m)))
+;; Circuit(modules*, main(c))
+;;
+;;;================== INLINE MODULES =========================
+;;defn cat-name (p: String|Symbol, s: String|Symbol) -> Symbol :
+;; if p == "" or p == `this : ;; TODO: REMOVE THIS WHEN `THIS GETS REMOVED
+;; to-symbol(s)
+;; else if s == `this : ;; TODO: DITTO
+;; to-symbol(p)
+;; else :
+;; symbol-join([p, "/", s])
+;;
+;;defn inline-command (c: Command, mods: HashTable<Symbol, Module>, prefix: String, cmds: Vector<Command>) :
+;; defn rename (n: Symbol) -> Symbol :
+;; cat-name(prefix, n)
+;; defn inline-name (i:Immediate) -> Symbol :
+;; match(i) :
+;; (r:Ref) : rename(name(r))
+;; (f:Field) : cat-name(inline-name(imm(f)), name(f))
+;; (f:Index) : cat-name(inline-name(imm(f)), to-string(value(f)))
+;; defn inline-imm (i:Immediate) -> Ref :
+;; Ref(inline-name(i), direction(i), type(i))
+;; match(c) :
+;; (c:DefUInt) : add(cmds, DefUInt(rename(name(c)), value(c), width(c)))
+;; (c:DefSInt) : add(cmds, DefSInt(rename(name(c)), value(c), width(c)))
+;; (c:DefWire) : add(cmds, DefWire(rename(name(c)), type(c)))
+;; (c:DefRegister) : add(cmds, DefRegister(rename(name(c)), type(c)))
+;; (c:DefMemory) : add(cmds, DefMemory(rename(name(c)), type(c), size(c)))
+;; (c:DefInstance) : inline-module(mods, mods[name(module(c))], to-string(rename(name(c))), cmds)
+;; (c:DoPrim) : add(cmds, DoPrim(rename(name(c)), op(c), map(inline-imm, args(c)), consts(c)))
+;; (c:DefAccessor) : add(cmds, DefAccessor(rename(name(c)), inline-imm(source(c)), direction(c), inline-imm(index(c))))
+;; (c:Connect) : add(cmds, Connect(inline-imm(loc(c)), inline-imm(exp(c))))
+;; (c:Begin) : do(inline-command{_, mods, prefix, cmds}, body(c))
+;; (c:EmptyCommand) : c
+;; (c) : error("Unsupported command")
+;;
+;;defn inline-port (p: Port, prefix: String) -> Command :
+;; DefWire(cat-name(prefix, name(p)), type(p))
+;;
+;;defn inline-module (mods: HashTable<Symbol, Module>, mod: Module, prefix: String, cmds: Vector<Command>) :
+;; do(add{cmds, _}, map(inline-port{_, prefix}, ports(mod)))
+;; inline-command(body(mod), mods, prefix, cmds)
+;;
+;;defn inline-modules (c: Circuit) -> Circuit :
+;; val cmds = Vector<Command>()
+;; val mods = HashTable<Symbol, Module>(symbol-hash)
+;; for mod in modules(c) do :
+;; mods[name(mod)] = mod
+;; val top = mods[main(c)]
+;; inline-command(body(top), mods, "", cmds)
+;; val main* = Module(name(top), ports(top), Begin(to-list(cmds)))
+;; Circuit(list(main*), name(top))
+;;
+;;
+;;;============= FLO PRINTER ======================================
+;;;;; TODO:
+;;;;; not supported gt, lte
+;;
+;;defn flo-op-name (op:PrimOp) -> String :
+;; switch {op == _} :
+;; ADD-OP : "add"
+;; ADD-MOD-OP : "add"
+;; MINUS-OP : "sub"
+;; SUB-MOD-OP : "sub"
+;; TIMES-OP : "mul" ;; todo: signed version
+;; DIVIDE-OP : "div" ;; todo: signed version
+;; MOD-OP : "mod" ;; todo: signed version
+;; SHIFT-LEFT-OP : "lsh" ;; todo: signed version
+;; SHIFT-RIGHT-OP : "rsh"
+;; PAD-OP : "pad" ;; todo: signed version
+;; BIT-AND-OP : "and"
+;; BIT-OR-OP : "or"
+;; BIT-XOR-OP : "xor"
+;; CONCAT-OP : "cat"
+;; BIT-SELECT-OP : "rsh"
+;; BITS-SELECT-OP : "rsh"
+;; LESS-OP : "lt" ;; todo: signed version
+;; LESS-EQ-OP : "lte" ;; todo: swap args
+;; GREATER-OP : "gt" ;; todo: swap args
+;; GREATER-EQ-OP : "gte" ;; todo: signed version
+;; EQUAL-OP : "eq"
+;; MULTIPLEX-OP : "mux"
+;; else : error $ string-join $
+;; ["Unable to print Primop: " op]
+;;
+;;defn emit (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elt) :
+;; match(elt) :
+;; (e:String|Symbol|Int) :
+;; print(o, e)
+;; (e:Ref) :
+;; if key?(lits, name(e)) :
+;; val lit = lits[name(e)]
+;; print-all(o, [value(lit) "'" width(lit)])
+;; else :
+;; if key?(ports, name(e)) :
+;; print-all(o, [top "::"])
+;; print(o, name(e))
+;; (e:IntWidth) :
+;; print(o, value(e))
+;; (e:PrimOp) :
+;; print(o, flo-op-name(e))
+;; (e) :
+;; println-all(["EMIT " e])
+;; error("Unable to emit")
+;;
+;;defn emit-all (o:OutputStream, top:Symbol, ports:HashTable<Symbol, Port>, lits:HashTable<Symbol, DefUInt>, elts: Streamable) :
+;; for e in elts do : emit(o, top, ports, lits, e)
+;;
+;;defn prim-width (type:Type) -> Width :
+;; match(type) :
+;; (t:UIntType) : width(t)
+;; (t:SIntType) : width(t)
+;; (t) : error("Bad prim width type")
+;;
+;;defn emit-command (o:OutputStream, cmd:Command, top:Symbol, lits:HashTable<Symbol, DefUInt>, regs:HashTable<Symbol, DefRegister>, accs:HashTable<Symbol, DefAccessor>, ports:HashTable<Symbol, Port>, outs:HashTable<Symbol, Port>) :
+;; match(cmd) :
+;; (c:DefUInt) :
+;; lits[name(c)] = c
+;; (c:DefSInt) :
+;; emit-all(o, top, ports, lits, [name(c) " = " value(c) "'" width(c) "\n"])
+;; (c:DoPrim) : ;; NEED TO FIGURE OUT WHEN WIDTHS ARE NECESSARY AND EXTRACT
+;; emit-all(o, top, ports, lits, [name(c) " = " op(c)])
+;; for arg in args(c) do :
+;; print(o, " ")
+;; emit(o, top, ports, lits, arg)
+;; for const in consts(c) do :
+;; print(o, " ")
+;; emit(o, top, ports, lits, const)
+;; print("\n")
+;; (c:DefRegister) :
+;; regs[name(c)] = c
+;; (c:DefMemory) :
+;; emit-all(o, top, ports, lits, [name(c) " : mem'" prim-width(type(c)) " " size(c) "\n"])
+;; (c:DefAccessor) :
+;; accs[name(c)] = c
+;; (c:Connect) :
+;; val dst = name(loc(c) as Ref)
+;; val src = name(exp(c) as Ref)
+;; if key?(regs, dst) :
+;; val reg = regs[dst]
+;; emit-all(o, top, ports, lits, [dst " = reg'" prim-width(type(reg)) " 0'" prim-width(type(reg)) " " exp(c) "\n"])
+;; else if key?(accs, dst) :
+;; val acc = accs[dst]
+;; ;; assert(direction(acc) == OUTPUT)
+;; emit-all(o, top, ports, lits, [dst " = wr " source(acc) " " index(acc) " " exp(c) "\n"])
+;; else if key?(outs, dst) :
+;; val out = outs[dst]
+;; emit-all(o, top, ports, lits, [top "::" dst " = out'" prim-width(type(out)) " " exp(c) "\n"])
+;; else if key?(accs, src) :
+;; val acc = accs[src]
+;; ;; assert(direction(acc) == INPUT)
+;; emit-all(o, top, ports, lits, [dst " = rd " source(acc) " " index(acc) "\n"])
+;; else :
+;; emit-all(o, top, ports, lits, [dst " = mov " exp(c) "\n"])
+;; (c:Begin) :
+;; do(emit-command{o, _, top, lits, regs, accs, ports, outs}, body(c))
+;; (c:DefWire|EmptyCommand) :
+;; print("")
+;; (c) :
+;; error("Unable to print command")
+;;
+;;defn emit-module (o:OutputStream, m:Module) :
+;; val regs = HashTable<Symbol, DefRegister>(symbol-hash)
+;; val accs = HashTable<Symbol, DefAccessor>(symbol-hash)
+;; val lits = HashTable<Symbol, DefUInt>(symbol-hash)
+;; val outs = HashTable<Symbol, Port>(symbol-hash)
+;; val portz = HashTable<Symbol, Port>(symbol-hash)
+;; for port in ports(m) do :
+;; portz[name(port)] = port
+;; if direction(port) == OUTPUT :
+;; outs[name(port)] = port
+;; else if name(port) == `reset :
+;; print-all(o, [name(m) "::reset = rst\n"])
+;; else :
+;; print-all(o, [name(m) "::" name(port) " = " "in'" prim-width(type(port)) "\n"])
+;; emit-command(o, body(m), name(m), lits, regs, accs, portz, outs)
+;;
+;;public defn emit-circuit (o:OutputStream, c:Circuit) :
+;; emit-module(o, modules(c)[0])
+
+
+;============= DRIVER ======================================
+public defn run-passes (c: Circuit) :
+ var c*:Circuit = c
+ defn do-stage (name:String, f: Circuit -> Circuit) :
+ println(name)
+ c* = f(c*)
+ println(c*)
+ println("\n\n\n\n")
+
+ do-stage("Working IR", to-working-ir)
+ do-stage("Resolve Kinds", resolve-kinds)
+ do-stage("Make Explicit Reset", make-explicit-reset)
+ do-stage("Infer Types", infer-types)
+ do-stage("Infer Directions", infer-directions)
+ do-stage("Expand Accessors", expand-accessors)
+ do-stage("Flatten Bundles", flatten-bundles)
+ do-stage("Expand Bundles", expand-bundles)
+ do-stage("Expand Multi Connects", expand-multi-connects)
+ do-stage("Expand Whens", expand-whens)
+ do-stage("Structural Form", structural-form)
+ do-stage("Infer Widths", infer-widths)
+ do-stage("Pad Widths", pad-widths)
+ do-stage("Inline Instances", inline-instances)
+
+
+ ;; println("Shim for Jonathan's Passes")
+ ;; c* = shim(c*)
+ ;; println("Inline Modules")
+ ;; c* = inline-modules(c*)
+ ; c*
diff --git a/src/main/stanza/widthsolver.stanza b/src/main/stanza/widthsolver.stanza
new file mode 100644
index 00000000..4c9da2c6
--- /dev/null
+++ b/src/main/stanza/widthsolver.stanza
@@ -0,0 +1,298 @@
+;Define the STANDALONE flag to run STANDALONE
+if-defined(STANDALONE) :
+ include<"core/stringeater.stanza">
+ include<"compiler/lexer.stanza">
+
+defpackage widthsolver :
+ import core
+ import verse
+ import stanza.lexer
+
+;============= Language of Constraints ======================
+public definterface WConstraint
+public defstruct WidthEqual <: WConstraint :
+ name: Symbol
+ value: Exp
+public defstruct WidthGreater <: WConstraint :
+ name: Symbol
+ value: Exp
+
+defmethod print (o:OutputStream, c:WConstraint) :
+ print-all{o, _} $
+ match(c) :
+ (c:WidthEqual) : [name(c) " = " value(c)]
+ (c:WidthGreater) : [name(c) " >= " value(c)]
+
+defn construct-eqns (cs: Streamable<WConstraint>) :
+ val eqns = HashTable<Symbol, False|Exp>(symbol-hash)
+ val lower-bounds = HashTable<Symbol, List<Exp>>(symbol-hash)
+ for c in cs do :
+ match(c) :
+ (c:WidthEqual) :
+ eqns[name(c)] = value(c)
+ (c:WidthGreater) :
+ lower-bounds[name(c)] =
+ List(value(c),
+ get?(lower-bounds, name(c), List()))
+
+ ;Create minimum expressions for lower-bounds
+ for entry in lower-bounds do :
+ val v = key(entry)
+ val exps = value(entry)
+ if not key?(eqns, v) :
+ eqns[v] = reduce(EMax, ELit(0), exps)
+
+ ;Return equations
+ eqns
+;============================================================
+
+;============= Language of Expressions ======================
+public definterface Exp
+public defstruct EVar <: Exp :
+ name: Symbol
+public defstruct EMax <: Exp :
+ a: Exp
+ b: Exp
+public defstruct EPlus <: Exp :
+ a: Exp
+ b: Exp
+public defstruct EMinus <: Exp :
+ a: Exp
+ b: Exp
+public defstruct ELit <: Exp :
+ width: Int
+
+defmethod print (o:OutputStream, e:Exp) :
+ match(e) :
+ (e:EVar) : print(o, name(e))
+ (e:EMax) : print-all(o, ["max(" a(e) ", " b(e) ")"])
+ (e:EPlus) : print-all(o, [a(e) " + " b(e)])
+ (e:EMinus) : print-all(o, [a(e) " - " b(e)])
+ (e:ELit) : print(o, width(e))
+
+defn map (f: (Exp) -> Exp, e: Exp) -> Exp :
+ match(e) :
+ (e:EMax) : EMax(f(a(e)), f(b(e)))
+ (e:EPlus) : EPlus(f(a(e)), f(b(e)))
+ (e:EMinus) : EMinus(f(a(e)), f(b(e)))
+ (e:Exp) : e
+
+defn children (e: Exp) -> List<Exp> :
+ match(e) :
+ (e:EMax) : list(a(e), b(e))
+ (e:EPlus) : list(a(e), b(e))
+ (e:EMinus) : list(a(e), b(e))
+ (e:Exp) : list()
+;============================================================
+
+;================== Reading from File =======================
+defn read-exp (x) :
+ match(unwrap-token(x)) :
+ (x:Symbol) :
+ EVar(x)
+ (x:Int) :
+ ELit(x)
+ (x:List) :
+ val tag = unwrap-token(x[1])
+ switch {tag == _} :
+ `plus : EPlus(read-exp(x[2]), read-exp(x[3]))
+ `minus : EMinus(read-exp(x[2]), read-exp(x[3]))
+ `max : EMax(read-exp(x[2]), read-exp(x[3]))
+ else : error $ string-join $
+ ["Improper expression: " x]
+
+defn read (filename: String) :
+ var form:List = lex-file(filename)
+ val cs = Vector<WConstraint>()
+ while not empty?(form) :
+ val x = unwrap-token(form[0])
+ val op = form[1]
+ val e = read-exp(form[2])
+ form = tailn(form, 3)
+ add{cs, _} $
+ switch {unwrap-token(op) == _} :
+ `= : WidthEqual(x, e)
+ `>= : WidthGreater(x, e)
+ else : error $ string-join $ ["Unsupported Operator: " op]
+ cs
+;============================================================
+
+;============ Operations on Expressions =====================
+defn occurs? (v: Symbol, exp: Exp) :
+ match(exp) :
+ (exp: EVar) : name(exp) == v
+ (exp: Exp) : any?(occurs?{v, _}, children(exp))
+
+defn freevars (exp: Exp) :
+ to-list $ generate<Symbol> :
+ defn loop (exp: Exp) :
+ match(exp) :
+ (exp: EVar) : yield(name(exp))
+ (exp: Exp) : do(loop, children(exp))
+ loop(exp)
+
+defn contains-only-max? (exp: Exp) :
+ match(exp) :
+ (exp:EVar|EMax|ELit) : all?(contains-only-max?, children(exp))
+ (exp) : false
+
+defn simplify (exp: Exp) :
+ match(map(simplify,exp)) :
+ (exp: EPlus) :
+ match(a(exp), b(exp)) :
+ (a: ELit, b: ELit) :
+ ELit(width(a) + width(b))
+ (a: ELit, b) :
+ if width(a) == 0 : b
+ else : exp
+ (a, b: ELit) :
+ if width(b) == 0 : a
+ else : exp
+ (a, b) :
+ exp
+ (exp: EMinus) :
+ match(a(exp), b(exp)) :
+ (a: ELit, b: ELit) :
+ ELit(width(a) - width(b))
+ (a, b: ELit) :
+ if width(b) == 0 : a
+ else : exp
+ (a, b) :
+ exp
+ (exp: EMax) :
+ match(a(exp), b(exp)) :
+ (a: ELit, b: ELit) :
+ ELit(max(width(a), width(b)))
+ (a: ELit, b) :
+ if width(a) == 0 : b
+ else : exp
+ (a, b: ELit) :
+ if width(b) == 0 : a
+ else : exp
+ (a, b) :
+ exp
+ (exp: Exp) :
+ exp
+
+defn eval (exp: Exp, state: HashTable<Symbol,Int>) -> Int :
+ defn loop (e: Exp) -> Int :
+ match(e) :
+ (e: EVar) : state[name(e)]
+ (e: EMax) : max(loop(a(e)), loop(b(e)))
+ (e: EPlus) : loop(a(e)) + loop(b(e))
+ (e: EMinus) : loop(a(e)) - loop(b(e))
+ (e: ELit) : width(e)
+ loop(exp)
+;============================================================
+
+
+;================ Constraint Solver =========================
+defn substitute (solns: HashTable<Symbol, Exp>, exp: Exp) :
+ match(exp) :
+ (exp: EVar) :
+ match(get?(solns, name(exp), false)) :
+ (s:Exp) : substitute(solns, s)
+ (f:False) : exp
+ (exp) :
+ map(substitute{solns, _}, exp)
+
+defn dataflow (eqns: HashTable<Symbol, False|Exp>, solns: HashTable<Symbol,Exp>) :
+ var progress?:True|False = false
+ for entry in eqns do :
+ if value(entry) != false :
+ val v = key(entry)
+ val exp = simplify(substitute(solns, value(entry) as Exp))
+ if occurs?(v, exp) :
+ eqns[v] = exp
+ else :
+ eqns[v] = false
+ solns[v] = exp
+ progress? = true
+ progress?
+
+defn fixpoint (eqns: HashTable<Symbol, False|Exp>, solns: HashTable<Symbol,Exp>) :
+ label<False|True> break :
+ for v in keys(eqns) do :
+ if eqns[v] != false :
+ val fix-eqns = fixpoint-eqns(v, eqns)
+ val has-fixpoint? = all?(contains-only-max?{value(_)}, fix-eqns)
+ if has-fixpoint? :
+ val soln = solve-fixpoint(fix-eqns)
+ for s in soln do :
+ solns[key(s)] = ELit(value(s))
+ eqns[key(s)] = false
+ break(true)
+ false
+
+defn fixpoint-eqns (v: Symbol, eqns: HashTable<Symbol,False|Exp>) :
+ val vs = HashTable<Symbol,Exp>(symbol-hash)
+ defn loop (v: Symbol) :
+ if not key?(vs, v) :
+ val eqn = eqns[v] as Exp
+ vs[v] = eqn
+ do(loop, freevars(eqn))
+ loop(v)
+ to-list(vs)
+
+defn solve-fixpoint (eqns: List<KeyValue<Symbol,Exp>>) :
+ ;Solve for fixpoint
+ val sol = HashTable<Symbol,Int>(symbol-hash)
+ do({sol[key(_)] = 0}, eqns)
+ defn loop () :
+ var progress?:True|False = false
+ for eqn in eqns do :
+ val v = key(eqn)
+ val x = eval(value(eqn), sol)
+ if x != sol[v] :
+ sol[v] = x
+ progress? = true
+ progress?
+ while loop() : false
+
+ ;Return solutions
+ to-list(sol)
+
+defn backsubstitute (vs:Streamable<Symbol>, solns: HashTable<Symbol,Exp>) :
+ val widths = HashTable<Symbol,False|Int>(symbol-hash)
+ defn get-width (v:Symbol) :
+ if key?(solns, v) :
+ val vs = freevars(solns[v])
+ ;Calculate dependencies
+ for v in vs do :
+ if not key?(widths, v) :
+ widths[v] = get-width(v)
+ ;Compute value
+ if none?({widths[_] == false}, vs) :
+ eval(solns[v], widths as HashTable<Symbol,Int>)
+
+ ;Compute all widths
+ for v in vs do :
+ widths[v] = get-width(v)
+
+ ;Return widths
+ to-list $ generate<WidthEqual> :
+ for entry in widths do :
+ if value(entry) != false :
+ yield $ WidthEqual(key(entry), ELit(value(entry) as Int))
+
+public defn solve-widths (cs: Streamable<WConstraint>) :
+ ;Copy to new hashtable
+ val eqns = construct-eqns(cs)
+ val solns = HashTable<Symbol,Exp>(symbol-hash)
+ defn loop () :
+ dataflow(eqns, solns) or
+ fixpoint(eqns, solns)
+ while loop() : false
+ backsubstitute(keys(eqns), solns)
+
+;================= Main =====================================
+if-defined(STANDALONE) :
+ defn main () :
+ val input = lex(commandline-arguments())
+ error("No input file!") when length(input) < 2
+ val cs = read(to-string(input[1]))
+ do(println, solve-widths(cs))
+
+ main()
+;============================================================
+
diff --git a/src/test/firrtl/firrtl-test.txt b/src/test/firrtl/firrtl-test.txt
new file mode 100644
index 00000000..7d8e66d2
--- /dev/null
+++ b/src/test/firrtl/firrtl-test.txt
@@ -0,0 +1,56 @@
+circuit top :
+ module subtracter :
+ input x:UInt
+ input y:UInt
+ output z:UInt
+ z := sub-mod(x, y)
+
+ module gcd :
+ input a: UInt(16)
+ input b: UInt(16)
+ input e: UInt(1)
+ output z: UInt(16)
+ output v: UInt(1)
+
+ reg x: UInt
+ reg y: UInt
+ x.init := UInt(0)
+ y.init := UInt(42)
+
+ when greater(x, y) :
+ inst s of subtracter
+ s.x := x
+ s.y := y
+ x := s.z
+ else :
+ inst s2 of subtracter
+ s2.x := x
+ s2.y := y
+ y := s2.z
+
+ when e :
+ x := a
+ y := b
+
+ v := equal(v, UInt(0))
+ z := x
+
+ module top :
+ input a: UInt(16)
+ input b: UInt(16)
+ output z: UInt
+
+ inst i of gcd
+ i.a := a
+ i.b := b
+ i.e := UInt(1)
+ z := i.z
+
+
+
+
+
+
+
+
+