aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu2016-09-07 16:33:24 -0700
committerGitHub2016-09-07 16:33:24 -0700
commita404cf5b2c4ca6457c964eb32aae8330c48422e1 (patch)
treed6c137d446a254bae87cd015c94bf86806225042 /src
parent13345ce816a51cc19f93a03c4148eecf0dd2c739 (diff)
parent8beaa3be259d2a793e3a99628b2f0d38d98f5b9a (diff)
Merge pull request #280 from ucb-bar/cleanup_passes
Clean up passes
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala237
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala234
-rw-r--r--src/main/scala/firrtl/passes/SplitExpressions.scala86
3 files changed, 253 insertions, 304 deletions
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index a4c584ed..57f8fd76 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -61,10 +61,10 @@ object LowerTypes extends Pass {
*/
def loweredName(e: Expression): String = e match {
case e: WRef => e.name
- case e: WSubField => loweredName(e.exp) + delim + e.name
- case e: WSubIndex => loweredName(e.exp) + delim + e.value
+ case e: WSubField => s"${loweredName(e.exp)}$delim${e.name}"
+ case e: WSubIndex => s"${loweredName(e.exp)}$delim${e.value}"
}
- def loweredName(s: Seq[String]): String = s.mkString(delim)
+ def loweredName(s: Seq[String]): String = s mkString delim
private case class LowerTypesException(msg: String) extends FIRRTLException(msg)
private def error(msg: String)(implicit sinfo: Info, mname: String) =
@@ -100,35 +100,31 @@ object LowerTypes extends Pass {
// and just need to be converted to refer to the correct new memory
def lowerTypesMemExp(e: Expression): Seq[Expression] = {
val (mem, port, field, tail) = splitMemRef(e)
- // Fields that need to be replicated for each resulting mem
- if (Seq("addr", "en", "clk", "wmode").contains(field.name)) {
- require(tail.isEmpty) // there can't be a tail for these
- val memType = memDataTypeMap(mem.name)
-
- memType match {
- case _: GroundType => Seq(e)
- case _ =>
- val exps = create_exps(mem.name, memType)
- exps map { e =>
+ field.name match {
+ // Fields that need to be replicated for each resulting mem
+ case "addr" | "en" | "clk" | "wmode" =>
+ require(tail.isEmpty) // there can't be a tail for these
+ memDataTypeMap(mem.name) match {
+ case _: GroundType => Seq(e)
+ case memType => create_exps(mem.name, memType) map { e =>
val loMemName = loweredName(e)
val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
mergeRef(loMem, mergeRef(port, field))
}
- }
- // Fields that need not be replicated for each
- // eg. mem.reader.data[0].a
- // (Connect/IsInvalid must already have been split to ground types)
- } else if (Seq("data", "mask", "rdata", "wdata", "wmask").contains(field.name)) {
- val loMem = tail match {
- case Some(e) =>
- val loMemExp = mergeRef(mem, e)
- val loMemName = loweredName(loMemExp)
- WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
- case None => mem
- }
- Seq(mergeRef(loMem, mergeRef(port, field)))
- } else {
- error(s"Error! Unhandled memory field ${field.name}")
+ }
+ // Fields that need not be replicated for each
+ // eg. mem.reader.data[0].a
+ // (Connect/IsInvalid must already have been split to ground types)
+ case "data" | "mask" | "rdata" | "wdata" | "wmask" =>
+ val loMem = tail match {
+ case Some(e) =>
+ val loMemExp = mergeRef(mem, e)
+ val loMemName = loweredName(loMemExp)
+ WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
+ case None => mem
+ }
+ Seq(mergeRef(loMem, mergeRef(port, field)))
+ case name => error(s"Error! Unhandled memory field ${name}")
}
}
@@ -141,116 +137,103 @@ object LowerTypes extends Pass {
WSubField(root, name, e.tpe, gender(e))
case k: MemKind =>
val exps = lowerTypesMemExp(e)
- if (exps.length > 1)
- error("Error! lowerTypesExp called on MemKind SubField that needs" +
- " to be expanded!")
- exps(0)
- case k =>
- WRef(loweredName(e), e.tpe, kind(e), gender(e))
+ exps.size match {
+ case 1 => exps.head
+ case _ => error("Error! lowerTypesExp called on MemKind " +
+ "SubField that needs to be expanded!")
+ }
+ case _ => WRef(loweredName(e), e.tpe, kind(e), gender(e))
}
case e: Mux => e map (lowerTypesExp)
case e: ValidIf => e map (lowerTypesExp)
- case (_: UIntLiteral | _: SIntLiteral) => e
case e: DoPrim => e map (lowerTypesExp)
+ case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
- def lowerTypesStmt(s: Statement): Statement = {
- s map lowerTypesStmt match {
- case s: DefWire =>
- sinfo = s.info
- s.tpe match {
- case _: GroundType => s
- case _ =>
- val exps = create_exps(s.name, s.tpe)
- val stmts = exps map (e => DefWire(s.info, loweredName(e), e.tpe))
- Block(stmts)
- }
- case s: DefRegister =>
- sinfo = s.info
- s.tpe match {
- case _: GroundType => s map lowerTypesExp
- case _ =>
- val es = create_exps(s.name, s.tpe)
- val inits = create_exps(s.init) map (lowerTypesExp)
- val clock = lowerTypesExp(s.clock)
- val reset = lowerTypesExp(s.reset)
- val stmts = es zip inits map { case (e, i) =>
- DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i)
- }
- Block(stmts)
- }
- // Could instead just save the type of each Module as it gets processed
- case s: WDefInstance =>
- sinfo = s.info
- s.tpe match {
- case t: BundleType =>
- val fieldsx = t.fields flatMap { f =>
- val exps = create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE)))
- exps map ( e =>
- // Flip because inst genders are reversed from Module type
- Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)
- )
- }
- WDefInstance(s.info, s.name, s.module, BundleType(fieldsx))
- case _ => error("WDefInstance type should be Bundle!")
- }
- case s: DefMemory =>
- sinfo = s.info
- memDataTypeMap += (s.name -> s.dataType)
- s.dataType match {
- case _: GroundType => s
- case _ =>
- val exps = create_exps(s.name, s.dataType)
- val stmts = exps map { e =>
- DefMemory(s.info, loweredName(e), e.tpe, s.depth,
- s.writeLatency, s.readLatency, s.readers, s.writers,
- s.readwriters)
- }
- Block(stmts)
- }
- // wire foo : { a , b }
- // node x = foo
- // node y = x.a
- // ->
- // node x_a = foo_a
- // node x_b = foo_b
- // node y = x_a
- case s: DefNode =>
- sinfo = s.info
- val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp)
- val exps = create_exps(s.value) map (lowerTypesExp)
- val stmts = names zip exps map { case (n, e) =>
- DefNode(s.info, loweredName(n), e)
- }
- Block(stmts)
- case s: IsInvalid =>
- sinfo = s.info
- kind(s.expr) match {
- case k: MemKind =>
- val exps = lowerTypesMemExp(s.expr)
- Block(exps map (exp => IsInvalid(s.info, exp)))
- case _ => s map (lowerTypesExp)
- }
- case s: Connect =>
- sinfo = s.info
- kind(s.loc) match {
- case k: MemKind =>
- val exp = lowerTypesExp(s.expr)
- val locs = lowerTypesMemExp(s.loc)
- Block(locs map (loc => Connect(s.info, loc, exp)))
- case _ => s map (lowerTypesExp)
- }
- case s => s map (lowerTypesExp)
- }
+ def lowerTypesStmt(s: Statement): Statement = s map lowerTypesStmt match {
+ case s: DefWire =>
+ sinfo = s.info
+ s.tpe match {
+ case _: GroundType => s
+ case _ => Block(create_exps(s.name, s.tpe) map (
+ e => DefWire(s.info, loweredName(e), e.tpe)))
+ }
+ case s: DefRegister =>
+ sinfo = s.info
+ s.tpe match {
+ case _: GroundType => s map lowerTypesExp
+ case _ =>
+ val es = create_exps(s.name, s.tpe)
+ val inits = create_exps(s.init) map (lowerTypesExp)
+ val clock = lowerTypesExp(s.clock)
+ val reset = lowerTypesExp(s.reset)
+ Block(es zip inits map { case (e, i) =>
+ DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i)
+ })
+ }
+ // Could instead just save the type of each Module as it gets processed
+ case s: WDefInstance =>
+ sinfo = s.info
+ s.tpe match {
+ case t: BundleType =>
+ val fieldsx = t.fields flatMap (f =>
+ create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) map (
+ // Flip because inst genders are reversed from Module type
+ e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)
+ )
+ )
+ WDefInstance(s.info, s.name, s.module, BundleType(fieldsx))
+ case _ => error("WDefInstance type should be Bundle!")
+ }
+ case s: DefMemory =>
+ sinfo = s.info
+ memDataTypeMap(s.name) = s.dataType
+ s.dataType match {
+ case _: GroundType => s
+ case _ => Block(create_exps(s.name, s.dataType) map (e =>
+ DefMemory(s.info, loweredName(e), e.tpe, s.depth,
+ s.writeLatency, s.readLatency, s.readers, s.writers,
+ s.readwriters)))
+ }
+ // wire foo : { a , b }
+ // node x = foo
+ // node y = x.a
+ // ->
+ // node x_a = foo_a
+ // node x_b = foo_b
+ // node y = x_a
+ case s: DefNode =>
+ sinfo = s.info
+ val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp)
+ val exps = create_exps(s.value) map (lowerTypesExp)
+ Block(names zip exps map {case (n, e) => DefNode(s.info, loweredName(n), e)})
+ case s: IsInvalid =>
+ sinfo = s.info
+ kind(s.expr) match {
+ case k: MemKind =>
+ Block(lowerTypesMemExp(s.expr) map (IsInvalid(s.info, _)))
+ case _ => s map (lowerTypesExp)
+ }
+ case s: Connect =>
+ sinfo = s.info
+ kind(s.loc) match {
+ case k: MemKind =>
+ val exp = lowerTypesExp(s.expr)
+ val locs = lowerTypesMemExp(s.loc)
+ Block(locs map (Connect(s.info, _, exp)))
+ case _ => s map (lowerTypesExp)
+ }
+ case s => s map (lowerTypesExp)
}
sinfo = m.info
mname = m.name
// Lower Ports
- val portsx = m.ports flatMap { p =>
- val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction)))
- exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe) )
- }
+ val portsx = m.ports flatMap ( p =>
+ create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) map (
+ e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe)
+ )
+ )
m match {
case m: ExtModule => m.copy(ports = portsx)
case m: Module => Module(m.info, m.name, portsx, lowerTypesStmt(m.body))
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 6b6dc811..c143212e 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -602,149 +602,115 @@ object InferWidths extends Pass {
}
object PullMuxes extends Pass {
- private var mname = ""
def name = "Pull Muxes"
- def run (c:Circuit): Circuit = {
- def pull_muxes_e (e:Expression) : Expression = {
- val ex = e map (pull_muxes_e) match {
- case (e:WRef) => e
- case (e:WSubField) => {
- e.exp match {
- case (ex:Mux) => Mux(ex.cond,WSubField(ex.tval,e.name,e.tpe,e.gender),WSubField(ex.fval,e.name,e.tpe,e.gender),e.tpe)
- case (ex:ValidIf) => ValidIf(ex.cond,WSubField(ex.value,e.name,e.tpe,e.gender),e.tpe)
- case (ex) => e
- }
- }
- case (e:WSubIndex) => {
- e.exp match {
- case (ex:Mux) => Mux(ex.cond,WSubIndex(ex.tval,e.value,e.tpe,e.gender),WSubIndex(ex.fval,e.value,e.tpe,e.gender),e.tpe)
- case (ex:ValidIf) => ValidIf(ex.cond,WSubIndex(ex.value,e.value,e.tpe,e.gender),e.tpe)
- case (ex) => e
- }
- }
- case (e:WSubAccess) => {
- e.exp match {
- case (ex:Mux) => Mux(ex.cond,WSubAccess(ex.tval,e.index,e.tpe,e.gender),WSubAccess(ex.fval,e.index,e.tpe,e.gender),e.tpe)
- case (ex:ValidIf) => ValidIf(ex.cond,WSubAccess(ex.value,e.index,e.tpe,e.gender),e.tpe)
- case (ex) => e
- }
- }
- case (e:Mux) => e
- case (e:ValidIf) => e
- case (e) => e
+ def run(c: Circuit): Circuit = {
+ def pull_muxes_e(e: Expression): Expression = {
+ val ex = e map (pull_muxes_e) match {
+ case (e: WSubField) => e.exp match {
+ case (ex: Mux) => Mux(ex.cond,
+ WSubField(ex.tval, e.name, e.tpe, e.gender),
+ WSubField(ex.fval, e.name, e.tpe, e.gender), e.tpe)
+ case (ex: ValidIf) => ValidIf(ex.cond,
+ WSubField(ex.value, e.name, e.tpe, e.gender), e.tpe)
+ case (ex) => e
}
- ex map (pull_muxes_e)
- }
- def pull_muxes (s:Statement) : Statement = s map (pull_muxes) map (pull_muxes_e)
- val modulesx = c.modules.map {
- m => {
- mname = m.name
- m match {
- case (m:Module) => Module(m.info,m.name,m.ports,pull_muxes(m.body))
- case (m:ExtModule) => m
- }
+ case (e: WSubIndex) => e.exp match {
+ case (ex: Mux) => Mux(ex.cond,
+ WSubIndex(ex.tval, e.value, e.tpe, e.gender),
+ WSubIndex(ex.fval, e.value, e.tpe, e.gender), e.tpe)
+ case (ex: ValidIf) => ValidIf(ex.cond,
+ WSubIndex(ex.value, e.value, e.tpe, e.gender), e.tpe)
+ case (ex) => e
}
- }
- Circuit(c.info,modulesx,c.main)
+ case (e: WSubAccess) => e.exp match {
+ case (ex: Mux) => Mux(ex.cond,
+ WSubAccess(ex.tval, e.index, e.tpe, e.gender),
+ WSubAccess(ex.fval, e.index, e.tpe, e.gender), e.tpe)
+ case (ex: ValidIf) => ValidIf(ex.cond,
+ WSubAccess(ex.value, e.index, e.tpe, e.gender), e.tpe)
+ case (ex) => e
+ }
+ case (e) => e
+ }
+ ex map (pull_muxes_e)
+ }
+ def pull_muxes(s: Statement): Statement = s map (pull_muxes) map (pull_muxes_e)
+ val modulesx = c.modules.map {
+ case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body))
+ case (m:ExtModule) => m
+ }
+ Circuit(c.info, modulesx, c.main)
}
}
object ExpandConnects extends Pass {
- private var mname = ""
- def name = "Expand Connects"
- def run (c:Circuit): Circuit = {
- def expand_connects (m:Module) : Module = {
- mname = m.name
- val genders = LinkedHashMap[String,Gender]()
- def expand_s (s:Statement) : Statement = {
- def set_gender (e:Expression) : Expression = {
- e map (set_gender) match {
- case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name))
- case (e:WSubField) => {
- val f = get_field(e.exp.tpe,e.name)
- val genderx = times(gender(e.exp),f.flip)
- WSubField(e.exp,e.name,e.tpe,genderx)
- }
- case (e:WSubIndex) => WSubIndex(e.exp,e.value,e.tpe,gender(e.exp))
- case (e:WSubAccess) => WSubAccess(e.exp,e.index,e.tpe,gender(e.exp))
- case (e) => e
+ def name = "Expand Connects"
+ def run(c: Circuit): Circuit = {
+ def expand_connects(m: Module): Module = {
+ val genders = LinkedHashMap[String,Gender]()
+ def expand_s(s: Statement): Statement = {
+ def set_gender(e: Expression): Expression = e map (set_gender) match {
+ case (e: WRef) => WRef(e.name, e.tpe, e.kind, genders(e.name))
+ case (e: WSubField) =>
+ val f = get_field(e.exp.tpe, e.name)
+ val genderx = times(gender(e.exp), f.flip)
+ WSubField(e.exp, e.name, e.tpe, genderx)
+ case (e: WSubIndex) => WSubIndex(e.exp, e.value, e.tpe, gender(e.exp))
+ case (e: WSubAccess) => WSubAccess(e.exp, e.index, e.tpe, gender(e.exp))
+ case (e) => e
+ }
+ s match {
+ case (s: DefWire) => genders(s.name) = BIGENDER; s
+ case (s: DefRegister) => genders(s.name) = BIGENDER; s
+ case (s: WDefInstance) => genders(s.name) = MALE; s
+ case (s: DefMemory) => genders(s.name) = MALE; s
+ case (s: DefNode) => genders(s.name) = MALE; s
+ case (s: IsInvalid) =>
+ val invalids = (create_exps(s.expr) foldLeft Seq[Statement]())(
+ (invalids, expx) => gender(set_gender(expx)) match {
+ case BIGENDER => invalids :+ IsInvalid(s.info, expx)
+ case FEMALE => invalids :+ IsInvalid(s.info, expx)
+ case _ => invalids
}
+ )
+ invalids.size match {
+ case 0 => EmptyStmt
+ case 1 => invalids.head
+ case _ => Block(invalids)
}
- s match {
- case (s:DefWire) => { genders(s.name) = BIGENDER; s }
- case (s:DefRegister) => { genders(s.name) = BIGENDER; s }
- case (s:WDefInstance) => { genders(s.name) = MALE; s }
- case (s:DefMemory) => { genders(s.name) = MALE; s }
- case (s:DefNode) => { genders(s.name) = MALE; s }
- case (s:IsInvalid) => {
- val n = get_size(s.expr.tpe)
- val invalids = ArrayBuffer[Statement]()
- val exps = create_exps(s.expr)
- for (i <- 0 until n) {
- val expx = exps(i)
- val gexpx = set_gender(expx)
- gender(gexpx) match {
- case BIGENDER => invalids += IsInvalid(s.info,expx)
- case FEMALE => invalids += IsInvalid(s.info,expx)
- case _ => {}
- }
- }
- if (invalids.length == 0) {
- EmptyStmt
- } else if (invalids.length == 1) {
- invalids(0)
- } else Block(invalids)
- }
- case (s:Connect) => {
- val n = get_size(s.loc.tpe)
- val connects = ArrayBuffer[Statement]()
- val locs = create_exps(s.loc)
- val exps = create_exps(s.expr)
- for (i <- 0 until n) {
- val locx = locs(i)
- val expx = exps(i)
- val sx = get_flip(s.loc.tpe,i,Default) match {
- case Default => Connect(s.info,locx,expx)
- case Flip => Connect(s.info,expx,locx)
- }
- connects += sx
- }
- Block(connects)
- }
- case (s:PartialConnect) => {
- val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
- val connects = ArrayBuffer[Statement]()
- val locs = create_exps(s.loc)
- val exps = create_exps(s.expr)
- ls.foreach { x => {
- val locx = locs(x._1)
- val expx = exps(x._2)
- val sx = get_flip(s.loc.tpe,x._1,Default) match {
- case Default => Connect(s.info,locx,expx)
- case Flip => Connect(s.info,expx,locx)
- }
- connects += sx
- }}
- Block(connects)
+ case (s: Connect) =>
+ val locs = create_exps(s.loc)
+ val exps = create_exps(s.expr)
+ Block((locs zip exps).zipWithIndex map {case ((locx, expx), i) =>
+ get_flip(s.loc.tpe, i, Default) match {
+ case Default => Connect(s.info, locx, expx)
+ case Flip => Connect(s.info, expx, locx)
}
- case (s) => s map (expand_s)
- }
- }
-
- m.ports.foreach { p => genders(p.name) = to_gender(p.direction) }
- Module(m.info,m.name,m.ports,expand_s(m.body))
- }
-
- val modulesx = c.modules.map {
- m => {
- m match {
- case (m:ExtModule) => m
- case (m:Module) => expand_connects(m)
- }
- }
+ })
+ case (s: PartialConnect) =>
+ val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default)
+ val locs = create_exps(s.loc)
+ val exps = create_exps(s.expr)
+ Block(ls map {case (x, y) =>
+ get_flip(s.loc.tpe, x, Default) match {
+ case Default => Connect(s.info, locs(x), exps(y))
+ case Flip => Connect(s.info, exps(y), locs(x))
+ }
+ })
+ case (s) => s map (expand_s)
+ }
}
- Circuit(c.info,modulesx,c.main)
- }
+
+ m.ports.foreach { p => genders(p.name) = to_gender(p.direction) }
+ Module(m.info, m.name, m.ports, expand_s(m.body))
+ }
+
+ val modulesx = c.modules.map {
+ case (m: ExtModule) => m
+ case (m: Module) => expand_connects(m)
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
}
@@ -754,8 +720,8 @@ object Legalize extends Pass {
def name = "Legalize"
def legalizeShiftRight (e: DoPrim): Expression = e.op match {
case Shr => {
- val amount = e.consts(0).toInt
- val width = long_BANG(e.args(0).tpe)
+ val amount = e.consts.head.toInt
+ val width = long_BANG(e.args.head.tpe)
lazy val msb = width - 1
if (amount >= width) {
e.tpe match {
diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala
index 3b6021ed..90b92a35 100644
--- a/src/main/scala/firrtl/passes/SplitExpressions.scala
+++ b/src/main/scala/firrtl/passes/SplitExpressions.scala
@@ -14,53 +14,53 @@ object SplitExpressions extends Pass {
private def onModule(m: Module): Module = {
val namespace = Namespace(m)
def onStmt(s: Statement): Statement = {
- val v = mutable.ArrayBuffer[Statement]()
- // Splits current expression if needed
- // Adds named temporaries to v
- def split(e: Expression): Expression = e match {
- case e: DoPrim => {
- val name = namespace.newTemp
- v += DefNode(get_info(s), name, e)
- WRef(name, e.tpe, kind(e), gender(e))
- }
- case e: Mux => {
- val name = namespace.newTemp
- v += DefNode(get_info(s), name, e)
- WRef(name, e.tpe, kind(e), gender(e))
- }
- case e: ValidIf => {
- val name = namespace.newTemp
- v += DefNode(get_info(s), name, e)
- WRef(name, e.tpe, kind(e), gender(e))
- }
- case e => e
- }
- // Recursive. Splits compound nodes
- def onExp(e: Expression): Expression = {
- val ex = e map onExp
- ex match {
- case (_: DoPrim) => ex map split
- case v => v
- }
- }
- val x = s map onExp
- x match {
- case x: Block => x map onStmt
- case EmptyStmt => x
- case x => {
- v += x
- if (v.size > 1) Block(v.toVector)
- else v(0)
- }
+ val v = mutable.ArrayBuffer[Statement]()
+ // Splits current expression if needed
+ // Adds named temporaries to v
+ def split(e: Expression): Expression = e match {
+ case e: DoPrim => {
+ val name = namespace.newTemp
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), gender(e))
+ }
+ case e: Mux => {
+ val name = namespace.newTemp
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), gender(e))
+ }
+ case e: ValidIf => {
+ val name = namespace.newTemp
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), gender(e))
+ }
+ case e => e
+ }
+
+ // Recursive. Splits compound nodes
+ def onExp(e: Expression): Expression =
+ e map onExp match {
+ case ex: DoPrim => ex map split
+ case v => v
}
+
+ s map onExp match {
+ case x: Block => x map onStmt
+ case EmptyStmt => EmptyStmt
+ case x =>
+ v += x
+ v.size match {
+ case 1 => v.head
+ case _ => Block(v.toSeq)
+ }
+ }
}
Module(m.info, m.name, m.ports, onStmt(m.body))
}
def run(c: Circuit): Circuit = {
- val modulesx = c.modules.map( _ match {
- case m: Module => onModule(m)
- case m: ExtModule => m
- })
- Circuit(c.info, modulesx, c.main)
+ val modulesx = c.modules map {
+ case m: Module => onModule(m)
+ case m: ExtModule => m
+ }
+ Circuit(c.info, modulesx, c.main)
}
}