From 427095ad97ac31e994fee3d083eb18f78e701004 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Thu, 6 Jul 2017 18:16:33 -0700 Subject: Fix ConstProp bug where multiple names would swap with one Fixes issue in https://github.com/freechipsproject/rocket-chip/pull/848 --- .../firrtl/transforms/ConstantPropagation.scala | 4 +++- .../firrtlTests/ConstantPropagationTests.scala | 27 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index bf8b1a55..46c12b2d 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -304,7 +304,9 @@ class ConstantPropagation extends Transform { // When propagating a reference, check if we want to keep the name that would be deleted def propagateRef(lname: String, value: Expression): Unit = { value match { - case WRef(rname,_,_,_) if betterName(lname, rname) => + case WRef(rname,_,_,_) if betterName(lname, rname) && !swapMap.contains(rname) => + assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a + // node declaration or the single connection to a wire or register swapMap += (lname -> rname, rname -> lname) case _ => } diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 75c43cf2..380d53e5 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -448,6 +448,33 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { node _T_1 = n z <= n n <= x +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "ConstProp" should "only swap a given name with one other name" in { + val input = +"""circuit Top : + module Top : + input x : UInt<1> + input y : UInt<1> + output z : UInt<3> + node _T_1 = add(x, y) + node n = _T_1 + node m = _T_1 + z <= add(n, m) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<1> + input y : UInt<1> + output z : UInt<3> + node n = add(x, y) + node _T_1 = n + node m = n + z <= add(n, n) """ (parse(exec(input))) should be (parse(check)) } -- cgit v1.2.3