diff options
| author | Adam Izraelevitz | 2016-11-21 13:30:11 -0800 |
|---|---|---|
| committer | GitHub | 2016-11-21 13:30:11 -0800 |
| commit | 9a967a27aa8bb51f4b62969d2889f9a9caa48e31 (patch) | |
| tree | dad7370c71df7ce8d4628f70b4079296bfee4d99 /src | |
| parent | aad8e09f355f4804d29361d75f54ce4a5c2d5c52 (diff) | |
Bugfix: exponential runtime of pull muxes (#379)
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 53 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/UnitTests.scala | 19 |
2 files changed, 44 insertions, 28 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 46e68d5c..87458a2b 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -54,35 +54,32 @@ object ToWorkingIR extends Pass { object PullMuxes extends Pass { def name = "Pull Muxes" def run(c: Circuit): Circuit = { - def pull_muxes_e(e: Expression): Expression = { - val exxx = e map pull_muxes_e match { - case ex: WSubField => ex.exp match { - case exx: Mux => Mux(exx.cond, - WSubField(exx.tval, ex.name, ex.tpe, ex.gender), - WSubField(exx.fval, ex.name, ex.tpe, ex.gender), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubField(exx.value, ex.name, ex.tpe, ex.gender), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex: WSubIndex => ex.exp match { - case exx: Mux => Mux(exx.cond, - WSubIndex(exx.tval, ex.value, ex.tpe, ex.gender), - WSubIndex(exx.fval, ex.value, ex.tpe, ex.gender), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubIndex(exx.value, ex.value, ex.tpe, ex.gender), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex: WSubAccess => ex.exp match { - case exx: Mux => Mux(exx.cond, - WSubAccess(exx.tval, ex.index, ex.tpe, ex.gender), - WSubAccess(exx.fval, ex.index, ex.tpe, ex.gender), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubAccess(exx.value, ex.index, ex.tpe, ex.gender), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex => ex + def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { + case ex: WSubField => ex.exp match { + case exx: Mux => Mux(exx.cond, + WSubField(exx.tval, ex.name, ex.tpe, ex.gender), + WSubField(exx.fval, ex.name, ex.tpe, ex.gender), ex.tpe) + case exx: ValidIf => ValidIf(exx.cond, + WSubField(exx.value, ex.name, ex.tpe, ex.gender), ex.tpe) + case _ => ex // case exx => exx causes failed tests } - exxx map pull_muxes_e + case ex: WSubIndex => ex.exp match { + case exx: Mux => Mux(exx.cond, + WSubIndex(exx.tval, ex.value, ex.tpe, ex.gender), + WSubIndex(exx.fval, ex.value, ex.tpe, ex.gender), ex.tpe) + case exx: ValidIf => ValidIf(exx.cond, + WSubIndex(exx.value, ex.value, ex.tpe, ex.gender), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex: WSubAccess => ex.exp match { + case exx: Mux => Mux(exx.cond, + WSubAccess(exx.tval, ex.index, ex.tpe, ex.gender), + WSubAccess(exx.fval, ex.index, ex.tpe, ex.gender), ex.tpe) + case exx: ValidIf => ValidIf(exx.cond, + WSubAccess(exx.value, ex.index, ex.tpe, ex.gender), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex => ex } def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e val modulesx = c.modules.map { diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 1e181141..e2f8f729 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -22,6 +22,25 @@ class UnitTests extends FirrtlFlatSpec { } } + "Pull muxes" should "not be exponential in runtime" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + PullMuxes) + val input = + """circuit Unit : + | module Unit : + | input _2: UInt<1> + | output x: UInt<32> + | x <= cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat( _2, cat(_2, cat(_2, cat(_2, _2)))))))))))))))))))))))))))))))""".stripMargin + passes.foldLeft(parse(input)) { + (c: Circuit, p: Pass) => p.run(c) + } + } + "Connecting bundles of different types" should "throw an exception" in { val passes = Seq( ToWorkingIR, |
