diff options
Diffstat (limited to 'src/main/stanza/ir-utils.stanza')
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 154 |
1 files changed, 131 insertions, 23 deletions
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index 7a9cff31..7d4c30b2 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -39,7 +39,7 @@ public defn firrtl-gensym (s:Symbol,sym-hash:HashTable<Symbol,Int>) -> Symbol : symbol-join([s delin num]) else : sym-hash[s] = 0 - s + symbol-join([s delin 0]) val s* = to-string(s) val i* = generated?(s*) val nex = match(i*) : @@ -113,7 +113,20 @@ public defn EQV (e1:Expression,e2:Expression) -> Expression : DoPrim(EQUIV-OP,list(e1,e2),list(),type(e1)) public defn MUX (p:Expression,e1:Expression,e2:Expression) -> Expression : - DoPrim(MUX-OP,list(p,e1,e2),list(),type(e1)) + Mux(p,e1,e2,mux-type(type(e1),type(e2))) + +public defn mux-type (e1:Expression,e2:Expression) -> Type : + mux-type(type(e1),type(e2)) +public defn mux-type (t1:Type,t2:Type) -> Type : + if t1 == t2 : + match(t1,t2) : + (t1:UIntType,t2:UIntType) : UIntType(UnknownWidth()) + (t1:SIntType,t2:SIntType) : SIntType(UnknownWidth()) + (t1:VectorType,t2:VectorType) : VectorType(mux-type(type(t1),type(t2)),size(t1)) + (t1:BundleType,t2:BundleType) : + BundleType $ for (f1 in fields(t1),f2 in fields(t2)) map : + Field(name(f1),flip(f1),mux-type(type(f1),type(f2))) + else : UnknownType() public defn CAT (e1:Expression,e2:Expression) -> Expression : DoPrim(CONCAT-OP,list(e1,e2),list(),type(e1)) @@ -148,13 +161,113 @@ public defn list-hash (l:List) -> Int : turn-on-debug(false) i +;===== Type Expansion Algorithms ========= +public defn times (f1:Flip,f2:Flip) -> Flip : + switch {_ == f2} : + DEFAULT : f1 + REVERSE : swap(f1) +public defn swap (f:Flip) -> Flip : + switch {_ == f} : + DEFAULT : REVERSE + REVERSE : DEFAULT + +public defmulti get-type (s:Stmt) -> Type +public defmethod get-type (s:Stmt) -> Type : + match(s) : + (s:DefWire|DefPoison|DefRegister) : type(s) + (s:DefNode) : type(value(s)) + (s:DefMemory) : + val depth = depth(s) + ; Fields + val addr = Field(`addr,DEFAULT,UIntType(IntWidth(ceil-log2(depth)))) + val en = Field(`en,DEFAULT,BoolType()) + val clk = Field(`clk,DEFAULT,ClockType()) + val def-data = Field(`data,DEFAULT,data-type(s)) + val rev-data = Field(`data,REVERSE,data-type(s)) + val rdata = Field(`rdata,REVERSE,data-type(s)) + val wdata = Field(`wdata,DEFAULT,data-type(s)) + val mask = Field(`mask,DEFAULT,create-mask(data-type(s))) + val wmask = Field(`wmask,DEFAULT,create-mask(data-type(s))) + val ren = Field(`ren,DEFAULT,UIntType(IntWidth(1))) + val wen = Field(`wen,DEFAULT,UIntType(IntWidth(1))) + val raddr = Field(`raddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth)))) + val waddr = Field(`waddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth)))) + + val read-type = BundleType(to-list([rev-data,addr,en,clk])) + val write-type = BundleType(to-list([def-data,mask,addr,en,clk])) + val readwrite-type = BundleType(to-list([wdata,wmask,waddr,wen,rdata,raddr,ren,clk])) + + val mem-fields = Vector<Field>() + for x in readers(s) do : + add(mem-fields,Field(x,REVERSE,read-type)) + for x in writers(s) do : + add(mem-fields,Field(x,REVERSE,write-type)) + for x in readwriters(s) do : + add(mem-fields,Field(x,REVERSE,readwrite-type)) + BundleType(to-list(mem-fields)) + (s:DefInstance) : UnknownType() + (s:Begin|Connect|BulkConnect|Stop|Print|Empty|IsInvalid) : UnknownType() + +public defn get-size (t:Type) -> Int : + val x = match(t) : + (t:BundleType) : + var sum = 0 + for f in fields(t) do : + sum = sum + get-size(type(f)) + sum + (t:VectorType) : size(t) * get-size(type(t)) + (t) : 1 + x +public defn get-valid-points (t1:Type,t2:Type,flip1:Flip,flip2:Flip) -> List<[Int,Int]> : + ;println-all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) + match(t1,t2) : + (t1:UIntType,t2:UIntType) : + if flip1 == flip2 : list([0, 0]) + else: list() + (t1:SIntType,t2:SIntType) : + if flip1 == flip2 : list([0, 0]) + else: list() + (t1:BundleType,t2:BundleType) : + val points = Vector<[Int,Int]>() + var ilen = 0 + var jlen = 0 + for i in 0 to length(fields(t1)) do : + for j in 0 to length(fields(t2)) do : + ;println(i) + ;println(j) + ;println(ilen) + ;println(jlen) + val f1 = fields(t1)[i] + val f2 = fields(t2)[j] + if name(f1) == name(f2) : + val ls = get-valid-points(type(f1),type(f2),flip1 * flip(f1), + flip2 * flip(f2)) + for x in ls do : + add(points,[x[0] + ilen, x[1] + jlen]) + ;println(points) + jlen = jlen + get-size(type(fields(t2)[j])) + ilen = ilen + get-size(type(fields(t1)[i])) + jlen = 0 + to-list(points) + (t1:VectorType,t2:VectorType) : + val points = Vector<[Int,Int]>() + var ilen = 0 + var jlen = 0 + for i in 0 to min(size(t1),size(t2)) do : + val ls = get-valid-points(type(t1),type(t2),flip1,flip2) + for x in ls do : + add(points,[x[0] + ilen, x[1] + jlen]) + ilen = ilen + get-size(type(t1)) + jlen = jlen + get-size(type(t2)) + to-list(points) + ;============= Useful functions ============== -public defn create-mask (n:Symbol,dt:Type) -> Field : - Field{n,DEFAULT,_} $ match(dt) : - (t:VectorType) : VectorType(BoolType(),size(t)) +public defn create-mask (dt:Type) -> Type : + match(dt) : + (t:VectorType) : VectorType(create-mask(type(t)),size(t)) (t:BundleType) : val fields* = for f in fields(t) map : - Field(name(f),flip(f),BoolType()) + Field(name(f),flip(f),create-mask(type(f))) BundleType(fields*) (t:UIntType|SIntType) : BoolType() @@ -265,7 +378,7 @@ defmethod print (o:OutputStream, op:PrimOp) : NEQUIV-OP : "neqv" EQUAL-OP : "eq" NEQUAL-OP : "neq" - MUX-OP : "mux" + ;MUX-OP : "mux" PAD-OP : "pad" AS-UINT-OP : "asUInt" AS-SINT-OP : "asSInt" @@ -298,6 +411,10 @@ defmethod print (o:OutputStream, e:Expression) : print-all(o, [op(e) "("]) print-all(o, join(concat(args(e), consts(e)), ", ")) print(o, ")") + (e:Mux) : + print-all(o, ["mux(" cond(e) ", " tval(e) ", " fval(e) ")"]) + (e:ValidIf) : + print-all(o, ["validif(" cond(e) ", " value(e) ")"]) print-debug(o,e) defmethod print (o:OutputStream, c:Stmt) : @@ -340,6 +457,8 @@ defmethod print (o:OutputStream, c:Stmt) : do(print{o,_}, join(body(c), "\n")) (c:Connect) : print-all(o, [loc(c) " <= " exp(c)]) + (c:IsInvalid) : + print-all(o, [exp(c) " is invalid"]) (c:BulkConnect) : print-all(o, [loc(c) " <- " exp(c)]) (c:Empty) : @@ -350,25 +469,9 @@ defmethod print (o:OutputStream, c:Stmt) : print-all(o, ["printf(" clk(c) ", " en(c) ", "]) ;" print-all(o, join(List(escape(string(c)),args(c)), ", ")) print(o, ")") - (c:CDefMemory) : - if seq?(c) : - print-all(o, ["smem " name(c) " : " type(c) "[" size(c) "]"]) - else : - print-all(o, ["cmem " name(c) " : " type(c) "[" size(c) "]"]) - (c:CDefMPort) : - if direction(c) == MRead : - print-all(o, [direction(c) " mport " name(c) " = " mem(c) "[" exps(c)[0] "], " exps(c)[1]]) - else : - print-all(o, [direction(c) " mport " name(c) " = " mem(c) "[" exps(c)[0] "], " exps(c)[1] ", " exps(c)[2]]) if not c typeof Conditionally|Begin|Empty: print-debug(o,c) -defmethod print (o:OutputStream, m:MPortDir) : - switch { m == _ } : - MRead : print(o,"read") - MWrite : print(o,"write") - MReadWrite : print(o,"rdwr") - defmethod print (o:OutputStream, t:Type) : match(t) : (t:UnknownType) : @@ -445,6 +548,8 @@ defmethod map (f: Expression -> Expression, e:Expression) -> Expression : (e:SubIndex) : SubIndex(f(exp(e)), value(e), type(e)) (e:SubAccess) : SubAccess(f(exp(e)), f(index(e)), type(e)) (e:DoPrim) : DoPrim(op(e), map(f, args(e)), consts(e), type(e)) + (e:Mux) : Mux(f(cond(e)),f(tval(e)),f(fval(e)),type(e)) + (e:ValidIf) : ValidIf(f(cond(e)),f(value(e)),type(e)) (e) : e public defmulti map<?T> (f: Symbol -> Symbol, c:?T&Stmt) -> T @@ -467,6 +572,7 @@ defmethod map (f: Expression -> Expression, c:Stmt) -> Stmt : (c:Conditionally) : Conditionally(info(c),f(pred(c)), conseq(c), alt(c)) (c:Connect) : Connect(info(c),f(loc(c)), f(exp(c))) (c:BulkConnect) : BulkConnect(info(c),f(loc(c)), f(exp(c))) + (c:IsInvalid) : IsInvalid(info(c),f(exp(c))) (c:Stop) : Stop(info(c),ret(c),f(clk(c)),f(en(c))) (c:Print) : Print(info(c),string(c),map(f,args(c)),f(clk(c)),f(en(c))) (c) : c @@ -500,6 +606,8 @@ defmethod map (f: Type -> Type, c:Expression) -> Expression : (c:SubIndex) : SubIndex(exp(c),value(c),f(type(c))) (c:SubAccess) : SubAccess(exp(c),index(c),f(type(c))) (c:DoPrim) : DoPrim(op(c),args(c),consts(c),f(type(c))) + (c:Mux) : Mux(cond(c),tval(c),fval(c),f(type(c))) + (c:ValidIf) : ValidIf(cond(c),value(c),f(type(c))) (c) : c public defmulti map<?T> (f: Type -> Type, c:?T&Stmt) -> T |
