aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/RemoveIntervals.scala
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/passes/RemoveIntervals.scala
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/passes/RemoveIntervals.scala')
-rw-r--r--src/main/scala/firrtl/passes/RemoveIntervals.scala149
1 files changed, 76 insertions, 73 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala
index 7059526c..657b4356 100644
--- a/src/main/scala/firrtl/passes/RemoveIntervals.scala
+++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala
@@ -13,14 +13,13 @@ import firrtl.options.Dependency
import scala.math.BigDecimal.RoundingMode._
class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim)
- extends PassException({
- val toWrap = wrap.args.head.serialize
- val toWrapTpe = wrap.args.head.tpe.serialize
- val wrapTo = wrap.args(1).serialize
- val wrapToTpe = wrap.args(1).tpe.serialize
- s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe"
- })
-
+ extends PassException({
+ val toWrap = wrap.args.head.serialize
+ val toWrapTpe = wrap.args.head.tpe.serialize
+ val wrapTo = wrap.args(1).serialize
+ val wrapToTpe = wrap.args(1).tpe.serialize
+ s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe"
+ })
/** Replaces IntervalType with SIntType, three AST walks:
* 1) Align binary points
@@ -39,48 +38,50 @@ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim)
class RemoveIntervals extends Pass {
override def prerequisites: Seq[Dependency[Transform]] =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses),
- Dependency[ExpandWhensAndCheck] ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses),
+ Dependency[ExpandWhensAndCheck]
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(transform: Transform): Boolean = {
transform match {
case InferTypes | ResolveKinds => true
- case _ => false
+ case _ => false
}
}
def run(c: Circuit): Circuit = {
val alignedCircuit = c
val errors = new Errors()
- val wiredCircuit = alignedCircuit map makeWireModule
- val replacedCircuit = wiredCircuit map replaceModuleInterval(errors)
+ val wiredCircuit = alignedCircuit.map(makeWireModule)
+ val replacedCircuit = wiredCircuit.map(replaceModuleInterval(errors))
errors.trigger()
replacedCircuit
}
/* Replace interval types */
private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule =
- m map replaceStmtInterval(errors, m.name) map replacePortInterval
+ m.map(replaceStmtInterval(errors, m.name)).map(replacePortInterval)
private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = {
val info = s match {
case h: HasInfo => h.info
case _ => NoInfo
}
- s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname)
+ s.map(replaceTypeInterval).map(replaceStmtInterval(errors, mname)).map(replaceExprInterval(errors, info, mname))
}
private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match {
case _: WRef | _: WSubIndex | _: WSubField => e
case o =>
- o map replaceExprInterval(errors, info, mname) match {
+ o.map(replaceExprInterval(errors, info, mname)) match {
case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe)
- case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe)
- case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe)
+ case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe)
+ case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe)
case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) =>
// Output interval (pre-calculated)
val clipLo = tpe.minAdjusted.get
@@ -94,13 +95,13 @@ class RemoveIntervals extends Pass {
val ltOpt = clipLo <= inLow
(gtOpt, ltOpt) match {
// input range within output range -> no optimization
- case (true, true) => a1
+ case (true, true) => a1
case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1)
case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1)
- case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1))
+ case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1))
}
- case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) =>
+ case sqz @ DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) =>
// Using (conditional) reassign interval w/o adding mux
val a1tpe = a1.tpe.asInstanceOf[IntervalType]
val a2tpe = a2.tpe.asInstanceOf[IntervalType]
@@ -117,54 +118,55 @@ class RemoveIntervals extends Pass {
val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2)))
DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2)))
}
- case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match {
- // If a2 type is Interval wrap around range. If UInt, wrap around width
- case t: IntervalType =>
- // Need to match binary points before getting *adjusted!
- val (wrapLo, wrapHi) = t.copy(point = tpe.point) match {
- case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get)
- case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType")
- }
- val (inLo, inHi) = a1.tpe match {
- case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
- case _ => sys.error("Shouldn't be here")
- }
- // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap)
- val range = wrapHi - wrapLo
- val ltOpt = Add(a1, (range + 1).S)
- val gtOpt = Sub(a1, (range + 1).S)
- // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it.
- // If x < wl
- // output: wh - (wl - x) + 1 AKA x + r + 1
- // worst case: wh - (wl - xl) + 1 = wl
- // -> xl + wr + 1 = wl
- // If x > wh
- // output: wl + (x - wh) - 1 AKA x - r - 1
- // worst case: wl + (xh - wh) - 1 = wh
- // -> xh - wr - 1 = wh
- val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S)
- (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match {
- case (true, true, _, _) => a1
- case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1)
- case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1)
- // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work)
- case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1))
- case _ =>
- errors.append(new WrapWithRemainder(info, mname, w))
- default
- }
- case _ => sys.error("Shouldn't be here")
- }
+ case w @ DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) =>
+ a2.tpe match {
+ // If a2 type is Interval wrap around range. If UInt, wrap around width
+ case t: IntervalType =>
+ // Need to match binary points before getting *adjusted!
+ val (wrapLo, wrapHi) = t.copy(point = tpe.point) match {
+ case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get)
+ case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType")
+ }
+ val (inLo, inHi) = a1.tpe match {
+ case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
+ case _ => sys.error("Shouldn't be here")
+ }
+ // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap)
+ val range = wrapHi - wrapLo
+ val ltOpt = Add(a1, (range + 1).S)
+ val gtOpt = Sub(a1, (range + 1).S)
+ // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it.
+ // If x < wl
+ // output: wh - (wl - x) + 1 AKA x + r + 1
+ // worst case: wh - (wl - xl) + 1 = wl
+ // -> xl + wr + 1 = wl
+ // If x > wh
+ // output: wl + (x - wh) - 1 AKA x - r - 1
+ // worst case: wl + (xh - wh) - 1 = wh
+ // -> xh - wr - 1 = wh
+ val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S)
+ (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match {
+ case (true, true, _, _) => a1
+ case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1)
+ case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1)
+ // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work)
+ case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1))
+ case _ =>
+ errors.append(new WrapWithRemainder(info, mname, w))
+ default
+ }
+ case _ => sys.error("Shouldn't be here")
+ }
case other => other
}
}
- private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+ private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval)
private def replaceTypeInterval(t: Type): Type = t match {
- case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width)
+ case i @ IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width)
case i: IntervalType => sys.error(s"Shouldn't be here: $i")
- case v => v map replaceTypeInterval
+ case v => v.map(replaceTypeInterval)
}
/** Replace Interval Nodes with Interval Wires
@@ -174,15 +176,16 @@ class RemoveIntervals extends Pass {
* @param m module to replace nodes with wire + connection
* @return
*/
- private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt
+ private def makeWireModule(m: DefModule): DefModule = m.map(makeWireStmt)
private def makeWireStmt(s: Statement): Statement = s match {
- case DefNode(info, name, value) => value.tpe match {
- case IntervalType(l, u, p) =>
- val newType = IntervalType(l, u, p)
- Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value)))
- case other => s
- }
- case other => other map makeWireStmt
+ case DefNode(info, name, value) =>
+ value.tpe match {
+ case IntervalType(l, u, p) =>
+ val newType = IntervalType(l, u, p)
+ Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value)))
+ case other => s
+ }
+ case other => other.map(makeWireStmt)
}
}