aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/Passes.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/Passes.scala')
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala512
1 files changed, 454 insertions, 58 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 6c77d35d..a6b53e86 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -9,12 +9,13 @@ import scala.sys.process._
import scala.io.Source
// Datastructures
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.ArrayBuffer
import firrtl._
import firrtl.Utils._
import firrtl.PrimOps._
+import firrtl.WrappedExpression._
trait Pass extends LazyLogging {
def name: String
@@ -107,7 +108,7 @@ object ResolveKinds extends Pass {
def name = "Resolve Kinds"
def run (c:Circuit): Circuit = {
def resolve_kinds (m:Module, c:Circuit):Module = {
- val kinds = HashMap[String,Kind]()
+ val kinds = LinkedHashMap[String,Kind]()
def resolve (body:Stmt) = {
def resolve_expr (e:Expression):Expression = {
e match {
@@ -157,7 +158,7 @@ object ResolveKinds extends Pass {
object InferTypes extends Pass {
private var mname = ""
def name = "Infer Types"
- val width_name_hash = HashMap[String,Int]()
+ val width_name_hash = LinkedHashMap[String,Int]()
def set_type (s:Stmt,t:Type) : Stmt = {
s match {
case s:DefWire => DefWire(s.info,s.name,t)
@@ -175,9 +176,9 @@ object InferTypes extends Pass {
}
def remove_unknowns (t:Type): Type = mapr(remove_unknowns_w _,t)
def run (c:Circuit): Circuit = {
- val module_types = HashMap[String,Type]()
+ val module_types = LinkedHashMap[String,Type]()
def infer_types (m:Module) : Module = {
- val types = HashMap[String,Type]()
+ val types = LinkedHashMap[String,Type]()
def infer_types_e (e:Expression) : Expression = {
eMap(infer_types_e _,e) match {
case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value))
@@ -328,10 +329,10 @@ object CheckGenders extends Pass with StanzaPass {
object InferWidths extends Pass {
def name = "Infer Widths"
var mname = ""
- def solve_constraints (l:Seq[WGeq]) : HashMap[String,Width] = {
+ def solve_constraints (l:Seq[WGeq]) : LinkedHashMap[String,Width] = {
def unique (ls:Seq[Width]) : Seq[Width] = ls.map(w => new WrappedWidth(w)).distinct.map(_.w)
- def make_unique (ls:Seq[WGeq]) : HashMap[String,Width] = {
- val h = HashMap[String,Width]()
+ def make_unique (ls:Seq[WGeq]) : LinkedHashMap[String,Width] = {
+ val h = LinkedHashMap[String,Width]()
for (g <- ls) {
(g.loc) match {
case (w:VarWidth) => {
@@ -369,10 +370,10 @@ object InferWidths extends Pass {
case (w1,w2) => w }}
case (w:ExpWidth) => {
(w.arg1) match {
- case (w1:IntWidth) => IntWidth((2 ^ w1.width) - 1)
+ case (w1:IntWidth) => IntWidth(BigInt((scala.math.pow(2,w1.width.toDouble) - 1).toLong))
case (w1) => w }}
case (w) => w } }
- def substitute (h:HashMap[String,Width])(w:Width) : Width = {
+ def substitute (h:LinkedHashMap[String,Width])(w:Width) : Width = {
//;println-all-debug(["Substituting for [" w "]"])
val wx = simplify(w)
//;println-all-debug(["After Simplify: [" wx "]"])
@@ -394,7 +395,7 @@ object InferWidths extends Pass {
//;println-all-debug(["not varwidth!" w])
}
}
- def b_sub (h:HashMap[String,Width])(w:Width) : Width = {
+ def b_sub (h:LinkedHashMap[String,Width])(w:Width) : Width = {
(wMap(b_sub(h) _,w)) match {
case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w
case (w) => w
@@ -433,54 +434,44 @@ object InferWidths extends Pass {
//; 2) Remove Cycles
//; 3) Move to solved if not self-recursive
val u = make_unique(l)
- /*
- println("======== UNIQUE CONSTRAINTS ========")
- for (x <- u) { println(x) }
- println("====================================")
- */
+
+ //println("======== UNIQUE CONSTRAINTS ========")
+ //for (x <- u) { println(x) }
+ //println("====================================")
+
- val f = HashMap[String,Width]()
+ val f = LinkedHashMap[String,Width]()
val o = ArrayBuffer[String]()
for (x <- u) {
- /*
- println("==== SOLUTIONS TABLE ====")
- for (x <- f) println(x)
- println("=========================")
- */
+ //println("==== SOLUTIONS TABLE ====")
+ //for (x <- f) println(x)
+ //println("=========================")
val (n, e) = (x._1, x._2)
val e_sub = substitute(f)(e)
- /*
- println("Solving " + n + " => " + e)
- println("After Substitute: " + n + " => " + e_sub)
- println("==== SOLUTIONS TABLE (Post Substitute) ====")
- for (x <- f) println(x)
- println("=========================")
- */
+ //println("Solving " + n + " => " + e)
+ //println("After Substitute: " + n + " => " + e_sub)
+ //println("==== SOLUTIONS TABLE (Post Substitute) ====")
+ //for (x <- f) println(x)
+ //println("=========================")
val ex = remove_cycle(n)(e_sub)
- /*
- println("After Remove Cycle: " + n + " => " + ex)
- */
+ //println("After Remove Cycle: " + n + " => " + ex)
if (!self_rec(n,ex)) {
- /*
- println("Not rec!: " + n + " => " + ex)
- println("Adding [" + n + "=>" + ex + "] to Solutions Table")
- */
+ //println("Not rec!: " + n + " => " + ex)
+ //println("Adding [" + n + "=>" + ex + "] to Solutions Table")
o += n
f(n) = ex
}
}
- /*
- println("Forward Solved Constraints")
- for (x <- f) println(x)
- */
+ //println("Forward Solved Constraints")
+ //for (x <- f) println(x)
//; Backwards Solve
- val b = HashMap[String,Width]()
+ val b = LinkedHashMap[String,Width]()
for (i <- 0 until o.size) {
val n = o(o.size - 1 - i)
/*
@@ -510,7 +501,7 @@ object InferWidths extends Pass {
case (t:ClockType) => IntWidth(1)
case (t) => error("No width!"); IntWidth(-1) } }
def width_BANG (e:Expression) : Width = width_BANG(tpe(e))
- def reduce_var_widths (c:Circuit,h:HashMap[String,Width]) : Circuit = {
+ def reduce_var_widths (c:Circuit,h:LinkedHashMap[String,Width]) : Circuit = {
def evaluate (w:Width) : Width = {
def apply_2 (a:Option[BigInt],b:Option[BigInt], f: (BigInt,BigInt) => BigInt) : Option[BigInt] = {
(a,b) match {
@@ -525,6 +516,7 @@ object InferWidths extends Pass {
}
def max (a:BigInt,b:BigInt) : BigInt = if (a >= b) a else b
def min (a:BigInt,b:BigInt) : BigInt = if (a >= b) b else a
+ def pow (a:BigInt,b:BigInt) : BigInt = BigInt((scala.math.pow(a.toDouble,b.toDouble) - 1).toLong)
def solve (w:Width) : Option[BigInt] = {
(w) match {
case (w:VarWidth) => {
@@ -539,7 +531,7 @@ object InferWidths extends Pass {
case (w:MinWidth) => apply_l(w.args.map(solve _),min)
case (w:PlusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ + _})
case (w:MinusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ - _})
- case (w:ExpWidth) => apply_2(Some(BigInt(2)),solve(w.arg1),{(x,y) => (x ^ y) - BigInt(1)})
+ case (w:ExpWidth) => apply_2(Some(BigInt(2)),solve(w.arg1),pow)
case (w:IntWidth) => Some(w.width)
case (w) => println(w); error("Shouldn't be here"); None;
}
@@ -691,7 +683,7 @@ object ExpandConnects extends Pass {
def run (c:Circuit): Circuit = {
def expand_connects (m:InModule) : InModule = {
mname = m.name
- val genders = HashMap[String,Gender]()
+ val genders = LinkedHashMap[String,Gender]()
def expand_s (s:Stmt) : Stmt = {
def set_gender (e:Expression) : Expression = {
eMap(set_gender _,e) match {
@@ -854,7 +846,7 @@ object RemoveAccesses extends Pass {
def remove_s (s:Stmt) : Stmt = {
val stmts = ArrayBuffer[Stmt]()
def create_temp (e:Expression) : Expression = {
- val n = firrtl_gensym("GEN",sh)
+ val n = firrtl_gensym_module(mname)
stmts += DefWire(info(s),n,tpe(e))
WRef(n,tpe(e),kind(e),gender(e))
}
@@ -897,7 +889,7 @@ object RemoveAccesses extends Pass {
if (has_access(s.loc)) {
val ls = get_locations(s.loc)
val locx =
- if (ls.size == 1 & ls(0).guard == one) s.loc
+ if (ls.size == 1 & weq(ls(0).guard,one)) s.loc
else {
val temp = create_temp(s.loc)
for (x <- ls) { stmts += Conditionally(s.info,x.guard,Connect(s.info,x.base,temp),Empty()) }
@@ -930,12 +922,12 @@ object ExpandWhens extends Pass {
def name = "Expand Whens"
var mname = ""
// ; ========== Expand When Utilz ==========
- def add (hash:HashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = {
+ def add (hash:LinkedHashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = {
hash += (key -> value)
}
- def get_entries (hash:HashMap[WrappedExpression,Expression],exps:Seq[Expression]) : HashMap[WrappedExpression,Expression] = {
- val hashx = HashMap[WrappedExpression,Expression]()
+ def get_entries (hash:LinkedHashMap[WrappedExpression,Expression],exps:Seq[Expression]) : LinkedHashMap[WrappedExpression,Expression] = {
+ val hashx = LinkedHashMap[WrappedExpression,Expression]()
exps.foreach { e => {
val value = hash.get(e)
value match {
@@ -987,10 +979,10 @@ object ExpandWhens extends Pass {
val bodyx = void_all_s(m.body)
InModule(m.info,m.name,m.ports,Begin(Seq(Begin(voids),bodyx)))
}
- def expand_whens (m:InModule) : Tuple2[HashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = {
+ def expand_whens (m:InModule) : Tuple2[LinkedHashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = {
val simlist = ArrayBuffer[Stmt]()
mname = m.name
- def expand_whens (netlist:HashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = {
+ def expand_whens (netlist:LinkedHashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = {
(s) match {
case (s:Connect) => netlist(s.loc) = s.exp
case (s:IsInvalid) => netlist(s.exp) = WInvalid()
@@ -1025,14 +1017,14 @@ object ExpandWhens extends Pass {
}
}
case (s:Print) => {
- if (p == one) {
+ if (weq(p,one)) {
simlist += s
} else {
simlist += Print(s.info,s.string,s.args,s.clk,AND(p,s.en))
}
}
case (s:Stop) => {
- if (p == one) {
+ if (weq(p,one)) {
simlist += s
} else {
simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en))
@@ -1042,7 +1034,7 @@ object ExpandWhens extends Pass {
}
s
}
- val netlist = HashMap[WrappedExpression,Expression]()
+ val netlist = LinkedHashMap[WrappedExpression,Expression]()
expand_whens(netlist,one)(m.body)
//println("Netlist:")
@@ -1052,7 +1044,7 @@ object ExpandWhens extends Pass {
( netlist, simlist )
}
- def create_module (netlist:HashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = {
+ def create_module (netlist:LinkedHashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = {
mname = m.name
val stmts = ArrayBuffer[Stmt]()
val connections = ArrayBuffer[Stmt]()
@@ -1242,10 +1234,9 @@ object SplitExp extends Pass {
def split_exp (m:InModule) : InModule = {
mname = m.name
val v = ArrayBuffer[Stmt]()
- val sh = sym_hash
def split_exp_s (s:Stmt) : Stmt = {
def split (e:Expression) : Expression = {
- val n = firrtl_gensym("GEN",sh)
+ val n = firrtl_gensym_module(mname)
v += DefNode(info(s),n,e)
WRef(n,tpe(e),kind(e),gender(e))
}
@@ -1385,7 +1376,7 @@ object LowerTypes extends Pass {
//;------------- Pass ------------------
def lower_types (m:Module) : Module = {
- val mdt = HashMap[String,Type]()
+ val mdt = LinkedHashMap[String,Type]()
mname = m.name
def lower_types (s:Stmt) : Stmt = {
def lower_mem (e:Expression) : Seq[Expression] = {
@@ -1522,3 +1513,408 @@ object LowerTypes extends Pass {
}
}
+object CInferTypes extends Pass {
+ def name = "CInfer Types"
+ var mname = ""
+ def set_type (s:Stmt,t:Type) : Stmt = {
+ (s) match {
+ case (s:DefWire) => DefWire(s.info,s.name,t)
+ case (s:DefRegister) => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init)
+ case (s:CDefMemory) => CDefMemory(s.info,s.name,t,s.size,s.seq)
+ case (s:CDefMPort) => CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction)
+ case (s:DefNode) => s
+ case (s:DefPoison) => DefPoison(s.info,s.name,t)
+ }
+ }
+
+ def to_field (p:Port) : Field = {
+ if (p.direction == OUTPUT) Field(p.name,DEFAULT,p.tpe)
+ else if (p.direction == INPUT) Field(p.name,REVERSE,p.tpe)
+ else error("Shouldn't be here"); Field(p.name,REVERSE,p.tpe)
+ }
+ def module_type (m:Module) : Type = BundleType(m.ports.map(p => to_field(p)))
+ def field_type (v:Type,s:String) : Type = {
+ (v) match {
+ case (v:BundleType) => {
+ val ft = v.fields.find(p => p.name == s)
+ if (ft != None) ft.get.tpe
+ else UnknownType()
+ }
+ case (v) => UnknownType()
+ }
+ }
+ def sub_type (v:Type) : Type =
+ (v) match {
+ case (v:VectorType) => v.tpe
+ case (v) => UnknownType()
+ }
+ def run (c:Circuit) : Circuit = {
+ val module_types = LinkedHashMap[String,Type]()
+ def infer_types (m:Module) : Module = {
+ val types = LinkedHashMap[String,Type]()
+ def infer_types_e (e:Expression) : Expression = {
+ (eMap(infer_types_e _,e)) match {
+ case (e:Ref) => Ref(e.name, types.getOrElse(e.name,UnknownType()))
+ case (e:SubField) => SubField(e.exp,e.name,field_type(tpe(e.exp),e.name))
+ case (e:SubIndex) => SubIndex(e.exp,e.value,sub_type(tpe(e.exp)))
+ case (e:SubAccess) => SubAccess(e.exp,e.index,sub_type(tpe(e.exp)))
+ case (e:DoPrim) => set_primop_type(e)
+ case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval))
+ case (e:ValidIf) => ValidIf(e.cond,e.value,tpe(e.value))
+ case (_:UIntValue|_:SIntValue) => e
+ }
+ }
+ def infer_types_s (s:Stmt) : Stmt = {
+ (s) match {
+ case (s:DefRegister) => {
+ types(s.name) = s.tpe
+ eMap(infer_types_e _,s)
+ s
+ }
+ case (s:DefWire) => {
+ types(s.name) = s.tpe
+ s
+ }
+ case (s:DefPoison) => {
+ types(s.name) = s.tpe
+ s
+ }
+ case (s:DefNode) => {
+ val sx = eMap(infer_types_e _,s)
+ val t = get_type(sx)
+ types(s.name) = t
+ sx
+ }
+ case (s:DefMemory) => {
+ types(s.name) = get_type(s)
+ s
+ }
+ case (s:CDefMPort) => {
+ val t = types.getOrElse(s.mem,UnknownType())
+ types(s.name) = t
+ CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction)
+ }
+ case (s:CDefMemory) => {
+ types(s.name) = s.tpe
+ s
+ }
+ case (s:DefInstance) => {
+ types(s.name) = module_types.getOrElse(s.module,UnknownType())
+ s
+ }
+ case (s) => eMap(infer_types_e _,sMap(infer_types_s _,s))
+ }
+ }
+ for (p <- m.ports) {
+ types(p.name) = p.tpe
+ }
+ (m) match {
+ case (m:InModule) => InModule(m.info,m.name,m.ports,infer_types_s(m.body))
+ case (m:ExModule) => m
+ }
+ }
+
+ //; MAIN
+ for (m <- c.modules) {
+ module_types(m.name) = module_type(m)
+ }
+ val modulesx = c.modules.map(m => infer_types(m))
+ Circuit(c.info, modulesx, c.main)
+ }
+}
+
+object CInferMDir extends Pass {
+ def name = "CInfer MDir"
+ var mname = ""
+ def run (c:Circuit) : Circuit = {
+ def infer_mdir (m:Module) : Module = {
+ val mports = LinkedHashMap[String,MPortDir]()
+ def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = {
+ (eMap(infer_mdir_e(dir) _,e)) match {
+ case (e:Ref) => {
+ if (mports.contains(e.name)) {
+ val new_mport_dir = {
+ (mports(e.name),dir) match {
+ case (MInfer,MInfer) => error("Shouldn't be here")
+ case (MInfer,MWrite) => MWrite
+ case (MInfer,MRead) => MRead
+ case (MInfer,MReadWrite) => MReadWrite
+ case (MWrite,MInfer) => error("Shouldn't be here")
+ case (MWrite,MWrite) => MWrite
+ case (MWrite,MRead) => MReadWrite
+ case (MWrite,MReadWrite) => MReadWrite
+ case (MRead,MInfer) => error("Shouldn't be here")
+ case (MRead,MWrite) => MReadWrite
+ case (MRead,MRead) => MRead
+ case (MRead,MReadWrite) => MReadWrite
+ case (MReadWrite,MInfer) => error("Shouldn't be here")
+ case (MReadWrite,MWrite) => MReadWrite
+ case (MReadWrite,MRead) => MReadWrite
+ case (MReadWrite,MReadWrite) => MReadWrite
+ }
+ }
+ mports(e.name) = new_mport_dir
+ }
+ e
+ }
+ case (e) => e
+ }
+ }
+ def infer_mdir_s (s:Stmt) : Stmt = {
+ (s) match {
+ case (s:CDefMPort) => {
+ mports(s.name) = s.direction
+ eMap(infer_mdir_e(MRead) _,s)
+ }
+ case (s:Connect) => {
+ infer_mdir_e(MRead)(s.exp)
+ infer_mdir_e(MWrite)(s.loc)
+ s
+ }
+ case (s:BulkConnect) => {
+ infer_mdir_e(MRead)(s.exp)
+ infer_mdir_e(MWrite)(s.loc)
+ s
+ }
+ case (s) => eMap(infer_mdir_e(MRead) _, sMap(infer_mdir_s,s))
+ }
+ }
+ def set_mdir_s (s:Stmt) : Stmt = {
+ (s) match {
+ case (s:CDefMPort) =>
+ CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name))
+ case (s) => sMap(set_mdir_s _,s)
+ }
+ }
+ (m) match {
+ case (m:InModule) => {
+ infer_mdir_s(m.body)
+ InModule(m.info,m.name,m.ports,set_mdir_s(m.body))
+ }
+ case (m:ExModule) => m
+ }
+ }
+
+ //; MAIN
+ Circuit(c.info, c.modules.map(m => infer_mdir(m)), c.main)
+ }
+}
+
+case class MPort( val name : String, val clk : Expression)
+case class MPorts( val readers : ArrayBuffer[MPort], val writers : ArrayBuffer[MPort], val readwriters : ArrayBuffer[MPort])
+case class DataRef( val exp : Expression, val male : String, val female : String, val mask : String, val rdwrite : Boolean)
+
+object RemoveCHIRRTL extends Pass {
+ def name = "Remove CHIRRTL"
+ var mname = ""
+ def create_exps (e:Expression) : Seq[Expression] = {
+ (e) match {
+ case (e:Mux)=>
+ (create_exps(e.tval),create_exps(e.fval)).zipped.map((e1,e2) => {
+ Mux(e.cond,e1,e2,mux_type(e1,e2))
+ })
+ case (e:ValidIf) =>
+ create_exps(e.value).map(e1 => {
+ ValidIf(e.cond,e1,tpe(e1))
+ })
+ case (e) => (tpe(e)) match {
+ case (_:UIntType|_:SIntType|_:ClockType) => Seq(e)
+ case (t:BundleType) =>
+ t.fields.flatMap(f => create_exps(SubField(e,f.name,f.tpe)))
+ case (t:VectorType)=>
+ (0 until t.size).flatMap(i => create_exps(SubIndex(e,i,t.tpe)))
+ case (t:UnknownType) => Seq(e)
+ }
+ }
+ }
+ def run (c:Circuit) : Circuit = {
+ def remove_chirrtl_m (m:InModule) : InModule = {
+ val hash = LinkedHashMap[String,MPorts]()
+ val repl = LinkedHashMap[String,DataRef]()
+ val ut = UnknownType()
+ val mport_types = LinkedHashMap[String,Type]()
+ def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]())
+ def collect_mports (s:Stmt) : Stmt = {
+ (s) match {
+ case (s:CDefMPort) => {
+ val mports = hash.getOrElse(s.mem,EMPs())
+ s.direction match {
+ case MRead => mports.readers += MPort(s.name,s.exps(1))
+ case MWrite => mports.writers += MPort(s.name,s.exps(1))
+ case MReadWrite => mports.readwriters += MPort(s.name,s.exps(1))
+ }
+ hash(s.mem) = mports
+ s
+ }
+ case (s) => sMap(collect_mports _,s)
+ }
+ }
+ def collect_refs (s:Stmt) : Stmt = {
+ (s) match {
+ case (s:CDefMemory) => {
+ mport_types(s.name) = s.tpe
+ val stmts = ArrayBuffer[Stmt]()
+ val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size))))
+ val tdata = s.tpe
+ def set_poison (vec:Seq[MPort],addr:String) : Unit = {
+ for (r <- vec ) {
+ stmts += IsInvalid(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),addr,taddr))
+ stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),"clk",taddr),r.clk)
+ }
+ }
+ def set_enable (vec:Seq[MPort],en:String) : Unit = {
+ for (r <- vec ) {
+ stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),en,taddr),zero)
+ }}
+ def set_wmode (vec:Seq[MPort],wmode:String) : Unit = {
+ for (r <- vec) {
+ stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),wmode,taddr),zero)
+ }}
+ def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = {
+ val tmask = create_mask(s.tpe)
+ for (r <- vec ) {
+ stmts += IsInvalid(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),data,tdata))
+ for (x <- create_exps(SubField(SubField(Ref(s.name,ut),r.name,ut),mask,tmask)) ) {
+ stmts += Connect(s.info,x,zero)
+ }}}
+ val rds = (hash.getOrElse(s.name,EMPs())).readers
+ set_poison(rds,"addr")
+ set_enable(rds,"en")
+ val wrs = (hash.getOrElse(s.name,EMPs())).writers
+ set_poison(wrs,"addr")
+ set_enable(wrs,"en")
+ set_write(wrs,"data","mask")
+ val rws = (hash.getOrElse(s.name,EMPs())).readwriters
+ set_poison(rws,"addr")
+ set_wmode(rws,"wmode")
+ set_enable(rws,"en")
+ set_write(rws,"data","mask")
+ val read_l = if (s.seq) 1 else 0
+ val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name))
+ Begin(Seq(mem,Begin(stmts)))
+ }
+ case (s:CDefMPort) => {
+ mport_types(s.name) = mport_types(s.mem)
+ val addrs = ArrayBuffer[String]()
+ val ens = ArrayBuffer[String]()
+ val masks = ArrayBuffer[String]()
+ s.direction match {
+ case MReadWrite => {
+ repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"rdata","data","mask",true)
+ addrs += "addr"
+ ens += "en"
+ masks += "mask"
+ }
+ case MWrite => {
+ repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"data","data","mask",false)
+ addrs += "addr"
+ ens += "en"
+ masks += "mask"
+ }
+ case _ => {
+ repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"data","data","blah",false)
+ addrs += "addr"
+ ens += "en"
+ }
+ }
+ val stmts = ArrayBuffer[Stmt]()
+ for (x <- addrs ) {
+ stmts += Connect(s.info,SubField(SubField(Ref(s.mem,ut),s.name,ut),x,ut),s.exps(0))
+ }
+ for (x <- ens ) {
+ stmts += Connect(s.info,SubField(SubField(Ref(s.mem,ut),s.name,ut),x,ut),one)
+ }
+ Begin(stmts)
+ }
+ case (s) => sMap(collect_refs _,s)
+ }
+ }
+ def remove_chirrtl_s (s:Stmt) : Stmt = {
+ var has_write_mport = false
+ var has_readwrite_mport:Option[Expression] = None
+ def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = {
+ (e) match {
+ case (e:Ref) => {
+ if (repl.contains(e.name)) {
+ val vt = repl(e.name)
+ g match {
+ case MALE => SubField(vt.exp,vt.male,e.tpe)
+ case FEMALE => {
+ has_write_mport = true
+ if (vt.rdwrite == true)
+ has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1))))
+ SubField(vt.exp,vt.female,e.tpe)
+ }
+ }
+ } else e
+ }
+ case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.exp),remove_chirrtl_e(MALE)(e.index),e.tpe)
+ case (e) => eMap(remove_chirrtl_e(g) _,e)
+ }
+ }
+ def get_mask (e:Expression) : Expression = {
+ (eMap(get_mask _,e)) match {
+ case (e:Ref) => {
+ if (repl.contains(e.name)) {
+ val vt = repl(e.name)
+ val t = create_mask(e.tpe)
+ SubField(vt.exp,vt.mask,t)
+ } else e
+ }
+ case (e) => e
+ }
+ }
+ (s) match {
+ case (s:Connect) => {
+ val stmts = ArrayBuffer[Stmt]()
+ val rocx = remove_chirrtl_e(MALE)(s.exp)
+ val locx = remove_chirrtl_e(FEMALE)(s.loc)
+ stmts += Connect(s.info,locx,rocx)
+ if (has_write_mport) {
+ val e = get_mask(s.loc)
+ for (x <- create_exps(e) ) {
+ stmts += Connect(s.info,x,one)
+ }
+ if (has_readwrite_mport != None) {
+ val wmode = has_readwrite_mport.get
+ stmts += Connect(s.info,wmode,one)
+ }
+ }
+ if (stmts.size > 1) Begin(stmts)
+ else stmts(0)
+ }
+ case (s:BulkConnect) => {
+ val stmts = ArrayBuffer[Stmt]()
+ val locx = remove_chirrtl_e(FEMALE)(s.loc)
+ val rocx = remove_chirrtl_e(MALE)(s.exp)
+ stmts += BulkConnect(s.info,locx,rocx)
+ if (has_write_mport != false) {
+ val ls = get_valid_points(tpe(s.loc),tpe(s.exp),DEFAULT,DEFAULT)
+ val locs = create_exps(get_mask(s.loc))
+ for (x <- ls ) {
+ val locx = locs(x._1)
+ stmts += Connect(s.info,locx,one)
+ }
+ if (has_readwrite_mport != None) {
+ val wmode = has_readwrite_mport.get
+ stmts += Connect(s.info,wmode,one)
+ }
+ }
+ if (stmts.size > 1) Begin(stmts)
+ else stmts(0)
+ }
+ case (s) => eMap(remove_chirrtl_e(MALE) _, sMap(remove_chirrtl_s,s))
+ }
+ }
+ collect_mports(m.body)
+ val sx = collect_refs(m.body)
+ InModule(m.info,m.name, m.ports, remove_chirrtl_s(sx))
+ }
+ val modulesx = c.modules.map{ m => {
+ (m) match {
+ case (m:InModule) => remove_chirrtl_m(m)
+ case (m:ExModule) => m
+ }}}
+ Circuit(c.info,modulesx, c.main)
+ }
+}