diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms')
6 files changed, 105 insertions, 69 deletions
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index dee2f9c8..eec9d1af 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -101,7 +101,7 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset)) + Seq(Dependency(passes.LowerTypes), Dependency(firrtl.transforms.RemoveReset)) override def optionalPrerequisites = Seq.empty diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala index a37e6f08..71ef34bf 100644 --- a/src/main/scala/firrtl/transforms/CombineCats.scala +++ b/src/main/scala/firrtl/transforms/CombineCats.scala @@ -63,12 +63,11 @@ class CombineCats extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowForm ++ Seq( Dependency(passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency(firrtl.passes.SplitExpressions) ) - override def optionalPrerequisites = Seq.empty + override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation]) override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index bc1fc9af..f216a3a3 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -3,7 +3,7 @@ package firrtl package transforms -import firrtl._ +import firrtl.{options, _} import firrtl.annotations._ import firrtl.annotations.TargetToken._ import firrtl.ir._ @@ -41,8 +41,64 @@ object ConstantPropagation { } ) case (we, wt) if we == wt => e + case (we, wt) => + throw new RuntimeException(s"Cannot pad from $we-bit to $wt-bit! ${e.serialize}") } + def constPropPad(e: DoPrim): Expression = { + // we constant prop through casts here in order to allow LegalizeConnects + // to not mess up async reset checks in CheckResets + val propCasts = e.args.head match { + case c @ DoPrim(AsUInt, _, _, _) => constPropCasts(c) + case c @ DoPrim(AsSInt, _, _, _) => constPropCasts(c) + case other => other + } + propCasts match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w))) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w))) + case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head + case _ => e + } + } + + def constPropCasts(e: DoPrim): Expression = e.op match { + case AsUInt => + e.args.head match { + case SIntLiteral(v, IntWidth(w)) => litToUInt(v, w.toInt) + case arg => + arg.tpe match { + case _: UIntType => arg + case _ => e + } + } + case AsSInt => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => litToSInt(v, w.toInt) + case arg => + arg.tpe match { + case _: SIntType => arg + case _ => e + } + } + case AsClock => + val arg = e.args.head + arg.tpe match { + case ClockType => arg + case _ => e + } + case AsAsyncReset => + val arg = e.args.head + arg.tpe match { + case AsyncResetType => arg + case _ => e + } + } + + private def litToSInt(unsignedValue: BigInt, w: Int): SIntLiteral = + SIntLiteral(unsignedValue - ((unsignedValue >> (w - 1)) << w), IntWidth(w)) + private def litToUInt(signedValue: BigInt, w: Int): UIntLiteral = + UIntLiteral(signedValue + (if (signedValue < 0) BigInt(1) << w else 0), IntWidth(w)) + def constPropBitExtract(e: DoPrim) = { val arg = e.args.head val (hi, lo) = e.op match { @@ -68,11 +124,26 @@ object ConstantPropagation { case 0 => e.args.head case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1))) - // take sign bit if shift amount is larger than arg width case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1))) - case _ => e + // Handle non-literal arguments where shift is larger than width + case _ => + val amount = e.consts.head.toInt + val width = bitWidth(e.args.head.tpe) + lazy val msb = width - 1 + if (amount >= width) { + e.tpe match { + // When amount >= x.width, return a zero-width wire + case UIntType(_) => zero + // Take sign bit if shift amount is larger than arg width + case SIntType(_) => + val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) + DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) + case t => error(s"Unsupported type $t for Primop Shift Right") + } + } else { + e + } } } @@ -114,12 +185,13 @@ object ConstantPropagation { class ConstantPropagation extends Transform with RegisteredTransform with DependencyAPIMigration { import ConstantPropagation._ - override def prerequisites = - ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize)).toSeq + override def prerequisites = firrtl.stage.Forms.LowForm - override def optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf)) + override def optionalPrerequisites = Seq( + // both passes allow constant prop to be more effective! + Dependency(firrtl.passes.RemoveValidIf), + Dependency(firrtl.passes.PadWidths) + ) override def optionalPrerequisiteOf = Seq( @@ -130,8 +202,7 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend ) override def invalidates(a: Transform): Boolean = a match { - case firrtl.passes.Legalize => true - case _ => false + case _ => false } val options = Seq( @@ -353,11 +424,16 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend def <(that: Range) = this.max < that.min def <=(that: Range) = this.max <= that.min } + // Padding increases the width but doesn't increase the range of values + def trueType(e: Expression): Type = e match { + case DoPrim(Pad, Seq(a), _, _) => a.tpe + case other => other.tpe + } def range(e: Expression): Range = e match { case UIntLiteral(value, _) => Range(value, value) case SIntLiteral(value, _) => Range(value, value) case _ => - e.tpe match { + trueType(e) match { case SIntType(IntWidth(width)) => Range( min = BigInt(0) - BigInt(2).pow(width.toInt - 1), @@ -444,45 +520,10 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) case _ => e } - case AsUInt => - e.args.head match { - case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case arg => - arg.tpe match { - case _: UIntType => arg - case _ => e - } - } - case AsSInt => - e.args.head match { - case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w)) - case arg => - arg.tpe match { - case _: SIntType => arg - case _ => e - } - } - case AsClock => - val arg = e.args.head - arg.tpe match { - case ClockType => arg - case _ => e - } - case AsAsyncReset => - val arg = e.args.head - arg.tpe match { - case AsyncResetType => arg - case _ => e - } - case Pad => - e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w))) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w))) - case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head - case _ => e - } - case (Bits | Head | Tail) => constPropBitExtract(e) - case _ => e + case AsUInt | AsSInt | AsClock | AsAsyncReset => constPropCasts(e) + case Pad => constPropPad(e) + case (Bits | Head | Tail) => constPropBitExtract(e) + case _ => e } private def constPropMuxCond(m: Mux) = m.cond match { diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index f72585d1..41ffd2be 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -33,15 +33,7 @@ import collection.mutable */ class DeadCodeElimination extends Transform with RegisteredTransform with DependencyAPIMigration { - override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( - Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats], - Dependency(passes.CommonSubexpressionElimination) - ) + override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisites = Seq.empty diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index d7f59321..62b341cd 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -19,7 +19,7 @@ import scala.collection.{immutable, mutable} object RemoveReset extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize)) + Seq(Dependency(passes.LowerTypes)) override def optionalPrerequisites = Seq.empty diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index 7500b386..4fa70002 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -12,6 +12,7 @@ import firrtl.graph.{CyclicException, MutableDiGraph} import firrtl.options.Dependency import firrtl.Utils.getGroundZero import firrtl.backends.experimental.smt.random.DefRandom +import firrtl.passes.PadWidths import scala.collection.mutable import scala.util.{Failure, Success, Try} @@ -27,10 +28,10 @@ class RemoveWires extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), Dependency(passes.ResolveKinds), Dependency(transforms.RemoveReset), - Dependency[transforms.CheckCombLoops] + Dependency[transforms.CheckCombLoops], + Dependency(passes.LegalizeConnects) ) override def optionalPrerequisites = Seq(Dependency[checks.CheckResets]) @@ -131,10 +132,13 @@ class RemoveWires extends Transform with DependencyAPIMigration { case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match { case WireKind => - // Be sure to pad the rhs since nodes get their type from the rhs - val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) + // be sure that connects have the same bit widths on rhs and lhs + assert( + bitWidth(lhs.tpe) == bitWidth(rhs.tpe), + "Connection widths should have been taken care of by LegalizeConnects!" + ) val dinfo = wireInfo(lhs) - netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo)) + netlist(we(lhs)) = (Seq(rhs), MultiInfo(dinfo, cinfo)) case _ => otherStmts += con // Other connections just pass through } case invalid @ IsInvalid(info, expr) => |
