aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/RemoveCHIRRTL.scala')
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala51
1 files changed, 44 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index b072dfa0..c841dc32 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -14,7 +14,9 @@ case class MPort(name: String, clk: Expression)
case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort])
case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean)
-object RemoveCHIRRTL extends Pass {
+object RemoveCHIRRTL extends Transform {
+ def inputForm: CircuitForm = UnknownForm
+ def outputForm: CircuitForm = UnknownForm
val ut = UnknownType
type MPortMap = collection.mutable.LinkedHashMap[String, MPorts]
type SeqMemSet = collection.mutable.HashSet[String]
@@ -22,6 +24,15 @@ object RemoveCHIRRTL extends Pass {
type DataRefMap = collection.mutable.LinkedHashMap[String, DataRef]
type AddrMap = collection.mutable.HashMap[String, Expression]
+ def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match {
+ case _: GroundType => Seq(ex)
+ case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) =>
+ exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex)
+ case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
+ exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex)
+ case UnknownType => Seq(ex)
+ }
+
def create_exps(e: Expression): Seq[Expression] = e match {
case ex: Mux =>
val e1s = create_exps(ex.tval)
@@ -59,7 +70,7 @@ object RemoveCHIRRTL extends Pass {
}
def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap,
- refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = s match {
+ refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match {
case sx: CDefMemory =>
types(sx.name) = sx.tpe
val taddr = UIntType(IntWidth(1 max ceilLog2(sx.size)))
@@ -104,11 +115,25 @@ object RemoveCHIRRTL extends Pass {
addrs += "addr"
clks += "clk"
ens += "en"
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.rdata")
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.wdata")
+ val es = create_all_exps(WRef(sx.name, sx.tpe))
+ val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe))
+ val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe))
+ ((es zip rs) zip ws) map {
+ case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
+ }
case MWrite =>
refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "mask", rdwrite = false)
addrs += "addr"
clks += "clk"
ens += "en"
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
+ val es = create_all_exps(WRef(sx.name, sx.tpe))
+ val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
+ (es zip ws) map {
+ case (e, w) => renames.rename(e.serialize, w.serialize)
+ }
case MRead =>
refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "blah", rdwrite = false)
addrs += "addr"
@@ -118,13 +143,19 @@ object RemoveCHIRRTL extends Pass {
raddrs(e.name) = SubField(SubField(Reference(sx.mem, ut), sx.name, ut), "en", ut)
case _ => ens += "en"
}
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
+ val es = create_all_exps(WRef(sx.name, sx.tpe))
+ val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
+ (es zip rs) map {
+ case (e, r) => renames.rename(e.serialize, r.serialize)
+ }
case MInfer => // do nothing if it's not being used
}
Block(
(addrs map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps.head))) ++
(clks map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps(1)))) ++
(ens map (x => Connect(sx.info,SubField(SubField(Reference(sx.mem,ut), sx.name, ut), x, ut), one))))
- case sx => sx map collect_refs(mports, smems, types, refs, raddrs)
+ case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames)
}
def get_mask(refs: DataRefMap)(e: Expression): Expression =
@@ -213,17 +244,23 @@ object RemoveCHIRRTL extends Pass {
}
}
- def remove_chirrtl_m(m: DefModule): DefModule = {
+ def remove_chirrtl_m(renames: RenameMap)(m: DefModule): DefModule = {
val mports = new MPortMap
val smems = new SeqMemSet
val types = new MPortTypeMap
val refs = new DataRefMap
val raddrs = new AddrMap
+ renames.setModule(m.name)
(m map collect_smems_and_mports(mports, smems)
- map collect_refs(mports, smems, types, refs, raddrs)
+ map collect_refs(mports, smems, types, refs, raddrs, renames)
map remove_chirrtl_s(refs, raddrs))
}
- def run(c: Circuit): Circuit =
- c copy (modules = c.modules map remove_chirrtl_m)
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+ val result = c copy (modules = c.modules map remove_chirrtl_m(renames))
+ CircuitState(result, outputForm, state.annotations, Some(renames))
+ }
}