aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorazidar2015-03-04 16:25:25 -0800
committerazidar2015-03-04 16:25:25 -0800
commit6ad6267d26b52258f6e0d4d004aeb5f36856cf95 (patch)
tree16aad9875b1f58dc0cc2a5cd59091e89d57a0861
parent355749c83d2066f1a149333ed762a7945d405076 (diff)
Finished infer-types pass
-rw-r--r--.gitignore9
-rw-r--r--TODO2
-rw-r--r--spec/spec.tex4
-rw-r--r--src/main/stanza/ir-parser.stanza68
-rw-r--r--src/main/stanza/ir-utils.stanza56
-rw-r--r--src/main/stanza/passes.stanza2341
-rw-r--r--test/passes/infer-types/bundle.fir13
-rw-r--r--test/passes/infer-types/gcd.fir20
-rw-r--r--test/passes/infer-types/primops.fir222
-rw-r--r--test/passes/initialize-register/when.fir2
-rw-r--r--test/passes/resolve-kinds/gcd.fir6
11 files changed, 1393 insertions, 1350 deletions
diff --git a/.gitignore b/.gitignore
index f33036d1..273dc02c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,8 @@
-src/*.DS_STORE
-src/*/*.DS_STORE
-src/*/*/*.DS_STORE
-src/*/*/*/*.DS_STORE
+.DS_STORE
+*/*.DS_STORE
+*/*/*.DS_STORE
+*/*/*/*.DS_STORE
+*/*/*/*/*.DS_STORE
*.swp
*/*.swp
*/*/*.swp
diff --git a/TODO b/TODO
index 332b127d..e122c35c 100644
--- a/TODO
+++ b/TODO
@@ -3,6 +3,8 @@ TODO
Figure out how types and widths propogate for all updated primops
Write infer-types pass
Remove letrec. Add to expressions: Register(input,en), ReadPort(mem,index,enable), WritePort(mem,index,enable)
+ Add bit-reduce-and etc to primops
+ Write pass to rename identifiers (alpha-transform)
Update spec
change concrete syntactical names of structural elements
change direction names for bundle fields
diff --git a/spec/spec.tex b/spec/spec.tex
index d92785c5..fc8a8129 100644
--- a/spec/spec.tex
+++ b/spec/spec.tex
@@ -576,8 +576,8 @@ The resultant value of a divide operation has width equal to the width of the di
\kws{mod}( \pds{op1}, \pds{op2}) & UInt & width(op1)|width(op2) - 1 \\
\kws{mod-uu}(\pds{op1}, \pds{op2}) & UInt & width(op2) \\
\kws{mod-us}(\pds{op1}, \pds{op2}) & UInt & width(op2) - 1? \\
-\kws{mod-su}(\pds{op1}, \pds{op2}) & UInt & width(op2) \\
-\kws{mod-ss}(\pds{op1}, \pds{op2}) & UInt & width(op2) - 1? \\
+\kws{mod-su}(\pds{op1}, \pds{op2}) & SInt & width(op2) \\
+\kws{mod-ss}(\pds{op1}, \pds{op2}) & SInt & width(op2) - 1? \\
\end{array}
\]
diff --git a/src/main/stanza/ir-parser.stanza b/src/main/stanza/ir-parser.stanza
index 43383f9a..cbd57f9b 100644
--- a/src/main/stanza/ir-parser.stanza
+++ b/src/main/stanza/ir-parser.stanza
@@ -54,11 +54,15 @@ defn unwrap-prefix-form (form) :
;======= Split Dots ============
defn split-dots (forms:List) :
+ defn to-form (x:String) :
+ val num? = for c in x all? :
+ c >= '0' and c <= '9'
+ to-int(x) when num? else to-symbol(x)
defn split (form) :
match(ut(form)) :
(f:Symbol) :
val fstr = to-string(f)
- if contains?(fstr, '.') : map(to-symbol, split-string(fstr, "."))
+ if contains?(fstr, '.') : map(to-form, split-string(fstr, "."))
else : list(form)
(f:List) :
list(map-append(split, f))
@@ -148,10 +152,10 @@ rd.defsyntax firrtl :
ut(name) => Instance(UnknownType(), module, ports)
defrule exp :
- (?x:#exp . ?f:#symbol) :
- Field(x, ut(f), UnknownType())
(?x:#exp . ?f:#int) :
Index(x, ut(f), UnknownType())
+ (?x:#exp . ?f:#symbol) :
+ Field(x, ut(f), UnknownType())
(?x:#exp-form) :
x
@@ -201,26 +205,26 @@ rd.defsyntax firrtl :
operators[`sub-wrap-us] = SUB-WRAP-US-OP
operators[`sub-wrap-su] = SUB-WRAP-SU-OP
operators[`sub-wrap-ss] = SUB-WRAP-SS-OP
- operators[`less] = LESS-OP
- operators[`less-uu] = LESS-UU-OP
- operators[`less-us] = LESS-US-OP
- operators[`less-su] = LESS-SU-OP
- operators[`less-ss] = LESS-SS-OP
- operators[`less-eq] = LESS-EQ-OP
- operators[`less-eq-uu] = LESS-EQ-UU-OP
- operators[`less-eq-us] = LESS-EQ-US-OP
- operators[`less-eq-su] = LESS-EQ-SU-OP
- operators[`less-eq-ss] = LESS-EQ-SS-OP
- operators[`greater] = GREATER-OP
- operators[`greater-uu] = GREATER-UU-OP
- operators[`greater-us] = GREATER-US-OP
- operators[`greater-su] = GREATER-SU-OP
- operators[`greater-ss] = GREATER-SS-OP
- operators[`greater-eq] = GREATER-EQ-OP
- operators[`greater-eq-uu] = GREATER-EQ-UU-OP
- operators[`greater-eq-us] = GREATER-EQ-US-OP
- operators[`greater-eq-su] = GREATER-EQ-SU-OP
- operators[`greater-eq-ss] = GREATER-EQ-SS-OP
+ operators[`lt] = LESS-OP
+ operators[`lt-uu] = LESS-UU-OP
+ operators[`lt-us] = LESS-US-OP
+ operators[`lt-su] = LESS-SU-OP
+ operators[`lt-ss] = LESS-SS-OP
+ operators[`leq] = LESS-EQ-OP
+ operators[`leq-uu] = LESS-EQ-UU-OP
+ operators[`leq-us] = LESS-EQ-US-OP
+ operators[`leq-su] = LESS-EQ-SU-OP
+ operators[`leq-ss] = LESS-EQ-SS-OP
+ operators[`gt] = GREATER-OP
+ operators[`gt-uu] = GREATER-UU-OP
+ operators[`gt-us] = GREATER-US-OP
+ operators[`gt-su] = GREATER-SU-OP
+ operators[`gt-ss] = GREATER-SS-OP
+ operators[`geq] = GREATER-EQ-OP
+ operators[`geq-uu] = GREATER-EQ-UU-OP
+ operators[`geq-us] = GREATER-EQ-US-OP
+ operators[`geq-su] = GREATER-EQ-SU-OP
+ operators[`geq-ss] = GREATER-EQ-SS-OP
operators[`equal] = EQUAL-OP
operators[`equal-uu] = EQUAL-UU-OP
operators[`equal-ss] = EQUAL-SS-OP
@@ -236,15 +240,15 @@ rd.defsyntax firrtl :
operators[`as-SInt] = AS-SINT-OP
operators[`as-SInt-u] = AS-SINT-U-OP
operators[`as-SInt-s] = AS-SINT-S-OP
- operators[`shift-left] = SHIFT-LEFT-OP
- operators[`shift-left-u] = SHIFT-LEFT-U-OP
- operators[`shift-left-s] = SHIFT-LEFT-S-OP
- operators[`shift-right] = SHIFT-RIGHT-OP
- operators[`shift-right-u] = SHIFT-RIGHT-U-OP
- operators[`shift-right-s] = SHIFT-RIGHT-S-OP
- operators[`convert] = SHIFT-RIGHT-OP
- operators[`convert-u] = SHIFT-RIGHT-U-OP
- operators[`convert-s] = SHIFT-RIGHT-S-OP
+ operators[`shl] = SHIFT-LEFT-OP
+ operators[`shl-u] = SHIFT-LEFT-U-OP
+ operators[`shl-s] = SHIFT-LEFT-S-OP
+ operators[`shr] = SHIFT-RIGHT-OP
+ operators[`shr-u] = SHIFT-RIGHT-U-OP
+ operators[`shr-s] = SHIFT-RIGHT-S-OP
+ operators[`convert] = CONVERT-OP
+ operators[`convert-u] = CONVERT-U-OP
+ operators[`convert-s] = CONVERT-S-OP
operators[`bit-and] = BIT-AND-OP
operators[`bit-or] = BIT-OR-OP
operators[`bit-xor] = BIT-XOR-OP
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza
index 4da64981..9e8c63c5 100644
--- a/src/main/stanza/ir-utils.stanza
+++ b/src/main/stanza/ir-utils.stanza
@@ -69,26 +69,26 @@ defmethod print (o:OutputStream, op:PrimOp) :
SUB-WRAP-US-OP : "sub-wrap-us"
SUB-WRAP-SU-OP : "sub-wrap-su"
SUB-WRAP-SS-OP : "sub-wrap-ss"
- LESS-OP : "less"
- LESS-UU-OP : "less-uu"
- LESS-US-OP : "less-us"
- LESS-SU-OP : "less-su"
- LESS-SS-OP : "less-ss"
- LESS-EQ-OP : "less-eq"
- LESS-EQ-UU-OP : "less-eq-uu"
- LESS-EQ-US-OP : "less-eq-us"
- LESS-EQ-SU-OP : "less-eq-su"
- LESS-EQ-SS-OP : "less-eq-ss"
- GREATER-OP : "greater"
- GREATER-UU-OP : "greater-uu"
- GREATER-US-OP : "greater-us"
- GREATER-SU-OP : "greater-su"
- GREATER-SS-OP : "greater-ss"
- GREATER-EQ-OP : "greater-eq"
- GREATER-EQ-UU-OP : "greater-eq-uu"
- GREATER-EQ-US-OP : "greater-eq-us"
- GREATER-EQ-SU-OP : "greater-eq-su"
- GREATER-EQ-SS-OP : "greater-eq-ss"
+ LESS-OP : "lt"
+ LESS-UU-OP : "lt-uu"
+ LESS-US-OP : "lt-us"
+ LESS-SU-OP : "lt-su"
+ LESS-SS-OP : "lt-ss"
+ LESS-EQ-OP : "leq"
+ LESS-EQ-UU-OP : "leq-uu"
+ LESS-EQ-US-OP : "leq-us"
+ LESS-EQ-SU-OP : "leq-su"
+ LESS-EQ-SS-OP : "leq-ss"
+ GREATER-OP : "gt"
+ GREATER-UU-OP : "gt-uu"
+ GREATER-US-OP : "gt-us"
+ GREATER-SU-OP : "gt-su"
+ GREATER-SS-OP : "gt-ss"
+ GREATER-EQ-OP : "geq"
+ GREATER-EQ-UU-OP : "geq-uu"
+ GREATER-EQ-US-OP : "geq-us"
+ GREATER-EQ-SU-OP : "geq-su"
+ GREATER-EQ-SS-OP : "geq-ss"
EQUAL-OP : "equal"
EQUAL-UU-OP : "equal-uu"
EQUAL-SS-OP : "equal-ss"
@@ -104,12 +104,12 @@ defmethod print (o:OutputStream, op:PrimOp) :
AS-SINT-OP : "as-SInt"
AS-SINT-U-OP : "as-SInt-u"
AS-SINT-S-OP : "as-SInt-s"
- SHIFT-LEFT-OP : "shift-left"
- SHIFT-LEFT-U-OP : "shift-left-u"
- SHIFT-LEFT-S-OP : "shift-left-s"
- SHIFT-RIGHT-OP : "shift-right"
- SHIFT-RIGHT-U-OP : "shift-right-u"
- SHIFT-RIGHT-S-OP : "shift-right-s"
+ SHIFT-LEFT-OP : "shl"
+ SHIFT-LEFT-U-OP : "shl-u"
+ SHIFT-LEFT-S-OP : "shl-s"
+ SHIFT-RIGHT-OP : "shr"
+ SHIFT-RIGHT-U-OP : "shr-u"
+ SHIFT-RIGHT-S-OP : "shr-s"
CONVERT-OP : "convert"
CONVERT-U-OP : "convert-u"
CONVERT-S-OP : "convert-s"
@@ -198,7 +198,9 @@ defmethod print (o:OutputStream, t:Type) :
(w:UnknownWidth) : print-all(o, ["UInt"])
(w) : print-all(o, ["UInt(" width(t) ")"])
(t:SIntType) :
- print-all(o, ["SInt(" width(t) ")"])
+ match(width(t)) :
+ (w:UnknownWidth) : print-all(o, ["SInt"])
+ (w) : print-all(o, ["SInt(" width(t) ")"])
(t:BundleType) :
print(o, "{")
print-all(o, join(ports(t), ", "))
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index 18593499..3cdb553b 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -84,9 +84,15 @@ defn any-debug? (e:Expression|Stmt|Type|Element|Port) :
(hasKind(e) and PRINT-KINDS)
defmethod print-debug (o:OutputStream, e:Expression|Stmt|Type|Element|Port) :
+ defn wipe-width (t:Type) -> Type :
+ match(t) :
+ (t:UIntType) : UIntType(UnknownWidth())
+ (t:SIntType) : SIntType(UnknownWidth())
+ (t) : t
+
if any-debug?(e) : print(o,"@")
if PRINT-KINDS and hasKind(e) : print-all(o,["<k:" kind(e as ?) ">"])
- if PRINT-TYPES and hasType(e) : print-all(o,["<t:" type(e as ?) ">"])
+ if PRINT-TYPES and hasType(e) : print-all(o,["<t:" wipe-width(type(e as ?)) ">"])
if PRINT-WIDTHS and hasWidth(e): print-all(o,["<w:" width(e as ?) ">"])
defmethod print (o:OutputStream, e:WRef) :
@@ -346,18 +352,20 @@ defn get-primop-rettype (e:DoPrim) -> Type :
defn u () : UIntType(UnknownWidth())
defn s () : SIntType(UnknownWidth())
defn u-and (op1:Expression,op2:Expression) :
- if type(op1) typeof UIntType and type(op2) typeof UIntType :
- UIntType(UnknownWidth())
- else :
- SIntType(UnknownWidth())
+ match(type(op1), type(op2)) :
+ (t1:UIntType, t2:UIntType) : u()
+ (t1:SIntType, t2) : s()
+ (t1, t2:SIntType) : s()
+ (t1, t2) : UnknownType()
+
defn of-type (op:Expression) :
- if type(op) typeof UIntType :
- UIntType(UnknownWidth())
- if type(op) typeof SIntType :
- SIntType(UnknownWidth())
- else : UnknownType()
+ match(type(op)) :
+ (t:UIntType) : u()
+ (t:SIntType) : s()
+ (t) : UnknownType()
- switch {e == _} :
+ ;println-all(["Inferencing primop type: " e])
+ switch {op(e) == _} :
ADD-OP : u-and(args(e)[0],args(e)[1])
ADD-UU-OP : u()
ADD-US-OP : s()
@@ -378,11 +386,11 @@ defn get-primop-rettype (e:DoPrim) -> Type :
DIV-US-OP : s()
DIV-SU-OP : s()
DIV-SS-OP : s()
- MOD-OP : u()
+ MOD-OP : of-type(args(e)[0])
MOD-UU-OP : u()
MOD-US-OP : u()
- MOD-SU-OP : u()
- MOD-SS-OP : u()
+ MOD-SU-OP : s()
+ MOD-SS-OP : s()
QUO-OP : u-and(args(e)[0],args(e)[1])
QUO-UU-OP : u()
QUO-US-OP : s()
@@ -458,22 +466,20 @@ defn type (m:Module) -> Type :
BundleType(ports(m))
defn get-type (b:Symbol,l:List<KeyValue<Symbol,Type>>) -> Type :
- val contains? = for kv in l any? : b == key(kv)
- if contains? :
- label<Type> myret :
- for kv in l do :
- if b == key(kv) : myret(value(kv))
- myret(UnknownType())
- else : UnknownType()
+ val ma = for kv in l find : b == key(kv)
+ if ma != false :
+ val ret = value(ma as KeyValue<Symbol,Type>)
+ ;println-all(["Found! Returning " ret " for " b])
+ ret
+ else :
+ ;println-all(["Not found! Returning " UnknownType() " for " b])
+ UnknownType()
defn bundle-field-type (v:Type,s:Symbol) -> Type :
match(v) :
(v:BundleType) :
- val contains? = for p in ports(v) any? : name(p) == s
- if contains? :
- label<Type> myret :
- for p in ports(v) do :
- if b == name(p) : myret(type(p))
+ val ft = for p in ports(v) find : name(p) == s
+ if ft != false : type(ft as Port)
else : UnknownType()
(v) : UnknownType()
@@ -482,8 +488,9 @@ defn get-vector-subtype (v:Type) -> Type :
(v:VectorType) : type(v)
(v) : UnknownType()
-defn infer-types (e:Expression, l:List<KeyValue<Symbol,Type>>) -> Expression :
- match(map(infer-types{_,l},e)) :
+defn infer-exp-types (e:Expression, l:List<KeyValue<Symbol,Type>>) -> Expression :
+ val r = map(infer-exp-types{_,l},e)
+ match(r) :
(e:WRef) : WRef(name(e), get-type(name(e),l),kind(e),dir(e))
(e:WField) : WField(exp(e),name(e), bundle-field-type(type(exp(e)),name(e)),dir(e))
(e:WIndex) : WIndex(exp(e),value(e), get-vector-subtype(type(exp(e))),dir(e))
@@ -492,7 +499,7 @@ defn infer-types (e:Expression, l:List<KeyValue<Symbol,Type>>) -> Expression :
(e:UIntValue|SIntValue|Null) : e
defn infer-types (s:Stmt, l:List<KeyValue<Symbol,Type>>) -> [Stmt, List<KeyValue<Symbol,Type>>] :
- match(s) :
+ match(map(infer-exp-types{_,l},s)) :
(s:LetRec) : [s,l] ;TODO, this is wrong but we might be getting rid of letrecs?
(s:Begin) :
var env = l
@@ -507,7 +514,7 @@ defn infer-types (s:Stmt, l:List<KeyValue<Symbol,Type>>) -> [Stmt, List<KeyValue
(s:DefMemory) : [s,List(name(s) => type(s),l)]
(s:DefInstance) : [s, List(name(s) => type(module(s)),l)]
(s:DefNode) : [s, List(name(s) => type(value(s)),l)]
- (s:WDefAccessor) : [s, List(name(s) => type(source(s)),l)]
+ (s:WDefAccessor) : [s, List(name(s) => get-vector-subtype(type(source(s))),l)]
(s:Conditionally) :
val [s*,l*] = infer-types(conseq(s),l)
val [s**,l**] = infer-types(alt(s),l)
@@ -518,6 +525,7 @@ defn infer-types (m:Module, l:List<KeyValue<Symbol,Type>>) -> Module :
val ptypes =
for p in ports(m) map :
name(p) => type(p)
+ ;println-all(append(ptypes,l))
val [s,l*] = infer-types(body(m),append(ptypes, l))
Module(name(m),ports(m),s)
@@ -525,1138 +533,1139 @@ defn infer-types (c:Circuit) -> Circuit :
val l =
for m in modules(c) map :
name(m) => BundleType(ports(m))
+ ;println-all(l)
Circuit{ _, main(c) } $
for m in modules(c) map :
infer-types(m,l)
;============= INFER DIRECTIONS ============================
-defn flip (d:Direction) :
- switch {d == _} :
- INPUT : OUTPUT
- OUTPUT : INPUT
- else : d
-
-defn times (d1:Direction, d2:Direction) :
- if d1 == INPUT : flip(d2)
- else : d2
-
-defn lookup-port (ports: Streamable<Port>, port-name: Symbol) :
- for port in ports find :
- name(port) == port-name
-
-defn bundle-field-dir (t:Type, n:Symbol) -> Direction :
- match(t) :
- (t:BundleType) :
- match(lookup-port(ports(t), n)) :
- (p:Port) : direction(p)
- (p) : UNKNOWN-DIR
- (t) : UNKNOWN-DIR
-
-defn infer-dirs (m:Module) :
- ;=== Direction of all Binders ===
- val BI-DIR = new Direction
- val directions = HashTable<Symbol,Direction>(symbol-hash)
- defn find-dirs (c:Stmt) :
- match(c) :
- (c:LetRec) :
- for entry in entries(c) do :
- directions[key(entry)] = OUTPUT
- find-dirs(body(c))
- (c:DefWire) :
- directions[name(c)] = BI-DIR
- (c:DefRegister) :
- directions[name(c)] = BI-DIR
- (c:DefInstance) :
- directions[name(c)] = OUTPUT
- (c:DefMemory) :
- directions[name(c)] = BI-DIR
- (c:WDefAccessor) :
- directions[name(c)] = dir(c)
- (c) :
- do(find-dirs, children(c))
- for p in ports(m) do :
- directions[name(p)] = flip(direction(p))
- find-dirs(body(m))
-
- ;=== Fix Point Status ===
- var changed? = false
-
- ;=== Infer directions of Expression ===
- defn infer-exp (e:Expression, desired:Direction) :
- match(e) :
- (e:WRef) :
- val dir* = let :
- if kind(e) typeof ModuleKind :
- OUTPUT
- else :
- val old-dir = directions[name(e)]
- switch {old-dir == _} :
- BI-DIR :
- desired
- UNKNOWN-DIR :
- if directions[name(e)] != desired :
- directions[name(e)] = desired
- changed? = true
- desired
- else :
- old-dir
- WRef(name(e), type(e), kind(e), dir*)
- (e:WField) :
- val port-dir = bundle-field-dir(type(exp(e)), name(e))
- val exp* = infer-exp(exp(e), port-dir * desired)
- WField(exp*, name(e), type(e), port-dir * dir(exp*))
- (e:WIndex) :
- val exp* = infer-exp(exp(e), desired)
- WIndex(exp*, value(e), type(e), dir(exp*))
- (e) :
- map(infer-exp{_, OUTPUT}, e)
-
- ;=== Infer directions of Stmts ===
- defn infer-comm (c:Stmt) :
- match(c) :
- (c:LetRec) :
- val c* = map(infer-exp{_, OUTPUT}, c)
- LetRec(entries(c*), infer-comm(body(c)))
- (c:DefInstance) :
- DefInstance(name(c),
- infer-exp(module(c), OUTPUT))
- (c:WDefAccessor) :
- val d = directions[name(c)]
- WDefAccessor(name(c),
- infer-exp(source(c), d),
- infer-exp(index(c), OUTPUT),
- d)
- (c:Conditionally) :
- Conditionally(infer-exp(pred(c), OUTPUT),
- infer-comm(conseq(c)),
- infer-comm(alt(c)))
- (c:Connect) :
- Connect(infer-exp(loc(c), INPUT), infer-exp(exp(c), OUTPUT))
- (c) :
- map(infer-comm, c)
-
- ;=== Iterate until fix point ===
- defn* fixpoint (c:Stmt) :
- changed? = false
- val c* = infer-comm(c)
- if changed? : fixpoint(c*)
- else : c*
-
- Module(name(m), ports(m), body*) where :
- val body* = fixpoint(body(m))
-
-defn infer-directions (c:Circuit) :
- Circuit(modules*, main(c)) where :
- val modules* = map(infer-dirs, modules(c))
-
-
-;============== EXPAND VECS ================================
-defstruct ManyConnect <: Stmt :
- index: Expression
- locs: List<Expression>
- exp: Expression
-
-defstruct ConnectMany <: Stmt :
- index: Expression
- loc: Expression
- exps: List<Expression>
-
-defmethod print (o:OutputStream, c:ManyConnect) :
- print-all(o, [locs(c) "[" index(c) "] := " exp(c)])
-defmethod print (o:OutputStream, c:ConnectMany) :
- print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"])
-
-defmethod map (f: Expression -> Expression, c:ManyConnect) :
- ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c)))
-defmethod map (f: Expression -> Expression, c:ConnectMany) :
- ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c)))
-
-defn expand-accessors (m: Module) :
- defn expand (c:Stmt) :
- match(c) :
- (c:WDefAccessor) :
- ;Is the source a memory?
- val mem? =
- match(source(c)) :
- (r:WRef) : kind(r) typeof MemKind
- (r) : false
-
- if mem? :
- c
- else :
- switch {dir(c) == _} :
- INPUT :
- Begin(list(
- DefWire(name(c), type(src-type))
- ManyConnect(index(c), elems, wire-ref)))
- where :
- val src-type = type(source(c)) as VectorType
- val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT)
- val elems = to-list $
- for i in 0 to size(src-type) stream :
- WIndex(source(c), i, type(src-type), INPUT)
- OUTPUT :
- Begin(list(
- DefWire(name(c), type(src-type))
- ConnectMany(index(c), wire-ref, elems)))
- where :
- val src-type = type(source(c)) as VectorType
- val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT)
- val elems = to-list $
- for i in 0 to size(src-type) stream :
- WIndex(source(c), i, type(src-type), OUTPUT)
- (c) :
- map(expand, c)
- Module(name(m), ports(m), expand(body(m)))
-
-defn expand-accessors (c:Circuit) :
- Circuit(modules*, main(c)) where :
- val modules* = map(expand-accessors, modules(c))
-
-
-
-;=============== BUNDLE FLATTENING =========================
-defn prefix (prefix, suffix) :
- symbol-join([prefix "/" suffix])
-
-defn prefix-ports (pre:Symbol, ports:List<Port>) :
- for p in ports map :
- Port(prefix(pre, name(p)), direction(p), type(p))
-
-defn flatten-ports (port:Port) -> List<Port> :
- match(type(port)) :
- (t:BundleType) :
- val ports = map-append(flatten-ports, ports(t))
- for p in ports map :
- Port(prefix(name(port), name(p)),
- direction(port) * direction(p),
- type(p))
- (t:VectorType) :
- val type* = flatten-type(t)
- flatten-ports(Port(name(port), direction(port), type*))
- (t:Type) :
- list(port)
-
-defn flatten-type (t:Type) -> Type :
- match(t) :
- (t:BundleType) :
- BundleType $
- map-append(flatten-ports, ports(t))
- (t:VectorType) :
- flatten-type $ BundleType $ to-list $
- for i in 0 to size(t) stream :
- Port(to-symbol(i), OUTPUT, type(t))
- (t:Type) :
- t
-
-defn flatten-bundles (c:Circuit) :
- defn flatten-exp (e:Expression) :
- match(map(flatten-exp, e)) :
- (e:UIntValue|SIntValue) :
- e
- (e:WRef) :
- match(kind(e)) :
- (k:MemKind|StructuralMemKind) :
- val type* = map(flatten-type, type(e))
- put-type(e, type*)
- (k) :
- val type* = flatten-type(type(e))
- put-type(e, type*)
- (e) :
- val type* = flatten-type(type(e))
- put-type(e, type*)
-
- defn flatten-element (e:Element) :
- val t* = flatten-type(type(e))
- match(map(flatten-exp, e)) :
- (e:Register) : Register(t*, value(e), enable(e))
- (e:Memory) : Memory(t*, writers(e))
- (e:Node) : Node(t*, value(e))
- (e:Instance) : Instance(t*, module(e), ports(e))
-
- defn flatten-comm (c:Stmt) :
- match(c) :
- (c:LetRec) :
- val entries* =
- for entry in entries(c) map :
- key(entry) => flatten-element(value(entry))
- LetRec(entries*, flatten-comm(body(c)))
- (c:DefWire) :
- DefWire(name(c), flatten-type(type(c)))
- (c:DefRegister) :
- DefRegister(name(c), flatten-type(type(c)))
- (c:DefMemory) :
- val type* = map(flatten-type, type(c))
- DefMemory(name(c), type*)
- (c) :
- map{flatten-comm, _} $
- map(flatten-exp, c)
-
- defn flatten-module (m:Module) :
- val ports* = map-append(flatten-ports, ports(m))
- val body* = flatten-comm(body(m))
- Module(name(m), ports*, body*)
-
- Circuit(modules*, main(c)) where :
- val modules* = map(flatten-module, modules(c))
-
-
-;================== BUNDLE EXPANSION =======================
-defn expand-bundles (m:Module) :
-
- ;Collapse all field/index expressions
- defn collapse-exp (e:Expression) -> Expression :
- match(e) :
- (e:WField) :
- match(collapse-exp(exp(e))) :
- (ei:WRef) :
- if kind(ei) typeof InstanceKind :
- e
- else :
- WRef(name*, type(e), kind(ei), dir(e)) where :
- val name* = prefix(name(ei), name(e))
- (ei:WField) :
- WField(exp(ei), name*, type(e), dir(e)) where :
- val name* = prefix(name(ei), name(e))
- (e:WIndex) :
- collapse-exp(WField(exp(e), name, type(e), dir(e))) where :
- val name = to-symbol(value(e))
- (e) :
- map(collapse-exp, e)
-
- ;Expand expressions
- defn expand-exp (e:Expression) -> List<Expression> :
- match(type(e)) :
- (t:BundleType) :
- for p in ports(t) map :
- val dir* = direction(p) * dir(e)
- collapse-exp(WField(e, name(p), type(p), dir*))
- (t) :
- list(collapse-exp(e))
-
- ;Expand commands
- defn expand-comm (c:Stmt) :
- match(c) :
- (c:DefWire) :
- match(type(c)) :
- (t:BundleType) :
- Begin $
- for p in ports(t) map :
- DefWire(prefix(name(c), name(p)), type(p))
- (t) :
- c
- (c:DefRegister) :
- match(type(c)) :
- (t:BundleType) :
- Begin $
- for p in ports(t) map :
- DefRegister(prefix(name(c), name(p)), type(p))
- (t) :
- c
- (c:DefMemory) :
- match(type(type(c))) :
- (t:BundleType) :
- Begin $
- for p in ports(t) map :
- DefMemory(prefix(name(c), name(p)), type*) where :
- val s = size(type(c))
- val type* = VectorType(type(p), s)
- (t) :
- c
- (c:WDefAccessor) :
- match(type(source(c))) :
- (t:BundleType) :
- val srcs = expand-exp(source(c))
- Begin $
- for (p in ports(t), src in srcs) map :
- WDefAccessor(name*, src, index(c), dir*) where :
- val name* = prefix(name(c), name(p))
- val dir* = direction(p) * dir(c)
- (t) :
- c
- (c:Connect) :
- val locs = expand-exp(loc(c))
- val exps = expand-exp(exp(c))
- Begin $
- for (l in locs, e in exps) map :
- switch {dir(l) == _} :
- INPUT : Connect(l, e)
- OUTPUT : Connect(e, l)
- (c:ManyConnect) :
- val locs-list = transpose(map(expand-exp, locs(c)))
- val exps = expand-exp(exp(c))
- Begin $
- for (locs in locs-list, e in exps) map :
- switch {dir(e) == _} :
- OUTPUT : ManyConnect(index(c), locs, e)
- INPUT : ConnectMany(index(c), e, locs)
- (c:ConnectMany) :
- val locs = expand-exp(loc(c))
- val exps-list = transpose(map(expand-exp, exps(c)))
- Begin $
- for (l in locs, exps in exps-list) map :
- switch {dir(l) == _} :
- INPUT : ConnectMany(index(c), l, exps)
- OUTPUT : ManyConnect(index(c), exps, l)
- (c) :
- map{expand-comm, _} $
- map(collapse-exp, c)
-
- Module(name(m), ports(m), expand-comm(body(m)))
-
-defn expand-bundles (c:Circuit) :
- Circuit(modules*, main(c)) where :
- val modules* = map(expand-bundles, modules(c))
-
-
-;=========== CONVERT MULTI CONNECTS to WHEN ================
-defn expand-multi-connects (c:Circuit) :
- defn equal-exp (e1:Expression, e2:Expression) :
- DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth()))
- defn uint (i:Int) :
- UIntValue(i, UnknownWidth())
-
- defn expand-comm (c:Stmt) :
- match(c) :
- (c:ConnectMany) :
- Begin $ to-list $
- for (i in 0 to false, e in exps(c)) stream :
- Conditionally(equal-exp(index(c), uint(i)),
- Connect(loc(c), e)
- EmptyStmt())
- (c:ManyConnect) :
- Begin $ to-list $
- for (i in 0 to false, l in locs(c)) stream :
- Conditionally(equal-exp(index(c), uint(i)),
- Connect(l, exp(c))
- EmptyStmt())
- (c) :
- map(expand-comm, c)
-
- defn expand (m:Module) :
- Module(name(m), ports(m), expand-comm(body(m)))
-
- Circuit(modules*, main(c)) where :
- val modules* = map(expand, modules(c))
-
-
-;================ EXPAND WHENS =============================
-definterface SymbolicValue
-defstruct ExpValue <: SymbolicValue :
- exp: Expression
-defstruct WhenValue <: SymbolicValue :
- pred: Expression
- conseq: SymbolicValue
- alt: SymbolicValue
-defstruct VoidValue <: SymbolicValue
-
-defmethod print (o:OutputStream, sv:SymbolicValue) :
- match(sv) :
- (sv:VoidValue) : print(o, "VOID")
- (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"])
- (sv:ExpValue) : print(o, exp(sv))
-
-defn key-eqv? (e1:Expression, e2:Expression) :
- match(e1, e2) :
- (e1:WRef, e2:WRef) :
- name(e1) == name(e2)
- (e1:WField, e2:WField) :
- name(e1) == name(e2) and
- key-eqv?(exp(e1), exp(e2))
- (e1, e2) :
- false
-
-defn merge-env (pred: Expression,
- con-env: List<KeyValue<Expression,SymbolicValue>>,
- alt-env: List<KeyValue<Expression,SymbolicValue>>) :
- val merged = Vector<KeyValue<Expression, SymbolicValue>>()
- defn new-key? (k:Expression) :
- for entry in merged none? :
- key-eqv?(key(entry), k)
-
- defn sv (env:List<KeyValue<Expression,SymbolicValue>>, k:Expression) :
- for entry in env search :
- if key-eqv?(key(entry), k) :
- value(entry)
-
- for k in stream(key, concat(con-env, alt-env)) do :
- if new-key?(k) :
- match(sv(con-env, k), sv(alt-env, k)) :
- (a:SymbolicValue, b:SymbolicValue) :
- if a == b :
- add(merged, k => a)
- else :
- add(merged, k => WhenValue(pred, a, b))
- (a:SymbolicValue, b:False) :
- add(merged, k => a)
- (a:False, b:SymbolicValue) :
- add(merged, k => b)
- (a:False, b:False) :
- false
-
- to-list(merged)
-
-defn simplify-env (env: List<KeyValue<Expression,SymbolicValue>>) :
- val merged = Vector<KeyValue<Expression, SymbolicValue>>()
- defn new-key? (k:Expression) :
- for entry in merged none? :
- key-eqv?(key(entry), k)
- for entry in env do :
- if new-key?(key(entry)) :
- add(merged, entry)
- to-list(merged)
-
-defn expand-whens (m:Module) :
- val commands = Vector<Stmt>()
- val elements = Vector<KeyValue<Symbol,Element>>()
- defn eval (c:Stmt, env:List<KeyValue<Expression,SymbolicValue>>) ->
- List<KeyValue<Expression,SymbolicValue>> :
- match(c) :
- (c:LetRec) :
- do(add{elements, _}, entries(c))
- eval(body(c), env)
- (c:DefWire) :
- add(commands, c)
- val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT)
- List(wire-ref => VoidValue(), env)
- (c:DefRegister) :
- add(commands, c)
- val reg-ref = WRef(name(c), type(c), RegKind(), INPUT)
- List(reg-ref => VoidValue(), env)
- (c:DefInstance) :
- add(commands, c)
- val entries = let :
- val module-type = type(module(c)) as BundleType
- val input-ports = to-list $
- for p in ports(module-type) filter :
- direction(p) == INPUT
- val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT)
- for p in input-ports map :
- WField(inst-ref, name(p), type(p), INPUT) => VoidValue()
- append(entries, env)
- (c:DefMemory) :
- add(commands, c)
- env
- (c:WDefAccessor) :
- add(commands, c)
- if dir(c) == INPUT :
- val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c))
- List(access-ref => VoidValue(), env)
- else :
- env
- (c:Conditionally) :
- val con-env = eval(conseq(c), env)
- val alt-env = eval(alt(c), env)
- merge-env(pred(c), con-env, alt-env)
- (c:Begin) :
- var env:List<KeyValue<Expression,SymbolicValue>> = env
- for c in body(c) do :
- env = eval(c, env)
- env
- (c:Connect) :
- List(loc(c) => ExpValue(exp(c)), env)
- (c:EmptyStmt) :
- env
-
- defn convert-symbolic (key:Expression, sv:SymbolicValue) :
- match(sv) :
- (sv:VoidValue) :
- throw $ PassException $ string-join $ [
- "No default value for " key "."]
- (sv:ExpValue) :
- exp(sv)
- (sv:WhenValue) :
- defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) :
- DoPrim(MUX-OP, list(pred, conseq, alt), List(), type(conseq))
- multiplex-exp(pred(sv),
- convert-symbolic(key, conseq(sv))
- convert-symbolic(key, alt(sv)))
-
- ;Compute final environment
- val env0 = let :
- val output-ports = to-list $
- for p in ports(m) filter :
- direction(p) == OUTPUT
- for p in output-ports map :
- val port-ref = WRef(name(p), type(p), PortKind(), INPUT)
- port-ref => VoidValue()
- val env* = simplify-env(eval(body(m), env0))
-
- ;Make new body
- val body* = Begin(list(defs, LetRec(elems, connections))) where :
- val defs = Begin(to-list(commands))
- val elems = to-list(elements)
- val connections = Begin $
- for entry in env* map :
- val sv = convert-symbolic(key(entry), value(entry))
- Connect(key(entry), sv)
-
- ;Final module
- Module(name(m), ports(m), body*)
-
-defn expand-whens (c:Circuit) :
- val modules* = map(expand-whens, modules(c))
- Circuit(modules*, main(c))
-
-
-;================ STRUCTURAL FORM ==========================
-
-defn structural-form (m:Module) :
- val elements = Vector<(() -> KeyValue<Symbol,Element>)>()
- val connected = HashTable<Symbol, Expression>(symbol-hash)
- val write-accessors = HashTable<Symbol, List<WDefAccessor>>(symbol-hash)
- val read-accessors = HashTable<Symbol, WDefAccessor>(symbol-hash)
- val inst-ports = HashTable<Symbol, List<KeyValue<Symbol, Expression>>>(symbol-hash)
- val port-connects = Vector<Connect>()
-
- defn scan (c:Stmt) :
- match(c) :
- (c:Connect) :
- match(loc(c)) :
- (loc:WRef) :
- match(kind(loc)) :
- (k:PortKind) : add(port-connects, c)
- (k) : connected[name(loc)] = exp(c)
- (loc:WField) :
- val inst = exp(loc) as WRef
- val entry = name(loc) => exp(c)
- inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List()))
- (c:LetRec) :
- for e in entries(c) do :
- add(elements, {e})
- scan(body(c))
- (c:DefWire) :
- add{elements, _} $ fn () :
- name(c) => Node(type(c), connected[name(c)])
- (c:DefRegister) :
- add{elements, _} $ fn () :
- val one = UIntValue(1, UnknownWidth())
- name(c) => Register(type(c), connected[name(c)], one)
- (c:DefInstance) :
- add{elements, _} $ fn () :
- name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)])
- (c:DefMemory) :
- add{elements, _} $ fn () :
- val ports = for a in get?(write-accessors, name(c), List()) map :
- val one = UIntValue(1, UnknownWidth())
- WritePort(index(a), connected[name(a)], one)
- name(c) => Memory(type(c), ports)
- (c:WDefAccessor) :
- val mem = source(c) as WRef
- switch {dir(c) == _} :
- INPUT :
- write-accessors[name(mem)] = List(c,
- get?(write-accessors, name(mem), List()))
- OUTPUT :
- read-accessors[name(c)] = c
- (c) :
- do(scan, children(c))
-
- defn make-read-ports (e:Expression) :
- match(e) :
- (e:WRef) :
- match(kind(e)) :
- (k:AccessorKind) :
- val accessor = read-accessors[name(e)]
- ReadPort(source(accessor), index(accessor), type(e))
- (k) : e
- (e) : map(make-read-ports, e)
-
- Module(name(m), ports(m), body*) where :
- scan(body(m))
- val elems = to-list $
- for e in elements stream :
- val entry = e()
- key(entry) => map(make-read-ports, value(entry))
- val connect-ports = Begin $ to-list $
- for c in port-connects stream :
- Connect(loc(c), make-read-ports(exp(c)))
- val body* =
- if empty?(elems) : connect-ports
- else : LetRec(elems, connect-ports)
-
-defn structural-form (c:Circuit) :
- val modules* = map(structural-form, modules(c))
- Circuit(modules*, main(c))
-
-
-;==================== WIDTH INFERENCE ======================
-defstruct WidthVar <: Width :
- name: Symbol
-
-defmethod print (o:OutputStream, w:WidthVar) :
- print(o, name(w))
-
-defn width! (t:Type) :
- match(t) :
- (t:UIntType) : width(t)
- (t:SIntType) : width(t)
- (t) : error("No width field.")
-
-defn put-width (t:Type, w:Width) :
- match(t) :
- (t:UIntType) : UIntType(w)
- (t:SIntType) : SIntType(w)
- (t) : t
-
-defn put-width (e:Expression, w:Width) :
- val type* = put-width(type(e), w)
- put-type(e, type*)
-
-defn add-width-vars (t:Type) :
- defn width? (w:Width) :
- match(w) :
- (w:UnknownWidth) : WidthVar(gensym())
- (w) : w
- match(t) :
- (t:UIntType) : UIntType(width?(width(t)))
- (t:SIntType) : SIntType(width?(width(t)))
- (t) : map(add-width-vars, t)
-
-defn uint-width (i:Int) :
- var v:Int = i
- var n:Int = 0
- while v != 0 :
- v = v >> 1
- n = n + 1
- IntWidth(n)
-
-defn sint-width (i:Int) :
- if i > 0 :
- val w = uint-width(i)
- IntWidth(width(w) + 1)
- else :
- val w = uint-width(neg(i) - 1)
- IntWidth(width(w) + 1)
-
-defn to-exp (w:Width) :
- match(w) :
- (w:IntWidth) : ELit(width(w))
- (w:WidthVar) : EVar(name(w))
- (w) : error $ string-join $ [
- "Cannot convert " w " to exp."]
-
-defn primop-width (op:PrimOp, ws:List<Width>, ints:List<Int>) -> Exp :
- defn wmax (w1:Width, w2:Width) :
- EMax(to-exp(w1), to-exp(w2))
- defn wplus (w1:Width, w2:Width) :
- EPlus(to-exp(w1), to-exp(w2))
- defn wplus (w1:Width, w2:Int) :
- EPlus(to-exp(w1), ELit(w2))
- defn wminus (w1:Width, w2:Width) :
- EMinus(to-exp(w1), to-exp(w2))
- defn wminus (w1:Width, w2:Int) :
- EMinus(to-exp(w1), ELit(w2))
- defn wmax-inc (w1:Width, w2:Width) :
- EPlus(wmax(w1, w2), ELit(1))
-
- switch {op == _} :
- ADD-OP : wmax-inc(ws[0], ws[1])
- ADD-WRAP-OP : wmax(ws[0], ws[1])
- SUB-OP : wmax-inc(ws[0], ws[1])
- SUB-WRAP-OP : wmax(ws[0], ws[1])
- MUL-OP : wplus(ws[0], ws[1])
- DIV-OP : wminus(ws[0], ws[1])
- MOD-OP : to-exp(ws[1])
- SHIFT-LEFT-OP : wplus(ws[0], ints[0])
- SHIFT-RIGHT-OP : wminus(ws[0], ints[0])
- PAD-OP : ELit(ints[0])
- BIT-AND-OP : wmax(ws[0], ws[1])
- BIT-OR-OP : wmax(ws[0], ws[1])
- BIT-XOR-OP : wmax(ws[0], ws[1])
- CONCAT-OP : wplus(ws[0], ints[0])
- BIT-SELECT-OP : ELit(1)
- BITS-SELECT-OP : ELit(ints[0])
- MUX-OP : wmax(ws[1], ws[2])
- LESS-OP : ELit(1)
- LESS-EQ-OP : ELit(1)
- GREATER-OP : ELit(1)
- GREATER-EQ-OP : ELit(1)
- EQUAL-OP : ELit(1)
-
-defn put-type (el:Element, t:Type) :
- match(el) :
- (el:Register) : Register(t, value(el), enable(el))
- (el:Memory) : Memory(t, writers(el))
- (el:Node) : Node(t, value(el))
- (el:Instance) : Instance(t, module(el), ports(el))
-
-defn generate-constraints (c:Circuit) -> [Circuit, Vector<WConstraint>] :
- ;Constraints
- val cs = Vector<WConstraint>()
- defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) :
- match(wvar) :
- (wvar:WidthVar) :
- add(cs, Constraint(name(wvar), to-exp(width)))
- (wvar) :
- false
-
- defn to-width (e:Exp) :
- match(e) :
- (e:ELit) :
- IntWidth(width(e))
- (e:EVar) :
- WidthVar(name(e))
- (e) :
- val x = gensym()
- add(cs, WidthEqual(x, e))
- WidthVar(x)
-
- ;Module types
- val mod-types = HashTable<Symbol,Type>(symbol-hash)
-
- defn add-port-vars (m:Module) -> Module :
- val ports* =
- for p in ports(m) map :
- val type* = add-width-vars(type(p))
- Port(name(p), direction(p), type*)
- mod-types[name(m)] = BundleType(ports*)
- Module(name(m), ports*, body(m))
-
- ;Add Width Variables
- defn add-module-vars (m:Module) -> Module :
- val types = HashTable<Symbol,Type>(symbol-hash)
- for p in ports(m) do :
- types[name(p)] = type(p)
-
- defn infer-exp-width (e:Expression) :
- match(map(infer-exp-width, e)) :
- (e:WRef) :
- match(kind(e)) :
- (k:ModuleKind) : put-type(e, mod-types[name(e)])
- (k) : put-type(e, types[name(e)])
- (e:WField) :
- val t = bundle-field-type(type(exp(e)), name(e))
- put-width(e, width!(t))
- (e:UIntValue) :
- match(width(e)) :
- (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e)))
- (w) : e
- (e:SIntValue) :
- match(width(e)) :
- (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e)))
- (w) : e
- (e:DoPrim) :
- val widths = map(width!{type(_)}, args(e))
- val w = to-width(primop-width(op(e), widths, consts(e)))
- put-width(e, w)
- (e:ReadPort) :
- val elem-type = type(type(mem(e)) as VectorType)
- put-width(e, width!(elem-type))
-
- defn infer-comm-width (c:Stmt) :
- match(c) :
- (c:LetRec) :
- ;Add width vars to elements
- var entries*: List<KeyValue<Symbol,Element>> =
- for entry in entries(c) map :
- val el-name = key(entry)
- key(entry) =>
- match(value(entry)) :
- (el:Register|Node) :
- put-type(el, add-width-vars(type(el)))
- (el:Memory) :
- el
- (el:Instance) :
- val mod-type = type(infer-exp-width(module(el))) as BundleType
- val type = BundleType $ to-list $
- for p in ports(mod-type) filter :
- direction(p) == OUTPUT
- put-type(el, type)
-
- ;Add vars to environment
- for entry in entries* do :
- types[key(entry)] = type(value(entry))
-
- ;Infer types for elements
- entries* =
- for entry in entries* map :
- key(entry) => map(infer-exp-width, value(entry))
-
- ;Generate constraints
- for entry in entries* do :
- val el-name = key(entry)
- match(value(entry)) :
- (el:Register) :
- new-constraint(WidthEqual, reg-width, val-width) where :
- val reg-width = width!(types[el-name])
- val val-width = width!(type(value(el)))
- (el:Node) :
- new-constraint(WidthEqual, node-width, val-width) where :
- val node-width = width!(types[el-name])
- val val-width = width!(type(value(el)))
- (el:Instance) :
- val mod-type = type(module(el)) as BundleType
- for entry in ports(el) do :
- new-constraint(WidthGreater, port-width, val-width) where :
- val port-name = key(entry)
- val port-width = width!(bundle-field-type(mod-type, port-name))
- val val-width = width!(type(value(entry)))
- (el) : false
-
- ;Analyze body
- LetRec(entries*, infer-comm-width(body(c)))
-
- (c:Connect) :
- val loc* = infer-exp-width(loc(c))
- val exp* = infer-exp-width(exp(c))
- new-constraint(WidthGreater, loc-width, exp-width) where :
- val loc-width = width!(type(loc*))
- val exp-width = width!(type(exp*))
- Connect(loc*, exp*)
-
- (c:Begin) :
- map(infer-comm-width, c)
-
- Module(name(m), ports(m), body*) where :
- val body* = infer-comm-width(body(m))
-
- val c* =
- Circuit(modules*, main(c)) where :
- val ms = map(add-port-vars, modules(c))
- val modules* = map(add-module-vars, ms)
- [c*, cs]
-
-
-;================== FILL WIDTHS ============================
-defn fill-widths (c:Circuit, solved:Streamable<WidthEqual>) :
- ;Populate table
- val table = HashTable<Symbol, Width>(symbol-hash)
- for eq in solved do :
- table[name(eq)] = IntWidth(width(value(eq) as ELit))
-
- defn width? (w:Width) :
- match(w) :
- (w:WidthVar) : get?(table, name(w), UnknownWidth())
- (w) : w
-
- defn fill-type (t:Type) :
- match(t) :
- (t:UIntType) : UIntType(width?(width(t)))
- (t:SIntType) : SIntType(width?(width(t)))
- (t) : map(fill-type, t)
-
- defn fill-exp (e:Expression) :
- val e* = map(fill-exp, e)
- val type* = fill-type(type(e))
- put-type(e*, type*)
-
- defn fill-element (e:Element) :
- val e* = map(fill-exp, e)
- val type* = fill-type(type(e))
- put-type(e*, type*)
-
- defn fill-comm (c:Stmt) :
- match(c) :
- (c:LetRec) :
- val entries* =
- for e in entries(c) map :
- key(e) => fill-element(value(e))
- LetRec(entries*, fill-comm(body(c)))
- (c) :
- map{fill-comm, _} $
- map(fill-exp, c)
-
- defn fill-port (p:Port) :
- Port(name(p), direction(p), fill-type(type(p)))
-
- defn fill-mod (m:Module) :
- Module(name(m), ports*, body*) where :
- val ports* = map(fill-port, ports(m))
- val body* = fill-comm(body(m))
-
- Circuit(modules*, main(c)) where :
- val modules* = map(fill-mod, modules(c))
-
-
-;=============== TYPE INFERENCE DRIVER =====================
-defn infer-widths (c:Circuit) :
- val [c*, cs] = generate-constraints(c)
- val solved = solve-widths(cs)
- fill-widths(c*, solved)
-
-
-;================ PAD WIDTHS ===============================
-defn pad-widths (c:Circuit) :
- ;Pad an expression to the given width
- defn pad-exp (e:Expression, w:Int) :
- match(type(e)) :
- (t:UIntType|SIntType) :
- val prev-w = width!(t) as IntWidth
- if width(prev-w) < w :
- val type* = put-width(t, IntWidth(w))
- DoPrim(PAD-OP, list(e), list(w), type*)
- else :
- e
- (t) :
- e
-
- defn pad-exp (e:Expression, w:Width) :
- val w-value = width(w as IntWidth)
- pad-exp(e, w-value)
-
- ;Convenience
- defn max-width (es:Streamable<Expression>) :
- defn int-width (e:Expression) :
- width(width!(type(e)) as IntWidth)
- maximum(stream(int-width, es))
-
- defn match-widths (es:List<Expression>) :
- val w = max-width(es)
- map(pad-exp{_, w}, es)
-
- ;Match widths for an expression
- defn match-exp-width (e:Expression) :
- match(map(match-exp-width, e)) :
- (e:DoPrim) :
- if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) :
- val args* = match-widths(args(e))
- DoPrim(op(e), args*, consts(e), type(e))
- else if op(e) == MUX-OP :
- val args* = List(head(args(e)), match-widths(tail(args(e))))
- DoPrim(op(e), args*, consts(e), type(e))
- else :
- e
- (e) : e
-
- defn match-element-width (e:Element) :
- match(map(match-exp-width, e)) :
- (e:Register) :
- val w = width!(type(e))
- val value* = pad-exp(value(e), w)
- Register(type(e), value*, enable(e))
- (e:Memory) :
- val width = width!(type(type(e) as VectorType))
- val writers* =
- for w in writers(e) map :
- WritePort(index(w), pad-exp(value(w), width), enable(w))
- Memory(type(e), writers*)
- (e:Node) :
- val w = width!(type(e))
- val value* = pad-exp(value(e), w)
- Node(type(e), value*)
- (e:Instance) :
- val mod-type = type(module(e)) as BundleType
- val ports* =
- for p in ports(e) map :
- val port-type = bundle-field-type(mod-type, key(p))
- val port-val = pad-exp(value(p), width!(port-type))
- key(p) => port-val
- Instance(type(e), module(e), ports*)
-
- ;Match widths for a command
- defn match-comm-width (c:Stmt) :
- match(map(match-exp-width, c)) :
- (c:LetRec) :
- val entries* =
- for e in entries(c) map :
- key(e) => match-element-width(value(e))
- LetRec(entries*, match-comm-width(body(c)))
- (c:Connect) :
- val w = width!(type(loc(c)))
- val exp* = pad-exp(exp(c), w)
- Connect(loc(c), exp*)
- (c) :
- map(match-comm-width, c)
-
- defn match-mod-width (m:Module) :
- Module(name(m), ports(m), body*) where :
- val body* = match-comm-width(body(m))
-
- Circuit(modules*, main(c)) where :
- val modules* = map(match-mod-width, modules(c))
-
-
-;================== INLINING ===============================
-defn inline-instances (c:Circuit) :
- val module-table = HashTable<Symbol,Module>(symbol-hash)
- val inlined? = HashTable<Symbol,True|False>(symbol-hash)
- for m in modules(c) do :
- module-table[name(m)] = m
- inlined?[name(m)] = false
-
- ;Convert a module into a sequence of elements
- defn to-elements (m:Module,
- inst:Symbol,
- port-exps:List<KeyValue<Symbol,Expression>>) ->
- List<KeyValue<Symbol, Element>> :
- defn rename-exp (e:Expression) :
- match(e) :
- (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e))
- (e) : map(rename-exp, e)
-
- defn to-elements (c:Stmt) -> List<KeyValue<Symbol,Element>> :
- match(c) :
- (c:LetRec) :
- val entries* =
- for entry in entries(c) map :
- val name* = prefix(inst, key(entry))
- val element* = map(rename-exp, value(entry))
- name* => element*
- val body* = to-elements(body(c))
- append(entries*, body*)
- (c:Connect) :
- val ref = loc(c) as WRef
- val name* = prefix(inst, name(ref))
- list(name* => Node(type(exp(c)), rename-exp(exp(c))))
- (c:Begin) :
- map-append(to-elements, body(c))
-
- val inputs =
- for p in ports(m) map-append :
- if direction(p) == INPUT :
- val port-exp = lookup!(port-exps, name(p))
- val name* = prefix(inst, name(p))
- list(name* => Node(type(port-exp), port-exp))
- else :
- List()
- append(inputs, to-elements(body(m)))
-
- ;Inline all instances in the module
- defn inline-instances (m:Module) :
- defn rename-exp (e:Expression) :
- match(e) :
- (e:WField) :
- val inst-exp = exp(e) as WRef
- val name* = prefix(name(inst-exp), name(e))
- WRef(name*, type(e), NodeKind(), dir(e))
- (e) :
- map(rename-exp, e)
-
- defn inline-elems (es:List<KeyValue<Symbol,Element>>) :
- for entry in es map-append :
- match(value(entry)) :
- (el:Instance) :
- val mod-name = name(module(el) as WRef)
- val module = inlined-module(mod-name)
- to-elements(module, key(entry), ports(el))
- (el) :
- list(entry)
-
- defn inline-comm (c:Stmt) :
- match(map(rename-exp, c)) :
- (c:LetRec) :
- val entries* = inline-elems(entries(c))
- LetRec(entries*, inline-comm(body(c)))
- (c) :
- map(inline-comm, c)
-
- Module(name(m), ports(m), inline-comm(body(m)))
-
- ;Retrieve the inlined instance of a module
- defn inlined-module (name:Symbol) :
- if inlined?[name] :
- module-table[name]
- else :
- val module* = inline-instances(module-table[name])
- module-table[name] = module*
- inlined?[name] = true
- module*
-
- ;Return the fully inlined circuit
- val main-module = inlined-module(main(c))
- Circuit(list(main-module), main(c))
+;defn flip (d:Direction) :
+; switch {d == _} :
+; INPUT : OUTPUT
+; OUTPUT : INPUT
+; else : d
+;
+;defn times (d1:Direction, d2:Direction) :
+; if d1 == INPUT : flip(d2)
+; else : d2
+;
+;defn lookup-port (ports: Streamable<Port>, port-name: Symbol) :
+; for port in ports find :
+; name(port) == port-name
+;
+;defn bundle-field-dir (t:Type, n:Symbol) -> Direction :
+; match(t) :
+; (t:BundleType) :
+; match(lookup-port(ports(t), n)) :
+; (p:Port) : direction(p)
+; (p) : UNKNOWN-DIR
+; (t) : UNKNOWN-DIR
+;
+;defn infer-dirs (m:Module) :
+; ;=== Direction of all Binders ===
+; val BI-DIR = new Direction
+; val directions = HashTable<Symbol,Direction>(symbol-hash)
+; defn find-dirs (c:Stmt) :
+; match(c) :
+; (c:LetRec) :
+; for entry in entries(c) do :
+; directions[key(entry)] = OUTPUT
+; find-dirs(body(c))
+; (c:DefWire) :
+; directions[name(c)] = BI-DIR
+; (c:DefRegister) :
+; directions[name(c)] = BI-DIR
+; (c:DefInstance) :
+; directions[name(c)] = OUTPUT
+; (c:DefMemory) :
+; directions[name(c)] = BI-DIR
+; (c:WDefAccessor) :
+; directions[name(c)] = dir(c)
+; (c) :
+; do(find-dirs, children(c))
+; for p in ports(m) do :
+; directions[name(p)] = flip(direction(p))
+; find-dirs(body(m))
+;
+; ;=== Fix Point Status ===
+; var changed? = false
+;
+; ;=== Infer directions of Expression ===
+; defn infer-exp (e:Expression, desired:Direction) :
+; match(e) :
+; (e:WRef) :
+; val dir* = let :
+; if kind(e) typeof ModuleKind :
+; OUTPUT
+; else :
+; val old-dir = directions[name(e)]
+; switch {old-dir == _} :
+; BI-DIR :
+; desired
+; UNKNOWN-DIR :
+; if directions[name(e)] != desired :
+; directions[name(e)] = desired
+; changed? = true
+; desired
+; else :
+; old-dir
+; WRef(name(e), type(e), kind(e), dir*)
+; (e:WField) :
+; val port-dir = bundle-field-dir(type(exp(e)), name(e))
+; val exp* = infer-exp(exp(e), port-dir * desired)
+; WField(exp*, name(e), type(e), port-dir * dir(exp*))
+; (e:WIndex) :
+; val exp* = infer-exp(exp(e), desired)
+; WIndex(exp*, value(e), type(e), dir(exp*))
+; (e) :
+; map(infer-exp{_, OUTPUT}, e)
+;
+; ;=== Infer directions of Stmts ===
+; defn infer-comm (c:Stmt) :
+; match(c) :
+; (c:LetRec) :
+; val c* = map(infer-exp{_, OUTPUT}, c)
+; LetRec(entries(c*), infer-comm(body(c)))
+; (c:DefInstance) :
+; DefInstance(name(c),
+; infer-exp(module(c), OUTPUT))
+; (c:WDefAccessor) :
+; val d = directions[name(c)]
+; WDefAccessor(name(c),
+; infer-exp(source(c), d),
+; infer-exp(index(c), OUTPUT),
+; d)
+; (c:Conditionally) :
+; Conditionally(infer-exp(pred(c), OUTPUT),
+; infer-comm(conseq(c)),
+; infer-comm(alt(c)))
+; (c:Connect) :
+; Connect(infer-exp(loc(c), INPUT), infer-exp(exp(c), OUTPUT))
+; (c) :
+; map(infer-comm, c)
+;
+; ;=== Iterate until fix point ===
+; defn* fixpoint (c:Stmt) :
+; changed? = false
+; val c* = infer-comm(c)
+; if changed? : fixpoint(c*)
+; else : c*
+;
+; Module(name(m), ports(m), body*) where :
+; val body* = fixpoint(body(m))
+;
+;defn infer-directions (c:Circuit) :
+; Circuit(modules*, main(c)) where :
+; val modules* = map(infer-dirs, modules(c))
+;
+;
+;;============== EXPAND VECS ================================
+;defstruct ManyConnect <: Stmt :
+; index: Expression
+; locs: List<Expression>
+; exp: Expression
+;
+;defstruct ConnectMany <: Stmt :
+; index: Expression
+; loc: Expression
+; exps: List<Expression>
+;
+;defmethod print (o:OutputStream, c:ManyConnect) :
+; print-all(o, [locs(c) "[" index(c) "] := " exp(c)])
+;defmethod print (o:OutputStream, c:ConnectMany) :
+; print-all(o, [loc(c) " := " exps(c) "[" index(c) "]"])
+;
+;defmethod map (f: Expression -> Expression, c:ManyConnect) :
+; ManyConnect(f(index(c)), map(f, locs(c)), f(exp(c)))
+;defmethod map (f: Expression -> Expression, c:ConnectMany) :
+; ConnectMany(f(index(c)), f(loc(c)), map(f, exps(c)))
+;
+;defn expand-accessors (m: Module) :
+; defn expand (c:Stmt) :
+; match(c) :
+; (c:WDefAccessor) :
+; ;Is the source a memory?
+; val mem? =
+; match(source(c)) :
+; (r:WRef) : kind(r) typeof MemKind
+; (r) : false
+;
+; if mem? :
+; c
+; else :
+; switch {dir(c) == _} :
+; INPUT :
+; Begin(list(
+; DefWire(name(c), type(src-type))
+; ManyConnect(index(c), elems, wire-ref)))
+; where :
+; val src-type = type(source(c)) as VectorType
+; val wire-ref = WRef(name(c), type(src-type), NodeKind(), OUTPUT)
+; val elems = to-list $
+; for i in 0 to size(src-type) stream :
+; WIndex(source(c), i, type(src-type), INPUT)
+; OUTPUT :
+; Begin(list(
+; DefWire(name(c), type(src-type))
+; ConnectMany(index(c), wire-ref, elems)))
+; where :
+; val src-type = type(source(c)) as VectorType
+; val wire-ref = WRef(name(c), type(src-type), NodeKind(), INPUT)
+; val elems = to-list $
+; for i in 0 to size(src-type) stream :
+; WIndex(source(c), i, type(src-type), OUTPUT)
+; (c) :
+; map(expand, c)
+; Module(name(m), ports(m), expand(body(m)))
+;
+;defn expand-accessors (c:Circuit) :
+; Circuit(modules*, main(c)) where :
+; val modules* = map(expand-accessors, modules(c))
+;
+;
+;
+;;=============== BUNDLE FLATTENING =========================
+;defn prefix (prefix, suffix) :
+; symbol-join([prefix "/" suffix])
+;
+;defn prefix-ports (pre:Symbol, ports:List<Port>) :
+; for p in ports map :
+; Port(prefix(pre, name(p)), direction(p), type(p))
+;
+;defn flatten-ports (port:Port) -> List<Port> :
+; match(type(port)) :
+; (t:BundleType) :
+; val ports = map-append(flatten-ports, ports(t))
+; for p in ports map :
+; Port(prefix(name(port), name(p)),
+; direction(port) * direction(p),
+; type(p))
+; (t:VectorType) :
+; val type* = flatten-type(t)
+; flatten-ports(Port(name(port), direction(port), type*))
+; (t:Type) :
+; list(port)
+;
+;defn flatten-type (t:Type) -> Type :
+; match(t) :
+; (t:BundleType) :
+; BundleType $
+; map-append(flatten-ports, ports(t))
+; (t:VectorType) :
+; flatten-type $ BundleType $ to-list $
+; for i in 0 to size(t) stream :
+; Port(to-symbol(i), OUTPUT, type(t))
+; (t:Type) :
+; t
+;
+;defn flatten-bundles (c:Circuit) :
+; defn flatten-exp (e:Expression) :
+; match(map(flatten-exp, e)) :
+; (e:UIntValue|SIntValue) :
+; e
+; (e:WRef) :
+; match(kind(e)) :
+; (k:MemKind|StructuralMemKind) :
+; val type* = map(flatten-type, type(e))
+; put-type(e, type*)
+; (k) :
+; val type* = flatten-type(type(e))
+; put-type(e, type*)
+; (e) :
+; val type* = flatten-type(type(e))
+; put-type(e, type*)
+;
+; defn flatten-element (e:Element) :
+; val t* = flatten-type(type(e))
+; match(map(flatten-exp, e)) :
+; (e:Register) : Register(t*, value(e), enable(e))
+; (e:Memory) : Memory(t*, writers(e))
+; (e:Node) : Node(t*, value(e))
+; (e:Instance) : Instance(t*, module(e), ports(e))
+;
+; defn flatten-comm (c:Stmt) :
+; match(c) :
+; (c:LetRec) :
+; val entries* =
+; for entry in entries(c) map :
+; key(entry) => flatten-element(value(entry))
+; LetRec(entries*, flatten-comm(body(c)))
+; (c:DefWire) :
+; DefWire(name(c), flatten-type(type(c)))
+; (c:DefRegister) :
+; DefRegister(name(c), flatten-type(type(c)))
+; (c:DefMemory) :
+; val type* = map(flatten-type, type(c))
+; DefMemory(name(c), type*)
+; (c) :
+; map{flatten-comm, _} $
+; map(flatten-exp, c)
+;
+; defn flatten-module (m:Module) :
+; val ports* = map-append(flatten-ports, ports(m))
+; val body* = flatten-comm(body(m))
+; Module(name(m), ports*, body*)
+;
+; Circuit(modules*, main(c)) where :
+; val modules* = map(flatten-module, modules(c))
+;
+;
+;;================== BUNDLE EXPANSION =======================
+;defn expand-bundles (m:Module) :
+;
+; ;Collapse all field/index expressions
+; defn collapse-exp (e:Expression) -> Expression :
+; match(e) :
+; (e:WField) :
+; match(collapse-exp(exp(e))) :
+; (ei:WRef) :
+; if kind(ei) typeof InstanceKind :
+; e
+; else :
+; WRef(name*, type(e), kind(ei), dir(e)) where :
+; val name* = prefix(name(ei), name(e))
+; (ei:WField) :
+; WField(exp(ei), name*, type(e), dir(e)) where :
+; val name* = prefix(name(ei), name(e))
+; (e:WIndex) :
+; collapse-exp(WField(exp(e), name, type(e), dir(e))) where :
+; val name = to-symbol(value(e))
+; (e) :
+; map(collapse-exp, e)
+;
+; ;Expand expressions
+; defn expand-exp (e:Expression) -> List<Expression> :
+; match(type(e)) :
+; (t:BundleType) :
+; for p in ports(t) map :
+; val dir* = direction(p) * dir(e)
+; collapse-exp(WField(e, name(p), type(p), dir*))
+; (t) :
+; list(collapse-exp(e))
+;
+; ;Expand commands
+; defn expand-comm (c:Stmt) :
+; match(c) :
+; (c:DefWire) :
+; match(type(c)) :
+; (t:BundleType) :
+; Begin $
+; for p in ports(t) map :
+; DefWire(prefix(name(c), name(p)), type(p))
+; (t) :
+; c
+; (c:DefRegister) :
+; match(type(c)) :
+; (t:BundleType) :
+; Begin $
+; for p in ports(t) map :
+; DefRegister(prefix(name(c), name(p)), type(p))
+; (t) :
+; c
+; (c:DefMemory) :
+; match(type(type(c))) :
+; (t:BundleType) :
+; Begin $
+; for p in ports(t) map :
+; DefMemory(prefix(name(c), name(p)), type*) where :
+; val s = size(type(c))
+; val type* = VectorType(type(p), s)
+; (t) :
+; c
+; (c:WDefAccessor) :
+; match(type(source(c))) :
+; (t:BundleType) :
+; val srcs = expand-exp(source(c))
+; Begin $
+; for (p in ports(t), src in srcs) map :
+; WDefAccessor(name*, src, index(c), dir*) where :
+; val name* = prefix(name(c), name(p))
+; val dir* = direction(p) * dir(c)
+; (t) :
+; c
+; (c:Connect) :
+; val locs = expand-exp(loc(c))
+; val exps = expand-exp(exp(c))
+; Begin $
+; for (l in locs, e in exps) map :
+; switch {dir(l) == _} :
+; INPUT : Connect(l, e)
+; OUTPUT : Connect(e, l)
+; (c:ManyConnect) :
+; val locs-list = transpose(map(expand-exp, locs(c)))
+; val exps = expand-exp(exp(c))
+; Begin $
+; for (locs in locs-list, e in exps) map :
+; switch {dir(e) == _} :
+; OUTPUT : ManyConnect(index(c), locs, e)
+; INPUT : ConnectMany(index(c), e, locs)
+; (c:ConnectMany) :
+; val locs = expand-exp(loc(c))
+; val exps-list = transpose(map(expand-exp, exps(c)))
+; Begin $
+; for (l in locs, exps in exps-list) map :
+; switch {dir(l) == _} :
+; INPUT : ConnectMany(index(c), l, exps)
+; OUTPUT : ManyConnect(index(c), exps, l)
+; (c) :
+; map{expand-comm, _} $
+; map(collapse-exp, c)
+;
+; Module(name(m), ports(m), expand-comm(body(m)))
+;
+;defn expand-bundles (c:Circuit) :
+; Circuit(modules*, main(c)) where :
+; val modules* = map(expand-bundles, modules(c))
+;
+;
+;;=========== CONVERT MULTI CONNECTS to WHEN ================
+;defn expand-multi-connects (c:Circuit) :
+; defn equal-exp (e1:Expression, e2:Expression) :
+; DoPrim(EQUAL-OP, list(e1, e2), List(), UIntType(UnknownWidth()))
+; defn uint (i:Int) :
+; UIntValue(i, UnknownWidth())
+;
+; defn expand-comm (c:Stmt) :
+; match(c) :
+; (c:ConnectMany) :
+; Begin $ to-list $
+; for (i in 0 to false, e in exps(c)) stream :
+; Conditionally(equal-exp(index(c), uint(i)),
+; Connect(loc(c), e)
+; EmptyStmt())
+; (c:ManyConnect) :
+; Begin $ to-list $
+; for (i in 0 to false, l in locs(c)) stream :
+; Conditionally(equal-exp(index(c), uint(i)),
+; Connect(l, exp(c))
+; EmptyStmt())
+; (c) :
+; map(expand-comm, c)
+;
+; defn expand (m:Module) :
+; Module(name(m), ports(m), expand-comm(body(m)))
+;
+; Circuit(modules*, main(c)) where :
+; val modules* = map(expand, modules(c))
+;
+;
+;;================ EXPAND WHENS =============================
+;definterface SymbolicValue
+;defstruct ExpValue <: SymbolicValue :
+; exp: Expression
+;defstruct WhenValue <: SymbolicValue :
+; pred: Expression
+; conseq: SymbolicValue
+; alt: SymbolicValue
+;defstruct VoidValue <: SymbolicValue
+;
+;defmethod print (o:OutputStream, sv:SymbolicValue) :
+; match(sv) :
+; (sv:VoidValue) : print(o, "VOID")
+; (sv:WhenValue) : print-all(o, ["(" pred(sv) "? " conseq(sv) " : " alt(sv) ")"])
+; (sv:ExpValue) : print(o, exp(sv))
+;
+;defn key-eqv? (e1:Expression, e2:Expression) :
+; match(e1, e2) :
+; (e1:WRef, e2:WRef) :
+; name(e1) == name(e2)
+; (e1:WField, e2:WField) :
+; name(e1) == name(e2) and
+; key-eqv?(exp(e1), exp(e2))
+; (e1, e2) :
+; false
+;
+;defn merge-env (pred: Expression,
+; con-env: List<KeyValue<Expression,SymbolicValue>>,
+; alt-env: List<KeyValue<Expression,SymbolicValue>>) :
+; val merged = Vector<KeyValue<Expression, SymbolicValue>>()
+; defn new-key? (k:Expression) :
+; for entry in merged none? :
+; key-eqv?(key(entry), k)
+;
+; defn sv (env:List<KeyValue<Expression,SymbolicValue>>, k:Expression) :
+; for entry in env search :
+; if key-eqv?(key(entry), k) :
+; value(entry)
+;
+; for k in stream(key, concat(con-env, alt-env)) do :
+; if new-key?(k) :
+; match(sv(con-env, k), sv(alt-env, k)) :
+; (a:SymbolicValue, b:SymbolicValue) :
+; if a == b :
+; add(merged, k => a)
+; else :
+; add(merged, k => WhenValue(pred, a, b))
+; (a:SymbolicValue, b:False) :
+; add(merged, k => a)
+; (a:False, b:SymbolicValue) :
+; add(merged, k => b)
+; (a:False, b:False) :
+; false
+;
+; to-list(merged)
+;
+;defn simplify-env (env: List<KeyValue<Expression,SymbolicValue>>) :
+; val merged = Vector<KeyValue<Expression, SymbolicValue>>()
+; defn new-key? (k:Expression) :
+; for entry in merged none? :
+; key-eqv?(key(entry), k)
+; for entry in env do :
+; if new-key?(key(entry)) :
+; add(merged, entry)
+; to-list(merged)
+;
+;defn expand-whens (m:Module) :
+; val commands = Vector<Stmt>()
+; val elements = Vector<KeyValue<Symbol,Element>>()
+; defn eval (c:Stmt, env:List<KeyValue<Expression,SymbolicValue>>) ->
+; List<KeyValue<Expression,SymbolicValue>> :
+; match(c) :
+; (c:LetRec) :
+; do(add{elements, _}, entries(c))
+; eval(body(c), env)
+; (c:DefWire) :
+; add(commands, c)
+; val wire-ref = WRef(name(c), type(c), NodeKind(), INPUT)
+; List(wire-ref => VoidValue(), env)
+; (c:DefRegister) :
+; add(commands, c)
+; val reg-ref = WRef(name(c), type(c), RegKind(), INPUT)
+; List(reg-ref => VoidValue(), env)
+; (c:DefInstance) :
+; add(commands, c)
+; val entries = let :
+; val module-type = type(module(c)) as BundleType
+; val input-ports = to-list $
+; for p in ports(module-type) filter :
+; direction(p) == INPUT
+; val inst-ref = WRef(name(c), module-type, InstanceKind(), OUTPUT)
+; for p in input-ports map :
+; WField(inst-ref, name(p), type(p), INPUT) => VoidValue()
+; append(entries, env)
+; (c:DefMemory) :
+; add(commands, c)
+; env
+; (c:WDefAccessor) :
+; add(commands, c)
+; if dir(c) == INPUT :
+; val access-ref = WRef(name(c), type(source(c)), AccessorKind(), dir(c))
+; List(access-ref => VoidValue(), env)
+; else :
+; env
+; (c:Conditionally) :
+; val con-env = eval(conseq(c), env)
+; val alt-env = eval(alt(c), env)
+; merge-env(pred(c), con-env, alt-env)
+; (c:Begin) :
+; var env:List<KeyValue<Expression,SymbolicValue>> = env
+; for c in body(c) do :
+; env = eval(c, env)
+; env
+; (c:Connect) :
+; List(loc(c) => ExpValue(exp(c)), env)
+; (c:EmptyStmt) :
+; env
+;
+; defn convert-symbolic (key:Expression, sv:SymbolicValue) :
+; match(sv) :
+; (sv:VoidValue) :
+; throw $ PassException $ string-join $ [
+; "No default value for " key "."]
+; (sv:ExpValue) :
+; exp(sv)
+; (sv:WhenValue) :
+; defn multiplex-exp (pred:Expression, conseq:Expression, alt:Expression) :
+; DoPrim(MUX-OP, list(pred, conseq, alt), List(), type(conseq))
+; multiplex-exp(pred(sv),
+; convert-symbolic(key, conseq(sv))
+; convert-symbolic(key, alt(sv)))
+;
+; ;Compute final environment
+; val env0 = let :
+; val output-ports = to-list $
+; for p in ports(m) filter :
+; direction(p) == OUTPUT
+; for p in output-ports map :
+; val port-ref = WRef(name(p), type(p), PortKind(), INPUT)
+; port-ref => VoidValue()
+; val env* = simplify-env(eval(body(m), env0))
+;
+; ;Make new body
+; val body* = Begin(list(defs, LetRec(elems, connections))) where :
+; val defs = Begin(to-list(commands))
+; val elems = to-list(elements)
+; val connections = Begin $
+; for entry in env* map :
+; val sv = convert-symbolic(key(entry), value(entry))
+; Connect(key(entry), sv)
+;
+; ;Final module
+; Module(name(m), ports(m), body*)
+;
+;defn expand-whens (c:Circuit) :
+; val modules* = map(expand-whens, modules(c))
+; Circuit(modules*, main(c))
+;
+;
+;;================ STRUCTURAL FORM ==========================
+;
+;defn structural-form (m:Module) :
+; val elements = Vector<(() -> KeyValue<Symbol,Element>)>()
+; val connected = HashTable<Symbol, Expression>(symbol-hash)
+; val write-accessors = HashTable<Symbol, List<WDefAccessor>>(symbol-hash)
+; val read-accessors = HashTable<Symbol, WDefAccessor>(symbol-hash)
+; val inst-ports = HashTable<Symbol, List<KeyValue<Symbol, Expression>>>(symbol-hash)
+; val port-connects = Vector<Connect>()
+;
+; defn scan (c:Stmt) :
+; match(c) :
+; (c:Connect) :
+; match(loc(c)) :
+; (loc:WRef) :
+; match(kind(loc)) :
+; (k:PortKind) : add(port-connects, c)
+; (k) : connected[name(loc)] = exp(c)
+; (loc:WField) :
+; val inst = exp(loc) as WRef
+; val entry = name(loc) => exp(c)
+; inst-ports[name(inst)] = List(entry, get?(inst-ports, name(inst), List()))
+; (c:LetRec) :
+; for e in entries(c) do :
+; add(elements, {e})
+; scan(body(c))
+; (c:DefWire) :
+; add{elements, _} $ fn () :
+; name(c) => Node(type(c), connected[name(c)])
+; (c:DefRegister) :
+; add{elements, _} $ fn () :
+; val one = UIntValue(1, UnknownWidth())
+; name(c) => Register(type(c), connected[name(c)], one)
+; (c:DefInstance) :
+; add{elements, _} $ fn () :
+; name(c) => Instance(UnknownType(), module(c), inst-ports[name(c)])
+; (c:DefMemory) :
+; add{elements, _} $ fn () :
+; val ports = for a in get?(write-accessors, name(c), List()) map :
+; val one = UIntValue(1, UnknownWidth())
+; WritePort(index(a), connected[name(a)], one)
+; name(c) => Memory(type(c), ports)
+; (c:WDefAccessor) :
+; val mem = source(c) as WRef
+; switch {dir(c) == _} :
+; INPUT :
+; write-accessors[name(mem)] = List(c,
+; get?(write-accessors, name(mem), List()))
+; OUTPUT :
+; read-accessors[name(c)] = c
+; (c) :
+; do(scan, children(c))
+;
+; defn make-read-ports (e:Expression) :
+; match(e) :
+; (e:WRef) :
+; match(kind(e)) :
+; (k:AccessorKind) :
+; val accessor = read-accessors[name(e)]
+; ReadPort(source(accessor), index(accessor), type(e))
+; (k) : e
+; (e) : map(make-read-ports, e)
+;
+; Module(name(m), ports(m), body*) where :
+; scan(body(m))
+; val elems = to-list $
+; for e in elements stream :
+; val entry = e()
+; key(entry) => map(make-read-ports, value(entry))
+; val connect-ports = Begin $ to-list $
+; for c in port-connects stream :
+; Connect(loc(c), make-read-ports(exp(c)))
+; val body* =
+; if empty?(elems) : connect-ports
+; else : LetRec(elems, connect-ports)
+;
+;defn structural-form (c:Circuit) :
+; val modules* = map(structural-form, modules(c))
+; Circuit(modules*, main(c))
+;
+;
+;;==================== WIDTH INFERENCE ======================
+;defstruct WidthVar <: Width :
+; name: Symbol
+;
+;defmethod print (o:OutputStream, w:WidthVar) :
+; print(o, name(w))
+;
+;defn width! (t:Type) :
+; match(t) :
+; (t:UIntType) : width(t)
+; (t:SIntType) : width(t)
+; (t) : error("No width field.")
+;
+;defn put-width (t:Type, w:Width) :
+; match(t) :
+; (t:UIntType) : UIntType(w)
+; (t:SIntType) : SIntType(w)
+; (t) : t
+;
+;defn put-width (e:Expression, w:Width) :
+; val type* = put-width(type(e), w)
+; put-type(e, type*)
+;
+;defn add-width-vars (t:Type) :
+; defn width? (w:Width) :
+; match(w) :
+; (w:UnknownWidth) : WidthVar(gensym())
+; (w) : w
+; match(t) :
+; (t:UIntType) : UIntType(width?(width(t)))
+; (t:SIntType) : SIntType(width?(width(t)))
+; (t) : map(add-width-vars, t)
+;
+;defn uint-width (i:Int) :
+; var v:Int = i
+; var n:Int = 0
+; while v != 0 :
+; v = v >> 1
+; n = n + 1
+; IntWidth(n)
+;
+;defn sint-width (i:Int) :
+; if i > 0 :
+; val w = uint-width(i)
+; IntWidth(width(w) + 1)
+; else :
+; val w = uint-width(neg(i) - 1)
+; IntWidth(width(w) + 1)
+;
+;defn to-exp (w:Width) :
+; match(w) :
+; (w:IntWidth) : ELit(width(w))
+; (w:WidthVar) : EVar(name(w))
+; (w) : error $ string-join $ [
+; "Cannot convert " w " to exp."]
+;
+;defn primop-width (op:PrimOp, ws:List<Width>, ints:List<Int>) -> Exp :
+; defn wmax (w1:Width, w2:Width) :
+; EMax(to-exp(w1), to-exp(w2))
+; defn wplus (w1:Width, w2:Width) :
+; EPlus(to-exp(w1), to-exp(w2))
+; defn wplus (w1:Width, w2:Int) :
+; EPlus(to-exp(w1), ELit(w2))
+; defn wminus (w1:Width, w2:Width) :
+; EMinus(to-exp(w1), to-exp(w2))
+; defn wminus (w1:Width, w2:Int) :
+; EMinus(to-exp(w1), ELit(w2))
+; defn wmax-inc (w1:Width, w2:Width) :
+; EPlus(wmax(w1, w2), ELit(1))
+;
+; switch {op == _} :
+; ADD-OP : wmax-inc(ws[0], ws[1])
+; ADD-WRAP-OP : wmax(ws[0], ws[1])
+; SUB-OP : wmax-inc(ws[0], ws[1])
+; SUB-WRAP-OP : wmax(ws[0], ws[1])
+; MUL-OP : wplus(ws[0], ws[1])
+; DIV-OP : wminus(ws[0], ws[1])
+; MOD-OP : to-exp(ws[1])
+; SHIFT-LEFT-OP : wplus(ws[0], ints[0])
+; SHIFT-RIGHT-OP : wminus(ws[0], ints[0])
+; PAD-OP : ELit(ints[0])
+; BIT-AND-OP : wmax(ws[0], ws[1])
+; BIT-OR-OP : wmax(ws[0], ws[1])
+; BIT-XOR-OP : wmax(ws[0], ws[1])
+; CONCAT-OP : wplus(ws[0], ints[0])
+; BIT-SELECT-OP : ELit(1)
+; BITS-SELECT-OP : ELit(ints[0])
+; MUX-OP : wmax(ws[1], ws[2])
+; LESS-OP : ELit(1)
+; LESS-EQ-OP : ELit(1)
+; GREATER-OP : ELit(1)
+; GREATER-EQ-OP : ELit(1)
+; EQUAL-OP : ELit(1)
+;
+;defn put-type (el:Element, t:Type) -> Element :
+; match(el) :
+; (el:Register) : Register(t, value(el), enable(el))
+; (el:Memory) : Memory(t, writers(el))
+; (el:Node) : Node(t, value(el))
+; (el:Instance) : Instance(t, module(el), ports(el))
+;
+;defn generate-constraints (c:Circuit) -> [Circuit, Vector<WConstraint>] :
+; ;Constraints
+; val cs = Vector<WConstraint>()
+; defn new-constraint (Constraint: (Symbol, Exp) -> WConstraint, wvar:Width, width:Width) :
+; match(wvar) :
+; (wvar:WidthVar) :
+; add(cs, Constraint(name(wvar), to-exp(width)))
+; (wvar) :
+; false
+;
+; defn to-width (e:Exp) :
+; match(e) :
+; (e:ELit) :
+; IntWidth(width(e))
+; (e:EVar) :
+; WidthVar(name(e))
+; (e) :
+; val x = gensym()
+; add(cs, WidthEqual(x, e))
+; WidthVar(x)
+;
+; ;Module types
+; val mod-types = HashTable<Symbol,Type>(symbol-hash)
+;
+; defn add-port-vars (m:Module) -> Module :
+; val ports* =
+; for p in ports(m) map :
+; val type* = add-width-vars(type(p))
+; Port(name(p), direction(p), type*)
+; mod-types[name(m)] = BundleType(ports*)
+; Module(name(m), ports*, body(m))
+;
+; ;Add Width Variables
+; defn add-module-vars (m:Module) -> Module :
+; val types = HashTable<Symbol,Type>(symbol-hash)
+; for p in ports(m) do :
+; types[name(p)] = type(p)
+;
+; defn infer-exp-width (e:Expression) -> Expression :
+; match(map(infer-exp-width, e)) :
+; (e:WRef) :
+; match(kind(e)) :
+; (k:ModuleKind) : put-type(e, mod-types[name(e)])
+; (k) : put-type(e, types[name(e)])
+; (e:WField) :
+; val t = bundle-field-type(type(exp(e)), name(e))
+; put-width(e, width!(t))
+; (e:UIntValue) :
+; match(width(e)) :
+; (w:UnknownWidth) : UIntValue(value(e), uint-width(value(e)))
+; (w) : e
+; (e:SIntValue) :
+; match(width(e)) :
+; (w:UnknownWidth) : SIntValue(value(e), sint-width(value(e)))
+; (w) : e
+; (e:DoPrim) :
+; val widths = map(width!{type(_)}, args(e))
+; val w = to-width(primop-width(op(e), widths, consts(e)))
+; put-width(e, w)
+; (e:ReadPort) :
+; val elem-type = type(type(mem(e)) as VectorType)
+; put-width(e, width!(elem-type))
+;
+; defn infer-comm-width (c:Stmt) :
+; match(c) :
+; (c:LetRec) :
+; ;Add width vars to elements
+; var entries*: List<KeyValue<Symbol,Element>> =
+; for entry in entries(c) map :
+; val el-name = key(entry)
+; key(entry) =>
+; match(value(entry)) :
+; (el:Register|Node) :
+; put-type(el, add-width-vars(type(el)))
+; (el:Memory) :
+; el
+; (el:Instance) :
+; val mod-type = type(infer-exp-width(module(el))) as BundleType
+; val type = BundleType $ to-list $
+; for p in ports(mod-type) filter :
+; direction(p) == OUTPUT
+; put-type(el, type)
+;
+; ;Add vars to environment
+; for entry in entries* do :
+; types[key(entry)] = type(value(entry))
+;
+; ;Infer types for elements
+; entries* =
+; for entry in entries* map :
+; key(entry) => map(infer-exp-width, value(entry))
+;
+; ;Generate constraints
+; for entry in entries* do :
+; val el-name = key(entry)
+; match(value(entry)) :
+; (el:Register) :
+; new-constraint(WidthEqual, reg-width, val-width) where :
+; val reg-width = width!(types[el-name])
+; val val-width = width!(type(value(el)))
+; (el:Node) :
+; new-constraint(WidthEqual, node-width, val-width) where :
+; val node-width = width!(types[el-name])
+; val val-width = width!(type(value(el)))
+; (el:Instance) :
+; val mod-type = type(module(el)) as BundleType
+; for entry in ports(el) do :
+; new-constraint(WidthGreater, port-width, val-width) where :
+; val port-name = key(entry)
+; val port-width = width!(bundle-field-type(mod-type, port-name))
+; val val-width = width!(type(value(entry)))
+; (el) : false
+;
+; ;Analyze body
+; LetRec(entries*, infer-comm-width(body(c)))
+;
+; (c:Connect) :
+; val loc* = infer-exp-width(loc(c))
+; val exp* = infer-exp-width(exp(c))
+; new-constraint(WidthGreater, loc-width, exp-width) where :
+; val loc-width = width!(type(loc*))
+; val exp-width = width!(type(exp*))
+; Connect(loc*, exp*)
+;
+; (c:Begin) :
+; map(infer-comm-width, c)
+;
+; Module(name(m), ports(m), body*) where :
+; val body* = infer-comm-width(body(m))
+;
+; val c* =
+; Circuit(modules*, main(c)) where :
+; val ms = map(add-port-vars, modules(c))
+; val modules* = map(add-module-vars, ms)
+; [c*, cs]
+;
+;
+;;================== FILL WIDTHS ============================
+;defn fill-widths (c:Circuit, solved:Streamable<WidthEqual>) :
+; ;Populate table
+; val table = HashTable<Symbol, Width>(symbol-hash)
+; for eq in solved do :
+; table[name(eq)] = IntWidth(width(value(eq) as ELit))
+;
+; defn width? (w:Width) :
+; match(w) :
+; (w:WidthVar) : get?(table, name(w), UnknownWidth())
+; (w) : w
+;
+; defn fill-type (t:Type) :
+; match(t) :
+; (t:UIntType) : UIntType(width?(width(t)))
+; (t:SIntType) : SIntType(width?(width(t)))
+; (t) : map(fill-type, t)
+;
+; defn fill-exp (e:Expression) -> Expression :
+; val e* = map(fill-exp, e)
+; val type* = fill-type(type(e))
+; put-type(e*, type*)
+;
+; defn fill-element (e:Element) :
+; val e* = map(fill-exp, e)
+; val type* = fill-type(type(e))
+; put-type(e*, type*)
+;
+; defn fill-comm (c:Stmt) :
+; match(c) :
+; (c:LetRec) :
+; val entries* =
+; for e in entries(c) map :
+; key(e) => fill-element(value(e))
+; LetRec(entries*, fill-comm(body(c)))
+; (c) :
+; map{fill-comm, _} $
+; map(fill-exp, c)
+;
+; defn fill-port (p:Port) :
+; Port(name(p), direction(p), fill-type(type(p)))
+;
+; defn fill-mod (m:Module) :
+; Module(name(m), ports*, body*) where :
+; val ports* = map(fill-port, ports(m))
+; val body* = fill-comm(body(m))
+;
+; Circuit(modules*, main(c)) where :
+; val modules* = map(fill-mod, modules(c))
+;
+;
+;;=============== TYPE INFERENCE DRIVER =====================
+;defn infer-widths (c:Circuit) :
+; val [c*, cs] = generate-constraints(c)
+; val solved = solve-widths(cs)
+; fill-widths(c*, solved)
+;
+;
+;;================ PAD WIDTHS ===============================
+;defn pad-widths (c:Circuit) :
+; ;Pad an expression to the given width
+; defn pad-exp (e:Expression, w:Int) :
+; match(type(e)) :
+; (t:UIntType|SIntType) :
+; val prev-w = width!(t) as IntWidth
+; if width(prev-w) < w :
+; val type* = put-width(t, IntWidth(w))
+; DoPrim(PAD-OP, list(e), list(w), type*)
+; else :
+; e
+; (t) :
+; e
+;
+; defn pad-exp (e:Expression, w:Width) :
+; val w-value = width(w as IntWidth)
+; pad-exp(e, w-value)
+;
+; ;Convenience
+; defn max-width (es:Streamable<Expression>) :
+; defn int-width (e:Expression) :
+; width(width!(type(e)) as IntWidth)
+; maximum(stream(int-width, es))
+;
+; defn match-widths (es:List<Expression>) :
+; val w = max-width(es)
+; map(pad-exp{_, w}, es)
+;
+; ;Match widths for an expression
+; defn match-exp-width (e:Expression) :
+; match(map(match-exp-width, e)) :
+; (e:DoPrim) :
+; if contains?([BIT-AND-OP, BIT-OR-OP, BIT-XOR-OP, EQUAL-OP], op(e)) :
+; val args* = match-widths(args(e))
+; DoPrim(op(e), args*, consts(e), type(e))
+; else if op(e) == MUX-OP :
+; val args* = List(head(args(e)), match-widths(tail(args(e))))
+; DoPrim(op(e), args*, consts(e), type(e))
+; else :
+; e
+; (e) : e
+;
+; defn match-element-width (e:Element) :
+; match(map(match-exp-width, e)) :
+; (e:Register) :
+; val w = width!(type(e))
+; val value* = pad-exp(value(e), w)
+; Register(type(e), value*, enable(e))
+; (e:Memory) :
+; val width = width!(type(type(e) as VectorType))
+; val writers* =
+; for w in writers(e) map :
+; WritePort(index(w), pad-exp(value(w), width), enable(w))
+; Memory(type(e), writers*)
+; (e:Node) :
+; val w = width!(type(e))
+; val value* = pad-exp(value(e), w)
+; Node(type(e), value*)
+; (e:Instance) :
+; val mod-type = type(module(e)) as BundleType
+; val ports* =
+; for p in ports(e) map :
+; val port-type = bundle-field-type(mod-type, key(p))
+; val port-val = pad-exp(value(p), width!(port-type))
+; key(p) => port-val
+; Instance(type(e), module(e), ports*)
+;
+; ;Match widths for a command
+; defn match-comm-width (c:Stmt) :
+; match(map(match-exp-width, c)) :
+; (c:LetRec) :
+; val entries* =
+; for e in entries(c) map :
+; key(e) => match-element-width(value(e))
+; LetRec(entries*, match-comm-width(body(c)))
+; (c:Connect) :
+; val w = width!(type(loc(c)))
+; val exp* = pad-exp(exp(c), w)
+; Connect(loc(c), exp*)
+; (c) :
+; map(match-comm-width, c)
+;
+; defn match-mod-width (m:Module) :
+; Module(name(m), ports(m), body*) where :
+; val body* = match-comm-width(body(m))
+;
+; Circuit(modules*, main(c)) where :
+; val modules* = map(match-mod-width, modules(c))
+;
+;
+;;================== INLINING ===============================
+;defn inline-instances (c:Circuit) :
+; val module-table = HashTable<Symbol,Module>(symbol-hash)
+; val inlined? = HashTable<Symbol,True|False>(symbol-hash)
+; for m in modules(c) do :
+; module-table[name(m)] = m
+; inlined?[name(m)] = false
+;
+; ;Convert a module into a sequence of elements
+; defn to-elements (m:Module,
+; inst:Symbol,
+; port-exps:List<KeyValue<Symbol,Expression>>) ->
+; List<KeyValue<Symbol, Element>> :
+; defn rename-exp (e:Expression) :
+; match(e) :
+; (e:WRef) : WRef(prefix(inst, name(e)), type(e), kind(e), dir(e))
+; (e) : map(rename-exp, e)
+;
+; defn to-elements (c:Stmt) -> List<KeyValue<Symbol,Element>> :
+; match(c) :
+; (c:LetRec) :
+; val entries* =
+; for entry in entries(c) map :
+; val name* = prefix(inst, key(entry))
+; val element* = map(rename-exp, value(entry))
+; name* => element*
+; val body* = to-elements(body(c))
+; append(entries*, body*)
+; (c:Connect) :
+; val ref = loc(c) as WRef
+; val name* = prefix(inst, name(ref))
+; list(name* => Node(type(exp(c)), rename-exp(exp(c))))
+; (c:Begin) :
+; map-append(to-elements, body(c))
+;
+; val inputs =
+; for p in ports(m) map-append :
+; if direction(p) == INPUT :
+; val port-exp = lookup!(port-exps, name(p))
+; val name* = prefix(inst, name(p))
+; list(name* => Node(type(port-exp), port-exp))
+; else :
+; List()
+; append(inputs, to-elements(body(m)))
+;
+; ;Inline all instances in the module
+; defn inline-instances (m:Module) :
+; defn rename-exp (e:Expression) :
+; match(e) :
+; (e:WField) :
+; val inst-exp = exp(e) as WRef
+; val name* = prefix(name(inst-exp), name(e))
+; WRef(name*, type(e), NodeKind(), dir(e))
+; (e) :
+; map(rename-exp, e)
+;
+; defn inline-elems (es:List<KeyValue<Symbol,Element>>) :
+; for entry in es map-append :
+; match(value(entry)) :
+; (el:Instance) :
+; val mod-name = name(module(el) as WRef)
+; val module = inlined-module(mod-name)
+; to-elements(module, key(entry), ports(el))
+; (el) :
+; list(entry)
+;
+; defn inline-comm (c:Stmt) :
+; match(map(rename-exp, c)) :
+; (c:LetRec) :
+; val entries* = inline-elems(entries(c))
+; LetRec(entries*, inline-comm(body(c)))
+; (c) :
+; map(inline-comm, c)
+;
+; Module(name(m), ports(m), inline-comm(body(m)))
+;
+; ;Retrieve the inlined instance of a module
+; defn inlined-module (name:Symbol) :
+; if inlined?[name] :
+; module-table[name]
+; else :
+; val module* = inline-instances(module-table[name])
+; module-table[name] = module*
+; inlined?[name] = true
+; module*
+;
+; ;Return the fully inlined circuit
+; val main-module = inlined-module(main(c))
+; Circuit(list(main-module), main(c))
;;;================ UTILITIES ================================
@@ -1982,16 +1991,16 @@ public defn run-passes (c: Circuit, p: List<Char>) :
if contains(p,'c') : do-stage("Make Explicit Reset", make-explicit-reset)
if contains(p,'d') : do-stage("Initialize Registers", initialize-registers)
if contains(p,'e') : do-stage("Infer Types", infer-types)
- if contains(p,'f') : do-stage("Infer Directions", infer-directions)
- if contains(p,'g') : do-stage("Expand Accessors", expand-accessors)
- if contains(p,'h') : do-stage("Flatten Bundles", flatten-bundles)
- if contains(p,'i') : do-stage("Expand Bundles", expand-bundles)
- if contains(p,'j') : do-stage("Expand Multi Connects", expand-multi-connects)
- if contains(p,'k') : do-stage("Expand Whens", expand-whens)
- if contains(p,'l') : do-stage("Structural Form", structural-form)
- if contains(p,'m') : do-stage("Infer Widths", infer-widths)
- if contains(p,'n') : do-stage("Pad Widths", pad-widths)
- if contains(p,'o') : do-stage("Inline Instances", inline-instances)
+ ;if contains(p,'f') : do-stage("Infer Directions", infer-directions)
+ ;if contains(p,'g') : do-stage("Expand Accessors", expand-accessors)
+ ;if contains(p,'h') : do-stage("Flatten Bundles", flatten-bundles)
+ ;if contains(p,'i') : do-stage("Expand Bundles", expand-bundles)
+ ;if contains(p,'j') : do-stage("Expand Multi Connects", expand-multi-connects)
+ ;if contains(p,'k') : do-stage("Expand Whens", expand-whens)
+ ;if contains(p,'l') : do-stage("Structural Form", structural-form)
+ ;if contains(p,'m') : do-stage("Infer Widths", infer-widths)
+ ;if contains(p,'n') : do-stage("Pad Widths", pad-widths)
+ ;if contains(p,'o') : do-stage("Inline Instances", inline-instances)
println("Done!")
diff --git a/test/passes/infer-types/bundle.fir b/test/passes/infer-types/bundle.fir
new file mode 100644
index 00000000..10a839a1
--- /dev/null
+++ b/test/passes/infer-types/bundle.fir
@@ -0,0 +1,13 @@
+; RUN: firrtl %s abcde ct | tee %s.out | FileCheck %s
+
+;CHECK: Infer Types
+circuit top :
+ module subtracter :
+ wire z : {input x : UInt, output y: SInt}
+ node x = z.x ;CHECK: node x = z@<t:{input x : UInt@<t:UInt>, output y : SInt@<t:SInt>}>.x@<t:UInt>
+ node y = z.y ;CHECK: node y = z@<t:{input x : UInt@<t:UInt>, output y : SInt@<t:SInt>}>.y@<t:SInt>
+
+ wire a : UInt(3)[10] ;CHECK: wire a : UInt(3)[10]@<t:UInt>@<t:UInt(3)[10]@<t:UInt>>
+ node b = a.2 ;CHECK: node b = a@<t:UInt(3)[10]@<t:UInt>>.2@<t:UInt>
+ accessor c = a[UInt(3)] ;CHECK: unknown accessor c = a@<t:UInt(3)[10]@<t:UInt>>[UInt(3)]
+; CHECK: Finished Infer Types
diff --git a/test/passes/infer-types/gcd.fir b/test/passes/infer-types/gcd.fir
index 4fdc9ab1..b60c4f38 100644
--- a/test/passes/infer-types/gcd.fir
+++ b/test/passes/infer-types/gcd.fir
@@ -6,8 +6,8 @@ circuit top :
input x : UInt
input y : UInt
output z : UInt
- z := sub-mod(x, y)
- ;CHECK: z@<t:UInt> := sub-mod(x@<t:UInt>, y@<t:UInt>)@<t:UInt>
+ z := sub-wrap(x, y)
+ ;CHECK: z@<t:UInt> := sub-wrap(x@<t:UInt>, y@<t:UInt>)@<t:UInt>
module gcd :
input a : UInt(16)
input b : UInt(16)
@@ -19,17 +19,17 @@ circuit top :
; CHECK: reg x : UInt
x.init := UInt(0)
y.init := UInt(42)
- when greater(x, y) :
- ;CHECK: when greater(x@<t:UInt>, y@<t:UInt>)@<t:UInt> :
+ when gt(x, y) :
+ ;CHECK: when gt(x@<t:UInt>, y@<t:UInt>)@<t:UInt> :
inst s of subtracter
- ;CHECK: inst s of subtracter@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt(1)>}>
+ ;CHECK: inst s of subtracter@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt>}>
s.x := x
s.y := y
x := s.z
- ;CHECK: s@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt(1)>}>.reset@<t:UInt(1)> := reset@<t:UInt(1)>
- ;CHECK: s@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt(1)>}>.x@<t:UInt> := x@<t:UInt>
- ;CHECK: s@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt(1)>}>.y@<t:UInt> := y@<t:UInt>
- ;CHECK: x@<t:UInt> := s@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt(1)>}>.z@<t:UInt>
+ ;CHECK: s@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt>}>.reset@<t:UInt> := reset@<t:UInt>
+ ;CHECK: s@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt>}>.x@<t:UInt> := x@<t:UInt>
+ ;CHECK: s@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt>}>.y@<t:UInt> := y@<t:UInt>
+ ;CHECK: x@<t:UInt> := s@<t:{input x : UInt@<t:UInt>, input y : UInt@<t:UInt>, output z : UInt@<t:UInt>, input reset : UInt(1)@<t:UInt>}>.z@<t:UInt>
else :
inst s2 of subtracter
s2.x := x
@@ -39,7 +39,7 @@ circuit top :
x := a
y := b
v := equal(v, UInt(0))
- ;CHECK: v@<t:UInt(1)> := equal(v@<t:UInt(1)>, UInt(0))@<t:UInt>
+ ;CHECK: v@<t:UInt> := equal(v@<t:UInt>, UInt(0))@<t:UInt>
z := x
module top :
input a : UInt(16)
diff --git a/test/passes/infer-types/primops.fir b/test/passes/infer-types/primops.fir
index 2a29efbf..7e7342ae 100644
--- a/test/passes/infer-types/primops.fir
+++ b/test/passes/infer-types/primops.fir
@@ -7,110 +7,122 @@ circuit top :
wire b : UInt(8)
wire c : SInt(16)
wire d : SInt(8)
+ wire e : UInt(1)
- ;add
- module add :
- node w = adduu(a,b) ;CHECK: node w = adduu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = addus(a,d) ;CHECK: node x = addus(a@<t:UInt>,d@<t:SInt>)@<t:SInt>
- node y = addsu(c,b) ;CHECK: node y = addsu(c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = addss(c,d) ;CHECK: node z = addss(c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;sub
- module sub :
- node w = subuu(a,b) ;CHECK: node w = subuu(a@<t:UInt>,b@<t:UInt>)@<t:SInt>
- node x = subus(a,d) ;CHECK: node x = subus(a@<t:UInt>,d@<t:SInt>)@<t:SInt>
- node y = subsu(c,b) ;CHECK: node y = subsu(c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = subss(c,d) ;CHECK: node z = subss(c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;mul
- module mul :
- node w = muluu(a,b) ;CHECK: node w = muluu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = mulus(a,d) ;CHECK: node x = mulus(a@<t:UInt>,d@<t:SInt>)@<t:SInt>
- node y = mulsu(c,b) ;CHECK: node y = mulsu(c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = mulss(c,d) ;CHECK: node z = mulss(c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;div
- module div :
- node w = divuu(a,b) ;CHECK: node w = divuu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = divus(a,d) ;CHECK: node x = divus(a@<t:UInt>,d@<t:SInt>)@<t:SInt>
- node y = divsu(c,b) ;CHECK: node y = divsu(c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = divss(c,d) ;CHECK: node z = divss(c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;mod
- module mod :
- node w = moduu(a,b) ;CHECK: node w = moduu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = modus(a,d) ;CHECK: node x = modus(a@<t:UInt>,d@<t:SInt>)@<t:SInt>
- node y = modsu(c,b) ;CHECK: node y = modsu(c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = modss(c,d) ;CHECK: node z = modss(c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;rem
- module rem :
- node w = remuu(a,b) ;CHECK: node w = remuu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = remus(a,d) ;CHECK: node x = remus(a@<t:UInt>,d@<t:SInt>)@<t:SInt>
- node y = remsu(c,b) ;CHECK: node y = remsu(c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = remss(c,d) ;CHECK: node z = remss(c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;add-mod
- module add-mod :
- node w = add-moduu(a,b) ;CHECK: node w = add-moduu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = add-modus(a,d) ;CHECK: node x = add-modus(a@<t:UInt>,d@<t:SInt>)@<t:UInt>
- node y = add-modsu(c,b) ;CHECK: node y = add-modsu(c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = add-modss(c,d) ;CHECK: node z = add-modss(c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;sub-mod
- module sub-mod :
- node w = sub-moduu(a,b) ;CHECK: node w = sub-moduu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = sub-modus(a,d) ;CHECK: node x = sub-modus(a@<t:UInt>,d@<t:SInt>)@<t:UInt>
- node y = sub-modsu(c,b) ;CHECK: node y = sub-modsu(c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = sub-modss(c,d) ;CHECK: node z = sub-modss(c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;lt
- module lt :
- node w = ltuu(a,b) ;CHECK: node w = ltuu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = ltus(a,d) ;CHECK: node x = ltus(a@<t:UInt>,d@<t:SInt>)@<t:UInt>
- node y = ltsu(c,b) ;CHECK: node y = ltsu(c@<t:SInt>,b@<t:UInt>)@<t:UInt>
- node z = ltss(c,d) ;CHECK: node z = ltss(c@<t:SInt>,d@<t:SInt>)@<t:UInt>
- ;leq
- module leq :
- node w = lequu(a,b) ;CHECK: node w = lequu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = lequs(a,d) ;CHECK: node x = lequs(a@<t:UInt>,d@<t:SInt>)@<t:UInt>
- node y = leqsu(c,b) ;CHECK: node y = leqsu(c@<t:SInt>,b@<t:UInt>)@<t:UInt>
- node z = leqss(c,d) ;CHECK: node z = leqss(c@<t:SInt>,d@<t:SInt>)@<t:UInt>
- ;gt
- module gt :
- node w = gtuu(a,b) ;CHECK: node w = gtuu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = gtus(a,d) ;CHECK: node x = gtus(a@<t:UInt>,d@<t:SInt>)@<t:UInt>
- node y = gtsu(c,b) ;CHECK: node y = gtsu(c@<t:SInt>,b@<t:UInt>)@<t:UInt>
- node z = gtss(c,d) ;CHECK: node z = gtss(c@<t:SInt>,d@<t:SInt>)@<t:UInt>
- ;geq
- module geq :
- node w = gequu(a,b) ;CHECK: node w = gequu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = gequs(a,d) ;CHECK: node x = gequs(a@<t:UInt>,d@<t:SInt>)@<t:UInt>
- node y = geqsu(c,b) ;CHECK: node y = geqsu(c@<t:SInt>,b@<t:UInt>)@<t:UInt>
- node z = geqss(c,d) ;CHECK: node z = geqss(c@<t:SInt>,d@<t:SInt>)@<t:UInt>
- ;pad
- module pad :
- node w = paduu(a,b) ;CHECK: node w = paduu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = padus(a,d) ;CHECK: node x = padus(a@<t:UInt>,d@<t:SInt>)@<t:UInt>
- node y = padsu(c,b) ;CHECK: node y = padsu(c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = padss(c,d) ;CHECK: node z = padss(c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;and
- module and :
- node w = and(a,b) ;CHECK: node w = and(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- module or :
- node w = or(a,b) ;CHECK: node w = or(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- module xor :
- node w = xor(a,b) ;CHECK: node w = xor(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- ;concat
- node w = concat(a,b) ;CHECK: node w = concat(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- ;equal
- node w = equaluu(a,b) ;CHECK: node w = equaluu(a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = equalus(a,d) ;CHECK: node x = equalus(a@<t:UInt>,d@<t:SInt>)@<t:UInt>
- node y = equalsu(c,b) ;CHECK: node y = equalsu(c@<t:SInt>,b@<t:UInt>)@<t:UInt>
- node z = equalss(c,d) ;CHECK: node z = equalss(c@<t:SInt>,d@<t:SInt>)@<t:UInt>
- ;mux
- node w = muxuu(e,a,b) ;CHECK: node w = muxuu(e@<t:UInt>,a@<t:UInt>,b@<t:UInt>)@<t:UInt>
- node x = muxus(e,a,d) ;CHECK: node x = muxus(e@<t:UInt>,a@<t:UInt>,d@<t:SInt>)@<t:SInt>
- node y = muxsu(e,c,b) ;CHECK: node y = muxsu(e@<t:UInt>,c@<t:SInt>,b@<t:UInt>)@<t:SInt>
- node z = muxss(e,c,d) ;CHECK: node z = muxss(e@<t:UInt>,c@<t:SInt>,d@<t:SInt>)@<t:SInt>
- ;shl
- ;shr
- ;bit
- ;bits
-
-
-
-
+ node vadd = add(a, c) ;CHECK: node vadd = add(a@<t:UInt>, c@<t:SInt>)@<t:SInt>
+ node wadd-uu = add-uu(a, b) ;CHECK: node wadd-uu = add-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xadd-us = add-us(a, d) ;CHECK: node xadd-us = add-us(a@<t:UInt>, d@<t:SInt>)@<t:SInt>
+ node yadd-su = add-su(c, b) ;CHECK: node yadd-su = add-su(c@<t:SInt>, b@<t:UInt>)@<t:SInt>
+ node zadd-ss = add-ss(c, d) ;CHECK: node zadd-ss = add-ss(c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vsub = sub(a, c) ;CHECK: node vsub = sub(a@<t:UInt>, c@<t:SInt>)@<t:SInt>
+ node wsub-uu = sub-uu(a, b) ;CHECK: node wsub-uu = sub-uu(a@<t:UInt>, b@<t:UInt>)@<t:SInt>
+ node xsub-us = sub-us(a, d) ;CHECK: node xsub-us = sub-us(a@<t:UInt>, d@<t:SInt>)@<t:SInt>
+ node ysub-su = sub-su(c, b) ;CHECK: node ysub-su = sub-su(c@<t:SInt>, b@<t:UInt>)@<t:SInt>
+ node zsub-ss = sub-ss(c, d) ;CHECK: node zsub-ss = sub-ss(c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vmul = mul(a, c) ;CHECK: node vmul = mul(a@<t:UInt>, c@<t:SInt>)@<t:SInt>
+ node wmul-uu = mul-uu(a, b) ;CHECK: node wmul-uu = mul-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xmul-us = mul-us(a, d) ;CHECK: node xmul-us = mul-us(a@<t:UInt>, d@<t:SInt>)@<t:SInt>
+ node ymul-su = mul-su(c, b) ;CHECK: node ymul-su = mul-su(c@<t:SInt>, b@<t:UInt>)@<t:SInt>
+ node zmul-ss = mul-ss(c, d) ;CHECK: node zmul-ss = mul-ss(c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vdiv = div(a, c) ;CHECK: node vdiv = div(a@<t:UInt>, c@<t:SInt>)@<t:SInt>
+ node wdiv-uu = div-uu(a, b) ;CHECK: node wdiv-uu = div-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xdiv-us = div-us(a, d) ;CHECK: node xdiv-us = div-us(a@<t:UInt>, d@<t:SInt>)@<t:SInt>
+ node ydiv-su = div-su(c, b) ;CHECK: node ydiv-su = div-su(c@<t:SInt>, b@<t:UInt>)@<t:SInt>
+ node zdiv-ss = div-ss(c, d) ;CHECK: node zdiv-ss = div-ss(c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vmod = mod(a, c) ;CHECK: node vmod = mod(a@<t:UInt>, c@<t:SInt>)@<t:UInt>
+ node wmod-uu = mod-uu(a, b) ;CHECK: node wmod-uu = mod-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xmod-us = mod-us(a, d) ;CHECK: node xmod-us = mod-us(a@<t:UInt>, d@<t:SInt>)@<t:UInt>
+ node ymod-su = mod-su(c, b) ;CHECK: node ymod-su = mod-su(c@<t:SInt>, b@<t:UInt>)@<t:SInt>
+ node zmod-ss = mod-ss(c, d) ;CHECK: node zmod-ss = mod-ss(c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vquo = quo(a, c) ;CHECK: node vquo = quo(a@<t:UInt>, c@<t:SInt>)@<t:SInt>
+ node wquo-uu = quo-uu(a, b) ;CHECK: node wquo-uu = quo-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xquo-us = quo-us(a, d) ;CHECK: node xquo-us = quo-us(a@<t:UInt>, d@<t:SInt>)@<t:SInt>
+ node yquo-su = quo-su(c, b) ;CHECK: node yquo-su = quo-su(c@<t:SInt>, b@<t:UInt>)@<t:SInt>
+ node zquo-ss = quo-ss(c, d) ;CHECK: node zquo-ss = quo-ss(c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vrem = rem(a, c) ;CHECK: node vrem = rem(a@<t:UInt>, c@<t:SInt>)@<t:SInt>
+ node wrem-uu = rem-uu(a, b) ;CHECK: node wrem-uu = rem-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xrem-us = rem-us(a, d) ;CHECK: node xrem-us = rem-us(a@<t:UInt>, d@<t:SInt>)@<t:SInt>
+ node yrem-su = rem-su(c, b) ;CHECK: node yrem-su = rem-su(c@<t:SInt>, b@<t:UInt>)@<t:UInt>
+ node zrem-ss = rem-ss(c, d) ;CHECK: node zrem-ss = rem-ss(c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vadd-wrap = add-wrap(a, c) ;CHECK: node vadd-wrap = add-wrap(a@<t:UInt>, c@<t:SInt>)@<t:SInt>
+ node wadd-wrap-uu = add-wrap-uu(a, b) ;CHECK: node wadd-wrap-uu = add-wrap-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xadd-wrap-us = add-wrap-us(a, d) ;CHECK: node xadd-wrap-us = add-wrap-us(a@<t:UInt>, d@<t:SInt>)@<t:SInt>
+ node yadd-wrap-su = add-wrap-su(c, b) ;CHECK: node yadd-wrap-su = add-wrap-su(c@<t:SInt>, b@<t:UInt>)@<t:SInt>
+ node zadd-wrap-ss = add-wrap-ss(c, d) ;CHECK: node zadd-wrap-ss = add-wrap-ss(c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vsub-wrap = sub-wrap(a, c) ;CHECK: node vsub-wrap = sub-wrap(a@<t:UInt>, c@<t:SInt>)@<t:SInt>
+ node wsub-wrap-uu = sub-wrap-uu(a, b) ;CHECK: node wsub-wrap-uu = sub-wrap-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xsub-wrap-us = sub-wrap-us(a, d) ;CHECK: node xsub-wrap-us = sub-wrap-us(a@<t:UInt>, d@<t:SInt>)@<t:SInt>
+ node ysub-wrap-su = sub-wrap-su(c, b) ;CHECK: node ysub-wrap-su = sub-wrap-su(c@<t:SInt>, b@<t:UInt>)@<t:SInt>
+ node zsub-wrap-ss = sub-wrap-ss(c, d) ;CHECK: node zsub-wrap-ss = sub-wrap-ss(c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vlt = lt(a, c) ;CHECK: node vlt = lt(a@<t:UInt>, c@<t:SInt>)@<t:UInt>
+ node wlt-uu = lt-uu(a, b) ;CHECK: node wlt-uu = lt-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xlt-us = lt-us(a, d) ;CHECK: node xlt-us = lt-us(a@<t:UInt>, d@<t:SInt>)@<t:UInt>
+ node ylt-su = lt-su(c, b) ;CHECK: node ylt-su = lt-su(c@<t:SInt>, b@<t:UInt>)@<t:UInt>
+ node zlt-ss = lt-ss(c, d) ;CHECK: node zlt-ss = lt-ss(c@<t:SInt>, d@<t:SInt>)@<t:UInt>
+
+ node vleq = leq(a, c) ;CHECK: node vleq = leq(a@<t:UInt>, c@<t:SInt>)@<t:UInt>
+ node wleq-uu = leq-uu(a, b) ;CHECK: node wleq-uu = leq-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xleq-us = leq-us(a, d) ;CHECK: node xleq-us = leq-us(a@<t:UInt>, d@<t:SInt>)@<t:UInt>
+ node yleq-su = leq-su(c, b) ;CHECK: node yleq-su = leq-su(c@<t:SInt>, b@<t:UInt>)@<t:UInt>
+ node zleq-ss = leq-ss(c, d) ;CHECK: node zleq-ss = leq-ss(c@<t:SInt>, d@<t:SInt>)@<t:UInt>
+
+ node vgt = gt(a, c) ;CHECK: node vgt = gt(a@<t:UInt>, c@<t:SInt>)@<t:UInt>
+ node wgt-uu = gt-uu(a, b) ;CHECK: node wgt-uu = gt-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xgt-us = gt-us(a, d) ;CHECK: node xgt-us = gt-us(a@<t:UInt>, d@<t:SInt>)@<t:UInt>
+ node ygt-su = gt-su(c, b) ;CHECK: node ygt-su = gt-su(c@<t:SInt>, b@<t:UInt>)@<t:UInt>
+ node zgt-ss = gt-ss(c, d) ;CHECK: node zgt-ss = gt-ss(c@<t:SInt>, d@<t:SInt>)@<t:UInt>
+
+ node vgeq = geq(a, c) ;CHECK: node vgeq = geq(a@<t:UInt>, c@<t:SInt>)@<t:UInt>
+ node wgeq-uu = geq-uu(a, b) ;CHECK: node wgeq-uu = geq-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xgeq-us = geq-us(a, d) ;CHECK: node xgeq-us = geq-us(a@<t:UInt>, d@<t:SInt>)@<t:UInt>
+ node ygeq-su = geq-su(c, b) ;CHECK: node ygeq-su = geq-su(c@<t:SInt>, b@<t:UInt>)@<t:UInt>
+ node zgeq-ss = geq-ss(c, d) ;CHECK: node zgeq-ss = geq-ss(c@<t:SInt>, d@<t:SInt>)@<t:UInt>
+
+ node vequal = equal(a, b) ;CHECK: node vequal = equal(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node wequal-uu = equal-uu(a, b) ;CHECK: node wequal-uu = equal-uu(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node zequal-ss = equal-ss(c, d) ;CHECK: node zequal-ss = equal-ss(c@<t:SInt>, d@<t:SInt>)@<t:UInt>
+
+ node vmux = mux(e, a, b) ;CHECK: node vmux = mux(e@<t:UInt>, a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node wmux-uu = mux-uu(e, a, b) ;CHECK: node wmux-uu = mux-uu(e@<t:UInt>, a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node zmux-ss = mux-ss(e, c, d) ;CHECK: node zmux-ss = mux-ss(e@<t:UInt>, c@<t:SInt>, d@<t:SInt>)@<t:SInt>
+
+ node vpad = pad(a, 10) ;CHECK: node vpad = pad(a@<t:UInt>, 10)@<t:UInt>
+ node wpad-u = pad-u(a, 10) ;CHECK: node wpad-u = pad-u(a@<t:UInt>, 10)@<t:UInt>
+ node zpad-s = pad-s(c, 10) ;CHECK: node zpad-s = pad-s(c@<t:SInt>, 10)@<t:SInt>
+
+ node vas-UInt = as-UInt(d) ;CHECK: node vas-UInt = as-UInt(d@<t:SInt>)@<t:UInt>
+ node was-UInt-u = as-UInt-u(a) ;CHECK: node was-UInt-u = as-UInt-u(a@<t:UInt>)@<t:UInt>
+ node zas-UInt-s = as-UInt-s(c) ;CHECK: node zas-UInt-s = as-UInt-s(c@<t:SInt>)@<t:UInt>
+
+ node vas-SInt = as-SInt(a) ;CHECK: node vas-SInt = as-SInt(a@<t:UInt>)@<t:SInt>
+ node was-SInt-u = as-SInt-u(a) ;CHECK: node was-SInt-u = as-SInt-u(a@<t:UInt>)@<t:SInt>
+ node zas-SInt-s = as-SInt-s(c) ;CHECK: node zas-SInt-s = as-SInt-s(c@<t:SInt>)@<t:SInt>
+
+ node vshl = shl(a, 10) ;CHECK: node vshl = shl(a@<t:UInt>, 10)@<t:UInt>
+ node wshl-u = shl-u(a, 10) ;CHECK: node wshl-u = shl-u(a@<t:UInt>, 10)@<t:UInt>
+ node zshl-s = shl-s(c, 10) ;CHECK: node zshl-s = shl-s(c@<t:SInt>, 10)@<t:SInt>
+
+ node vshr = shr(a, 10) ;CHECK: node vshr = shr(a@<t:UInt>, 10)@<t:UInt>
+ node wshr-u = shr-u(a, 10) ;CHECK: node wshr-u = shr-u(a@<t:UInt>, 10)@<t:UInt>
+ node zshr-s = shr-s(c, 10) ;CHECK: node zshr-s = shr-s(c@<t:SInt>, 10)@<t:SInt>
+
+ node vconvert = convert(a) ;CHECK: node vconvert = convert(a@<t:UInt>)@<t:SInt>
+ node wconvert-u = convert-u(a) ;CHECK: node wconvert-u = convert-u(a@<t:UInt>)@<t:SInt>
+ node zconvert-s = convert-s(c) ;CHECK: node zconvert-s = convert-s(c@<t:SInt>)@<t:SInt>
+
+ node uand = bit-and(a, b) ;CHECK: node uand = bit-and(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node vor = bit-or(a, b) ;CHECK: node vor = bit-or(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node wxor = bit-xor(a, b) ;CHECK: node wxor = bit-xor(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node xconcat = concat(a, b) ;CHECK: node xconcat = concat(a@<t:UInt>, b@<t:UInt>)@<t:UInt>
+ node ybit = bit(a, 0) ;CHECK: node ybit = bit(a@<t:UInt>, 0)@<t:UInt>
+ node zbits = bits(a, 2, 0) ;CHECK: node zbits = bits(a@<t:UInt>, 2, 0)@<t:UInt>
;CHECK: Finished Infer Types
diff --git a/test/passes/initialize-register/when.fir b/test/passes/initialize-register/when.fir
index 4e2bef79..4e0690d8 100644
--- a/test/passes/initialize-register/when.fir
+++ b/test/passes/initialize-register/when.fir
@@ -6,7 +6,7 @@
input a : UInt(16)
input b : UInt(16)
output z : UInt
- when greater(1, 2) :
+ when gt(1, 2) :
reg r1: UInt
r1.init := UInt(12)
; CHECK: wire [[R1:gen[0-9]*]] : UInt
diff --git a/test/passes/resolve-kinds/gcd.fir b/test/passes/resolve-kinds/gcd.fir
index b06da6c5..f4ad0e05 100644
--- a/test/passes/resolve-kinds/gcd.fir
+++ b/test/passes/resolve-kinds/gcd.fir
@@ -6,8 +6,8 @@ circuit top :
input x : UInt
input y : UInt
output z : UInt
- z := sub-mod(x, y)
- ;CHECK: z@<k:port> := sub-mod(x@<k:port>, y@<k:port>)
+ z := sub-wrap(x, y)
+ ;CHECK: z@<k:port> := sub-wrap(x@<k:port>, y@<k:port>)
module gcd :
input a : UInt(16)
input b : UInt(16)
@@ -18,7 +18,7 @@ circuit top :
reg y : UInt
x.init := UInt(0)
y.init := UInt(42)
- when greater(x, y) :
+ when gt(x, y) :
inst s of subtracter
s.x := x
;CHECK: s@<k:inst>.x := x@<k:reg>