aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorazidar2016-11-09 20:14:13 -0500
committerazidar2016-11-09 20:14:13 -0500
commit76f91f313312895476ae06b45ed72494ab653f1c (patch)
tree689b8cc6d3396b68b7733f8c4bf24f4afb8b38e6 /src
parentba417a89c1a654d24c628c7e276433c9f5d64e55 (diff)
Added optimizations to for better width inference
Also added exceptions for uninferred widths when checking DoPrim width legality to not trigger compiler error
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala14
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala16
2 files changed, 24 insertions, 6 deletions
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 4eb0b26c..0d16ef00 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -600,6 +600,12 @@ object CheckWidths extends Pass {
w
}
+ def hasWidth(tpe: Type): Boolean = tpe match {
+ case GroundType(IntWidth(w)) => true
+ case GroundType(_) => false
+ case _ => println(tpe); throwInternalError
+ }
+
def check_width_t(info: Info, mname: String)(t: Type): Type =
t map check_width_t(info, mname) map check_width_w(info, mname)
@@ -615,13 +621,13 @@ object CheckWidths extends Pass {
errors append new WidthTooSmall(info, mname, e.value)
case _ =>
}
- case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if bitWidth(a.tpe) <= hi =>
+ case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= hi) =>
errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe))
- case DoPrim(Head, Seq(a), Seq(n), _) if bitWidth(a.tpe) < n =>
+ case DoPrim(Head, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) < n) =>
errors append new HeadWidthException(info, mname, n, bitWidth(a.tpe))
- case DoPrim(Tail, Seq(a), Seq(n), _) if bitWidth(a.tpe) <= n =>
+ case DoPrim(Tail, Seq(a), Seq(n), _) if (hasWidth(a.tpe) && bitWidth(a.tpe) <= n) =>
errors append new TailWidthException(info, mname, n, bitWidth(a.tpe))
- case DoPrim(Dshl, Seq(a, b), _, _) if bitWidth(b.tpe) >= BigInt(32) =>
+ case DoPrim(Dshl, Seq(a, b), _, _) if (hasWidth(a.tpe) && bitWidth(b.tpe) >= BigInt(32)) =>
errors append new WidthTooBig(info, mname)
case _ =>
}
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 619bb25a..a7a150bd 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -37,11 +37,23 @@ object InferWidths extends Pass {
case (res, wxx) => res :+ wxx
}))
case wx: PlusWidth => (wx.arg1, wx.arg2) match {
- case (w1: IntWidth, w2 :IntWidth) => IntWidth(w1.width + w2.width)
+ case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width)
+ case (w1, IntWidth(x)) if x == 0 => w1
+ case (IntWidth(x), w1) if x == 0 => w1
+ case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
+ case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
+ case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1)
+ case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1)
+ case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x
+ case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x
case _ => wx
}
case wx: MinusWidth => (wx.arg1, wx.arg2) match {
case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
+ case (w1, IntWidth(x)) if x == 0 => w1
+ case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x
+ case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x
+ case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x
case _ => wx
}
case wx: ExpWidth => wx.arg1 match {
@@ -126,7 +138,7 @@ object InferWidths extends Pass {
//for (x <- f) println(x)
//println("=========================")
- val e_sub = substitute(f)(e)
+ val e_sub = simplify(substitute(f)(e))
//println("Solving " + n + " => " + e)
//println("After Substitute: " + n + " => " + e_sub)