From ba417a89c1a654d24c628c7e276433c9f5d64e55 Mon Sep 17 00:00:00 2001 From: azidar Date: Wed, 9 Nov 2016 19:05:52 -0500 Subject: Bugfix: removed recursive removal in infer widths This will certainly lead to more uninferred width errors, but now widths that were previously incorrectly inferred are now correctly uninferred. An example is: reg r : UInt, clock with: (reset => (reset, UInt<2>(3))) node x = add(r, r) r <= x Here, r's width follows the following formula, which cannot be solved: rWidth >= max(max(rWidth, rWidth) + 1, 2) --- src/main/scala/firrtl/passes/InferWidths.scala | 2 +- src/test/scala/firrtlTests/WidthSpec.scala | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 67f2b90e..619bb25a 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -81,7 +81,7 @@ object InferWidths extends Pass { def remove_cycle(n: String)(w: Width): Width = { //;println-all-debug(["Removing cycle for " n " inside " w]) - w map remove_cycle(n) match { + w match { case wx: MaxWidth => MaxWidth(wx.args filter { case wxx: VarWidth => !(n equals wxx.name) case _ => true diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala index f2938016..9b0ee139 100644 --- a/src/test/scala/firrtlTests/WidthSpec.scala +++ b/src/test/scala/firrtlTests/WidthSpec.scala @@ -78,4 +78,27 @@ class WidthSpec extends FirrtlFlatSpec { executeTest(input, Nil, passes) } } + "Circular reg depending on reg + 1" should "error" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input clock: Clock + | input reset: UInt<1> + | reg r : UInt, clock with : + | reset => (reset, UInt(3)) + | node T_7 = add(r, r) + | r <= T_7 + |""".stripMargin + intercept[CheckWidths.UninferredWidth] { + executeTest(input, Nil, passes) + } + } } -- cgit v1.2.3 From 76f91f313312895476ae06b45ed72494ab653f1c Mon Sep 17 00:00:00 2001 From: azidar Date: Wed, 9 Nov 2016 20:14:13 -0500 Subject: Added optimizations to for better width inference Also added exceptions for uninferred widths when checking DoPrim width legality to not trigger compiler error --- src/main/scala/firrtl/passes/Checks.scala | 14 ++++++++++---- src/main/scala/firrtl/passes/InferWidths.scala | 16 ++++++++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 4eb0b26c..0d16ef00 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -600,6 +600,12 @@ object CheckWidths extends Pass { w } + def hasWidth(tpe: Type): Boolean = tpe match { + case GroundType(IntWidth(w)) => true + case GroundType(_) => false + case _ => println(tpe); throwInternalError + } + def check_width_t(info: Info, mname: String)(t: Type): Type = t map check_width_t(info, mname) map check_width_w(info, mname) @@ -615,13 +621,13 @@ object CheckWidths extends Pass { errors append new WidthTooSmall(info, mname, e.value) case _ => } - case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if bitWidth(a.tpe) <= hi => + case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) => errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe)) - case DoPrim(Head, Seq(a), Seq(n), _) if bitWidth(a.tpe) < n => + case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) => errors append new HeadWidthException(info, mname, n, bitWidth(a.tpe)) - case DoPrim(Tail, Seq(a), Seq(n), _) if bitWidth(a.tpe) <= n => + case DoPrim(Tail, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= n) => errors append new TailWidthException(info, mname, n, bitWidth(a.tpe)) - case DoPrim(Dshl, Seq(a, b), _, _) if bitWidth(b.tpe) >= BigInt(32) => + case DoPrim(Dshl, Seq(a, b), _, _) if (hasWidth(a.tpe) && bitWidth(b.tpe) >= BigInt(32)) => errors append new WidthTooBig(info, mname) case _ => } diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 619bb25a..a7a150bd 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -37,11 +37,23 @@ object InferWidths extends Pass { case (res, wxx) => res :+ wxx })) case wx: PlusWidth => (wx.arg1, wx.arg2) match { - case (w1: IntWidth, w2 :IntWidth) => IntWidth(w1.width + w2.width) + case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width) + case (w1, IntWidth(x)) if x == 0 => w1 + case (IntWidth(x), w1) if x == 0 => w1 + case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) + case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) + case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1) + case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1) + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x + case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x case _ => wx } case wx: MinusWidth => (wx.arg1, wx.arg2) match { case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) + case (w1, IntWidth(x)) if x == 0 => w1 + case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x + case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x case _ => wx } case wx: ExpWidth => wx.arg1 match { @@ -126,7 +138,7 @@ object InferWidths extends Pass { //for (x <- f) println(x) //println("=========================") - val e_sub = substitute(f)(e) + val e_sub = simplify(substitute(f)(e)) //println("Solving " + n + " => " + e) //println("After Substitute: " + n + " => " + e_sub) -- cgit v1.2.3 From fdde6f839d7f6811e12127dbe9f3f1ae429ee12c Mon Sep 17 00:00:00 2001 From: azidar Date: Thu, 10 Nov 2016 13:42:38 -0500 Subject: Added additional optimizations Required for passing all chisel3 tests --- src/main/scala/firrtl/passes/InferWidths.scala | 58 ++++++++++++++++++++------ 1 file changed, 46 insertions(+), 12 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index a7a150bd..36466c09 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -27,40 +27,73 @@ object InferWidths extends Pass { case _ => h }) } - def simplify(w: Width): Width = w map simplify match { - case wx: MinWidth => MinWidth(unique((wx.args foldLeft Seq[Width]()){ + def pullMinMax(w: Width): Width = w map pullMinMax match { + case PlusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) + case PlusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) + case MinusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => MinusWidth(m, IntWidth(i)))) + case MinusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => MinusWidth(IntWidth(i), m))) + case PlusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) + case PlusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) + case MinusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => MinusWidth(m, IntWidth(i)))) + case MinusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => MinusWidth(IntWidth(i), m))) + case wx => wx + } + def collectMinMax(w: Width): Width = w map collectMinMax match { + case MinWidth(args) => MinWidth(unique((args.foldLeft(Seq[Width]())) { case (res, wxx: MinWidth) => res ++ wxx.args case (res, wxx) => res :+ wxx })) - case wx: MaxWidth => MaxWidth(unique((wx.args foldLeft Seq[Width]()){ + case MaxWidth(args) => MaxWidth(unique((args.foldLeft(Seq[Width]())) { case (res, wxx: MaxWidth) => res ++ wxx.args case (res, wxx) => res :+ wxx })) + case wx => wx + } + def mergePlusMinus(w: Width): Width = w map mergePlusMinus match { case wx: PlusWidth => (wx.arg1, wx.arg2) match { case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width) - case (w1, IntWidth(x)) if x == 0 => w1 - case (IntWidth(x), w1) if x == 0 => w1 case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1) case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1) - case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x - case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(y - x), w1) + case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(y - x), w1) case _ => wx } case wx: MinusWidth => (wx.arg1, wx.arg2) match { case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) - case (w1, IntWidth(x)) if x == 0 => w1 - case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x - case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x - case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x + case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) + case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) case _ => wx } case wx: ExpWidth => wx.arg1 match { case w1: IntWidth => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong)) case _ => wx } - case _ => w + case wx => wx + } + def removeZeros(w: Width): Width = w map removeZeros match { + case wx: PlusWidth => (wx.arg1, wx.arg2) match { + case (w1, IntWidth(x)) if x == 0 => w1 + case (IntWidth(x), w1) if x == 0 => w1 + case _ => wx + } + case wx: MinusWidth => (wx.arg1, wx.arg2) match { + case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) + case (w1, IntWidth(x)) if x == 0 => w1 + case _ => wx + } + case wx => wx + } + def simplify(w: Width): Width = { + val opts = Seq( + pullMinMax _, + collectMinMax _, + mergePlusMinus _, + removeZeros _ + ) + opts.foldLeft(w) { (width, opt) => opt(width) } } def substitute(h: ConstraintMap)(w: Width): Width = { @@ -96,6 +129,7 @@ object InferWidths extends Pass { w match { case wx: MaxWidth => MaxWidth(wx.args filter { case wxx: VarWidth => !(n equals wxx.name) + case MinusWidth(VarWidth(name), IntWidth(i)) if ((i >= 0) && (n == name)) => false case _ => true }) case wx: MinusWidth => wx.arg1 match { -- cgit v1.2.3