aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-09-01 14:49:40 -0700
committerDonggyu Kim2016-09-13 16:59:45 -0700
commit856909047609020957023ddf12f9dadc927d1a05 (patch)
tree95b09e27526a55b2846fa4c70c0c476a00901625 /src
parent6b9a0e6253d375369cda308d26a61455076b3f7c (diff)
clean up LowerTypes
no vars for mname, info
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala290
1 files changed, 134 insertions, 156 deletions
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index c72c5fe3..b3969bea 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -32,9 +32,6 @@ import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
-// Datastructures
-import scala.collection.mutable.HashMap
-
/** Removes all aggregate types from a [[firrtl.ir.Circuit]]
*
* @note Assumes [[firrtl.ir.SubAccess]]es have been removed
@@ -65,8 +62,8 @@ object LowerTypes extends Pass {
def loweredName(s: Seq[String]): String = s mkString delim
private case class LowerTypesException(msg: String) extends FIRRTLException(msg)
- private def error(msg: String)(implicit sinfo: Info, mname: String) =
- throw new LowerTypesException(s"$sinfo: [module $mname] $msg")
+ private def error(msg: String)(info: Info, mname: String) =
+ throw LowerTypesException(s"$info: [module $mname] $msg")
// TODO Improve? Probably not the best way to do this
private def splitMemRef(e1: Expression): (WRef, WRef, WRef, Option[Expression]) = {
@@ -81,165 +78,146 @@ object LowerTypes extends Pass {
}
}
- // Everything wrapped in run so that it's thread safe
- def run(c: Circuit): Circuit = {
- // Debug state
- implicit var mname: String = ""
- implicit var sinfo: Info = NoInfo
-
- def lowerTypes(m: DefModule): DefModule = {
- val memDataTypeMap = HashMap[String, Type]()
-
- // Lowers an expression of MemKind
- // Since mems with Bundle type must be split into multiple ground type
- // mem, references to fields addr, en, clk, and rmode must be replicated
- // for each resulting memory
- // References to data, mask, rdata, wdata, and wmask have already been split in expand connects
- // and just need to be converted to refer to the correct new memory
- def lowerTypesMemExp(e: Expression): Seq[Expression] = {
- val (mem, port, field, tail) = splitMemRef(e)
- field.name match {
- // Fields that need to be replicated for each resulting mem
- case "addr" | "en" | "clk" | "wmode" =>
- require(tail.isEmpty) // there can't be a tail for these
- memDataTypeMap(mem.name) match {
- case _: GroundType => Seq(e)
- case memType => create_exps(mem.name, memType) map { e =>
- val loMemName = loweredName(e)
- val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
- mergeRef(loMem, mergeRef(port, field))
- }
- }
- // Fields that need not be replicated for each
- // eg. mem.reader.data[0].a
- // (Connect/IsInvalid must already have been split to ground types)
- case "data" | "mask" | "rdata" | "wdata" | "wmask" =>
- val loMem = tail match {
- case Some(e) =>
- val loMemExp = mergeRef(mem, e)
- val loMemName = loweredName(loMemExp)
- WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
- case None => mem
- }
- Seq(mergeRef(loMem, mergeRef(port, field)))
- case name => error(s"Error! Unhandled memory field ${name}")
+ // Lowers an expression of MemKind
+ // Since mems with Bundle type must be split into multiple ground type
+ // mem, references to fields addr, en, clk, and rmode must be replicated
+ // for each resulting memory
+ // References to data, mask, rdata, wdata, and wmask have already been split in expand connects
+ // and just need to be converted to refer to the correct new memory
+ type MemDataTypeMap = collection.mutable.HashMap[String, Type]
+ def lowerTypesMemExp(memDataTypeMap: MemDataTypeMap,
+ info: Info, mname: String)(e: Expression): Seq[Expression] = {
+ val (mem, port, field, tail) = splitMemRef(e)
+ field.name match {
+ // Fields that need to be replicated for each resulting mem
+ case "addr" | "en" | "clk" | "wmode" =>
+ require(tail.isEmpty) // there can't be a tail for these
+ memDataTypeMap(mem.name) match {
+ case _: GroundType => Seq(e)
+ case memType => create_exps(mem.name, memType) map { e =>
+ val loMemName = loweredName(e)
+ val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
+ mergeRef(loMem, mergeRef(port, field))
+ }
}
- }
+ // Fields that need not be replicated for each
+ // eg. mem.reader.data[0].a
+ // (Connect/IsInvalid must already have been split to ground types)
+ case "data" | "mask" | "rdata" | "wdata" | "wmask" =>
+ val loMem = tail match {
+ case Some(e) =>
+ val loMemExp = mergeRef(mem, e)
+ val loMemName = loweredName(loMemExp)
+ WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
+ case None => mem
+ }
+ Seq(mergeRef(loMem, mergeRef(port, field)))
+ case name => error(s"Error! Unhandled memory field ${name}")(info, mname)
+ }
+ }
- def lowerTypesExp(e: Expression): Expression = e match {
- case e: WRef => e
- case (_: WSubField | _: WSubIndex) => kind(e) match {
- case k: InstanceKind =>
- val (root, tail) = splitRef(e)
- val name = loweredName(tail)
- WSubField(root, name, e.tpe, gender(e))
- case k: MemKind =>
- val exps = lowerTypesMemExp(e)
- exps.size match {
- case 1 => exps.head
- case _ => error("Error! lowerTypesExp called on MemKind " +
- "SubField that needs to be expanded!")
- }
- case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e))
+ def lowerTypesExp(memDataTypeMap: MemDataTypeMap,
+ info: Info, mname: String)(e: Expression): Expression = e match {
+ case e: WRef => e
+ case (_: WSubField | _: WSubIndex) => kind(e) match {
+ case k: InstanceKind =>
+ val (root, tail) = splitRef(e)
+ val name = loweredName(tail)
+ WSubField(root, name, e.tpe, gender(e))
+ case k: MemKind =>
+ val exps = lowerTypesMemExp(memDataTypeMap, info, mname)(e)
+ exps.size match {
+ case 1 => exps.head
+ case _ => error("Error! lowerTypesExp called on MemKind " +
+ "SubField that needs to be expanded!")(info, mname)
}
- case e: Mux => e map (lowerTypesExp)
- case e: ValidIf => e map (lowerTypesExp)
- case e: DoPrim => e map (lowerTypesExp)
- case e @ (_: UIntLiteral | _: SIntLiteral) => e
- }
+ case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e))
+ }
+ case e: Mux => e map lowerTypesExp(memDataTypeMap, info, mname)
+ case e: ValidIf => e map lowerTypesExp(memDataTypeMap, info, mname)
+ case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname)
+ case e @ (_: UIntLiteral | _: SIntLiteral) => e
+ }
- def lowerTypesStmt(s: Statement): Statement = s map lowerTypesStmt match {
- case s: DefWire =>
- sinfo = s.info
- s.tpe match {
- case _: GroundType => s
- case _ => Block(create_exps(s.name, s.tpe) map (
- e => DefWire(s.info, loweredName(e), e.tpe)))
- }
- case s: DefRegister =>
- sinfo = s.info
- s.tpe match {
- case _: GroundType => s map lowerTypesExp
- case _ =>
- val es = create_exps(s.name, s.tpe)
- val inits = create_exps(s.init) map (lowerTypesExp)
- val clock = lowerTypesExp(s.clock)
- val reset = lowerTypesExp(s.reset)
- Block(es zip inits map { case (e, i) =>
- DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i)
- })
- }
- // Could instead just save the type of each Module as it gets processed
- case s: WDefInstance =>
- sinfo = s.info
- s.tpe match {
- case t: BundleType =>
- val fieldsx = t.fields flatMap (f =>
- create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map (
- // Flip because inst genders are reversed from Module type
- e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)
- )
- )
- WDefInstance(s.info, s.name, s.module, BundleType(fieldsx))
- case _ => error("WDefInstance type should be Bundle!")
- }
- case s: DefMemory =>
- sinfo = s.info
- memDataTypeMap(s.name) = s.dataType
- s.dataType match {
- case _: GroundType => s
- case _ => Block(create_exps(s.name, s.dataType) map (e =>
- DefMemory(s.info, loweredName(e), e.tpe, s.depth,
- s.writeLatency, s.readLatency, s.readers, s.writers,
- s.readwriters)))
- }
- // wire foo : { a , b }
- // node x = foo
- // node y = x.a
- // ->
- // node x_a = foo_a
- // node x_b = foo_b
- // node y = x_a
- case s: DefNode =>
- sinfo = s.info
- val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp)
- val exps = create_exps(s.value) map (lowerTypesExp)
- Block(names zip exps map {case (n, e) => DefNode(s.info, loweredName(n), e)})
- case s: IsInvalid =>
- sinfo = s.info
- kind(s.expr) match {
- case k: MemKind =>
- Block(lowerTypesMemExp(s.expr) map (IsInvalid(s.info, _)))
- case _ => s map (lowerTypesExp)
- }
- case s: Connect =>
- sinfo = s.info
- kind(s.loc) match {
- case k: MemKind =>
- val exp = lowerTypesExp(s.expr)
- val locs = lowerTypesMemExp(s.loc)
- Block(locs map (Connect(s.info, _, exp)))
- case _ => s map (lowerTypesExp)
- }
- case s => s map (lowerTypesExp)
+ def lowerTypesStmt(memDataTypeMap: MemDataTypeMap,
+ minfo: Info, mname: String)(s: Statement): Statement = {
+ val info = get_info(s) match {case NoInfo => minfo case x => x}
+ s map lowerTypesStmt(memDataTypeMap, info, mname) match {
+ case s: DefWire => s.tpe match {
+ case _: GroundType => s
+ case _ => Block(create_exps(s.name, s.tpe) map (
+ e => DefWire(s.info, loweredName(e), e.tpe)))
}
-
- sinfo = m.info
- mname = m.name
- // Lower Ports
- val portsx = m.ports flatMap ( p =>
- create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) map (
- e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe)
- )
- )
- m match {
- case m: ExtModule => m.copy(ports = portsx)
- case m: Module => Module(m.info, m.name, portsx, lowerTypesStmt(m.body))
+ case s: DefRegister => s.tpe match {
+ case _: GroundType => s map lowerTypesExp(memDataTypeMap, info, mname)
+ case _ =>
+ val es = create_exps(s.name, s.tpe)
+ val inits = create_exps(s.init) map lowerTypesExp(memDataTypeMap, info, mname)
+ val clock = lowerTypesExp(memDataTypeMap, info, mname)(s.clock)
+ val reset = lowerTypesExp(memDataTypeMap, info, mname)(s.reset)
+ Block(es zip inits map { case (e, i) =>
+ DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i)
+ })
+ }
+ // Could instead just save the type of each Module as it gets processed
+ case s: WDefInstance => s.tpe match {
+ case t: BundleType =>
+ val fieldsx = t.fields flatMap (f =>
+ create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map (
+ // Flip because inst genders are reversed from Module type
+ e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)))
+ WDefInstance(s.info, s.name, s.module, BundleType(fieldsx))
+ case _ => error("WDefInstance type should be Bundle!")(info, mname)
}
+ case s: DefMemory =>
+ memDataTypeMap(s.name) = s.dataType
+ s.dataType match {
+ case _: GroundType => s
+ case _ => Block(create_exps(s.name, s.dataType) map (e =>
+ s copy (name = loweredName(e), dataType = e.tpe)))
+ }
+ // wire foo : { a , b }
+ // node x = foo
+ // node y = x.a
+ // ->
+ // node x_a = foo_a
+ // node x_b = foo_b
+ // node y = x_a
+ case s: DefNode =>
+ val names = create_exps(s.name, s.value.tpe) map lowerTypesExp(memDataTypeMap, info, mname)
+ val exps = create_exps(s.value) map lowerTypesExp(memDataTypeMap, info, mname)
+ Block(names zip exps map { case (n, e) => DefNode(info, loweredName(n), e) })
+ case s: IsInvalid => kind(s.expr) match {
+ case _: MemKind =>
+ Block(lowerTypesMemExp(memDataTypeMap, info, mname)(s.expr) map (IsInvalid(info, _)))
+ case _ => s map lowerTypesExp(memDataTypeMap, info, mname)
+ }
+ case s: Connect => kind(s.loc) match {
+ case k: MemKind =>
+ val exp = lowerTypesExp(memDataTypeMap, info, mname)(s.expr)
+ val locs = lowerTypesMemExp(memDataTypeMap, info, mname)(s.loc)
+ Block(locs map (Connect(info, _, exp)))
+ case _ => s map lowerTypesExp(memDataTypeMap, info, mname)
+ }
+ case s => s map lowerTypesExp(memDataTypeMap, info, mname)
}
+ }
- sinfo = c.info
- Circuit(c.info, c.modules map lowerTypes, c.main)
+ def lowerTypes(m: DefModule): DefModule = {
+ val memDataTypeMap = new MemDataTypeMap
+ // Lower Ports
+ val portsx = m.ports flatMap { p =>
+ val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction)))
+ exps map (e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe))
+ }
+ m match {
+ case m: ExtModule =>
+ m copy (ports = portsx)
+ case m: Module =>
+ m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name)
+ }
}
+
+ def run(c: Circuit): Circuit = c copy (modules = (c.modules map lowerTypes))
}