diff options
Diffstat (limited to 'src/main/stanza/ir-utils.stanza')
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 227 |
1 files changed, 227 insertions, 0 deletions
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza new file mode 100644 index 00000000..be08cac8 --- /dev/null +++ b/src/main/stanza/ir-utils.stanza @@ -0,0 +1,227 @@ +defpackage chipper.ir-utils : + import core + import verse + import chipper.ir2 + +;============== PRINTERS =================================== +defmethod print (o:OutputStream, d:Direction) : + print{o, _} $ + switch {d == _} : + INPUT : "input" + OUTPUT: "output" + UNKNOWN-DIR : "unknown" + +defmethod print (o:OutputStream, w:Width) : + print{o, _} $ + match(w) : + (w:UnknownWidth) : "?" + (w:IntWidth) : width(w) + +defmethod print (o:OutputStream, op:PrimOp) : + print{o, _} $ + switch {op == _} : + ADD-OP : "ADD" + ADD-MOD-OP : "ADD-MOD" + SUB-OP : "MINUS" + SUB-MOD-OP : "SUB-MOD" + TIMES-OP : "TIMES" + DIVIDE-OP : "DIVIDE" + MOD-OP : "MOD" + SHIFT-LEFT-OP : "SHIFT-LEFT" + SHIFT-RIGHT-OP : "SHIFT-RIGHT" + PAD-OP : "PAD" + BIT-AND-OP : "BIT-AND" + BIT-OR-OP : "BIT-OR" + BIT-XOR-OP : "BIT-XOR" + CONCAT-OP : "CONCAT" + BIT-SELECT-OP : "BIT-SELECT" + BITS-SELECT-OP : "BITS-SELECT" + MULTIPLEX-OP : "MULTIPLEX" + LESS-OP : "LESS" + LESS-EQ-OP : "LESS-EQ" + GREATER-OP : "GREATER" + GREATER-EQ-OP : "GREATER-EQ" + EQUAL-OP : "EQUAL" + +defmethod print (o:OutputStream, e:Expression) : + match(e) : + (e:Ref) : print(o, name(e)) + (e:Field) : print-all(o, [exp(e) "." name(e)]) + (e:Index) : print-all(o, [exp(e) "." value(e)]) + (e:UIntValue) : print-all(o, ["UInt(" value(e) ")"]) + (e:SIntValue) : print-all(o, ["SInt(" value(e) ")"]) + (e:DoPrim) : + print-all(o, [op(e) "("]) + print-all(o, join(concat(args(e), consts(e)), ", ")) + print(o, ")") + (e:ReadPort) : print-all(o, ["ReadPort(" mem(e) ", " index(e) ")"]) + +defmethod print (o:OutputStream, c:Command) : + match(c) : + (c:LetRec) : + println(o, "let : ") + indented{o, _} $ fn () : + for entry in entries(c) do : + println-all([key(entry) " = " value(entry)]) + println(o, "in :") + indented(o, print{o, body(c)}) + (c:DefWire) : + print-all(["wire " name(c) " : " type(c)]) + (c:DefRegister) : + print-all(["reg " name(c) " : " type(c)]) + (c:DefMemory) : + print-all(["mem " name(c) " : " type(c)]) + (c:DefInstance) : + print-all(["inst " name(c) " of " module(c)]) + (c:DefAccessor) : + print-all(["accessor " name(c) " = " source(c) "[" index(c) "]"]) + (c:Conditionally) : + println-all(o, ["when " pred(c) " :"]) + indented(o, print{conseq(c)}) + if alt(c) not-typeof EmptyCommand : + println(o, "\nelse :") + indented(o, print{alt(c)}) + (c:Begin) : + do(print, join(body(c), "\n")) + (c:Connect) : + print-all(o, [loc(c) " := " exp(c)]) + (c:EmptyCommand) : + print(o, "skip") + +defmethod print (o:OutputStream, e:Element) : + match(e) : + (e:Register) : + print-all(o, ["Register(" type(e) ", " value(e) ", " enable(e) ")"]) + (e:Memory) : + print-all(o, ["Memory(" type(e) ", "]) + print-all(o, join(writers(e), ", ")) + print(o, ")") + (e:Node) : + print-all(o, ["Node(" type(e) ", " value(e) ")"]) + (e:Instance) : + print-all(o, ["Instance(" module(e) ", "]) + print-all(o, join(ports(e), ", ")) + print(o, ")") + +defmethod print (o:OutputStream, p:WritePort) : + print-all(o, [index(p) " => WritePort(" value(p) ", " enable(p) ")"]) + +defmethod print (o:OutputStream, t:Type) : + match(t) : + (t:UnknownType) : + print(o, "?") + (t:UIntType) : + print-all(o, ["UInt(" width(t) ")"]) + (t:SIntType) : + print-all(o, ["SInt(" width(t) ")"]) + (t:BundleType) : + print(o, "{") + print-all(o, join(ports(t), ", ")) + print(o, "}") + (t:VectorType) : + print-all(o, [type(t) "[" size(t) "]"]) + +defmethod print (o:OutputStream, p:Port) : + print-all(o, [direction(p) " " name(p) " : " type(p)]) + +defmethod print (o:OutputStream, m:Module) : + println-all(o, ["module " name(m) " :"]) + indented{o, _} $ fn () : + do(println, ports(m)) + print(body(m)) + +defmethod print (o:OutputStream, c:Circuit) : + println-all(o, ["circuit " main(c) " :"]) + indented(o, do{println, modules(c)}) + +;================== INDENTATION ============================ +defn IndentedStream (o:OutputStream, n:Int) : + var indent? = true + defn put (c:Char) : + if indent? : + do(print{o, " "}, 0 to n) + indent? = false + print(o, c) + if c == '\n' : + indent? = true + + new OutputStream : + defmethod print (this, s:String) : do(put, s) + defmethod print (this, c:Char) : put(c) + +defn indented (o:OutputStream, f: () -> ?) : + val prev-stream = CURRENT-OUTPUT-STREAM + dynamic-wind( + fn () : CURRENT-OUTPUT-STREAM = IndentedStream(o, 3) + f + fn (f) : CURRENT-OUTPUT-STREAM = prev-stream) + + +;=================== MAPPERS =============================== +public defn map<?T> (f: Type -> Type, t:?T&Type) -> T : + val type = + match(t) : + (t:T&BundleType) : + BundleType $ + for p in ports(t) map : + Port(name(p), direction(p), f(type(p))) + (t:T&VectorType) : + VectorType(f(type(t)), size(t)) + (t) : + t + type as T&Type + +public defmulti map<?T> (f: Expression -> Expression, e:?T&Expression) -> T +defmethod map (f: Expression -> Expression, e:Expression) -> Expression : + match(e) : + (e:Field) : Field(f(exp(e)), name(e), type(e)) + (e:Index) : Index(f(exp(e)), value(e), type(e)) + (e:DoPrim) : DoPrim(op(e), map(f, args(e)), consts(e), type(e)) + (e:ReadPort) : ReadPort(f(mem(e)), f(index(e)), type(e)) + (e) : e + +public defmulti map<?T> (f: Expression -> Expression, e:?T&Element) -> T +defmethod map (f: Expression -> Expression, e:Element) -> Element : + match(e) : + (e:Register) : + Register(type(e), f(value(e)), f(enable(e))) + (e:Memory) : + val writers* = for w in writers(e) map : + WritePort(f(index(w)), f(value(w)), f(enable(w))) + Memory(type(e), writers*) + (e:Node) : + Node(type(e), f(value(e))) + (e:Instance) : + val ports* = for p in ports(e) map : + key(p) => f(value(p)) + Instance(type(e), f(module(e)), ports*) + +public defmulti map<?T> (f: Expression -> Expression, c:?T&Command) -> T +defmethod map (f: Expression -> Expression, c:Command) -> Command : + match(c) : + (c:LetRec) : + val entries* = for entry in entries(c) map : + key(entry) => map(f, value(entry)) + LetRec(entries*, body(c)) + (c:DefAccessor) : DefAccessor(name(c), f(source(c)), f(index(c))) + (c:DefInstance) : DefInstance(name(c), f(module(c))) + (c:Conditionally) : Conditionally(f(pred(c)), conseq(c), alt(c)) + (c:Connect) : Connect(f(loc(c)), f(exp(c))) + (c) : c + +public defmulti map<?T> (f: Command -> Command, c:?T&Command) -> T +defmethod map (f: Command -> Command, c:Command) -> Command : + match(c) : + (c:LetRec) : LetRec(entries(c), f(body(c))) + (c:Conditionally) : Conditionally(pred(c), f(conseq(c)), f(alt(c))) + (c:Begin) : Begin(map(f, body(c))) + (c) : c + +public defmulti children (c:Command) -> List<Command> +defmethod children (c:Command) : + match(c) : + (c:LetRec) : list(body(c)) + (c:Conditionally) : list(conseq(c), alt(c)) + (c:Begin) : body(c) + (c) : List() + |
