diff options
| author | jackkoenig | 2016-04-13 16:38:34 -0700 |
|---|---|---|
| committer | jackkoenig | 2016-04-22 13:46:15 -0700 |
| commit | 6fcf8bcf215106d3c34a5d33ad89fa5a1adaa4db (patch) | |
| tree | 3ef98b2ed2d211320a6ffaadeafe089ad07f48da /src | |
| parent | 07cae2b9a53c9e2e3492c493dc23c0afe2d0f7e0 (diff) | |
Refactor LowerTypes
Make loweredName a public utility function of the Pass
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 42 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 7 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 455 |
3 files changed, 271 insertions, 233 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index ceccccbc..109bfad8 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -87,8 +87,8 @@ class VerilogEmitter extends Emitter { case (e:Mux) => emit2(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1) case (e:ValidIf) => emit2(Seq(cast(e.value)),top + 1) case (e:WRef) => w.get.write(e.serialize) - case (e:WSubField) => w.get.write(lowered_name(e)) - case (e:WSubAccess) => w.get.write(lowered_name(e.exp) + "[" + lowered_name(e.index) + "]") + case (e:WSubField) => w.get.write(LowerTypes.loweredName(e)) + case (e:WSubAccess) => w.get.write(LowerTypes.loweredName(e.exp) + "[" + LowerTypes.loweredName(e.index) + "]") case (e:WSubIndex) => w.get.write(e.serialize) case (_:UIntValue|_:SIntValue) => v_print(e) } @@ -335,14 +335,14 @@ class VerilogEmitter extends Emitter { def instantiate (n:String,m:String,es:Seq[Expression]) = { instdeclares += Seq(m," ",n," (") (es,0 until es.size).zipped.foreach{ (e,i) => { - val s = Seq(tab,".",remove_root(e),"(",lowered_name(e),")") + val s = Seq(tab,".",remove_root(e),"(",LowerTypes.loweredName(e),")") if (i != es.size - 1) instdeclares += Seq(s,",") else instdeclares += s }} instdeclares += Seq(");") for (e <- es) { - declare("wire",lowered_name(e),tpe(e)) - val ex = WRef(lowered_name(e),tpe(e),kind(e),gender(e)) + declare("wire",LowerTypes.loweredName(e),tpe(e)) + val ex = WRef(LowerTypes.loweredName(e),tpe(e),kind(e),gender(e)) if (gender(e) == FEMALE) { assign(ex,netlist(e)) } @@ -449,10 +449,10 @@ class VerilogEmitter extends Emitter { val en = mem_exp(r,"en") val clk = mem_exp(r,"clk") - declare("wire",lowered_name(data),tpe(data)) - declare("wire",lowered_name(addr),tpe(addr)) - declare("wire",lowered_name(en),tpe(en)) - declare("wire",lowered_name(clk),tpe(clk)) + declare("wire",LowerTypes.loweredName(data),tpe(data)) + declare("wire",LowerTypes.loweredName(addr),tpe(addr)) + declare("wire",LowerTypes.loweredName(en),tpe(en)) + declare("wire",LowerTypes.loweredName(clk),tpe(clk)) //; Read port assign(addr,netlist(addr)) //;Connects value to m.r.addr @@ -471,11 +471,11 @@ class VerilogEmitter extends Emitter { val en = mem_exp(w,"en") val clk = mem_exp(w,"clk") - declare("wire",lowered_name(data),tpe(data)) - declare("wire",lowered_name(addr),tpe(addr)) - declare("wire",lowered_name(mask),tpe(mask)) - declare("wire",lowered_name(en),tpe(en)) - declare("wire",lowered_name(clk),tpe(clk)) + declare("wire",LowerTypes.loweredName(data),tpe(data)) + declare("wire",LowerTypes.loweredName(addr),tpe(addr)) + declare("wire",LowerTypes.loweredName(mask),tpe(mask)) + declare("wire",LowerTypes.loweredName(en),tpe(en)) + declare("wire",LowerTypes.loweredName(clk),tpe(clk)) //; Write port assign(data,netlist(data)) @@ -501,13 +501,13 @@ class VerilogEmitter extends Emitter { val en = mem_exp(rw,"en") val clk = mem_exp(rw,"clk") - declare("wire",lowered_name(wmode),tpe(wmode)) - declare("wire",lowered_name(rdata),tpe(rdata)) - declare("wire",lowered_name(data),tpe(data)) - declare("wire",lowered_name(mask),tpe(mask)) - declare("wire",lowered_name(addr),tpe(addr)) - declare("wire",lowered_name(en),tpe(en)) - declare("wire",lowered_name(clk),tpe(clk)) + declare("wire",LowerTypes.loweredName(wmode),tpe(wmode)) + declare("wire",LowerTypes.loweredName(rdata),tpe(rdata)) + declare("wire",LowerTypes.loweredName(data),tpe(data)) + declare("wire",LowerTypes.loweredName(mask),tpe(mask)) + declare("wire",LowerTypes.loweredName(addr),tpe(addr)) + declare("wire",LowerTypes.loweredName(en),tpe(en)) + declare("wire",LowerTypes.loweredName(clk),tpe(clk)) //; Assigned to lowered wires of each assign(clk,netlist(clk)) diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 951b5c75..16a56893 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -136,13 +136,6 @@ object Utils { } } } - def lowered_name (e:Expression) : String = { - (e) match { - case (e:WRef) => e.name - case (e:WSubField) => lowered_name(e.exp) + "_" + e.name - case (e:WSubIndex) => lowered_name(e.exp) + "_" + e.value - } - } def get_flip (t:Type, i:Int, f:Flip) : Flip = { if (i >= get_size(t)) error("Shouldn't be here") val x = t match { diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 7adf7ad3..47e78170 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -27,231 +27,276 @@ MODIFICATIONS. package firrtl.passes +import com.typesafe.scalalogging.LazyLogging + import firrtl._ import firrtl.Utils._ import firrtl.Mappers._ // Datastructures -import scala.collection.mutable.LinkedHashMap import scala.collection.mutable.HashMap -import scala.collection.mutable.ArrayBuffer +/** Removes all aggregate types from a [[Circuit]] + * + * @note Assumes [[firrtl.SubAccess]]es have been removed + * @note Assumes [[firrtl.Connect]]s and [[firrtl.IsInvalid]]s only operate on [[firrtl.Expression]]s of ground type + * @example + * {{{ + * wire foo : { a : UInt<32>, b : UInt<16> } + * }}} lowers to + * {{{ + * wire foo_a : UInt<32> + * wire foo_b : UInt<16> + * }}} + */ object LowerTypes extends Pass { - def name = "Lower Types" - var mname = "" - def is_ground (t:Type) : Boolean = { - (t) match { - case (_:UIntType|_:SIntType) => true - case (t) => false + def name = "Lower Types" + + /** Delimiter used in lowering names */ + val delim = "_" + /** Expands a chain of referential [[firrtl.Expression]]s into the equivalent lowered name + * @param e [[firrtl.Expression]] made up of _only_ [[firrtl.WRef]], [[firrtl.WSubField]], and [[firrtl.WSubIndex]] + * @return Lowered name of e + */ + def loweredName(e: Expression): String = e match { + case e: WRef => e.name + case e: WSubField => loweredName(e.exp) + delim + e.name + case e: WSubIndex => loweredName(e.exp) + delim + e.value + } + def loweredName(s: Seq[String]): String = s.mkString(delim) + + private case class LowerTypesException(msg: String) extends FIRRTLException(msg) + private def error(msg: String)(implicit sinfo: Info, mname: String) = + throw new LowerTypesException(s"$sinfo: [module $mname] $msg") + + // Useful for splitting then remerging references + private case object EmptyExpression extends Expression { def tpe = UnknownType() } + + /** Splits an Expression into root Ref and tail + * + * @example + * Given: SubField(SubIndex(SubField(Ref("a", UIntType(IntWidth(32))), "b"), 2), "c") + * Returns: (Ref("a"), SubField(SubIndex(Ref("b"), 2), "c")) + * a.b[2].c -> (a, b[2].c) + * @example + * Given: SubField(SubIndex(Ref("b"), 2), "c") + * Returns: (Ref("b"), SubField(SubIndex(EmptyExpression, 2), "c")) + * b[2].c -> (b, EMPTY[2].c) + * @note This function only supports WRef, WSubField, and WSubIndex + */ + private def splitRef(e: Expression): (WRef, Expression) = e match { + case e: WRef => (e, EmptyExpression) + case e: WSubIndex => + val (root, tail) = splitRef(e.exp) + (root, WSubIndex(tail, e.value, e.tpe, e.gender)) + case e: WSubField => + val (root, tail) = splitRef(e.exp) + tail match { + case EmptyExpression => (root, WRef(e.name, e.tpe, root.kind, e.gender)) + case exp => (root, WSubField(tail, e.name, e.tpe, e.gender)) } - } - def data (ex:Expression) : Boolean = { - (kind(ex)) match { - case (k:MemKind) => (ex) match { - case (_:WRef|_:WSubIndex) => false - case (ex:WSubField) => { - var yes = ex.name match { - case "rdata" => true - case "data" => true - case "mask" => true - case _ => false - } - yes && ((ex.exp) match { - case (e:WSubField) => kind(e).as[MemKind].get.ports.contains(e.name) && (e.exp.typeof[WRef]) - case (e) => false - }) + } + + /** Adds a root reference to some SubField/SubIndex chain */ + private def mergeRef(root: WRef, body: Expression): Expression = body match { + case e: WRef => + WSubField(root, e.name, e.tpe, e.gender) + case e: WSubIndex => + WSubIndex(mergeRef(root, e.exp), e.value, e.tpe, e.gender) + case e: WSubField => + WSubField(mergeRef(root, e.exp), e.name, e.tpe, e.gender) + case EmptyExpression => root + } + + // TODO Improve? Probably not the best way to do this + private def splitMemRef(e1: Expression): (WRef, WRef, WRef, Option[Expression]) = { + val (mem, tail1) = splitRef(e1) + val (port, tail2) = splitRef(tail1) + tail2 match { + case e2: WRef => + (mem, port, e2, None) + case _ => + val (field, tail3) = splitRef(tail2) + (mem, port, field, Some(tail3)) + } + } + + // Everything wrapped in run so that it's thread safe + def run(c: Circuit): Circuit = { + // Debug state + implicit var mname: String = "" + implicit var sinfo: Info = NoInfo + + def lowerTypes(m: Module): Module = { + val memDataTypeMap = HashMap[String, Type]() + + // Lowers an expression of MemKind + // Since mems with Bundle type must be split into multiple ground type + // mem, references to fields addr, en, clk, and rmode must be replicated + // for each resulting memory + // References to data, mask, and rdata have already been split in expand connects + // and just need to be converted to refer to the correct new memory + def lowerTypesMemExp(e: Expression): Seq[Expression] = { + val (mem, port, field, tail) = splitMemRef(e) + // Fields that need to be replicated for each resulting mem + if (Seq("addr", "en", "clk", "rmode").contains(field.name)) { + require(tail.isEmpty) // there can't be a tail for these + val memType = memDataTypeMap(mem.name) + + if (memType.isGround) { + Seq(e) + } else { + val exps = create_exps(mem.name, memType) + exps map { e => + val loMemName = loweredName(e) + val loMem = WRef(loMemName, UnknownType(), kind(mem), UNKNOWNGENDER) + mergeRef(loMem, mergeRef(port, field)) } - case (ex) => false - } - case (k) => false + } + // Fields that need not be replicated for each + // eg. mem.reader.data[0].a + // (Connect/IsInvalid must already have been split to gorund types) + } else if (Seq("data", "mask", "rdata").contains(field.name)) { + val loMem = tail match { + case Some(e) => + val loMemExp = mergeRef(mem, e) + val loMemName = loweredName(loMemExp) + WRef(loMemName, UnknownType(), kind(mem), UNKNOWNGENDER) + case None => mem + } + Seq(mergeRef(loMem, mergeRef(port, field))) + } else { + error(s"Error! Unhandled memory field ${field.name}") + } } - } - def expand_name (e:Expression) : Seq[String] = { - val names = ArrayBuffer[String]() - def expand_name_e (e:Expression) : Expression = { - (e map (expand_name_e)) match { - case (e:WRef) => names += e.name - case (e:WSubField) => names += e.name - case (e:WSubIndex) => names += e.value.toString - } - e - } - expand_name_e(e) - names - } - def lower_other_mem (e:Expression, dt:Type) : Seq[Expression] = { - val names = expand_name(e) - if (names.size < 3) error("Shouldn't be here") - create_exps(names(0),dt).map{ x => { - var base = lowered_name(x) - for (i <- 0 until names.size) { - if (i >= 3) base = base + "_" + names(i) - } - val m = WRef(base, UnknownType(), kind(e), UNKNOWNGENDER) - val p = WSubField(m,names(1),UnknownType(),UNKNOWNGENDER) - WSubField(p,names(2),UnknownType(),UNKNOWNGENDER) - }} - } - def lower_data_mem (e:Expression) : Expression = { - val names = expand_name(e) - if (names.size < 3) error("Shouldn't be here") - else { - var base = names(0) - for (i <- 0 until names.size) { - if (i >= 3) base = base + "_" + names(i) - } - val m = WRef(base, UnknownType(), kind(e), UNKNOWNGENDER) - val p = WSubField(m,names(1),UnknownType(),UNKNOWNGENDER) - WSubField(p,names(2),UnknownType(),UNKNOWNGENDER) - } - } - def merge (a:String,b:String,x:String) : String = a + x + b - def root_ref (e:Expression) : WRef = { - (e) match { - case (e:WRef) => e - case (e:WSubField) => root_ref(e.exp) - case (e:WSubIndex) => root_ref(e.exp) - case (e:WSubAccess) => root_ref(e.exp) + + def lowerTypesExp(e: Expression): Expression = e match { + case e: WRef => e + case (_: WSubField | _: WSubIndex) => kind(e) match { + case k: InstanceKind => + val (root, tail) = splitRef(e) + val name = loweredName(tail) + WSubField(root, name, tpe(e), gender(e)) + case k: MemKind => + val exps = lowerTypesMemExp(e) + if (exps.length > 1) + error("Error! lowerTypesExp called on MemKind SubField that needs" + + " to be expanded!") + exps(0) + case k => + WRef(loweredName(e), tpe(e), kind(e), gender(e)) + } + case e: Mux => e map (lowerTypesExp) + case e: ValidIf => e map (lowerTypesExp) + case (_: UIntValue | _: SIntValue) => e + case e: DoPrim => e map (lowerTypesExp) } - } - - //;------------- Pass ------------------ - - def lower_types (m:Module) : Module = { - val mdt = LinkedHashMap[String,Type]() - mname = m.name - def lower_types (s:Stmt) : Stmt = { - def lower_mem (e:Expression) : Seq[Expression] = { - val names = expand_name(e) - if (Seq("data","mask","rdata").contains(names(2))) Seq(lower_data_mem(e)) - else lower_other_mem(e,mdt(root_ref(e).name)) - } - def lower_types_e (e:Expression) : Expression = { - e match { - case (_:WRef|_:UIntValue|_:SIntValue) => e - case (_:WSubField|_:WSubIndex) => { - (kind(e)) match { - case (k:InstanceKind) => { - val names = expand_name(e) - var n = names(1) - for (i <- 0 until names.size) { - if (i > 1) n = n + "_" + names(i) - } - WSubField(root_ref(e),n,tpe(e),gender(e)) - } - case (k:MemKind) => { - if (gender(e) != FEMALE) lower_mem(e)(0) - else e - } - case (k) => WRef(lowered_name(e),tpe(e),kind(e),gender(e)) - } - } - case (e:DoPrim) => e map (lower_types_e) - case (e:Mux) => e map (lower_types_e) - case (e:ValidIf) => e map (lower_types_e) - } - } - (s) match { - case (s:DefWire) => { - if (is_ground(s.tpe)) s else { - val es = create_exps(s.name,s.tpe) - val stmts = (es, 0 until es.size).zipped.map{ (e,i) => { - DefWire(s.info,lowered_name(e),tpe(e)) - }} - Begin(stmts) - } - } - case (s:DefPoison) => { - if (is_ground(s.tpe)) s else { - val es = create_exps(s.name,s.tpe) - val stmts = (es, 0 until es.size).zipped.map{ (e,i) => { - DefPoison(s.info,lowered_name(e),tpe(e)) - }} - Begin(stmts) - } + + def lowerTypesStmt(s: Stmt): Stmt = { + s map lowerTypesStmt match { + case s: DefWire => + sinfo = s.info + if (s.tpe.isGround) { + s + } else { + val exps = create_exps(s.name, s.tpe) + val stmts = exps map (e => DefWire(s.info, loweredName(e), tpe(e))) + Begin(stmts) } - case (s:DefRegister) => { - if (is_ground(s.tpe)) s else { - val es = create_exps(s.name,s.tpe) - val inits = create_exps(s.init) - val stmts = (es, 0 until es.size).zipped.map{ (e,i) => { - val init = lower_types_e(inits(i)) - DefRegister(s.info,lowered_name(e),tpe(e),s.clock,s.reset,init) - }} - Begin(stmts) - } + case s: DefRegister => + sinfo = s.info + if (s.tpe.isGround) { + s map lowerTypesExp + } else { + val es = create_exps(s.name, s.tpe) + val inits = create_exps(s.init) map (lowerTypesExp) + val clock = lowerTypesExp(s.clock) + val reset = lowerTypesExp(s.reset) + val stmts = es zip inits map { case (e, i) => + DefRegister(s.info, loweredName(e), tpe(e), clock, reset, i) + } + Begin(stmts) } - case (s:WDefInstance) => { - val fieldsx = s.tpe.as[BundleType].get.fields.flatMap{ f => { - val es = create_exps(WRef(f.name,f.tpe,ExpKind(),times(f.flip,MALE))) - es.map{ e => { - gender(e) match { - case MALE => Field(lowered_name(e),DEFAULT,f.tpe) - case FEMALE => Field(lowered_name(e),REVERSE,f.tpe) - } - }} - }} - WDefInstance(s.info,s.name,s.module,BundleType(fieldsx)) + // Could instead just save the type of each Module as it gets processed + case s: WDefInstance => + sinfo = s.info + s.tpe match { + case t: BundleType => + val fieldsx = t.fields flatMap { f => + val exps = create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) + exps map ( e => + // Flip because inst genders are reversed from Module type + Field(loweredName(e), toFlip(gender(e)).flip, tpe(e)) + ) + } + WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) + case _ => error("WDefInstance type should be Bundle!") } - case (s:DefMemory) => { - mdt(s.name) = s.data_type - if (is_ground(s.data_type)) s else { - val es = create_exps(s.name,s.data_type) - val stmts = es.map{ e => { - DefMemory(s.info,lowered_name(e),tpe(e),s.depth,s.write_latency,s.read_latency,s.readers,s.writers,s.readwriters) - }} - Begin(stmts) - } + case s: DefMemory => + sinfo = s.info + memDataTypeMap += (s.name -> s.data_type) + if (s.data_type.isGround) { + s + } else { + val exps = create_exps(s.name, s.data_type) + val stmts = exps map { e => + DefMemory(s.info, loweredName(e), tpe(e), s.depth, + s.write_latency, s.read_latency, s.readers, s.writers, + s.readwriters) + } + Begin(stmts) } - case (s:IsInvalid) => { - val sx = (s map (lower_types_e)).as[IsInvalid].get - kind(sx.exp) match { - case (k:MemKind) => { - val es = lower_mem(sx.exp) - Begin(es.map(e => {IsInvalid(sx.info,e)})) - } - case (_) => sx - } + // wire foo : { a , b } + // node x = foo + // node y = x.a + // -> + // node x_a = foo_a + // node x_b = foo_b + // node y = x_a + case s: DefNode => + sinfo = s.info + val names = create_exps(s.name, tpe(s.value)) map (lowerTypesExp) + val exps = create_exps(s.value) map (lowerTypesExp) + val stmts = names zip exps map { case (n, e) => + DefNode(s.info, loweredName(n), e) } - case (s:Connect) => { - val sx = (s map (lower_types_e)).as[Connect].get - kind(sx.loc) match { - case (k:MemKind) => { - val es = lower_mem(sx.loc) - Begin(es.map(e => {Connect(sx.info,e,sx.exp)})) - } - case (_) => sx - } + Begin(stmts) + case s: IsInvalid => + sinfo = s.info + kind(s.exp) match { + case k: MemKind => + val exps = lowerTypesMemExp(s.exp) + Begin(exps map (exp => IsInvalid(s.info, exp))) + case _ => s map (lowerTypesExp) } - case (s:DefNode) => { - val locs = create_exps(s.name,tpe(s.value)) - val n = locs.size - val nodes = ArrayBuffer[Stmt]() - val exps = create_exps(s.value) - for (i <- 0 until n) { - val locx = locs(i) - val expx = exps(i) - nodes += DefNode(s.info,lowered_name(locx),lower_types_e(expx)) - } - if (n == 1) nodes(0) else Begin(nodes) + case s: Connect => + sinfo = s.info + kind(s.loc) match { + case k: MemKind => + val exp = lowerTypesExp(s.exp) + val locs = lowerTypesMemExp(s.loc) + Begin(locs map (loc => Connect(s.info, loc, exp))) + case _ => s map (lowerTypesExp) } - case (s) => s map (lower_types) map (lower_types_e) - } + case s => s map (lowerTypesExp) + } } - - val portsx = m.ports.flatMap{ p => { - val es = create_exps(WRef(p.name,p.tpe,PortKind(),to_gender(p.direction))) - es.map(e => { Port(p.info,lowered_name(e),to_dir(gender(e)),tpe(e)) }) - }} - (m) match { - case (m:ExModule) => ExModule(m.info,m.name,portsx) - case (m:InModule) => InModule(m.info,m.name,portsx,lower_types(m.body)) + + sinfo = m.info + mname = m.name + // Lower Ports + val portsx = m.ports flatMap { p => + val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) + exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), tpe(e)) ) } - } - - def run (c:Circuit) : Circuit = { - val modulesx = c.modules.map(m => lower_types(m)) - Circuit(c.info,modulesx,c.main) - } + m match { + case m: ExModule => m.copy(ports = portsx) + case m: InModule => InModule(m.info, m.name, portsx, lowerTypesStmt(m.body)) + } + } + + sinfo = c.info + Circuit(c.info, c.modules map lowerTypes, c.main) + } } |
