diff options
| author | Donggyu Kim | 2016-09-01 14:49:40 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-13 16:59:45 -0700 |
| commit | 856909047609020957023ddf12f9dadc927d1a05 (patch) | |
| tree | 95b09e27526a55b2846fa4c70c0c476a00901625 /src | |
| parent | 6b9a0e6253d375369cda308d26a61455076b3f7c (diff) | |
clean up LowerTypes
no vars for mname, info
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 290 |
1 files changed, 134 insertions, 156 deletions
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index c72c5fe3..b3969bea 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -32,9 +32,6 @@ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -// Datastructures -import scala.collection.mutable.HashMap - /** Removes all aggregate types from a [[firrtl.ir.Circuit]] * * @note Assumes [[firrtl.ir.SubAccess]]es have been removed @@ -65,8 +62,8 @@ object LowerTypes extends Pass { def loweredName(s: Seq[String]): String = s mkString delim private case class LowerTypesException(msg: String) extends FIRRTLException(msg) - private def error(msg: String)(implicit sinfo: Info, mname: String) = - throw new LowerTypesException(s"$sinfo: [module $mname] $msg") + private def error(msg: String)(info: Info, mname: String) = + throw LowerTypesException(s"$info: [module $mname] $msg") // TODO Improve? Probably not the best way to do this private def splitMemRef(e1: Expression): (WRef, WRef, WRef, Option[Expression]) = { @@ -81,165 +78,146 @@ object LowerTypes extends Pass { } } - // Everything wrapped in run so that it's thread safe - def run(c: Circuit): Circuit = { - // Debug state - implicit var mname: String = "" - implicit var sinfo: Info = NoInfo - - def lowerTypes(m: DefModule): DefModule = { - val memDataTypeMap = HashMap[String, Type]() - - // Lowers an expression of MemKind - // Since mems with Bundle type must be split into multiple ground type - // mem, references to fields addr, en, clk, and rmode must be replicated - // for each resulting memory - // References to data, mask, rdata, wdata, and wmask have already been split in expand connects - // and just need to be converted to refer to the correct new memory - def lowerTypesMemExp(e: Expression): Seq[Expression] = { - val (mem, port, field, tail) = splitMemRef(e) - field.name match { - // Fields that need to be replicated for each resulting mem - case "addr" | "en" | "clk" | "wmode" => - require(tail.isEmpty) // there can't be a tail for these - memDataTypeMap(mem.name) match { - case _: GroundType => Seq(e) - case memType => create_exps(mem.name, memType) map { e => - val loMemName = loweredName(e) - val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) - mergeRef(loMem, mergeRef(port, field)) - } - } - // Fields that need not be replicated for each - // eg. mem.reader.data[0].a - // (Connect/IsInvalid must already have been split to ground types) - case "data" | "mask" | "rdata" | "wdata" | "wmask" => - val loMem = tail match { - case Some(e) => - val loMemExp = mergeRef(mem, e) - val loMemName = loweredName(loMemExp) - WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) - case None => mem - } - Seq(mergeRef(loMem, mergeRef(port, field))) - case name => error(s"Error! Unhandled memory field ${name}") + // Lowers an expression of MemKind + // Since mems with Bundle type must be split into multiple ground type + // mem, references to fields addr, en, clk, and rmode must be replicated + // for each resulting memory + // References to data, mask, rdata, wdata, and wmask have already been split in expand connects + // and just need to be converted to refer to the correct new memory + type MemDataTypeMap = collection.mutable.HashMap[String, Type] + def lowerTypesMemExp(memDataTypeMap: MemDataTypeMap, + info: Info, mname: String)(e: Expression): Seq[Expression] = { + val (mem, port, field, tail) = splitMemRef(e) + field.name match { + // Fields that need to be replicated for each resulting mem + case "addr" | "en" | "clk" | "wmode" => + require(tail.isEmpty) // there can't be a tail for these + memDataTypeMap(mem.name) match { + case _: GroundType => Seq(e) + case memType => create_exps(mem.name, memType) map { e => + val loMemName = loweredName(e) + val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) + mergeRef(loMem, mergeRef(port, field)) + } } - } + // Fields that need not be replicated for each + // eg. mem.reader.data[0].a + // (Connect/IsInvalid must already have been split to ground types) + case "data" | "mask" | "rdata" | "wdata" | "wmask" => + val loMem = tail match { + case Some(e) => + val loMemExp = mergeRef(mem, e) + val loMemName = loweredName(loMemExp) + WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) + case None => mem + } + Seq(mergeRef(loMem, mergeRef(port, field))) + case name => error(s"Error! Unhandled memory field ${name}")(info, mname) + } + } - def lowerTypesExp(e: Expression): Expression = e match { - case e: WRef => e - case (_: WSubField | _: WSubIndex) => kind(e) match { - case k: InstanceKind => - val (root, tail) = splitRef(e) - val name = loweredName(tail) - WSubField(root, name, e.tpe, gender(e)) - case k: MemKind => - val exps = lowerTypesMemExp(e) - exps.size match { - case 1 => exps.head - case _ => error("Error! lowerTypesExp called on MemKind " + - "SubField that needs to be expanded!") - } - case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e)) + def lowerTypesExp(memDataTypeMap: MemDataTypeMap, + info: Info, mname: String)(e: Expression): Expression = e match { + case e: WRef => e + case (_: WSubField | _: WSubIndex) => kind(e) match { + case k: InstanceKind => + val (root, tail) = splitRef(e) + val name = loweredName(tail) + WSubField(root, name, e.tpe, gender(e)) + case k: MemKind => + val exps = lowerTypesMemExp(memDataTypeMap, info, mname)(e) + exps.size match { + case 1 => exps.head + case _ => error("Error! lowerTypesExp called on MemKind " + + "SubField that needs to be expanded!")(info, mname) } - case e: Mux => e map (lowerTypesExp) - case e: ValidIf => e map (lowerTypesExp) - case e: DoPrim => e map (lowerTypesExp) - case e @ (_: UIntLiteral | _: SIntLiteral) => e - } + case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e)) + } + case e: Mux => e map lowerTypesExp(memDataTypeMap, info, mname) + case e: ValidIf => e map lowerTypesExp(memDataTypeMap, info, mname) + case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname) + case e @ (_: UIntLiteral | _: SIntLiteral) => e + } - def lowerTypesStmt(s: Statement): Statement = s map lowerTypesStmt match { - case s: DefWire => - sinfo = s.info - s.tpe match { - case _: GroundType => s - case _ => Block(create_exps(s.name, s.tpe) map ( - e => DefWire(s.info, loweredName(e), e.tpe))) - } - case s: DefRegister => - sinfo = s.info - s.tpe match { - case _: GroundType => s map lowerTypesExp - case _ => - val es = create_exps(s.name, s.tpe) - val inits = create_exps(s.init) map (lowerTypesExp) - val clock = lowerTypesExp(s.clock) - val reset = lowerTypesExp(s.reset) - Block(es zip inits map { case (e, i) => - DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) - }) - } - // Could instead just save the type of each Module as it gets processed - case s: WDefInstance => - sinfo = s.info - s.tpe match { - case t: BundleType => - val fieldsx = t.fields flatMap (f => - create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map ( - // Flip because inst genders are reversed from Module type - e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe) - ) - ) - WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) - case _ => error("WDefInstance type should be Bundle!") - } - case s: DefMemory => - sinfo = s.info - memDataTypeMap(s.name) = s.dataType - s.dataType match { - case _: GroundType => s - case _ => Block(create_exps(s.name, s.dataType) map (e => - DefMemory(s.info, loweredName(e), e.tpe, s.depth, - s.writeLatency, s.readLatency, s.readers, s.writers, - s.readwriters))) - } - // wire foo : { a , b } - // node x = foo - // node y = x.a - // -> - // node x_a = foo_a - // node x_b = foo_b - // node y = x_a - case s: DefNode => - sinfo = s.info - val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp) - val exps = create_exps(s.value) map (lowerTypesExp) - Block(names zip exps map {case (n, e) => DefNode(s.info, loweredName(n), e)}) - case s: IsInvalid => - sinfo = s.info - kind(s.expr) match { - case k: MemKind => - Block(lowerTypesMemExp(s.expr) map (IsInvalid(s.info, _))) - case _ => s map (lowerTypesExp) - } - case s: Connect => - sinfo = s.info - kind(s.loc) match { - case k: MemKind => - val exp = lowerTypesExp(s.expr) - val locs = lowerTypesMemExp(s.loc) - Block(locs map (Connect(s.info, _, exp))) - case _ => s map (lowerTypesExp) - } - case s => s map (lowerTypesExp) + def lowerTypesStmt(memDataTypeMap: MemDataTypeMap, + minfo: Info, mname: String)(s: Statement): Statement = { + val info = get_info(s) match {case NoInfo => minfo case x => x} + s map lowerTypesStmt(memDataTypeMap, info, mname) match { + case s: DefWire => s.tpe match { + case _: GroundType => s + case _ => Block(create_exps(s.name, s.tpe) map ( + e => DefWire(s.info, loweredName(e), e.tpe))) } - - sinfo = m.info - mname = m.name - // Lower Ports - val portsx = m.ports flatMap ( p => - create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) map ( - e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe) - ) - ) - m match { - case m: ExtModule => m.copy(ports = portsx) - case m: Module => Module(m.info, m.name, portsx, lowerTypesStmt(m.body)) + case s: DefRegister => s.tpe match { + case _: GroundType => s map lowerTypesExp(memDataTypeMap, info, mname) + case _ => + val es = create_exps(s.name, s.tpe) + val inits = create_exps(s.init) map lowerTypesExp(memDataTypeMap, info, mname) + val clock = lowerTypesExp(memDataTypeMap, info, mname)(s.clock) + val reset = lowerTypesExp(memDataTypeMap, info, mname)(s.reset) + Block(es zip inits map { case (e, i) => + DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) + }) + } + // Could instead just save the type of each Module as it gets processed + case s: WDefInstance => s.tpe match { + case t: BundleType => + val fieldsx = t.fields flatMap (f => + create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map ( + // Flip because inst genders are reversed from Module type + e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe))) + WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) + case _ => error("WDefInstance type should be Bundle!")(info, mname) } + case s: DefMemory => + memDataTypeMap(s.name) = s.dataType + s.dataType match { + case _: GroundType => s + case _ => Block(create_exps(s.name, s.dataType) map (e => + s copy (name = loweredName(e), dataType = e.tpe))) + } + // wire foo : { a , b } + // node x = foo + // node y = x.a + // -> + // node x_a = foo_a + // node x_b = foo_b + // node y = x_a + case s: DefNode => + val names = create_exps(s.name, s.value.tpe) map lowerTypesExp(memDataTypeMap, info, mname) + val exps = create_exps(s.value) map lowerTypesExp(memDataTypeMap, info, mname) + Block(names zip exps map { case (n, e) => DefNode(info, loweredName(n), e) }) + case s: IsInvalid => kind(s.expr) match { + case _: MemKind => + Block(lowerTypesMemExp(memDataTypeMap, info, mname)(s.expr) map (IsInvalid(info, _))) + case _ => s map lowerTypesExp(memDataTypeMap, info, mname) + } + case s: Connect => kind(s.loc) match { + case k: MemKind => + val exp = lowerTypesExp(memDataTypeMap, info, mname)(s.expr) + val locs = lowerTypesMemExp(memDataTypeMap, info, mname)(s.loc) + Block(locs map (Connect(info, _, exp))) + case _ => s map lowerTypesExp(memDataTypeMap, info, mname) + } + case s => s map lowerTypesExp(memDataTypeMap, info, mname) } + } - sinfo = c.info - Circuit(c.info, c.modules map lowerTypes, c.main) + def lowerTypes(m: DefModule): DefModule = { + val memDataTypeMap = new MemDataTypeMap + // Lower Ports + val portsx = m.ports flatMap { p => + val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) + exps map (e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe)) + } + m match { + case m: ExtModule => + m copy (ports = portsx) + case m: Module => + m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name) + } } + + def run(c: Circuit): Circuit = c copy (modules = (c.modules map lowerTypes)) } |
