From 427095ad97ac31e994fee3d083eb18f78e701004 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Thu, 6 Jul 2017 18:16:33 -0700 Subject: Fix ConstProp bug where multiple names would swap with one Fixes issue in https://github.com/freechipsproject/rocket-chip/pull/848 --- .../firrtl/transforms/ConstantPropagation.scala | 4 +++- .../firrtlTests/ConstantPropagationTests.scala | 27 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index bf8b1a55..46c12b2d 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -304,7 +304,9 @@ class ConstantPropagation extends Transform { // When propagating a reference, check if we want to keep the name that would be deleted def propagateRef(lname: String, value: Expression): Unit = { value match { - case WRef(rname,_,_,_) if betterName(lname, rname) => + case WRef(rname,_,_,_) if betterName(lname, rname) && !swapMap.contains(rname) => + assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a + // node declaration or the single connection to a wire or register swapMap += (lname -> rname, rname -> lname) case _ => } diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 75c43cf2..380d53e5 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -448,6 +448,33 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { node _T_1 = n z <= n n <= x +""" + (parse(exec(input))) should be (parse(check)) + } + + // ============================= + "ConstProp" should "only swap a given name with one other name" in { + val input = +"""circuit Top : + module Top : + input x : UInt<1> + input y : UInt<1> + output z : UInt<3> + node _T_1 = add(x, y) + node n = _T_1 + node m = _T_1 + z <= add(n, m) +""" + val check = +"""circuit Top : + module Top : + input x : UInt<1> + input y : UInt<1> + output z : UInt<3> + node n = add(x, y) + node _T_1 = n + node m = n + z <= add(n, n) """ (parse(exec(input))) should be (parse(check)) } -- cgit v1.2.3 From 97642d6ddeca4e2109010ac5d6a0a199df01f28c Mon Sep 17 00:00:00 2001 From: Donggyu Kim Date: Mon, 17 Jul 2017 11:48:11 -0700 Subject: do not swap wire names with node names --- .../firrtl/transforms/ConstantPropagation.scala | 2 +- .../firrtlTests/ConstantPropagationTests.scala | 30 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 46c12b2d..d5a4b7e1 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -317,7 +317,7 @@ class ConstantPropagation extends Transform { val stmtx = s map constPropStmt map constPropExpression stmtx match { case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) - case Connect(_, WRef(wname, wtpe, WireKind, _), expr) if !dontTouches.contains(wname) => + case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => val exprx = constPropExpression(pad(expr, wtpe)) propagateRef(wname, exprx) // Const prop registers that are fed only a constant or a mux between and constant and the diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 380d53e5..e42ecfac 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -475,6 +475,36 @@ class ConstantPropagationSpec extends FirrtlFlatSpec { node _T_1 = n node m = n z <= add(n, n) +""" + (parse(exec(input))) should be (parse(check)) + } + + "ConstProp" should "NOT swap wire names with node names" in { + val input = +"""circuit Top : + module Top : + input clock : Clock + input x : UInt<1> + input y : UInt<1> + output z : UInt<1> + wire hit : UInt<1> + node _T_1 = or(x, y) + node _T_2 = eq(_T_1, UInt<1>(1)) + hit <= _T_2 + z <= hit +""" + val check = +"""circuit Top : + module Top : + input clock : Clock + input x : UInt<1> + input y : UInt<1> + output z : UInt<1> + wire hit : UInt<1> + node _T_1 = or(x, y) + node _T_2 = _T_1 + hit <= _T_1 + z <= hit """ (parse(exec(input))) should be (parse(check)) } -- cgit v1.2.3