diff options
| author | Donggyu | 2016-09-07 16:33:24 -0700 |
|---|---|---|
| committer | GitHub | 2016-09-07 16:33:24 -0700 |
| commit | a404cf5b2c4ca6457c964eb32aae8330c48422e1 (patch) | |
| tree | d6c137d446a254bae87cd015c94bf86806225042 /src | |
| parent | 13345ce816a51cc19f93a03c4148eecf0dd2c739 (diff) | |
| parent | 8beaa3be259d2a793e3a99628b2f0d38d98f5b9a (diff) | |
Merge pull request #280 from ucb-bar/cleanup_passes
Clean up passes
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 237 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 234 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/SplitExpressions.scala | 86 |
3 files changed, 253 insertions, 304 deletions
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index a4c584ed..57f8fd76 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -61,10 +61,10 @@ object LowerTypes extends Pass { */ def loweredName(e: Expression): String = e match { case e: WRef => e.name - case e: WSubField => loweredName(e.exp) + delim + e.name - case e: WSubIndex => loweredName(e.exp) + delim + e.value + case e: WSubField => s"${loweredName(e.exp)}$delim${e.name}" + case e: WSubIndex => s"${loweredName(e.exp)}$delim${e.value}" } - def loweredName(s: Seq[String]): String = s.mkString(delim) + def loweredName(s: Seq[String]): String = s mkString delim private case class LowerTypesException(msg: String) extends FIRRTLException(msg) private def error(msg: String)(implicit sinfo: Info, mname: String) = @@ -100,35 +100,31 @@ object LowerTypes extends Pass { // and just need to be converted to refer to the correct new memory def lowerTypesMemExp(e: Expression): Seq[Expression] = { val (mem, port, field, tail) = splitMemRef(e) - // Fields that need to be replicated for each resulting mem - if (Seq("addr", "en", "clk", "wmode").contains(field.name)) { - require(tail.isEmpty) // there can't be a tail for these - val memType = memDataTypeMap(mem.name) - - memType match { - case _: GroundType => Seq(e) - case _ => - val exps = create_exps(mem.name, memType) - exps map { e => + field.name match { + // Fields that need to be replicated for each resulting mem + case "addr" | "en" | "clk" | "wmode" => + require(tail.isEmpty) // there can't be a tail for these + memDataTypeMap(mem.name) match { + case _: GroundType => Seq(e) + case memType => create_exps(mem.name, memType) map { e => val loMemName = loweredName(e) val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) mergeRef(loMem, mergeRef(port, field)) } - } - // Fields that need not be replicated for each - // eg. mem.reader.data[0].a - // (Connect/IsInvalid must already have been split to ground types) - } else if (Seq("data", "mask", "rdata", "wdata", "wmask").contains(field.name)) { - val loMem = tail match { - case Some(e) => - val loMemExp = mergeRef(mem, e) - val loMemName = loweredName(loMemExp) - WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) - case None => mem - } - Seq(mergeRef(loMem, mergeRef(port, field))) - } else { - error(s"Error! Unhandled memory field ${field.name}") + } + // Fields that need not be replicated for each + // eg. mem.reader.data[0].a + // (Connect/IsInvalid must already have been split to ground types) + case "data" | "mask" | "rdata" | "wdata" | "wmask" => + val loMem = tail match { + case Some(e) => + val loMemExp = mergeRef(mem, e) + val loMemName = loweredName(loMemExp) + WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) + case None => mem + } + Seq(mergeRef(loMem, mergeRef(port, field))) + case name => error(s"Error! Unhandled memory field ${name}") } } @@ -141,116 +137,103 @@ object LowerTypes extends Pass { WSubField(root, name, e.tpe, gender(e)) case k: MemKind => val exps = lowerTypesMemExp(e) - if (exps.length > 1) - error("Error! lowerTypesExp called on MemKind SubField that needs" + - " to be expanded!") - exps(0) - case k => - WRef(loweredName(e), e.tpe, kind(e), gender(e)) + exps.size match { + case 1 => exps.head + case _ => error("Error! lowerTypesExp called on MemKind " + + "SubField that needs to be expanded!") + } + case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e)) } case e: Mux => e map (lowerTypesExp) case e: ValidIf => e map (lowerTypesExp) - case (_: UIntLiteral | _: SIntLiteral) => e case e: DoPrim => e map (lowerTypesExp) + case e @ (_: UIntLiteral | _: SIntLiteral) => e } - def lowerTypesStmt(s: Statement): Statement = { - s map lowerTypesStmt match { - case s: DefWire => - sinfo = s.info - s.tpe match { - case _: GroundType => s - case _ => - val exps = create_exps(s.name, s.tpe) - val stmts = exps map (e => DefWire(s.info, loweredName(e), e.tpe)) - Block(stmts) - } - case s: DefRegister => - sinfo = s.info - s.tpe match { - case _: GroundType => s map lowerTypesExp - case _ => - val es = create_exps(s.name, s.tpe) - val inits = create_exps(s.init) map (lowerTypesExp) - val clock = lowerTypesExp(s.clock) - val reset = lowerTypesExp(s.reset) - val stmts = es zip inits map { case (e, i) => - DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) - } - Block(stmts) - } - // Could instead just save the type of each Module as it gets processed - case s: WDefInstance => - sinfo = s.info - s.tpe match { - case t: BundleType => - val fieldsx = t.fields flatMap { f => - val exps = create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) - exps map ( e => - // Flip because inst genders are reversed from Module type - Field(loweredName(e), swap(to_flip(gender(e))), e.tpe) - ) - } - WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) - case _ => error("WDefInstance type should be Bundle!") - } - case s: DefMemory => - sinfo = s.info - memDataTypeMap += (s.name -> s.dataType) - s.dataType match { - case _: GroundType => s - case _ => - val exps = create_exps(s.name, s.dataType) - val stmts = exps map { e => - DefMemory(s.info, loweredName(e), e.tpe, s.depth, - s.writeLatency, s.readLatency, s.readers, s.writers, - s.readwriters) - } - Block(stmts) - } - // wire foo : { a , b } - // node x = foo - // node y = x.a - // -> - // node x_a = foo_a - // node x_b = foo_b - // node y = x_a - case s: DefNode => - sinfo = s.info - val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp) - val exps = create_exps(s.value) map (lowerTypesExp) - val stmts = names zip exps map { case (n, e) => - DefNode(s.info, loweredName(n), e) - } - Block(stmts) - case s: IsInvalid => - sinfo = s.info - kind(s.expr) match { - case k: MemKind => - val exps = lowerTypesMemExp(s.expr) - Block(exps map (exp => IsInvalid(s.info, exp))) - case _ => s map (lowerTypesExp) - } - case s: Connect => - sinfo = s.info - kind(s.loc) match { - case k: MemKind => - val exp = lowerTypesExp(s.expr) - val locs = lowerTypesMemExp(s.loc) - Block(locs map (loc => Connect(s.info, loc, exp))) - case _ => s map (lowerTypesExp) - } - case s => s map (lowerTypesExp) - } + def lowerTypesStmt(s: Statement): Statement = s map lowerTypesStmt match { + case s: DefWire => + sinfo = s.info + s.tpe match { + case _: GroundType => s + case _ => Block(create_exps(s.name, s.tpe) map ( + e => DefWire(s.info, loweredName(e), e.tpe))) + } + case s: DefRegister => + sinfo = s.info + s.tpe match { + case _: GroundType => s map lowerTypesExp + case _ => + val es = create_exps(s.name, s.tpe) + val inits = create_exps(s.init) map (lowerTypesExp) + val clock = lowerTypesExp(s.clock) + val reset = lowerTypesExp(s.reset) + Block(es zip inits map { case (e, i) => + DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) + }) + } + // Could instead just save the type of each Module as it gets processed + case s: WDefInstance => + sinfo = s.info + s.tpe match { + case t: BundleType => + val fieldsx = t.fields flatMap (f => + create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map ( + // Flip because inst genders are reversed from Module type + e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe) + ) + ) + WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) + case _ => error("WDefInstance type should be Bundle!") + } + case s: DefMemory => + sinfo = s.info + memDataTypeMap(s.name) = s.dataType + s.dataType match { + case _: GroundType => s + case _ => Block(create_exps(s.name, s.dataType) map (e => + DefMemory(s.info, loweredName(e), e.tpe, s.depth, + s.writeLatency, s.readLatency, s.readers, s.writers, + s.readwriters))) + } + // wire foo : { a , b } + // node x = foo + // node y = x.a + // -> + // node x_a = foo_a + // node x_b = foo_b + // node y = x_a + case s: DefNode => + sinfo = s.info + val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp) + val exps = create_exps(s.value) map (lowerTypesExp) + Block(names zip exps map {case (n, e) => DefNode(s.info, loweredName(n), e)}) + case s: IsInvalid => + sinfo = s.info + kind(s.expr) match { + case k: MemKind => + Block(lowerTypesMemExp(s.expr) map (IsInvalid(s.info, _))) + case _ => s map (lowerTypesExp) + } + case s: Connect => + sinfo = s.info + kind(s.loc) match { + case k: MemKind => + val exp = lowerTypesExp(s.expr) + val locs = lowerTypesMemExp(s.loc) + Block(locs map (Connect(s.info, _, exp))) + case _ => s map (lowerTypesExp) + } + case s => s map (lowerTypesExp) } sinfo = m.info mname = m.name // Lower Ports - val portsx = m.ports flatMap { p => - val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) - exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe) ) - } + val portsx = m.ports flatMap ( p => + create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) map ( + e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe) + ) + ) m match { case m: ExtModule => m.copy(ports = portsx) case m: Module => Module(m.info, m.name, portsx, lowerTypesStmt(m.body)) diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 6b6dc811..c143212e 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -602,149 +602,115 @@ object InferWidths extends Pass { } object PullMuxes extends Pass { - private var mname = "" def name = "Pull Muxes" - def run (c:Circuit): Circuit = { - def pull_muxes_e (e:Expression) : Expression = { - val ex = e map (pull_muxes_e) match { - case (e:WRef) => e - case (e:WSubField) => { - e.exp match { - case (ex:Mux) => Mux(ex.cond,WSubField(ex.tval,e.name,e.tpe,e.gender),WSubField(ex.fval,e.name,e.tpe,e.gender),e.tpe) - case (ex:ValidIf) => ValidIf(ex.cond,WSubField(ex.value,e.name,e.tpe,e.gender),e.tpe) - case (ex) => e - } - } - case (e:WSubIndex) => { - e.exp match { - case (ex:Mux) => Mux(ex.cond,WSubIndex(ex.tval,e.value,e.tpe,e.gender),WSubIndex(ex.fval,e.value,e.tpe,e.gender),e.tpe) - case (ex:ValidIf) => ValidIf(ex.cond,WSubIndex(ex.value,e.value,e.tpe,e.gender),e.tpe) - case (ex) => e - } - } - case (e:WSubAccess) => { - e.exp match { - case (ex:Mux) => Mux(ex.cond,WSubAccess(ex.tval,e.index,e.tpe,e.gender),WSubAccess(ex.fval,e.index,e.tpe,e.gender),e.tpe) - case (ex:ValidIf) => ValidIf(ex.cond,WSubAccess(ex.value,e.index,e.tpe,e.gender),e.tpe) - case (ex) => e - } - } - case (e:Mux) => e - case (e:ValidIf) => e - case (e) => e + def run(c: Circuit): Circuit = { + def pull_muxes_e(e: Expression): Expression = { + val ex = e map (pull_muxes_e) match { + case (e: WSubField) => e.exp match { + case (ex: Mux) => Mux(ex.cond, + WSubField(ex.tval, e.name, e.tpe, e.gender), + WSubField(ex.fval, e.name, e.tpe, e.gender), e.tpe) + case (ex: ValidIf) => ValidIf(ex.cond, + WSubField(ex.value, e.name, e.tpe, e.gender), e.tpe) + case (ex) => e } - ex map (pull_muxes_e) - } - def pull_muxes (s:Statement) : Statement = s map (pull_muxes) map (pull_muxes_e) - val modulesx = c.modules.map { - m => { - mname = m.name - m match { - case (m:Module) => Module(m.info,m.name,m.ports,pull_muxes(m.body)) - case (m:ExtModule) => m - } + case (e: WSubIndex) => e.exp match { + case (ex: Mux) => Mux(ex.cond, + WSubIndex(ex.tval, e.value, e.tpe, e.gender), + WSubIndex(ex.fval, e.value, e.tpe, e.gender), e.tpe) + case (ex: ValidIf) => ValidIf(ex.cond, + WSubIndex(ex.value, e.value, e.tpe, e.gender), e.tpe) + case (ex) => e } - } - Circuit(c.info,modulesx,c.main) + case (e: WSubAccess) => e.exp match { + case (ex: Mux) => Mux(ex.cond, + WSubAccess(ex.tval, e.index, e.tpe, e.gender), + WSubAccess(ex.fval, e.index, e.tpe, e.gender), e.tpe) + case (ex: ValidIf) => ValidIf(ex.cond, + WSubAccess(ex.value, e.index, e.tpe, e.gender), e.tpe) + case (ex) => e + } + case (e) => e + } + ex map (pull_muxes_e) + } + def pull_muxes(s: Statement): Statement = s map (pull_muxes) map (pull_muxes_e) + val modulesx = c.modules.map { + case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body)) + case (m:ExtModule) => m + } + Circuit(c.info, modulesx, c.main) } } object ExpandConnects extends Pass { - private var mname = "" - def name = "Expand Connects" - def run (c:Circuit): Circuit = { - def expand_connects (m:Module) : Module = { - mname = m.name - val genders = LinkedHashMap[String,Gender]() - def expand_s (s:Statement) : Statement = { - def set_gender (e:Expression) : Expression = { - e map (set_gender) match { - case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name)) - case (e:WSubField) => { - val f = get_field(e.exp.tpe,e.name) - val genderx = times(gender(e.exp),f.flip) - WSubField(e.exp,e.name,e.tpe,genderx) - } - case (e:WSubIndex) => WSubIndex(e.exp,e.value,e.tpe,gender(e.exp)) - case (e:WSubAccess) => WSubAccess(e.exp,e.index,e.tpe,gender(e.exp)) - case (e) => e + def name = "Expand Connects" + def run(c: Circuit): Circuit = { + def expand_connects(m: Module): Module = { + val genders = LinkedHashMap[String,Gender]() + def expand_s(s: Statement): Statement = { + def set_gender(e: Expression): Expression = e map (set_gender) match { + case (e: WRef) => WRef(e.name, e.tpe, e.kind, genders(e.name)) + case (e: WSubField) => + val f = get_field(e.exp.tpe, e.name) + val genderx = times(gender(e.exp), f.flip) + WSubField(e.exp, e.name, e.tpe, genderx) + case (e: WSubIndex) => WSubIndex(e.exp, e.value, e.tpe, gender(e.exp)) + case (e: WSubAccess) => WSubAccess(e.exp, e.index, e.tpe, gender(e.exp)) + case (e) => e + } + s match { + case (s: DefWire) => genders(s.name) = BIGENDER; s + case (s: DefRegister) => genders(s.name) = BIGENDER; s + case (s: WDefInstance) => genders(s.name) = MALE; s + case (s: DefMemory) => genders(s.name) = MALE; s + case (s: DefNode) => genders(s.name) = MALE; s + case (s: IsInvalid) => + val invalids = (create_exps(s.expr) foldLeft Seq[Statement]())( + (invalids, expx) => gender(set_gender(expx)) match { + case BIGENDER => invalids :+ IsInvalid(s.info, expx) + case FEMALE => invalids :+ IsInvalid(s.info, expx) + case _ => invalids } + ) + invalids.size match { + case 0 => EmptyStmt + case 1 => invalids.head + case _ => Block(invalids) } - s match { - case (s:DefWire) => { genders(s.name) = BIGENDER; s } - case (s:DefRegister) => { genders(s.name) = BIGENDER; s } - case (s:WDefInstance) => { genders(s.name) = MALE; s } - case (s:DefMemory) => { genders(s.name) = MALE; s } - case (s:DefNode) => { genders(s.name) = MALE; s } - case (s:IsInvalid) => { - val n = get_size(s.expr.tpe) - val invalids = ArrayBuffer[Statement]() - val exps = create_exps(s.expr) - for (i <- 0 until n) { - val expx = exps(i) - val gexpx = set_gender(expx) - gender(gexpx) match { - case BIGENDER => invalids += IsInvalid(s.info,expx) - case FEMALE => invalids += IsInvalid(s.info,expx) - case _ => {} - } - } - if (invalids.length == 0) { - EmptyStmt - } else if (invalids.length == 1) { - invalids(0) - } else Block(invalids) - } - case (s:Connect) => { - val n = get_size(s.loc.tpe) - val connects = ArrayBuffer[Statement]() - val locs = create_exps(s.loc) - val exps = create_exps(s.expr) - for (i <- 0 until n) { - val locx = locs(i) - val expx = exps(i) - val sx = get_flip(s.loc.tpe,i,Default) match { - case Default => Connect(s.info,locx,expx) - case Flip => Connect(s.info,expx,locx) - } - connects += sx - } - Block(connects) - } - case (s:PartialConnect) => { - val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) - val connects = ArrayBuffer[Statement]() - val locs = create_exps(s.loc) - val exps = create_exps(s.expr) - ls.foreach { x => { - val locx = locs(x._1) - val expx = exps(x._2) - val sx = get_flip(s.loc.tpe,x._1,Default) match { - case Default => Connect(s.info,locx,expx) - case Flip => Connect(s.info,expx,locx) - } - connects += sx - }} - Block(connects) + case (s: Connect) => + val locs = create_exps(s.loc) + val exps = create_exps(s.expr) + Block((locs zip exps).zipWithIndex map {case ((locx, expx), i) => + get_flip(s.loc.tpe, i, Default) match { + case Default => Connect(s.info, locx, expx) + case Flip => Connect(s.info, expx, locx) } - case (s) => s map (expand_s) - } - } - - m.ports.foreach { p => genders(p.name) = to_gender(p.direction) } - Module(m.info,m.name,m.ports,expand_s(m.body)) - } - - val modulesx = c.modules.map { - m => { - m match { - case (m:ExtModule) => m - case (m:Module) => expand_connects(m) - } - } + }) + case (s: PartialConnect) => + val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default) + val locs = create_exps(s.loc) + val exps = create_exps(s.expr) + Block(ls map {case (x, y) => + get_flip(s.loc.tpe, x, Default) match { + case Default => Connect(s.info, locs(x), exps(y)) + case Flip => Connect(s.info, exps(y), locs(x)) + } + }) + case (s) => s map (expand_s) + } } - Circuit(c.info,modulesx,c.main) - } + + m.ports.foreach { p => genders(p.name) = to_gender(p.direction) } + Module(m.info, m.name, m.ports, expand_s(m.body)) + } + + val modulesx = c.modules.map { + case (m: ExtModule) => m + case (m: Module) => expand_connects(m) + } + Circuit(c.info, modulesx, c.main) + } } @@ -754,8 +720,8 @@ object Legalize extends Pass { def name = "Legalize" def legalizeShiftRight (e: DoPrim): Expression = e.op match { case Shr => { - val amount = e.consts(0).toInt - val width = long_BANG(e.args(0).tpe) + val amount = e.consts.head.toInt + val width = long_BANG(e.args.head.tpe) lazy val msb = width - 1 if (amount >= width) { e.tpe match { diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 3b6021ed..90b92a35 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -14,53 +14,53 @@ object SplitExpressions extends Pass { private def onModule(m: Module): Module = { val namespace = Namespace(m) def onStmt(s: Statement): Statement = { - val v = mutable.ArrayBuffer[Statement]() - // Splits current expression if needed - // Adds named temporaries to v - def split(e: Expression): Expression = e match { - case e: DoPrim => { - val name = namespace.newTemp - v += DefNode(get_info(s), name, e) - WRef(name, e.tpe, kind(e), gender(e)) - } - case e: Mux => { - val name = namespace.newTemp - v += DefNode(get_info(s), name, e) - WRef(name, e.tpe, kind(e), gender(e)) - } - case e: ValidIf => { - val name = namespace.newTemp - v += DefNode(get_info(s), name, e) - WRef(name, e.tpe, kind(e), gender(e)) - } - case e => e - } - // Recursive. Splits compound nodes - def onExp(e: Expression): Expression = { - val ex = e map onExp - ex match { - case (_: DoPrim) => ex map split - case v => v - } - } - val x = s map onExp - x match { - case x: Block => x map onStmt - case EmptyStmt => x - case x => { - v += x - if (v.size > 1) Block(v.toVector) - else v(0) - } + val v = mutable.ArrayBuffer[Statement]() + // Splits current expression if needed + // Adds named temporaries to v + def split(e: Expression): Expression = e match { + case e: DoPrim => { + val name = namespace.newTemp + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), gender(e)) + } + case e: Mux => { + val name = namespace.newTemp + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), gender(e)) + } + case e: ValidIf => { + val name = namespace.newTemp + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), gender(e)) + } + case e => e + } + + // Recursive. Splits compound nodes + def onExp(e: Expression): Expression = + e map onExp match { + case ex: DoPrim => ex map split + case v => v } + + s map onExp match { + case x: Block => x map onStmt + case EmptyStmt => EmptyStmt + case x => + v += x + v.size match { + case 1 => v.head + case _ => Block(v.toSeq) + } + } } Module(m.info, m.name, m.ports, onStmt(m.body)) } def run(c: Circuit): Circuit = { - val modulesx = c.modules.map( _ match { - case m: Module => onModule(m) - case m: ExtModule => m - }) - Circuit(c.info, modulesx, c.main) + val modulesx = c.modules map { + case m: Module => onModule(m) + case m: ExtModule => m + } + Circuit(c.info, modulesx, c.main) } } |
