aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-11-14 10:52:21 -0800
committerGitHub2016-11-14 10:52:21 -0800
commitcf1372d5b721c2384f88632e76841e6dc6772c6c (patch)
tree6232b14043cea1d930f8ac9291dba64e0fc039a7 /src
parentc19a53a562883ebb7d9c6131c4ef308bcfbd720a (diff)
parentfdde6f839d7f6811e12127dbe9f3f1ae429ee12c (diff)
Bugfix inferwidth (#372)
* Bugfix: removed recursive removal in infer widths This will certainly lead to more uninferred width errors, but now widths that were previously incorrectly inferred are now correctly uninferred. An example is: reg r : UInt, clock with: (reset => (reset, UInt<2>(3))) node x = add(r, r) r <= x Here, r's width follows the following formula, which cannot be solved: rWidth >= max(max(rWidth, rWidth) + 1, 2) * Added optimizations to for better width inference Also added exceptions for uninferred widths when checking DoPrim width legality to not trigger compiler error * Added additional optimizations Required for passing all chisel3 tests
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala14
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala60
-rw-r--r--src/test/scala/firrtlTests/WidthSpec.scala23
3 files changed, 86 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 4eb0b26c..0d16ef00 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -600,6 +600,12 @@ object CheckWidths extends Pass {
w
}
+ def hasWidth(tpe: Type): Boolean = tpe match {
+ case GroundType(IntWidth(w)) => true
+ case GroundType(_) => false
+ case _ => println(tpe); throwInternalError
+ }
+
def check_width_t(info: Info, mname: String)(t: Type): Type =
t map check_width_t(info, mname) map check_width_w(info, mname)
@@ -615,13 +621,13 @@ object CheckWidths extends Pass {
errors append new WidthTooSmall(info, mname, e.value)
case _ =>
}
- case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if bitWidth(a.tpe) <= hi =>
+ case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) =>
errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe))
- case DoPrim(Head, Seq(a), Seq(n), _) if bitWidth(a.tpe) < n =>
+ case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) =>
errors append new HeadWidthException(info, mname, n, bitWidth(a.tpe))
- case DoPrim(Tail, Seq(a), Seq(n), _) if bitWidth(a.tpe) <= n =>
+ case DoPrim(Tail, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= n) =>
errors append new TailWidthException(info, mname, n, bitWidth(a.tpe))
- case DoPrim(Dshl, Seq(a, b), _, _) if bitWidth(b.tpe) >= BigInt(32) =>
+ case DoPrim(Dshl, Seq(a, b), _, _) if (hasWidth(a.tpe) && bitWidth(b.tpe) >= BigInt(32)) =>
errors append new WidthTooBig(info, mname)
case _ =>
}
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 67f2b90e..36466c09 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -27,28 +27,73 @@ object InferWidths extends Pass {
case _ => h
})
}
- def simplify(w: Width): Width = w map simplify match {
- case wx: MinWidth => MinWidth(unique((wx.args foldLeft Seq[Width]()){
+ def pullMinMax(w: Width): Width = w map pullMinMax match {
+ case PlusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i))))
+ case PlusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i))))
+ case MinusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => MinusWidth(m, IntWidth(i))))
+ case MinusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => MinusWidth(IntWidth(i), m)))
+ case PlusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i))))
+ case PlusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i))))
+ case MinusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => MinusWidth(m, IntWidth(i))))
+ case MinusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => MinusWidth(IntWidth(i), m)))
+ case wx => wx
+ }
+ def collectMinMax(w: Width): Width = w map collectMinMax match {
+ case MinWidth(args) => MinWidth(unique((args.foldLeft(Seq[Width]())) {
case (res, wxx: MinWidth) => res ++ wxx.args
case (res, wxx) => res :+ wxx
}))
- case wx: MaxWidth => MaxWidth(unique((wx.args foldLeft Seq[Width]()){
+ case MaxWidth(args) => MaxWidth(unique((args.foldLeft(Seq[Width]())) {
case (res, wxx: MaxWidth) => res ++ wxx.args
case (res, wxx) => res :+ wxx
}))
+ case wx => wx
+ }
+ def mergePlusMinus(w: Width): Width = w map mergePlusMinus match {
case wx: PlusWidth => (wx.arg1, wx.arg2) match {
- case (w1: IntWidth, w2 :IntWidth) => IntWidth(w1.width + w2.width)
+ case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width)
+ case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
+ case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
+ case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1)
+ case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1)
+ case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(y - x), w1)
+ case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(y - x), w1)
case _ => wx
}
case wx: MinusWidth => (wx.arg1, wx.arg2) match {
case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
+ case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
+ case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
+ case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
case _ => wx
}
case wx: ExpWidth => wx.arg1 match {
case w1: IntWidth => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong))
case _ => wx
}
- case _ => w
+ case wx => wx
+ }
+ def removeZeros(w: Width): Width = w map removeZeros match {
+ case wx: PlusWidth => (wx.arg1, wx.arg2) match {
+ case (w1, IntWidth(x)) if x == 0 => w1
+ case (IntWidth(x), w1) if x == 0 => w1
+ case _ => wx
+ }
+ case wx: MinusWidth => (wx.arg1, wx.arg2) match {
+ case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
+ case (w1, IntWidth(x)) if x == 0 => w1
+ case _ => wx
+ }
+ case wx => wx
+ }
+ def simplify(w: Width): Width = {
+ val opts = Seq(
+ pullMinMax _,
+ collectMinMax _,
+ mergePlusMinus _,
+ removeZeros _
+ )
+ opts.foldLeft(w) { (width, opt) => opt(width) }
}
def substitute(h: ConstraintMap)(w: Width): Width = {
@@ -81,9 +126,10 @@ object InferWidths extends Pass {
def remove_cycle(n: String)(w: Width): Width = {
//;println-all-debug(["Removing cycle for " n " inside " w])
- w map remove_cycle(n) match {
+ w match {
case wx: MaxWidth => MaxWidth(wx.args filter {
case wxx: VarWidth => !(n equals wxx.name)
+ case MinusWidth(VarWidth(name), IntWidth(i)) if ((i >= 0) && (n == name)) => false
case _ => true
})
case wx: MinusWidth => wx.arg1 match {
@@ -126,7 +172,7 @@ object InferWidths extends Pass {
//for (x <- f) println(x)
//println("=========================")
- val e_sub = substitute(f)(e)
+ val e_sub = simplify(substitute(f)(e))
//println("Solving " + n + " => " + e)
//println("After Substitute: " + n + " => " + e_sub)
diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala
index f2938016..9b0ee139 100644
--- a/src/test/scala/firrtlTests/WidthSpec.scala
+++ b/src/test/scala/firrtlTests/WidthSpec.scala
@@ -78,4 +78,27 @@ class WidthSpec extends FirrtlFlatSpec {
executeTest(input, Nil, passes)
}
}
+ "Circular reg depending on reg + 1" should "error" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ InferWidths,
+ CheckWidths)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input clock: Clock
+ | input reset: UInt<1>
+ | reg r : UInt, clock with :
+ | reset => (reset, UInt(3))
+ | node T_7 = add(r, r)
+ | r <= T_7
+ |""".stripMargin
+ intercept[CheckWidths.UninferredWidth] {
+ executeTest(input, Nil, passes)
+ }
+ }
}