diff options
| author | azidar | 2016-11-09 20:14:13 -0500 |
|---|---|---|
| committer | azidar | 2016-11-09 20:14:13 -0500 |
| commit | 76f91f313312895476ae06b45ed72494ab653f1c (patch) | |
| tree | 689b8cc6d3396b68b7733f8c4bf24f4afb8b38e6 /src | |
| parent | ba417a89c1a654d24c628c7e276433c9f5d64e55 (diff) | |
Added optimizations to for better width inference
Also added exceptions for uninferred widths when checking DoPrim width
legality to not trigger compiler error
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 14 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 16 |
2 files changed, 24 insertions, 6 deletions
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 4eb0b26c..0d16ef00 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -600,6 +600,12 @@ object CheckWidths extends Pass { w } + def hasWidth(tpe: Type): Boolean = tpe match { + case GroundType(IntWidth(w)) => true + case GroundType(_) => false + case _ => println(tpe); throwInternalError + } + def check_width_t(info: Info, mname: String)(t: Type): Type = t map check_width_t(info, mname) map check_width_w(info, mname) @@ -615,13 +621,13 @@ object CheckWidths extends Pass { errors append new WidthTooSmall(info, mname, e.value) case _ => } - case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if bitWidth(a.tpe) <= hi => + case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) => errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe)) - case DoPrim(Head, Seq(a), Seq(n), _) if bitWidth(a.tpe) < n => + case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) => errors append new HeadWidthException(info, mname, n, bitWidth(a.tpe)) - case DoPrim(Tail, Seq(a), Seq(n), _) if bitWidth(a.tpe) <= n => + case DoPrim(Tail, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= n) => errors append new TailWidthException(info, mname, n, bitWidth(a.tpe)) - case DoPrim(Dshl, Seq(a, b), _, _) if bitWidth(b.tpe) >= BigInt(32) => + case DoPrim(Dshl, Seq(a, b), _, _) if (hasWidth(a.tpe) && bitWidth(b.tpe) >= BigInt(32)) => errors append new WidthTooBig(info, mname) case _ => } diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 619bb25a..a7a150bd 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -37,11 +37,23 @@ object InferWidths extends Pass { case (res, wxx) => res :+ wxx })) case wx: PlusWidth => (wx.arg1, wx.arg2) match { - case (w1: IntWidth, w2 :IntWidth) => IntWidth(w1.width + w2.width) + case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width) + case (w1, IntWidth(x)) if x == 0 => w1 + case (IntWidth(x), w1) if x == 0 => w1 + case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) + case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) + case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1) + case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1) + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x + case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x case _ => wx } case wx: MinusWidth => (wx.arg1, wx.arg2) match { case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) + case (w1, IntWidth(x)) if x == 0 => w1 + case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x + case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x case _ => wx } case wx: ExpWidth => wx.arg1 match { @@ -126,7 +138,7 @@ object InferWidths extends Pass { //for (x <- f) println(x) //println("=========================") - val e_sub = substitute(f)(e) + val e_sub = simplify(substitute(f)(e)) //println("Solving " + n + " => " + e) //println("After Substitute: " + n + " => " + e_sub) |
