aboutsummaryrefslogtreecommitdiff
path: root/src/main/stanza/chirrtl.stanza
diff options
context:
space:
mode:
authorazidar2016-01-23 06:58:59 -0800
committerazidar2016-01-23 06:58:59 -0800
commit99062792e5006dbf4c6b1f97da9121bbd6217c7a (patch)
treeb8cf5153edecc108a65894da5b00c03c05e9fffb /src/main/stanza/chirrtl.stanza
parentcce5603ac7f5765434ec8239053b1fde74a2c67f (diff)
Changed chirrtl to not require known mask values
Diffstat (limited to 'src/main/stanza/chirrtl.stanza')
-rw-r--r--src/main/stanza/chirrtl.stanza177
1 files changed, 155 insertions, 22 deletions
diff --git a/src/main/stanza/chirrtl.stanza b/src/main/stanza/chirrtl.stanza
index 049d0b0d..fd39c001 100644
--- a/src/main/stanza/chirrtl.stanza
+++ b/src/main/stanza/chirrtl.stanza
@@ -3,11 +3,146 @@ defpackage firrtl/chirrtl :
import verse
import firrtl/ir2
import firrtl/ir-utils
+ import firrtl/primops
-public defstruct ToIR <: Pass
-public defmethod pass (b:ToIR) -> (Circuit -> Circuit) : to-ir
-public defmethod name (b:ToIR) -> String : "To FIRRTL"
-public defmethod short-name (b:ToIR) -> String : "to-firrtl"
+; CHIRRTL Additional IR Nodes
+public definterface MPortDir
+public val MRead = new MPortDir
+public val MWrite = new MPortDir
+public val MReadWrite = new MPortDir
+
+defmethod print (o:OutputStream, m:MPortDir) :
+ switch { m == _ } :
+ MRead : print(o,"read")
+ MWrite : print(o,"write")
+ MReadWrite : print(o,"rdwr")
+
+public defstruct CDefMemory <: Stmt : ;LOW
+ info: FileInfo with: (as-method => true)
+ name: Symbol
+ type: Type
+ size: Int
+ seq?: True|False
+public defstruct CDefMPort <: Stmt :
+ info: FileInfo with: (as-method => true)
+ name: Symbol
+ type: Type
+ mem: Symbol
+ exps: List<Expression>
+ direction: MPortDir
+
+defmethod print (o:OutputStream,c:CDefMemory) :
+ if seq?(c) : print-all(o, ["smem " name(c) " : " type(c) "[" size(c) "]"])
+ else : print-all(o, ["cmem " name(c) " : " type(c) "[" size(c) "]"])
+defmethod map (f: Type -> Type, c:CDefMemory) -> CDefMemory :
+ CDefMemory(info(c),name(c),f(type(c)),size(c),seq?(c))
+defmethod map (f: Symbol -> Symbol, c:CDefMemory) -> CDefMemory :
+ CDefMemory(info(c),f(name(c)),type(c),size(c),seq?(c))
+
+defmethod print (o:OutputStream,c:CDefMPort) :
+ print-all(o, [direction(c) " mport " name(c) " = " mem(c) "[" exps(c)[0] "], " exps(c)[1]])
+defmethod map (f: Expression -> Expression, c:CDefMPort) -> CDefMPort :
+ CDefMPort(info(c),name(c),type(c),mem(c),map(f,exps(c)),direction(c))
+defmethod map (f: Type -> Type, c:CDefMPort) -> CDefMPort :
+ CDefMPort(info(c),name(c),f(type(c)),mem(c),exps(c),direction(c))
+defmethod map (f: Symbol -> Symbol, c:CDefMPort) -> CDefMPort :
+ CDefMPort(info(c),f(name(c)),type(c),mem(c),exps(c),direction(c))
+
+
+;======================= Infer Chirrtl Types ======================
+public defstruct CInferTypes <: Pass
+public defmethod pass (b:CInferTypes) -> (Circuit -> Circuit) : infer-types
+public defmethod name (b:CInferTypes) -> String : "Infer CTypes"
+public defmethod short-name (b:CInferTypes) -> String : "infer-ctypes"
+
+;--------------- Utils -----------------
+
+defn set-type (s:Stmt,t:Type) -> Stmt :
+ match(s) :
+ (s:DefWire) : DefWire(info(s),name(s),t)
+ (s:DefRegister) : DefRegister(info(s),name(s),t,clock(s),reset(s),init(s))
+ (s:CDefMemory) : CDefMemory(info(s),name(s),t,size(s),seq?(s))
+ (s:CDefMPort) : CDefMPort(info(s),name(s),t,mem(s),exps(s),direction(s))
+ (s:DefNode) : s
+ (s:DefPoison) : DefPoison(info(s),name(s),t)
+
+defn to-field (p:Port) -> Field :
+ if direction(p) == OUTPUT : Field(name(p),DEFAULT,type(p))
+ else if direction(p) == INPUT : Field(name(p),REVERSE,type(p))
+ else : error("Shouldn't be here")
+defn module-type (m:Module) -> Type :
+ BundleType(for p in ports(m) map : to-field(p))
+defn field-type (v:Type,s:Symbol) -> Type :
+ match(v) :
+ (v:BundleType) :
+ val ft = for p in fields(v) find : name(p) == s
+ if ft != false : type(ft as Field)
+ else : UnknownType()
+ (v) : UnknownType()
+defn sub-type (v:Type) -> Type :
+ match(v) :
+ (v:VectorType) : type(v)
+ (v) : UnknownType()
+
+;--------------- Pass -----------------
+
+defn infer-types (c:Circuit) -> Circuit :
+ val module-types = HashTable<Symbol,Type>(symbol-hash)
+ defn infer-types (m:Module) -> Module :
+ val types = HashTable<Symbol,Type>(symbol-hash)
+ defn infer-types-e (e:Expression) -> Expression :
+ match(map(infer-types-e,e)) :
+ (e:Ref) : Ref(name(e), types[name(e)])
+ (e:SubField) : SubField(exp(e),name(e),field-type(type(exp(e)),name(e)))
+ (e:SubIndex) : SubIndex(exp(e),value(e),sub-type(type(exp(e))))
+ (e:SubAccess) : SubAccess(exp(e),index(e),sub-type(type(exp(e))))
+ (e:DoPrim) : set-primop-type(e)
+ (e:UIntValue|SIntValue) : e
+ defn infer-types-s (s:Stmt) -> Stmt :
+ match(s) :
+ (s:DefRegister) :
+ types[name(s)] = type(s)
+ map(infer-types-e,s)
+ s
+ (s:DefWire|DefPoison) :
+ types[name(s)] = type(s)
+ s
+ (s:DefNode) :
+ val s* = map(infer-types-e,s)
+ val t = type(value(s*))
+ types[name(s*)] = t
+ s*
+ (s:CDefMPort) :
+ val t = types[mem(s)]
+ types[name(s)] = t
+ CDefMPort(info(s),name(s),t,mem(s),exps(s),direction(s))
+ (s:CDefMemory) :
+ types[name(s)] = type(s)
+ s
+ (s:DefInstance) :
+ types[name(s)] = module-types[module(s)]
+ s
+ (s) : map{infer-types-e,_} $ map(infer-types-s,s)
+ for p in ports(m) do :
+ types[name(p)] = type(p)
+ match(m) :
+ (m:InModule) :
+ InModule(info(m),name(m),ports(m),infer-types-s(body(m)))
+ (m:ExModule) : m
+
+ ; MAIN
+ for m in modules(c) do :
+ module-types[name(m)] = module-type(m)
+ Circuit{info(c), _, main(c) } $
+ for m in modules(c) map :
+ infer-types(m)
+
+;==========================================
+
+public defstruct RemoveCHIRRTL <: Pass
+public defmethod pass (b:RemoveCHIRRTL) -> (Circuit -> Circuit) : remove-chirrtl
+public defmethod name (b:RemoveCHIRRTL) -> String : "Remove CHIRRTL"
+public defmethod short-name (b:RemoveCHIRRTL) -> String : "remove-chirrtl"
defstruct MPort :
name : Symbol
@@ -36,8 +171,8 @@ defn create-exps (e:Expression) -> List<Expression> :
for i in 0 to size(t) map-append :
create-exps(SubIndex(e,i,type(t)))
-defn to-ir (c:Circuit) :
- defn to-ir-m (m:InModule) -> InModule :
+defn remove-chirrtl (c:Circuit) :
+ defn remove-chirrtl-m (m:InModule) -> InModule :
val hash = HashTable<Symbol,MPorts>(symbol-hash)
val sh = get-sym-hash(m,keys(v-keywords))
val repl = HashTable<Symbol,DataRef>(symbol-hash)
@@ -126,14 +261,12 @@ defn to-ir (c:Circuit) :
add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),exps(s)[0]))
for x in ens do :
add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),one))
- for x in masks do :
- add(stmts,Connect(info(s),SubField(SubField(Ref(mem(s),ut),name(s),ut),x,ut),exps(s)[2]))
Begin $ to-list $ stmts
(s) : map(collect-refs,s)
- defn to-ir-s (s:Stmt) -> Stmt :
+ defn remove-chirrtl-s (s:Stmt) -> Stmt :
var has-write-mport? = false
- defn to-ir-e (e:Expression,g:Gender) -> Expression :
- match(map(to-ir-e{_,g},e)) :
+ defn remove-chirrtl-e (e:Expression,g:Gender) -> Expression :
+ match(map(remove-chirrtl-e{_,g},e)) :
(e:Ref) :
if key?(repl,name(e)) :
val vt = repl[name(e)]
@@ -156,37 +289,37 @@ defn to-ir (c:Circuit) :
match(s) :
(s:Connect) :
val stmts = Vector<Stmt>()
- val roc* = to-ir-e(exp(s),MALE)
- val loc* = to-ir-e(loc(s),FEMALE)
+ val roc* = remove-chirrtl-e(exp(s),MALE)
+ val loc* = remove-chirrtl-e(loc(s),FEMALE)
add(stmts,Connect(info(s),loc*,roc*))
if has-write-mport? :
val e = get-mask(loc(s))
for x in create-exps(e) do :
- add(stmts,Connect(info(s),x,UInt(1)))
- if length(stmts > 1) : Begin(to-list(stmts))
+ add(stmts,Connect(info(s),x,one))
+ if length(stmts) > 1 : Begin(to-list(stmts))
else : stmts[0]
(s:BulkConnect) :
val stmts = Vector<Stmt>()
- val loc* = to-ir-e(loc(s),FEMALE)
- val roc* = to-ir-e(exp(s),MALE)
+ val loc* = remove-chirrtl-e(loc(s),FEMALE)
+ val roc* = remove-chirrtl-e(exp(s),MALE)
add(stmts,BulkConnect(info(s),loc*,roc*))
if has-write-mport? :
val ls = get-valid-points(type(loc(s)),type(exp(s)),DEFAULT,DEFAULT)
val locs = create-exps(get-mask(loc(s)))
for x in ls do :
val loc* = locs[x[0]]
- add(stmts,Connect(info(s),loc*,UInt(1)))
- if length(stmts > 1) : Begin(to-list(stmts))
+ add(stmts,Connect(info(s),loc*,one))
+ if length(stmts) > 1 : Begin(to-list(stmts))
else : stmts[0]
- (s) : map(to-ir-e{_,MALE}, map(to-ir-s,s))
+ (s) : map(remove-chirrtl-e{_,MALE}, map(remove-chirrtl-s,s))
collect-mports(body(m))
val s* = collect-refs(body(m))
- InModule(info(m),name(m), ports(m), to-ir-s(s*))
+ InModule(info(m),name(m), ports(m), remove-chirrtl-s(s*))
Circuit(info(c),modules*, main(c)) where :
val modules* =
for m in modules(c) map :
match(m) :
- (m:InModule) : to-ir-m(m)
+ (m:InModule) : remove-chirrtl-m(m)
(m:ExModule) : m