diff options
| author | Adam Izraelevitz | 2016-11-14 10:52:21 -0800 |
|---|---|---|
| committer | GitHub | 2016-11-14 10:52:21 -0800 |
| commit | cf1372d5b721c2384f88632e76841e6dc6772c6c (patch) | |
| tree | 6232b14043cea1d930f8ac9291dba64e0fc039a7 /src | |
| parent | c19a53a562883ebb7d9c6131c4ef308bcfbd720a (diff) | |
| parent | fdde6f839d7f6811e12127dbe9f3f1ae429ee12c (diff) | |
Bugfix inferwidth (#372)
* Bugfix: removed recursive removal in infer widths
This will certainly lead to more uninferred width errors, but now widths
that were previously incorrectly inferred are now correctly uninferred.
An example is:
reg r : UInt, clock with: (reset => (reset, UInt<2>(3)))
node x = add(r, r)
r <= x
Here, r's width follows the following formula, which cannot be solved:
rWidth >= max(max(rWidth, rWidth) + 1, 2)
* Added optimizations to for better width inference
Also added exceptions for uninferred widths when checking DoPrim width
legality to not trigger compiler error
* Added additional optimizations
Required for passing all chisel3 tests
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 14 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 60 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/WidthSpec.scala | 23 |
3 files changed, 86 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 4eb0b26c..0d16ef00 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -600,6 +600,12 @@ object CheckWidths extends Pass { w } + def hasWidth(tpe: Type): Boolean = tpe match { + case GroundType(IntWidth(w)) => true + case GroundType(_) => false + case _ => println(tpe); throwInternalError + } + def check_width_t(info: Info, mname: String)(t: Type): Type = t map check_width_t(info, mname) map check_width_w(info, mname) @@ -615,13 +621,13 @@ object CheckWidths extends Pass { errors append new WidthTooSmall(info, mname, e.value) case _ => } - case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if bitWidth(a.tpe) <= hi => + case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) => errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe)) - case DoPrim(Head, Seq(a), Seq(n), _) if bitWidth(a.tpe) < n => + case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) => errors append new HeadWidthException(info, mname, n, bitWidth(a.tpe)) - case DoPrim(Tail, Seq(a), Seq(n), _) if bitWidth(a.tpe) <= n => + case DoPrim(Tail, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= n) => errors append new TailWidthException(info, mname, n, bitWidth(a.tpe)) - case DoPrim(Dshl, Seq(a, b), _, _) if bitWidth(b.tpe) >= BigInt(32) => + case DoPrim(Dshl, Seq(a, b), _, _) if (hasWidth(a.tpe) && bitWidth(b.tpe) >= BigInt(32)) => errors append new WidthTooBig(info, mname) case _ => } diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 67f2b90e..36466c09 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -27,28 +27,73 @@ object InferWidths extends Pass { case _ => h }) } - def simplify(w: Width): Width = w map simplify match { - case wx: MinWidth => MinWidth(unique((wx.args foldLeft Seq[Width]()){ + def pullMinMax(w: Width): Width = w map pullMinMax match { + case PlusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) + case PlusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) + case MinusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => MinusWidth(m, IntWidth(i)))) + case MinusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => MinusWidth(IntWidth(i), m))) + case PlusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) + case PlusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) + case MinusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => MinusWidth(m, IntWidth(i)))) + case MinusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => MinusWidth(IntWidth(i), m))) + case wx => wx + } + def collectMinMax(w: Width): Width = w map collectMinMax match { + case MinWidth(args) => MinWidth(unique((args.foldLeft(Seq[Width]())) { case (res, wxx: MinWidth) => res ++ wxx.args case (res, wxx) => res :+ wxx })) - case wx: MaxWidth => MaxWidth(unique((wx.args foldLeft Seq[Width]()){ + case MaxWidth(args) => MaxWidth(unique((args.foldLeft(Seq[Width]())) { case (res, wxx: MaxWidth) => res ++ wxx.args case (res, wxx) => res :+ wxx })) + case wx => wx + } + def mergePlusMinus(w: Width): Width = w map mergePlusMinus match { case wx: PlusWidth => (wx.arg1, wx.arg2) match { - case (w1: IntWidth, w2 :IntWidth) => IntWidth(w1.width + w2.width) + case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width) + case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) + case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) + case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1) + case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1) + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(y - x), w1) + case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(y - x), w1) case _ => wx } case wx: MinusWidth => (wx.arg1, wx.arg2) match { case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) + case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) + case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) case _ => wx } case wx: ExpWidth => wx.arg1 match { case w1: IntWidth => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong)) case _ => wx } - case _ => w + case wx => wx + } + def removeZeros(w: Width): Width = w map removeZeros match { + case wx: PlusWidth => (wx.arg1, wx.arg2) match { + case (w1, IntWidth(x)) if x == 0 => w1 + case (IntWidth(x), w1) if x == 0 => w1 + case _ => wx + } + case wx: MinusWidth => (wx.arg1, wx.arg2) match { + case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) + case (w1, IntWidth(x)) if x == 0 => w1 + case _ => wx + } + case wx => wx + } + def simplify(w: Width): Width = { + val opts = Seq( + pullMinMax _, + collectMinMax _, + mergePlusMinus _, + removeZeros _ + ) + opts.foldLeft(w) { (width, opt) => opt(width) } } def substitute(h: ConstraintMap)(w: Width): Width = { @@ -81,9 +126,10 @@ object InferWidths extends Pass { def remove_cycle(n: String)(w: Width): Width = { //;println-all-debug(["Removing cycle for " n " inside " w]) - w map remove_cycle(n) match { + w match { case wx: MaxWidth => MaxWidth(wx.args filter { case wxx: VarWidth => !(n equals wxx.name) + case MinusWidth(VarWidth(name), IntWidth(i)) if ((i >= 0) && (n == name)) => false case _ => true }) case wx: MinusWidth => wx.arg1 match { @@ -126,7 +172,7 @@ object InferWidths extends Pass { //for (x <- f) println(x) //println("=========================") - val e_sub = substitute(f)(e) + val e_sub = simplify(substitute(f)(e)) //println("Solving " + n + " => " + e) //println("After Substitute: " + n + " => " + e_sub) diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala index f2938016..9b0ee139 100644 --- a/src/test/scala/firrtlTests/WidthSpec.scala +++ b/src/test/scala/firrtlTests/WidthSpec.scala @@ -78,4 +78,27 @@ class WidthSpec extends FirrtlFlatSpec { executeTest(input, Nil, passes) } } + "Circular reg depending on reg + 1" should "error" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input clock: Clock + | input reset: UInt<1> + | reg r : UInt, clock with : + | reset => (reset, UInt(3)) + | node T_7 = add(r, r) + | r <= T_7 + |""".stripMargin + intercept[CheckWidths.UninferredWidth] { + executeTest(input, Nil, passes) + } + } } |
