aboutsummaryrefslogtreecommitdiff
path: root/src/main/stanza/passes.stanza
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/stanza/passes.stanza')
-rw-r--r--src/main/stanza/passes.stanza129
1 files changed, 60 insertions, 69 deletions
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index 0e20cd4c..236addce 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -196,12 +196,12 @@ defn hasGender (e:?) :
e typeof WRef|WSubfield|WIndex|WDefAccessor
defn hasWidth (e:?) :
- e typeof UIntType|SIntType|UIntValue|SIntValue|Pad
+ e typeof UIntType|SIntType|UIntValue|SIntValue
defn hasType (e:?) :
e typeof Ref|Subfield|Index|DoPrim|WritePort|ReadPort|WRef|WSubfield
|WIndex|DefWire|DefRegister|DefMemory|Register
- |VectorType|Port|Field|Pad
+ |VectorType|Port|Field
defn hasKind (e:?) :
e typeof WRef
@@ -244,15 +244,15 @@ defmethod print (o:OutputStream, e:WIndex) :
print-debug(o,e as ?)
defmethod print (o:OutputStream, s:WDefAccessor) :
- print-all(o,["accessor " name(s) " = " source(s) "[" index(s) "] @[" info(s) "]"])
+ print-all(o,["accessor " name(s) " = " source(s) "[" index(s) "]"])
print-debug(o,s)
defmethod print (o:OutputStream, c:ConnectToIndexed) :
- print-all(o, [locs(c) "[" index(c) "] := " exp(c) " @[" info(c) "]"])
+ print-all(o, [locs(c) "[" index(c) "] := " exp(c)])
print-debug(o,c as ?)
defmethod print (o:OutputStream, c:ConnectFromIndexed) :
- print-all(o, [loc(c) " := " exps(c) "[" index(c) "] @[" info(c) "]"])
+ print-all(o, [loc(c) " := " exps(c) "[" index(c) "]" ])
print-debug(o,c as ?)
defmethod map (f: Expression -> Expression, e: WSubfield) :
@@ -482,7 +482,6 @@ defn infer-exp-types (e:Expression, l:List<KeyValue<Symbol,Type>>) -> Expression
(e:ReadPort) : ReadPort(mem(e),index(e),get-vector-subtype(type(mem(e))),enable(e))
(e:WritePort) : WritePort(mem(e),index(e),get-vector-subtype(type(mem(e))),enable(e))
(e:Register) : Register(type(value(e)),value(e),enable(e))
- (e:Pad) : Pad(value(e),width(e),type(value(e)))
(e:UIntValue|SIntValue) : e
defn infer-types (s:Stmt, l:List<KeyValue<Symbol,Type>>) -> [Stmt List<KeyValue<Symbol,Type>>] :
@@ -790,9 +789,9 @@ defn expand-expr (e:Expression) -> List<EF> :
val exps = expand-expr(exp(e))
val len = num-elems(type(e))
headn(tailn(exps,len * value(e)),len)
- (e:Pad) :
- val v = exp(head(expand-expr(value(e))))
- list(EF(Pad(v,width(e),type(e)),DEFAULT))
+ ;(e:Pad) :
+ ;val v = exp(head(expand-expr(value(e))))
+ ;list(EF(Pad(v,width(e),type(e)),DEFAULT))
(e:DoPrim) :
val args = for x in args(e) map : exp(head(expand-expr(x)))
list(EF(DoPrim(op(e),args,consts(e),type(e)),DEFAULT))
@@ -931,7 +930,7 @@ public defmethod short-name (b:ExpandIndexedConnects) -> String : "expand-indexe
defn expand-connect-indexed-stmt (s: Stmt) -> Stmt :
defn equality (e1:Expression,e2:Expression) -> Expression :
- DoPrim(EQUAL-UU-OP,list(e1,e2),List(),UIntType(UnknownWidth()))
+ DoPrim(EQUAL-OP,list(e1,e2),List(),UIntType(UnknownWidth()))
defn get-name (e:Expression) -> Symbol :
match(e) :
(e:WRef) : symbol-join([name(e) `#])
@@ -1018,7 +1017,7 @@ defmethod equal? (e1:Expression,e2:Expression) -> True|False :
(e1:WRef,e2:WRef) : name(e1) == name(e2)
;(e1:DoPrim,e2:DoPrim) : TODO
(e1:WSubfield,e2:WSubfield) : name(e1) == name(e2)
- (e1:Pad,e2:Pad) : width(e1) == width(e2) and value(e1) == value(e2)
+ ;(e1:Pad,e2:Pad) : width(e1) == width(e2) and value(e1) == value(e2)
(e1:DoPrim,e2:DoPrim) :
var are-equal? = op(e1) == op(e2)
for (x in args(e1),y in args(e2)) do :
@@ -1047,7 +1046,7 @@ defn OR (e1:Expression,e2:Expression) -> Expression :
defn NOT (e1:Expression) -> Expression :
if e1 == one : zero
else if e1 == zero : one
- else : DoPrim(EQUAL-UU-OP,list(e1,zero),list(),UIntType(IntWidth(1)))
+ else : DoPrim(EQUAL-OP,list(e1,zero),list(),UIntType(IntWidth(1)))
defn children (e:Expression) -> List<Expression> :
val es = Vector<Expression>()
@@ -1144,7 +1143,7 @@ defn remove-nul (sv:SymbolicValue) -> SymbolicValue :
defn to-exp (sv:SymbolicValue) -> Expression|False :
match(remove-nul(sv)) :
(sv:SVMux) :
- DoPrim(MUX-UU-OP,
+ DoPrim(MUX-OP,
list(pred(sv),to-exp(conseq(sv)) as Expression,to-exp(alt(sv)) as Expression),
list(),
UIntType(IntWidth(1)))
@@ -1453,10 +1452,8 @@ defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False :
if a typeof Int and b typeof Int : f(a as Int, b as Int)
else : false
-; TODO: I should make MaxWidth take a variable list of arguments, which would make it easier to write the simplify function. It looks like there isn't a bug in the algorithm, but simplification reallllly speeds it up.
-
-defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
+defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Width> :
defn contains? (n:Symbol,h:HashTable<Symbol,?>) -> True|False : key?(h,n)
defn make-unique (ls:List<WGeq>) -> HashTable<Symbol,Width> :
val h = HashTable<Symbol,Width>(symbol-hash)
@@ -1533,34 +1530,6 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
w
look(w)
has?
- defn evaluate (h:HashTable<Symbol,Width>) -> HashTable<Symbol,Int> :
- defn apply (a:Int|False,f:(Int) -> Int) -> Int|False :
- if a typeof Int : f(a as Int)
- else : false
- defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False :
- if a typeof Int and b typeof Int : f(a as Int, b as Int)
- else : false
- defn apply-l (l:List<Int|False>,f:(Int,Int) -> Int) -> Int|False :
- if length(l) == 0 : 0
- else : apply(head(l),apply-l(tail(l),f),f)
- defn max (a:Int,b:Int) -> Int :
- if a >= b : a
- else : b
- defn solve (w:Width) -> Int|False :
- match(w) :
- (w:VarWidth) : false
- (w:MaxWidth) : apply-l(map(solve,args(w)),max)
- (w:PlusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ + _})
- (w:MinusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ - _})
- (w:ExpWidth) : apply(2,solve(arg1(w)),{pow(_,_) - 1})
- (w:IntWidth) : width(w)
- (w) : error("Shouldn't be here")
-
- val i = HashTable<Symbol,Int>(symbol-hash)
- for x in h do :
- val s = solve(value(x))
- if s typeof Int : i[key(x)] = s as Int
- i
; Forward solve
; Returns a solved list where each constraint undergoes:
@@ -1613,11 +1582,7 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
for x in b do : println-debug(x)
println-debug("=========================")
- ; Evaluate
- val e = evaluate(b)
- println-debug("Evaluated Constraints")
- for x in e do : println-debug(x)
- e
+ b
public defn width! (t:Type) -> Width :
match(t) :
@@ -1652,15 +1617,15 @@ defn gen-constraints (m:Module, h:HashTable<Symbol,Type>, v:Vector<WGeq>) -> Mod
(e:WSubfield) : WSubfield(exp(e),name(e),bundle-field-type(type(exp(e)),name(e)),gender(e))
(e:WIndex) : error("Shouldn't be here")
(e:DoPrim) : DoPrim(op(e),args(e),consts(e),primop-gen-constraints(e,v))
- (e:Pad) :
- val value-w = width!(value(e))
- val pad-w = remove-unknowns-w(width(e))
- add(v,WGeq(pad-w, value-w))
- val pad-t = match(type(e)) :
- (t:UIntType) : UIntType(pad-w)
- (t:SIntType) : SIntType(pad-w)
- (t) : error("Shouldn't be here")
- Pad(value(e),pad-w,pad-t)
+ ;(e:Pad) :
+ ; val value-w = width!(value(e))
+ ; val pad-w = remove-unknowns-w(width(e))
+ ; add(v,WGeq(pad-w, value-w))
+ ; val pad-t = match(type(e)) :
+ ; (t:UIntType) : UIntType(pad-w)
+ ; (t:SIntType) : SIntType(pad-w)
+ ; (t) : error("Shouldn't be here")
+ ; Pad(value(e),pad-w,pad-t)
(e:ReadPort) : ReadPort(mem(e),index(e),type(type(mem(e)) as VectorType),enable(e))
(e:WritePort) : WritePort(mem(e),index(e),type(type(mem(e)) as VectorType),enable(e))
(e:Register) : Register(type(value(e)),value(e),enable(e))
@@ -1699,21 +1664,47 @@ defn build-environment (c:Circuit,m:Module,h:HashTable<Symbol,Type>) -> HashTabl
build-environment(body(m))
h
-defn replace-var-widths (c:Circuit,h:HashTable<Symbol,Int>) -> Circuit :
- defn replace-var-widths-w (w:Width) -> Width :
+defn reduce-var-widths (c:Circuit,h:HashTable<Symbol,Width>) -> Circuit :
+ defn evaluate (w:Width) -> Width :
+ defn apply (a:Int|False,f:(Int) -> Int) -> Int|False :
+ if a typeof Int : f(a as Int)
+ else : false
+ defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False :
+ if a typeof Int and b typeof Int : f(a as Int, b as Int)
+ else : false
+ defn apply-l (l:List<Int|False>,f:(Int,Int) -> Int) -> Int|False :
+ if length(l) == 0 : 0
+ else : apply(head(l),apply-l(tail(l),f),f)
+ defn max (a:Int,b:Int) -> Int :
+ if a >= b : a
+ else : b
+ defn solve (w:Width) -> Int|False :
+ match(w) :
+ (w:VarWidth) :
+ val w* = h[name(w)]
+ if w* typeof VarWidth : false
+ else : solve(w*)
+ (w:MaxWidth) : apply-l(map(solve,args(w)),max)
+ (w:PlusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ + _})
+ (w:MinusWidth) : apply(solve(arg1(w)),solve(arg2(w)),{_ - _})
+ (w:ExpWidth) : apply(2,solve(arg1(w)),{pow(_,_) - 1})
+ (w:IntWidth) : width(w)
+ (w) : error("Shouldn't be here")
+
+ val s = solve(w)
+ if s typeof Int : IntWidth(s as Int)
+ else : w
+
+ defn reduce-var-widths-w (w:Width) -> Width :
println-all-debug(["REPLACE: " w])
- val w* = match(w) :
- (w:VarWidth) :
- if key?(h,name(w)) : IntWidth(h[name(w)])
- else: w
- (w) : w
+ val w* = evaluate(w)
println-all-debug(["WITH: " w*])
w*
val modules* = for m in modules(c) map :
- Module{info(m),name(m),_,mapr(replace-var-widths-w,body(m))} $
+ Module{info(m),name(m),_,mapr(reduce-var-widths-w,body(m))} $
for p in ports(m) map :
- Port(info(p),name(p),direction(p),mapr(replace-var-widths-w,type(p)))
+ Port(info(p),name(p),direction(p),mapr(reduce-var-widths-w,type(p)))
Circuit(info(c),modules*,main(c))
@@ -1751,7 +1742,7 @@ defn infer-widths (c:Circuit) -> Circuit :
println-debug("======== SOLVED CONSTRAINTS ========")
for x in h do : println-debug(x)
println-debug("====================================")
- replace-var-widths(Circuit(info(c),modules*,main(c)),h)
+ reduce-var-widths(Circuit(info(c),modules*,main(c)),h)
;================= Inline Instances ========================
@@ -1835,7 +1826,7 @@ defn split-exp (c:Circuit) :
false
defn split-exp-e (e:Expression,v:Vector<Stmt>,n:Symbol|False,info:FileInfo) -> Expression :
match(map(split-exp-e{_,v,n,info},e)):
- (e:Subfield|DoPrim|Pad|ReadPort|Register|WritePort) :
+ (e:Subfield|DoPrim|ReadPort|Register|WritePort) :
val n* =
if n typeof False : firrtl-gensym(`T)
else : firrtl-gensym(symbol-join([n as Symbol `#]))