aboutsummaryrefslogtreecommitdiff
path: root/src/main/stanza/ir-utils.stanza
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/stanza/ir-utils.stanza')
-rw-r--r--src/main/stanza/ir-utils.stanza154
1 files changed, 131 insertions, 23 deletions
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza
index 7a9cff31..7d4c30b2 100644
--- a/src/main/stanza/ir-utils.stanza
+++ b/src/main/stanza/ir-utils.stanza
@@ -39,7 +39,7 @@ public defn firrtl-gensym (s:Symbol,sym-hash:HashTable<Symbol,Int>) -> Symbol :
symbol-join([s delin num])
else :
sym-hash[s] = 0
- s
+ symbol-join([s delin 0])
val s* = to-string(s)
val i* = generated?(s*)
val nex = match(i*) :
@@ -113,7 +113,20 @@ public defn EQV (e1:Expression,e2:Expression) -> Expression :
DoPrim(EQUIV-OP,list(e1,e2),list(),type(e1))
public defn MUX (p:Expression,e1:Expression,e2:Expression) -> Expression :
- DoPrim(MUX-OP,list(p,e1,e2),list(),type(e1))
+ Mux(p,e1,e2,mux-type(type(e1),type(e2)))
+
+public defn mux-type (e1:Expression,e2:Expression) -> Type :
+ mux-type(type(e1),type(e2))
+public defn mux-type (t1:Type,t2:Type) -> Type :
+ if t1 == t2 :
+ match(t1,t2) :
+ (t1:UIntType,t2:UIntType) : UIntType(UnknownWidth())
+ (t1:SIntType,t2:SIntType) : SIntType(UnknownWidth())
+ (t1:VectorType,t2:VectorType) : VectorType(mux-type(type(t1),type(t2)),size(t1))
+ (t1:BundleType,t2:BundleType) :
+ BundleType $ for (f1 in fields(t1),f2 in fields(t2)) map :
+ Field(name(f1),flip(f1),mux-type(type(f1),type(f2)))
+ else : UnknownType()
public defn CAT (e1:Expression,e2:Expression) -> Expression :
DoPrim(CONCAT-OP,list(e1,e2),list(),type(e1))
@@ -148,13 +161,113 @@ public defn list-hash (l:List) -> Int :
turn-on-debug(false)
i
+;===== Type Expansion Algorithms =========
+public defn times (f1:Flip,f2:Flip) -> Flip :
+ switch {_ == f2} :
+ DEFAULT : f1
+ REVERSE : swap(f1)
+public defn swap (f:Flip) -> Flip :
+ switch {_ == f} :
+ DEFAULT : REVERSE
+ REVERSE : DEFAULT
+
+public defmulti get-type (s:Stmt) -> Type
+public defmethod get-type (s:Stmt) -> Type :
+ match(s) :
+ (s:DefWire|DefPoison|DefRegister) : type(s)
+ (s:DefNode) : type(value(s))
+ (s:DefMemory) :
+ val depth = depth(s)
+ ; Fields
+ val addr = Field(`addr,DEFAULT,UIntType(IntWidth(ceil-log2(depth))))
+ val en = Field(`en,DEFAULT,BoolType())
+ val clk = Field(`clk,DEFAULT,ClockType())
+ val def-data = Field(`data,DEFAULT,data-type(s))
+ val rev-data = Field(`data,REVERSE,data-type(s))
+ val rdata = Field(`rdata,REVERSE,data-type(s))
+ val wdata = Field(`wdata,DEFAULT,data-type(s))
+ val mask = Field(`mask,DEFAULT,create-mask(data-type(s)))
+ val wmask = Field(`wmask,DEFAULT,create-mask(data-type(s)))
+ val ren = Field(`ren,DEFAULT,UIntType(IntWidth(1)))
+ val wen = Field(`wen,DEFAULT,UIntType(IntWidth(1)))
+ val raddr = Field(`raddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth))))
+ val waddr = Field(`waddr,DEFAULT,UIntType(IntWidth(ceil-log2(depth))))
+
+ val read-type = BundleType(to-list([rev-data,addr,en,clk]))
+ val write-type = BundleType(to-list([def-data,mask,addr,en,clk]))
+ val readwrite-type = BundleType(to-list([wdata,wmask,waddr,wen,rdata,raddr,ren,clk]))
+
+ val mem-fields = Vector<Field>()
+ for x in readers(s) do :
+ add(mem-fields,Field(x,REVERSE,read-type))
+ for x in writers(s) do :
+ add(mem-fields,Field(x,REVERSE,write-type))
+ for x in readwriters(s) do :
+ add(mem-fields,Field(x,REVERSE,readwrite-type))
+ BundleType(to-list(mem-fields))
+ (s:DefInstance) : UnknownType()
+ (s:Begin|Connect|BulkConnect|Stop|Print|Empty|IsInvalid) : UnknownType()
+
+public defn get-size (t:Type) -> Int :
+ val x = match(t) :
+ (t:BundleType) :
+ var sum = 0
+ for f in fields(t) do :
+ sum = sum + get-size(type(f))
+ sum
+ (t:VectorType) : size(t) * get-size(type(t))
+ (t) : 1
+ x
+public defn get-valid-points (t1:Type,t2:Type,flip1:Flip,flip2:Flip) -> List<[Int,Int]> :
+ ;println-all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2])
+ match(t1,t2) :
+ (t1:UIntType,t2:UIntType) :
+ if flip1 == flip2 : list([0, 0])
+ else: list()
+ (t1:SIntType,t2:SIntType) :
+ if flip1 == flip2 : list([0, 0])
+ else: list()
+ (t1:BundleType,t2:BundleType) :
+ val points = Vector<[Int,Int]>()
+ var ilen = 0
+ var jlen = 0
+ for i in 0 to length(fields(t1)) do :
+ for j in 0 to length(fields(t2)) do :
+ ;println(i)
+ ;println(j)
+ ;println(ilen)
+ ;println(jlen)
+ val f1 = fields(t1)[i]
+ val f2 = fields(t2)[j]
+ if name(f1) == name(f2) :
+ val ls = get-valid-points(type(f1),type(f2),flip1 * flip(f1),
+ flip2 * flip(f2))
+ for x in ls do :
+ add(points,[x[0] + ilen, x[1] + jlen])
+ ;println(points)
+ jlen = jlen + get-size(type(fields(t2)[j]))
+ ilen = ilen + get-size(type(fields(t1)[i]))
+ jlen = 0
+ to-list(points)
+ (t1:VectorType,t2:VectorType) :
+ val points = Vector<[Int,Int]>()
+ var ilen = 0
+ var jlen = 0
+ for i in 0 to min(size(t1),size(t2)) do :
+ val ls = get-valid-points(type(t1),type(t2),flip1,flip2)
+ for x in ls do :
+ add(points,[x[0] + ilen, x[1] + jlen])
+ ilen = ilen + get-size(type(t1))
+ jlen = jlen + get-size(type(t2))
+ to-list(points)
+
;============= Useful functions ==============
-public defn create-mask (n:Symbol,dt:Type) -> Field :
- Field{n,DEFAULT,_} $ match(dt) :
- (t:VectorType) : VectorType(BoolType(),size(t))
+public defn create-mask (dt:Type) -> Type :
+ match(dt) :
+ (t:VectorType) : VectorType(create-mask(type(t)),size(t))
(t:BundleType) :
val fields* = for f in fields(t) map :
- Field(name(f),flip(f),BoolType())
+ Field(name(f),flip(f),create-mask(type(f)))
BundleType(fields*)
(t:UIntType|SIntType) : BoolType()
@@ -265,7 +378,7 @@ defmethod print (o:OutputStream, op:PrimOp) :
NEQUIV-OP : "neqv"
EQUAL-OP : "eq"
NEQUAL-OP : "neq"
- MUX-OP : "mux"
+ ;MUX-OP : "mux"
PAD-OP : "pad"
AS-UINT-OP : "asUInt"
AS-SINT-OP : "asSInt"
@@ -298,6 +411,10 @@ defmethod print (o:OutputStream, e:Expression) :
print-all(o, [op(e) "("])
print-all(o, join(concat(args(e), consts(e)), ", "))
print(o, ")")
+ (e:Mux) :
+ print-all(o, ["mux(" cond(e) ", " tval(e) ", " fval(e) ")"])
+ (e:ValidIf) :
+ print-all(o, ["validif(" cond(e) ", " value(e) ")"])
print-debug(o,e)
defmethod print (o:OutputStream, c:Stmt) :
@@ -340,6 +457,8 @@ defmethod print (o:OutputStream, c:Stmt) :
do(print{o,_}, join(body(c), "\n"))
(c:Connect) :
print-all(o, [loc(c) " <= " exp(c)])
+ (c:IsInvalid) :
+ print-all(o, [exp(c) " is invalid"])
(c:BulkConnect) :
print-all(o, [loc(c) " <- " exp(c)])
(c:Empty) :
@@ -350,25 +469,9 @@ defmethod print (o:OutputStream, c:Stmt) :
print-all(o, ["printf(" clk(c) ", " en(c) ", "]) ;"
print-all(o, join(List(escape(string(c)),args(c)), ", "))
print(o, ")")
- (c:CDefMemory) :
- if seq?(c) :
- print-all(o, ["smem " name(c) " : " type(c) "[" size(c) "]"])
- else :
- print-all(o, ["cmem " name(c) " : " type(c) "[" size(c) "]"])
- (c:CDefMPort) :
- if direction(c) == MRead :
- print-all(o, [direction(c) " mport " name(c) " = " mem(c) "[" exps(c)[0] "], " exps(c)[1]])
- else :
- print-all(o, [direction(c) " mport " name(c) " = " mem(c) "[" exps(c)[0] "], " exps(c)[1] ", " exps(c)[2]])
if not c typeof Conditionally|Begin|Empty: print-debug(o,c)
-defmethod print (o:OutputStream, m:MPortDir) :
- switch { m == _ } :
- MRead : print(o,"read")
- MWrite : print(o,"write")
- MReadWrite : print(o,"rdwr")
-
defmethod print (o:OutputStream, t:Type) :
match(t) :
(t:UnknownType) :
@@ -445,6 +548,8 @@ defmethod map (f: Expression -> Expression, e:Expression) -> Expression :
(e:SubIndex) : SubIndex(f(exp(e)), value(e), type(e))
(e:SubAccess) : SubAccess(f(exp(e)), f(index(e)), type(e))
(e:DoPrim) : DoPrim(op(e), map(f, args(e)), consts(e), type(e))
+ (e:Mux) : Mux(f(cond(e)),f(tval(e)),f(fval(e)),type(e))
+ (e:ValidIf) : ValidIf(f(cond(e)),f(value(e)),type(e))
(e) : e
public defmulti map<?T> (f: Symbol -> Symbol, c:?T&Stmt) -> T
@@ -467,6 +572,7 @@ defmethod map (f: Expression -> Expression, c:Stmt) -> Stmt :
(c:Conditionally) : Conditionally(info(c),f(pred(c)), conseq(c), alt(c))
(c:Connect) : Connect(info(c),f(loc(c)), f(exp(c)))
(c:BulkConnect) : BulkConnect(info(c),f(loc(c)), f(exp(c)))
+ (c:IsInvalid) : IsInvalid(info(c),f(exp(c)))
(c:Stop) : Stop(info(c),ret(c),f(clk(c)),f(en(c)))
(c:Print) : Print(info(c),string(c),map(f,args(c)),f(clk(c)),f(en(c)))
(c) : c
@@ -500,6 +606,8 @@ defmethod map (f: Type -> Type, c:Expression) -> Expression :
(c:SubIndex) : SubIndex(exp(c),value(c),f(type(c)))
(c:SubAccess) : SubAccess(exp(c),index(c),f(type(c)))
(c:DoPrim) : DoPrim(op(c),args(c),consts(c),f(type(c)))
+ (c:Mux) : Mux(cond(c),tval(c),fval(c),f(type(c)))
+ (c:ValidIf) : ValidIf(cond(c),value(c),f(type(c)))
(c) : c
public defmulti map<?T> (f: Type -> Type, c:?T&Stmt) -> T