aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorazidar2016-11-10 13:42:38 -0500
committerazidar2016-11-10 13:42:38 -0500
commitfdde6f839d7f6811e12127dbe9f3f1ae429ee12c (patch)
tree6232b14043cea1d930f8ac9291dba64e0fc039a7 /src
parent76f91f313312895476ae06b45ed72494ab653f1c (diff)
Added additional optimizations
Required for passing all chisel3 tests
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala58
1 files changed, 46 insertions, 12 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index a7a150bd..36466c09 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -27,40 +27,73 @@ object InferWidths extends Pass {
case _ => h
})
}
- def simplify(w: Width): Width = w map simplify match {
- case wx: MinWidth => MinWidth(unique((wx.args foldLeft Seq[Width]()){
+ def pullMinMax(w: Width): Width = w map pullMinMax match {
+ case PlusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i))))
+ case PlusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i))))
+ case MinusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => MinusWidth(m, IntWidth(i))))
+ case MinusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => MinusWidth(IntWidth(i), m)))
+ case PlusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i))))
+ case PlusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i))))
+ case MinusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => MinusWidth(m, IntWidth(i))))
+ case MinusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => MinusWidth(IntWidth(i), m)))
+ case wx => wx
+ }
+ def collectMinMax(w: Width): Width = w map collectMinMax match {
+ case MinWidth(args) => MinWidth(unique((args.foldLeft(Seq[Width]())) {
case (res, wxx: MinWidth) => res ++ wxx.args
case (res, wxx) => res :+ wxx
}))
- case wx: MaxWidth => MaxWidth(unique((wx.args foldLeft Seq[Width]()){
+ case MaxWidth(args) => MaxWidth(unique((args.foldLeft(Seq[Width]())) {
case (res, wxx: MaxWidth) => res ++ wxx.args
case (res, wxx) => res :+ wxx
}))
+ case wx => wx
+ }
+ def mergePlusMinus(w: Width): Width = w map mergePlusMinus match {
case wx: PlusWidth => (wx.arg1, wx.arg2) match {
case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width)
- case (w1, IntWidth(x)) if x == 0 => w1
- case (IntWidth(x), w1) if x == 0 => w1
case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1)
case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1)
case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1)
- case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x
- case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x
+ case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(y - x), w1)
+ case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(y - x), w1)
case _ => wx
}
case wx: MinusWidth => (wx.arg1, wx.arg2) match {
case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
- case (w1, IntWidth(x)) if x == 0 => w1
- case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x
- case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x
- case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x
+ case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
+ case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
+ case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1)
case _ => wx
}
case wx: ExpWidth => wx.arg1 match {
case w1: IntWidth => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong))
case _ => wx
}
- case _ => w
+ case wx => wx
+ }
+ def removeZeros(w: Width): Width = w map removeZeros match {
+ case wx: PlusWidth => (wx.arg1, wx.arg2) match {
+ case (w1, IntWidth(x)) if x == 0 => w1
+ case (IntWidth(x), w1) if x == 0 => w1
+ case _ => wx
+ }
+ case wx: MinusWidth => (wx.arg1, wx.arg2) match {
+ case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
+ case (w1, IntWidth(x)) if x == 0 => w1
+ case _ => wx
+ }
+ case wx => wx
+ }
+ def simplify(w: Width): Width = {
+ val opts = Seq(
+ pullMinMax _,
+ collectMinMax _,
+ mergePlusMinus _,
+ removeZeros _
+ )
+ opts.foldLeft(w) { (width, opt) => opt(width) }
}
def substitute(h: ConstraintMap)(w: Width): Width = {
@@ -96,6 +129,7 @@ object InferWidths extends Pass {
w match {
case wx: MaxWidth => MaxWidth(wx.args filter {
case wxx: VarWidth => !(n equals wxx.name)
+ case MinusWidth(VarWidth(name), IntWidth(i)) if ((i >= 0) && (n == name)) => false
case _ => true
})
case wx: MinusWidth => wx.arg1 match {