diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/passes/RemoveIntervals.scala | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/passes/RemoveIntervals.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveIntervals.scala | 149 |
1 files changed, 76 insertions, 73 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala index 7059526c..657b4356 100644 --- a/src/main/scala/firrtl/passes/RemoveIntervals.scala +++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala @@ -13,14 +13,13 @@ import firrtl.options.Dependency import scala.math.BigDecimal.RoundingMode._ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) - extends PassException({ - val toWrap = wrap.args.head.serialize - val toWrapTpe = wrap.args.head.tpe.serialize - val wrapTo = wrap.args(1).serialize - val wrapToTpe = wrap.args(1).tpe.serialize - s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe" - }) - + extends PassException({ + val toWrap = wrap.args.head.serialize + val toWrapTpe = wrap.args.head.tpe.serialize + val wrapTo = wrap.args(1).serialize + val wrapToTpe = wrap.args(1).tpe.serialize + s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe" + }) /** Replaces IntervalType with SIntType, three AST walks: * 1) Align binary points @@ -39,48 +38,50 @@ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) class RemoveIntervals extends Pass { override def prerequisites: Seq[Dependency[Transform]] = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses), - Dependency[ExpandWhensAndCheck] ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck] + ) ++ firrtl.stage.Forms.Deduped override def invalidates(transform: Transform): Boolean = { transform match { case InferTypes | ResolveKinds => true - case _ => false + case _ => false } } def run(c: Circuit): Circuit = { val alignedCircuit = c val errors = new Errors() - val wiredCircuit = alignedCircuit map makeWireModule - val replacedCircuit = wiredCircuit map replaceModuleInterval(errors) + val wiredCircuit = alignedCircuit.map(makeWireModule) + val replacedCircuit = wiredCircuit.map(replaceModuleInterval(errors)) errors.trigger() replacedCircuit } /* Replace interval types */ private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule = - m map replaceStmtInterval(errors, m.name) map replacePortInterval + m.map(replaceStmtInterval(errors, m.name)).map(replacePortInterval) private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = { val info = s match { case h: HasInfo => h.info case _ => NoInfo } - s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname) + s.map(replaceTypeInterval).map(replaceStmtInterval(errors, mname)).map(replaceExprInterval(errors, info, mname)) } private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match { case _: WRef | _: WSubIndex | _: WSubField => e case o => - o map replaceExprInterval(errors, info, mname) match { + o.map(replaceExprInterval(errors, info, mname)) match { case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe) - case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) - case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) + case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) + case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) => // Output interval (pre-calculated) val clipLo = tpe.minAdjusted.get @@ -94,13 +95,13 @@ class RemoveIntervals extends Pass { val ltOpt = clipLo <= inLow (gtOpt, ltOpt) match { // input range within output range -> no optimization - case (true, true) => a1 + case (true, true) => a1 case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1) case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1) - case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1)) + case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1)) } - case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) => + case sqz @ DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) => // Using (conditional) reassign interval w/o adding mux val a1tpe = a1.tpe.asInstanceOf[IntervalType] val a2tpe = a2.tpe.asInstanceOf[IntervalType] @@ -117,54 +118,55 @@ class RemoveIntervals extends Pass { val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2))) DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2))) } - case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match { - // If a2 type is Interval wrap around range. If UInt, wrap around width - case t: IntervalType => - // Need to match binary points before getting *adjusted! - val (wrapLo, wrapHi) = t.copy(point = tpe.point) match { - case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get) - case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType") - } - val (inLo, inHi) = a1.tpe match { - case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) - case _ => sys.error("Shouldn't be here") - } - // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap) - val range = wrapHi - wrapLo - val ltOpt = Add(a1, (range + 1).S) - val gtOpt = Sub(a1, (range + 1).S) - // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it. - // If x < wl - // output: wh - (wl - x) + 1 AKA x + r + 1 - // worst case: wh - (wl - xl) + 1 = wl - // -> xl + wr + 1 = wl - // If x > wh - // output: wl + (x - wh) - 1 AKA x - r - 1 - // worst case: wl + (xh - wh) - 1 = wh - // -> xh - wr - 1 = wh - val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S) - (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match { - case (true, true, _, _) => a1 - case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1) - case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1) - // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work) - case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1)) - case _ => - errors.append(new WrapWithRemainder(info, mname, w)) - default - } - case _ => sys.error("Shouldn't be here") - } + case w @ DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => + a2.tpe match { + // If a2 type is Interval wrap around range. If UInt, wrap around width + case t: IntervalType => + // Need to match binary points before getting *adjusted! + val (wrapLo, wrapHi) = t.copy(point = tpe.point) match { + case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get) + case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType") + } + val (inLo, inHi) = a1.tpe match { + case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) + case _ => sys.error("Shouldn't be here") + } + // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap) + val range = wrapHi - wrapLo + val ltOpt = Add(a1, (range + 1).S) + val gtOpt = Sub(a1, (range + 1).S) + // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it. + // If x < wl + // output: wh - (wl - x) + 1 AKA x + r + 1 + // worst case: wh - (wl - xl) + 1 = wl + // -> xl + wr + 1 = wl + // If x > wh + // output: wl + (x - wh) - 1 AKA x - r - 1 + // worst case: wl + (xh - wh) - 1 = wh + // -> xh - wr - 1 = wh + val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S) + (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match { + case (true, true, _, _) => a1 + case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1) + case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1) + // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work) + case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1)) + case _ => + errors.append(new WrapWithRemainder(info, mname, w)) + default + } + case _ => sys.error("Shouldn't be here") + } case other => other } } - private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval) private def replaceTypeInterval(t: Type): Type = t match { - case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width) + case i @ IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width) case i: IntervalType => sys.error(s"Shouldn't be here: $i") - case v => v map replaceTypeInterval + case v => v.map(replaceTypeInterval) } /** Replace Interval Nodes with Interval Wires @@ -174,15 +176,16 @@ class RemoveIntervals extends Pass { * @param m module to replace nodes with wire + connection * @return */ - private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt + private def makeWireModule(m: DefModule): DefModule = m.map(makeWireStmt) private def makeWireStmt(s: Statement): Statement = s match { - case DefNode(info, name, value) => value.tpe match { - case IntervalType(l, u, p) => - val newType = IntervalType(l, u, p) - Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value))) - case other => s - } - case other => other map makeWireStmt + case DefNode(info, name, value) => + value.tpe match { + case IntervalType(l, u, p) => + val newType = IntervalType(l, u, p) + Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value))) + case other => s + } + case other => other.map(makeWireStmt) } } |
