aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-11-21 13:30:11 -0800
committerGitHub2016-11-21 13:30:11 -0800
commit9a967a27aa8bb51f4b62969d2889f9a9caa48e31 (patch)
treedad7370c71df7ce8d4628f70b4079296bfee4d99 /src
parentaad8e09f355f4804d29361d75f54ce4a5c2d5c52 (diff)
Bugfix: exponential runtime of pull muxes (#379)
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala53
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala19
2 files changed, 44 insertions, 28 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 46e68d5c..87458a2b 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -54,35 +54,32 @@ object ToWorkingIR extends Pass {
object PullMuxes extends Pass {
def name = "Pull Muxes"
def run(c: Circuit): Circuit = {
- def pull_muxes_e(e: Expression): Expression = {
- val exxx = e map pull_muxes_e match {
- case ex: WSubField => ex.exp match {
- case exx: Mux => Mux(exx.cond,
- WSubField(exx.tval, ex.name, ex.tpe, ex.gender),
- WSubField(exx.fval, ex.name, ex.tpe, ex.gender), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubField(exx.value, ex.name, ex.tpe, ex.gender), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex: WSubIndex => ex.exp match {
- case exx: Mux => Mux(exx.cond,
- WSubIndex(exx.tval, ex.value, ex.tpe, ex.gender),
- WSubIndex(exx.fval, ex.value, ex.tpe, ex.gender), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubIndex(exx.value, ex.value, ex.tpe, ex.gender), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex: WSubAccess => ex.exp match {
- case exx: Mux => Mux(exx.cond,
- WSubAccess(exx.tval, ex.index, ex.tpe, ex.gender),
- WSubAccess(exx.fval, ex.index, ex.tpe, ex.gender), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubAccess(exx.value, ex.index, ex.tpe, ex.gender), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex => ex
+ def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match {
+ case ex: WSubField => ex.exp match {
+ case exx: Mux => Mux(exx.cond,
+ WSubField(exx.tval, ex.name, ex.tpe, ex.gender),
+ WSubField(exx.fval, ex.name, ex.tpe, ex.gender), ex.tpe)
+ case exx: ValidIf => ValidIf(exx.cond,
+ WSubField(exx.value, ex.name, ex.tpe, ex.gender), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
}
- exxx map pull_muxes_e
+ case ex: WSubIndex => ex.exp match {
+ case exx: Mux => Mux(exx.cond,
+ WSubIndex(exx.tval, ex.value, ex.tpe, ex.gender),
+ WSubIndex(exx.fval, ex.value, ex.tpe, ex.gender), ex.tpe)
+ case exx: ValidIf => ValidIf(exx.cond,
+ WSubIndex(exx.value, ex.value, ex.tpe, ex.gender), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex: WSubAccess => ex.exp match {
+ case exx: Mux => Mux(exx.cond,
+ WSubAccess(exx.tval, ex.index, ex.tpe, ex.gender),
+ WSubAccess(exx.fval, ex.index, ex.tpe, ex.gender), ex.tpe)
+ case exx: ValidIf => ValidIf(exx.cond,
+ WSubAccess(exx.value, ex.index, ex.tpe, ex.gender), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex => ex
}
def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e
val modulesx = c.modules.map {
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index 1e181141..e2f8f729 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -22,6 +22,25 @@ class UnitTests extends FirrtlFlatSpec {
}
}
+ "Pull muxes" should "not be exponential in runtime" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ PullMuxes)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input _2: UInt<1>
+ | output x: UInt<32>
+ | x <= cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat( _2, cat(_2, cat(_2, cat(_2, _2)))))))))))))))))))))))))))))))""".stripMargin
+ passes.foldLeft(parse(input)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ }
+
"Connecting bundles of different types" should "throw an exception" in {
val passes = Seq(
ToWorkingIR,