diff options
| author | azidar | 2016-11-10 13:42:38 -0500 |
|---|---|---|
| committer | azidar | 2016-11-10 13:42:38 -0500 |
| commit | fdde6f839d7f6811e12127dbe9f3f1ae429ee12c (patch) | |
| tree | 6232b14043cea1d930f8ac9291dba64e0fc039a7 /src | |
| parent | 76f91f313312895476ae06b45ed72494ab653f1c (diff) | |
Added additional optimizations
Required for passing all chisel3 tests
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 58 |
1 files changed, 46 insertions, 12 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index a7a150bd..36466c09 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -27,40 +27,73 @@ object InferWidths extends Pass { case _ => h }) } - def simplify(w: Width): Width = w map simplify match { - case wx: MinWidth => MinWidth(unique((wx.args foldLeft Seq[Width]()){ + def pullMinMax(w: Width): Width = w map pullMinMax match { + case PlusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) + case PlusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => PlusWidth(m, IntWidth(i)))) + case MinusWidth(MaxWidth(maxs), IntWidth(i)) => MaxWidth(maxs.map(m => MinusWidth(m, IntWidth(i)))) + case MinusWidth(IntWidth(i), MaxWidth(maxs)) => MaxWidth(maxs.map(m => MinusWidth(IntWidth(i), m))) + case PlusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) + case PlusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => PlusWidth(m, IntWidth(i)))) + case MinusWidth(MinWidth(mins), IntWidth(i)) => MinWidth(mins.map(m => MinusWidth(m, IntWidth(i)))) + case MinusWidth(IntWidth(i), MinWidth(mins)) => MinWidth(mins.map(m => MinusWidth(IntWidth(i), m))) + case wx => wx + } + def collectMinMax(w: Width): Width = w map collectMinMax match { + case MinWidth(args) => MinWidth(unique((args.foldLeft(Seq[Width]())) { case (res, wxx: MinWidth) => res ++ wxx.args case (res, wxx) => res :+ wxx })) - case wx: MaxWidth => MaxWidth(unique((wx.args foldLeft Seq[Width]()){ + case MaxWidth(args) => MaxWidth(unique((args.foldLeft(Seq[Width]())) { case (res, wxx: MaxWidth) => res ++ wxx.args case (res, wxx) => res :+ wxx })) + case wx => wx + } + def mergePlusMinus(w: Width): Width = w map mergePlusMinus match { case wx: PlusWidth => (wx.arg1, wx.arg2) match { case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width + w2.width) - case (w1, IntWidth(x)) if x == 0 => w1 - case (IntWidth(x), w1) if x == 0 => w1 case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x + y), w1) case (IntWidth(y), PlusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(x + y), w1) case (IntWidth(y), PlusWidth(IntWidth(x), w1)) => PlusWidth(IntWidth(x + y), w1) - case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x - case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => simplify(PlusWidth(IntWidth(y - x), w1)) // call simplify in case y = x + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(y - x), w1) + case (IntWidth(y), MinusWidth(w1, IntWidth(x))) => PlusWidth(IntWidth(y - x), w1) case _ => wx } case wx: MinusWidth => (wx.arg1, wx.arg2) match { case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) - case (w1, IntWidth(x)) if x == 0 => w1 - case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x - case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x - case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => simplify(PlusWidth(IntWidth(x - y), w1)) // call simplify in case y = x + case (PlusWidth(IntWidth(x), w1), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) + case (PlusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) + case (MinusWidth(w1, IntWidth(x)), IntWidth(y)) => PlusWidth(IntWidth(x - y), w1) case _ => wx } case wx: ExpWidth => wx.arg1 match { case w1: IntWidth => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong)) case _ => wx } - case _ => w + case wx => wx + } + def removeZeros(w: Width): Width = w map removeZeros match { + case wx: PlusWidth => (wx.arg1, wx.arg2) match { + case (w1, IntWidth(x)) if x == 0 => w1 + case (IntWidth(x), w1) if x == 0 => w1 + case _ => wx + } + case wx: MinusWidth => (wx.arg1, wx.arg2) match { + case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width) + case (w1, IntWidth(x)) if x == 0 => w1 + case _ => wx + } + case wx => wx + } + def simplify(w: Width): Width = { + val opts = Seq( + pullMinMax _, + collectMinMax _, + mergePlusMinus _, + removeZeros _ + ) + opts.foldLeft(w) { (width, opt) => opt(width) } } def substitute(h: ConstraintMap)(w: Width): Width = { @@ -96,6 +129,7 @@ object InferWidths extends Pass { w match { case wx: MaxWidth => MaxWidth(wx.args filter { case wxx: VarWidth => !(n equals wxx.name) + case MinusWidth(VarWidth(name), IntWidth(i)) if ((i >= 0) && (n == name)) => false case _ => true }) case wx: MinusWidth => wx.arg1 match { |
