defpackage firrtl/ir-utils :
   import core
   import verse
   import firrtl/ir2

;============== DEBUG STUFF =============================
; Hook called by every IR printer below right after it prints a node.
; Implementations are not provided in this package.
public defmulti print-debug (o:OutputStream, e:Expression|Stmt|Type|Port|Field|Module|Circuit) -> False

;============== GENSYM STUFF ======================

; Verilog reserved words plus a few macro names (SYNTHESIS, PRINTF_COND, VCS).
; NOTE(review): presumably passed as the `keywords` argument of get-sym-hash so
; generated names never collide with them -- confirm against callers.
public val v-keywords = to-list $ [
   `always, `and, `assign, `attribute, `begin, `buf, `bufif0, `bufif1,
   `case, `casex, `casez, `cmos, `deassign, `default, `defparam, `disable,
   `edge, `else, `end, `endattribute, `endcase, `endfunction, `endmodule,
   `endprimitive, `endspecify, `endtable, `endtask, `event, `for, `force,
   `forever, `fork, `function, `highz0, `highz1, `if, `ifnone, `initial,
   `inout, `input, `integer, `initvar, `join, `medium, `module, `large,
   `macromodule, `nand, `negedge, `nmos, `nor, `not, `notif0, `notif1, `or,
   `output, `parameter, `pmos, `posedge, `primitive, `pull0, `pull1,
   `pulldown, `pullup, `rcmos, `real, `realtime, `reg, `release, `repeat,
   `rnmos, `rpmos, `rtran, `rtranif0, `rtranif1, `scalared, `signed, `small,
   `specify, `specparam, `strength, `strong0, `strong1, `supply0, `supply1,
   `table, `task, `time, `tran, `tranif0, `tranif1, `tri, `tri0, `tri1,
   `triand, `trior, `trireg, `unsigned, `vectored, `wait, `wand, `weak0,
   `weak1, `while, `wire, `wor, `xnor, `xor, `SYNTHESIS, `PRINTF_COND, `VCS ]

; Fresh name based on `s`, using a brand-new (empty) table, so `s` itself
; is always returned after stripping any `_<digits>` suffix.
public defn firrtl-gensym (s:Symbol) -> Symbol :
   firrtl-gensym(s,HashTable(symbol-hash))

; Fresh name based on the default prefix `gen, tracked in `sym-hash`.
public defn firrtl-gensym (sym-hash:HashTable) -> Symbol :
   firrtl-gensym(`gen,sym-hash)

; True iff every character of `s` is a decimal digit (also true for "").
defn digits? (s:String) -> True|False :
   val digits = "0123456789"
   var yes = true
   for c in s do :
      if not contains?(digits,c) : yes = false
   yes

; Returns a name derived from `s` that is not yet recorded in `sym-hash`.
; `sym-hash` maps a base name to the highest numeric suffix handed out so far
; (0 means only the bare name is taken).
; - If `s` ends in `_<digits>`, that suffix is stripped and the base is used.
; - An unseen base is recorded with 0 and returned unchanged.
; - A seen base gets its counter bumped and is joined with the new number
;   via `delin` (not defined in this section).
public defn firrtl-gensym (s:Symbol,sym-hash:HashTable) -> Symbol :
   defn get-name (s:Symbol) -> Symbol :
      if key?(sym-hash,s) :
         val num = sym-hash[s] + 1
         sym-hash[s] = num
         symbol-join([s delin num])
      else :
         sym-hash[s] = 0
         s
   val s* = to-string(s)
   ; The last index is excluded so a trailing "_" (empty digit run) never matches.
   val i* = for i in 0 to length(s*) - 1 find :
      s*[i] == '_' and digits?(substring(s*,i + 1))
   match(i*) :
      (i:False) : get-name(s)
      (i:Int) : get-name(to-symbol(substring(s*,0,i)))

; Name table for module `m` with no reserved keywords.
public defn get-sym-hash (m:InModule) -> HashTable :
   get-sym-hash(m,list())

; Builds the table consumed by firrtl-gensym from every name declared in `m`.
; This covers ports, plus wires, registers, instances, memories, nodes and
; accessors found anywhere in the body. Each `keywords` entry is pre-seeded
; with 0. A name of the form `base_<digits>` raises base's counter to at least
; <digits>.
public defn get-sym-hash (m:InModule,keywords:List) -> HashTable :
   val sym-hash = HashTable(symbol-hash)
   for k in keywords do : sym-hash[k] = 0
   defn add-name (s:Symbol) -> False :
      val s* = to-string(s)
      val i* = for i in 0 to length(s*) - 1 find :
         s*[i] == '_' and digits?(substring(s*,i + 1))
      match(i*) :
         (i:False) :
            ; Only seed a bare name: overwriting would reset a counter already
            ; raised by a suffixed name (e.g. x_5 seen before x), making
            ; firrtl-gensym reissue names that already exist.
            if not key?(sym-hash,s) : sym-hash[s] = 0
         (i:Int) :
            val name = to-symbol(substring(s*,0,i))
            val digit = to-int(substring(s*,i + 1))
            if key?(sym-hash,name) :
               val num = sym-hash[name]
               sym-hash[name] = max(num,digit)
            else :
               sym-hash[name] = digit
   defn to-port (p:Port) -> False : add-name(name(p))
   ; Records the declaration (if any) and recurses into nested statements.
   defn to-stmt (s:Stmt) -> Stmt :
      match(s) :
         (s:DefWire) : add-name(name(s))
         (s:DefRegister) : add-name(name(s))
         (s:DefInstance) : add-name(name(s))
         (s:DefMemory) : add-name(name(s))
         (s:DefNode) : add-name(name(s))
         (s:DefAccessor) : add-name(name(s))
         (s) : false
      map(to-stmt,s)
   to-stmt(body(m))
   map(to-port,ports(m))
   sym-hash

;============== Exceptions =====================

public definterface PassException <: Exception

; A PassException whose printed form is exactly `s`.
public defn PassException (s:String) :
   new PassException :
      defmethod print (o:OutputStream, this) :
         print(o, s)

; Bundles several errors into one PassException, one per line.
public defn PassExceptions (xs:Streamable) :
   PassException(string-join(xs, "\n"))

;============== Pass/Compiler Structs ============

public definterface Compiler
public defmulti passes (c:Compiler) -> List
public defmulti with-output (c:Compiler) -> ((() -> False) -> False)

public definterface Pass
public defmulti pass (p:Pass) -> (Circuit -> Circuit)
public defmulti name (p:Pass) -> String
public defmulti short-name (p:Pass) -> String
public defmethod print (o:OutputStream, p:Pass) :
   print(o,name(p))

;============== Various Useful Functions ==============

; Smallest l with 2^l >= i (0 for i <= 1). Calls error(...) for negative i.
; NOTE(review): `l` is a Long but is compared to the Int literal 30; confirm the
; 31 cap can actually trigger, otherwise very large inputs keep doubling `n`.
public defn ceil-log2 (i:Long) -> Long :
   defn* loop (n:Long, l:Long) :
      if n < i :
         if l == 30 : to-long(31)
         else : loop(n * to-long(2), l + to-long(1))
      else : l
   error("Log of negative number!") when i < to-long(0)
   loop(to-long $ 1, to-long $ 0)

; Absolute value of a Long.
public defn abs (x:Long) -> Long :
   if x < to-long(0) : to-long(0) - x
   else : x

; Larger of two Longs.
public defn max (x:Long,y:Long) -> Long :
   if x < y : y
   else : x

; Narrows a Long to an Int via its decimal string. Calls error(...) when x is
; outside the 32-bit signed range.
public defn to-int (x:Long) -> Int :
   if x > to-long(2147483647) or x < to-long(-2147483648) :
      error("Long too big to convert to Int")
   else : to-int(to-string(x))

;============== PRINTERS ===================================

defmethod print (o:OutputStream, d:Flip) :
   print{o, _} $ switch {d == _} :
      DEFAULT : ""
      REVERSE: "flip"

defmethod print (o:OutputStream, d:AccDirection) :
   print{o, _} $ switch {d == _} :
      READ : "read"
      WRITE: "write"
      INFER: "infer"
      RDWR: "rdwr"

defmethod print (o:OutputStream, d:PortDirection) :
   print{o, _} $ switch {d == _} :
      INPUT : "input"
      OUTPUT: "output"

defmethod print (o:OutputStream, w:Width) :
   print{o, _} $ match(w) :
      (w:UnknownWidth) : "?"
      (w:IntWidth) : width(w)
      (w:LongWidth) : width(w)

; Textual mnemonic of each primitive operation.
defmethod print (o:OutputStream, op:PrimOp) :
   print{o, _} $ switch {op == _} :
      ADD-OP : "add"
      SUB-OP : "sub"
      MUL-OP : "mul"
      DIV-OP : "div"
      MOD-OP : "mod"
      QUO-OP : "quo"
      REM-OP : "rem"
      ADD-WRAP-OP : "addw"
      SUB-WRAP-OP : "subw"
      LESS-OP : "lt"
      LESS-EQ-OP : "leq"
      GREATER-OP : "gt"
      GREATER-EQ-OP : "geq"
      EQUIV-OP : "eqv"
      NEQUIV-OP : "neqv"
      EQUAL-OP : "eq"
      NEQUAL-OP : "neq"
      MUX-OP : "mux"
      PAD-OP : "pad"
      AS-UINT-OP : "asUInt"
      AS-SINT-OP : "asSInt"
      DYN-SHIFT-LEFT-OP : "dshl"
      DYN-SHIFT-RIGHT-OP : "dshr"
      SHIFT-LEFT-OP : "shl"
      SHIFT-RIGHT-OP : "shr"
      CONVERT-OP : "cvt"
      NEG-OP : "neg"
      BIT-NOT-OP : "not"
      BIT-AND-OP : "and"
      BIT-OR-OP : "or"
      BIT-XOR-OP : "xor"
      BIT-AND-REDUCE-OP : "andr"
      BIT-OR-REDUCE-OP : "orr"
      BIT-XOR-REDUCE-OP : "xorr"
      CONCAT-OP : "cat"
      BIT-SELECT-OP : "bit"
      BITS-SELECT-OP : "bits"

defmethod print (o:OutputStream, e:Expression) :
   match(e) :
      (e:Ref) : print(o, name(e))
      (e:Subfield) : print-all(o, [exp(e) "." name(e)])
      (e:Index) : print-all(o, [exp(e) "[" value(e) "]"])
      (e:UIntValue) : print-all(o, ["UInt(" value(e) ")"])
      (e:SIntValue) : print-all(o, ["SInt(" value(e) ")"])
      (e:DoPrim) :
         ; op(args..., consts...)
         print-all(o, [op(e) "("])
         print-all(o, join(concat(args(e), consts(e)), ", "))
         print(o, ")")
   print-debug(o,e)

; Statement printer. Nested bodies go through a 3-space IndentedStream.
defmethod print (o:OutputStream, c:Stmt) :
   val io = IndentedStream(o, 3)
   match(c) :
      (c:DefWire) :
         print-all(o,["wire " name(c) " : " type(c)])
      (c:DefRegister) :
         print-all(o,["reg " name(c) " : " type(c) ", " clock(c) ", " reset(c)])
      (c:DefMemory) :
         if seq?(c) :
            print-all(o,["smem " name(c) " : " type(c) ", " clock(c)])
         else :
            print-all(o,["cmem " name(c) " : " type(c) ", " clock(c)])
      (c:DefInstance) :
         print-all(o,["inst " name(c) " of " module(c)])
      (c:DefNode) :
         print-all(o,["node " name(c) " = " value(c)])
      (c:DefAccessor) :
         print-all(o,[acc-dir(c) " accessor " name(c) " = " source(c) "[" index(c) "]"])
      (c:Conditionally) :
         ; A Begin body goes on indented lines; a single statement stays inline.
         if conseq(c) typeof Begin :
            print-all(o, ["when " pred(c) " :"])
            print-debug(o,c)
            print(o,"\n")
            print(io,conseq(c))
         else :
            print-all(o, ["when " pred(c) " : " conseq(c)])
            print-debug(o,c)
         if alt(c) not-typeof EmptyStmt :
            print(o, "\nelse :")
            print(io, "\n")
            print(io,alt(c))
      (c:Begin) :
         do(print{o,_}, join(body(c), "\n"))
      (c:Connect) :
         print-all(o, [loc(c) " := " exp(c)])
      (c:BulkConnect) :
         print-all(o, [loc(c) " <> " exp(c)])
      (c:OnReset) :
         print-all(o, ["onreset " loc(c) " := " exp(c)])
      (c:EmptyStmt) :
         print(o, "skip")
   ; Conditionally already emitted its debug info; Begin/EmptyStmt get none.
   if not c typeof Conditionally|Begin|EmptyStmt : print-debug(o,c)

defmethod print (o:OutputStream, t:Type) :
   match(t) :
      (t:UnknownType) : print(o, "?")
      (t:ClockType) : print(o, "Clock")
      (t:UIntType) :
         match(width(t)) :
            (w:UnknownWidth) : print-all(o, ["UInt"])
            (w) : print-all(o, ["UInt<" width(t) ">"])
      (t:SIntType) :
         match(width(t)) :
            (w:UnknownWidth) : print-all(o, ["SInt"])
            (w) : print-all(o, ["SInt<" width(t) ">"])
      (t:BundleType) :
         print(o, "{")
         print-all(o, join(fields(t), ", "))
         print(o, "}")
      (t:VectorType) : print-all(o, [type(t) "[" size(t) "]"])
   print-debug(o,t)

defmethod print (o:OutputStream, f:Field) :
   print-all(o, [flip(f) " " name(f) " : " type(f)])
   print-debug(o,f)

defmethod print (o:OutputStream, p:Port) :
   print-all(o, [direction(p) " " name(p) " : " type(p)])
   print-debug(o,p)

defmethod print (o:OutputStream, m:InModule) :
   print-all(o, ["module " name(m) " :"])
   print-debug(o,m)
   print(o,"\n")
   val io = IndentedStream(o, 3)
   for p in ports(m) do :
      println(io,p)
   print(io,body(m))

defmethod print (o:OutputStream, m:ExModule) :
   print-all(o, ["extmodule " name(m) " :"])
   print-debug(o,m)
   print(o,"\n")
   val io = IndentedStream(o, 3)
   for p in ports(m) do :
      println(io,p)

defmethod print (o:OutputStream, c:Circuit) :
   print-all(o, ["circuit " main(c) " :"])
   print-debug(o,c)
   print(o,"\n")
   val io = IndentedStream(o, 3)
   for m in modules(c) do :
      println(io, m)

;=================== MAPPERS ===============================
; Each `map` applies `f` to the immediate children of one node (not
; recursively) and rebuilds the node. Nodes with no matching children are
; returned unchanged.

; Applies f to the element types of a bundle or vector.
public defn map (f: Type -> Type, t:?T&Type) -> T :
   ; Named `mapped` rather than `type` so it does not shadow the `type` accessor used inside.
   val mapped =
      match(t) :
         (t:T&BundleType) :
            BundleType $
               for p in fields(t) map :
                  Field(name(p), flip(p), f(type(p)))
         (t:T&VectorType) :
            VectorType(f(type(t)), size(t))
         (t) :
            t
   mapped as T&Type

; Applies f to the sub-expressions of an expression.
public defmulti map (f: Expression -> Expression, e:?T&Expression) -> T
defmethod map (f: Expression -> Expression, e:Expression) -> Expression :
   match(e) :
      (e:Subfield) : Subfield(f(exp(e)), name(e), type(e))
      (e:Index) : Index(f(exp(e)), value(e), type(e))
      (e:DoPrim) : DoPrim(op(e), map(f, args(e)), consts(e), type(e))
      (e) : e

; Applies f to the expressions held directly by a statement.
public defmulti map (f: Expression -> Expression, c:?T&Stmt) -> T
defmethod map (f: Expression -> Expression, c:Stmt) -> Stmt :
   match(c) :
      (c:DefAccessor) : DefAccessor(info(c),name(c), f(source(c)), f(index(c)),acc-dir(c))
      (c:DefRegister) : DefRegister(info(c),name(c), type(c), f(clock(c)), f(reset(c)))
      (c:DefMemory) : DefMemory(info(c),name(c), type(c), seq?(c), f(clock(c)))
      (c:DefNode) : DefNode(info(c),name(c), f(value(c)))
      (c:DefInstance) : DefInstance(info(c),name(c), f(module(c)))
      (c:Conditionally) : Conditionally(info(c),f(pred(c)), conseq(c), alt(c))
      (c:Connect) : Connect(info(c),f(loc(c)), f(exp(c)))
      (c:BulkConnect) : BulkConnect(info(c),f(loc(c)), f(exp(c)))
      (c:OnReset) : OnReset(info(c),f(loc(c)),f(exp(c)))
      (c) : c

; Applies f to the child statements of Conditionally and Begin.
public defmulti map (f: Stmt -> Stmt, c:?T&Stmt) -> T
defmethod map (f: Stmt -> Stmt, c:Stmt) -> Stmt :
   match(c) :
      (c:Conditionally) : Conditionally(info(c),pred(c), f(conseq(c)), f(alt(c)))
      (c:Begin) : Begin(map(f, body(c)))
      (c) : c

; Applies f to the width of UInt/SInt literals.
public defmulti map (f: Width -> Width, c:?T&Expression) -> T
defmethod map (f: Width -> Width, c:Expression) -> Expression :
   match(c) :
      (c:UIntValue) : UIntValue(value(c),f(width(c)))
      (c:SIntValue) : SIntValue(value(c),f(width(c)))
      (c) : c

; Applies f to the width of UInt/SInt types.
public defmulti map (f: Width -> Width, c:?T&Type) -> T
defmethod map (f: Width -> Width, c:Type) -> Type :
   match(c) :
      (c:UIntType) : UIntType(f(width(c)))
      (c:SIntType) : SIntType(f(width(c)))
      (c) : c

; Applies f to the type annotation carried by an expression.
public defmulti map (f: Type -> Type, c:?T&Expression) -> T
defmethod map (f: Type -> Type, c:Expression) -> Expression :
   match(c) :
      (c:Ref) : Ref(name(c),f(type(c)))
      (c:Subfield) : Subfield(exp(c),name(c),f(type(c)))
      (c:Index) : Index(exp(c),value(c),f(type(c)))
      (c:DoPrim) : DoPrim(op(c),args(c),consts(c),f(type(c)))
      (c) : c

; Applies f to the declared type of wires, registers and memories.
public defmulti map (f: Type -> Type, c:?T&Stmt) -> T
defmethod map (f: Type -> Type, c:Stmt) -> Stmt :
   match(c) :
      (c:DefWire) : DefWire(info(c),name(c),f(type(c)))
      (c:DefRegister) : DefRegister(info(c),name(c),f(type(c)),clock(c),reset(c))
      (c:DefMemory) : DefMemory(info(c),name(c),f(type(c)) as VectorType,seq?(c),clock(c))
      (c) : c

; Recursively applies f to every width nested anywhere inside a type.
public defmulti mapr (f: Width -> Width, t:?T&Type) -> T
defmethod mapr (f: Width -> Width, t:Type) -> Type :
   defn apply-t (t:Type) -> Type :
      map{f,_} $ map(apply-t,t)
   apply-t(t)

; Recursively applies f to every width inside a statement tree. This covers
; literal widths, expression types and declared types.
public defmulti mapr (f: Width -> Width, s:?T&Stmt) -> T
defmethod mapr (f: Width -> Width, s:Stmt) -> Stmt :
   defn apply-t (t:Type) -> Type : mapr(f,t)
   defn apply-e (e:Expression) -> Expression :
      map{f,_} $ map{apply-t,_} $ map(apply-e,e)
   defn apply-s (s:Stmt) -> Stmt :
      map{apply-t,_} $ map{apply-e,_} $ map(apply-s,s)
   apply-s(s)

;================= HELPER FUNCTIONS USING MAP ===================
; `do` calls f on each immediate child (via map, discarding the rebuilt node)
; purely for its side effects and returns false.

public defmulti do (f:Expression -> ?, e:Expression) -> False
defmethod do (f:Expression -> ?, e:Expression) -> False :
   defn f* (x:Expression) :
      f(x)
      x
   map(f*,e)
   false

public defmulti do (f:Expression -> ?, s:Stmt) -> False
defmethod do (f:Expression -> ?, s:Stmt) -> False :
   defn f* (x:Expression) :
      f(x)
      x
   map(f*,s)
   false

public defmulti do (f:Stmt -> ?, s:Stmt) -> False
defmethod do (f:Stmt -> ?, s:Stmt) -> False :
   defn f* (x:Stmt) :
      f(x)
      x
   map(f*,s)
   false

; Not well defined - usually use dor on fields of a recursive type
;public defmulti dor (f:Expression -> ?, e:Expression) -> False
;defmethod dor (f:Expression -> ?, e:Expression) -> False :
;   f(e)
;   for x in e map :
;      dor(f,x)
;      x
;   false
;
;public defmulti dor (f:Expression -> ?, s:Stmt) -> False
;defmethod dor (f:Expression -> ?, s:Stmt) -> False :
;   defn f* (x:Expression) :
;      dor(f,x)
;      x
;   map(f*,s)
;   false
;
;public defmulti dor (f:Stmt -> ?, s:Stmt) -> False
;defmethod dor (f:Stmt -> ?, s:Stmt) -> False :
;   f(s)
;   defn f* (x:Stmt) :
;      dor(f,x)
;      x
;   map(f*,s)
;   false
;
;public defmulti sub-exps (s:Expression|Stmt) -> List
;defmethod sub-exps (e:Expression) -> List :
;   val l = Vector()
;   defn f (x:Expression) : add(l,x)
;   do(f,e)
;   to-list(l)
;defmethod sub-exps (e:Stmt) -> List :
;   val l = Vector()
;   defn f (x:Expression) : add(l,x)
;   do(f,e)
;   to-list(l)
;
;public defmulti sub-stmts (s:Stmt) -> List
;defmethod sub-stmts (s:Stmt) :
;   val l = Vector()
;   defn f (x:Stmt) : add(l,x)
;   do(f,s)
;   to-list(l)

;=================== ADAM OPS ===============================

; Splits `s` at every occurrence of `c`.
; Leading, trailing or adjacent separators produce "" pieces; a string with
; no `c` yields list(s).
public defn split (s:String,c:Char) -> List :
   if not contains(to-list(s),c) : list(s)
   else :
      ; Position of the first `c`; the fallback is unreachable because
      ; `contains` succeeded above.
      val index = label ret :
         var i = 0
         for c* in to-list(s) do :
            if c* == c : ret(i)
            else : i = i + 1
         ret(0)
      val h = substring(s,0,index)
      val t = substring(s,index + 1,length(s))
      List(h,split(t,c))

; True iff `c` occurs in `l`.
public defn contains (l:List, c:Char) :
   label myret :
      for x in l do :
         if x == c : myret(true)
      false

; Copies every entry of `b` into `a`, overwriting entries with the same key.
public defn merge! (a:HashTable, b:HashTable) :
   for e in b do :
      a[key(e)] = value(e)

; x raised to the y-th power by y repeated multiplications (x^0 == 1).
; Overflow is not detected. A negative exponent calls error(...): previously
; it never reached zero by decrementing and effectively looped forever.
public defn pow (x:Long,y:Long) -> Long :
   error("Negative exponent!") when y < to-long(0)
   var x* = to-long(1)
   var y* = y
   while y* != to-long(0) :
      x* = times(x*,x)
      y* = minus(y*,to-long(1))
   x*