diff options
| author | Donggyu | 2016-09-12 12:43:23 -0700 |
|---|---|---|
| committer | GitHub | 2016-09-12 12:43:23 -0700 |
| commit | 00bef01b6df158939406f3e744cbdda544823ae5 (patch) | |
| tree | 30a09340c7dd7e21ed031de125dc1e83a2c13b37 /src | |
| parent | 20ff9c96a7c07df8e0cb91444f223384261d35fe (diff) | |
| parent | 9e7ce6366454347e0ad912c7da0252070e5bb4a1 (diff) | |
Merge pull request #247 from ucb-bar/fix-invalid
Bugfix: ExpandWhen was emitting WInvalid()
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 21 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ExpandWhensSpec.scala | 42 |
2 files changed, 55 insertions, 8 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index dcefb20f..5a7a7bac 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -137,15 +137,20 @@ object ExpandWhens extends Pass { conseqNetlist getOrElse (lvalue, altNetlist(lvalue)) } - nodes get res match { - case Some(name) => - netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE) + res match { + case _: ValidIf | _: Mux | _: DoPrim => nodes get res match { + case Some(name) => + netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE) + EmptyStmt + case None => + val name = namespace.newTemp + nodes(res) = name + netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE) + DefNode(s.info, name, res) + } + case _ => + netlist(lvalue) = res EmptyStmt - case None => - val name = namespace.newTemp - nodes(res) = name - netlist(lvalue) = WRef(name, res.tpe, NodeKind(), MALE) - DefNode(s.info, name, res) } } Block(Seq(conseqStmt, altStmt) ++ memos) diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala index 6219fd8c..82809dc8 100644 --- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala +++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala @@ -32,9 +32,51 @@ import org.scalatest._ import org.scalatest.prop._ import firrtl._ import firrtl.passes._ +import firrtl.ir._ +import firrtl.Parser.IgnoreInfo class ExpandWhensSpec extends FirrtlFlatSpec { + private def parse(input: String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) + private def executeTest(input: String, notExpected: String, passes: Seq[Pass]) = { + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + val lines = c.serialize.split("\n") map normalized + + lines foreach { l => + l.contains(notExpected) should be (false) + } + } "Expand Whens" should "compile and run" in { runFirrtlTest("ExpandWhens", "/passes/ExpandWhens") } + "Expand Whens" should "not emit INVALID" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + Uniquify, + ResolveKinds, + InferTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths, + PullMuxes, + ExpandConnects, + RemoveAccesses, + ExpandWhens) + val input = + """|circuit Tester : + | module Tester : + | input p : UInt<1> + | when p : + | wire a : {b : UInt<64>, c : UInt<64>} + | a is invalid + | a.b <= UInt<64>("h04000000000000000")""".stripMargin + val check = "INVALID" + executeTest(input, check, passes) + } } |
