aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/transforms')
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala2
-rw-r--r--src/main/scala/firrtl/transforms/CombineCats.scala3
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala143
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala10
-rw-r--r--src/main/scala/firrtl/transforms/RemoveReset.scala2
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala14
6 files changed, 105 insertions, 69 deletions
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index dee2f9c8..eec9d1af 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -101,7 +101,7 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget
class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset))
+ Seq(Dependency(passes.LowerTypes), Dependency(firrtl.transforms.RemoveReset))
override def optionalPrerequisites = Seq.empty
diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala
index a37e6f08..71ef34bf 100644
--- a/src/main/scala/firrtl/transforms/CombineCats.scala
+++ b/src/main/scala/firrtl/transforms/CombineCats.scala
@@ -63,12 +63,11 @@ class CombineCats extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowForm ++
Seq(
Dependency(passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
Dependency(firrtl.passes.memlib.VerilogMemDelays),
Dependency(firrtl.passes.SplitExpressions)
)
- override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation])
override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index bc1fc9af..f216a3a3 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -3,7 +3,7 @@
package firrtl
package transforms
-import firrtl._
+import firrtl.{options, _}
import firrtl.annotations._
import firrtl.annotations.TargetToken._
import firrtl.ir._
@@ -41,8 +41,64 @@ object ConstantPropagation {
}
)
case (we, wt) if we == wt => e
+ case (we, wt) =>
+ throw new RuntimeException(s"Cannot pad from $we-bit to $wt-bit! ${e.serialize}")
}
+ def constPropPad(e: DoPrim): Expression = {
+ // we constant prop through casts here in order to allow LegalizeConnects
+ // to not mess up async reset checks in CheckResets
+ val propCasts = e.args.head match {
+ case c @ DoPrim(AsUInt, _, _, _) => constPropCasts(c)
+ case c @ DoPrim(AsSInt, _, _, _) => constPropCasts(c)
+ case other => other
+ }
+ propCasts match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
+ case _ => e
+ }
+ }
+
+ def constPropCasts(e: DoPrim): Expression = e.op match {
+ case AsUInt =>
+ e.args.head match {
+ case SIntLiteral(v, IntWidth(w)) => litToUInt(v, w.toInt)
+ case arg =>
+ arg.tpe match {
+ case _: UIntType => arg
+ case _ => e
+ }
+ }
+ case AsSInt =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => litToSInt(v, w.toInt)
+ case arg =>
+ arg.tpe match {
+ case _: SIntType => arg
+ case _ => e
+ }
+ }
+ case AsClock =>
+ val arg = e.args.head
+ arg.tpe match {
+ case ClockType => arg
+ case _ => e
+ }
+ case AsAsyncReset =>
+ val arg = e.args.head
+ arg.tpe match {
+ case AsyncResetType => arg
+ case _ => e
+ }
+ }
+
+ private def litToSInt(unsignedValue: BigInt, w: Int): SIntLiteral =
+ SIntLiteral(unsignedValue - ((unsignedValue >> (w - 1)) << w), IntWidth(w))
+ private def litToUInt(signedValue: BigInt, w: Int): UIntLiteral =
+ UIntLiteral(signedValue + (if (signedValue < 0) BigInt(1) << w else 0), IntWidth(w))
+
def constPropBitExtract(e: DoPrim) = {
val arg = e.args.head
val (hi, lo) = e.op match {
@@ -68,11 +124,26 @@ object ConstantPropagation {
case 0 => e.args.head
case x =>
e.args.head match {
- // TODO when amount >= x.width, return a zero-width wire
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1)))
- // take sign bit if shift amount is larger than arg width
case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1)))
- case _ => e
+ // Handle non-literal arguments where shift is larger than width
+ case _ =>
+ val amount = e.consts.head.toInt
+ val width = bitWidth(e.args.head.tpe)
+ lazy val msb = width - 1
+ if (amount >= width) {
+ e.tpe match {
+ // When amount >= x.width, return a zero-width wire
+ case UIntType(_) => zero
+ // Take sign bit if shift amount is larger than arg width
+ case SIntType(_) =>
+ val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
+ DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
+ case t => error(s"Unsupported type $t for Primop Shift Right")
+ }
+ } else {
+ e
+ }
}
}
@@ -114,12 +185,13 @@ object ConstantPropagation {
class ConstantPropagation extends Transform with RegisteredTransform with DependencyAPIMigration {
import ConstantPropagation._
- override def prerequisites =
- ((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)).toSeq
+ override def prerequisites = firrtl.stage.Forms.LowForm
- override def optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf))
+ override def optionalPrerequisites = Seq(
+ // both passes allow constant prop to be more effective!
+ Dependency(firrtl.passes.RemoveValidIf),
+ Dependency(firrtl.passes.PadWidths)
+ )
override def optionalPrerequisiteOf =
Seq(
@@ -130,8 +202,7 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
)
override def invalidates(a: Transform): Boolean = a match {
- case firrtl.passes.Legalize => true
- case _ => false
+ case _ => false
}
val options = Seq(
@@ -353,11 +424,16 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
def <(that: Range) = this.max < that.min
def <=(that: Range) = this.max <= that.min
}
+ // Padding increases the width but doesn't increase the range of values
+ def trueType(e: Expression): Type = e match {
+ case DoPrim(Pad, Seq(a), _, _) => a.tpe
+ case other => other.tpe
+ }
def range(e: Expression): Range = e match {
case UIntLiteral(value, _) => Range(value, value)
case SIntLiteral(value, _) => Range(value, value)
case _ =>
- e.tpe match {
+ trueType(e) match {
case SIntType(IntWidth(width)) =>
Range(
min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
@@ -444,45 +520,10 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
case _ => e
}
- case AsUInt =>
- e.args.head match {
- case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
- case arg =>
- arg.tpe match {
- case _: UIntType => arg
- case _ => e
- }
- }
- case AsSInt =>
- e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w))
- case arg =>
- arg.tpe match {
- case _: SIntType => arg
- case _ => e
- }
- }
- case AsClock =>
- val arg = e.args.head
- arg.tpe match {
- case ClockType => arg
- case _ => e
- }
- case AsAsyncReset =>
- val arg = e.args.head
- arg.tpe match {
- case AsyncResetType => arg
- case _ => e
- }
- case Pad =>
- e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w)))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w)))
- case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
- case _ => e
- }
- case (Bits | Head | Tail) => constPropBitExtract(e)
- case _ => e
+ case AsUInt | AsSInt | AsClock | AsAsyncReset => constPropCasts(e)
+ case Pad => constPropPad(e)
+ case (Bits | Head | Tail) => constPropBitExtract(e)
+ case _ => e
}
private def constPropMuxCond(m: Mux) = m.cond match {
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index f72585d1..41ffd2be 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -33,15 +33,7 @@ import collection.mutable
*/
class DeadCodeElimination extends Transform with RegisteredTransform with DependencyAPIMigration {
- override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq(
- Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats],
- Dependency(passes.CommonSubexpressionElimination)
- )
+ override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisites = Seq.empty
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala
index d7f59321..62b341cd 100644
--- a/src/main/scala/firrtl/transforms/RemoveReset.scala
+++ b/src/main/scala/firrtl/transforms/RemoveReset.scala
@@ -19,7 +19,7 @@ import scala.collection.{immutable, mutable}
object RemoveReset extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize))
+ Seq(Dependency(passes.LowerTypes))
override def optionalPrerequisites = Seq.empty
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
index 7500b386..4fa70002 100644
--- a/src/main/scala/firrtl/transforms/RemoveWires.scala
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -12,6 +12,7 @@ import firrtl.graph.{CyclicException, MutableDiGraph}
import firrtl.options.Dependency
import firrtl.Utils.getGroundZero
import firrtl.backends.experimental.smt.random.DefRandom
+import firrtl.passes.PadWidths
import scala.collection.mutable
import scala.util.{Failure, Success, Try}
@@ -27,10 +28,10 @@ class RemoveWires extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
Seq(
Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
Dependency(passes.ResolveKinds),
Dependency(transforms.RemoveReset),
- Dependency[transforms.CheckCombLoops]
+ Dependency[transforms.CheckCombLoops],
+ Dependency(passes.LegalizeConnects)
)
override def optionalPrerequisites = Seq(Dependency[checks.CheckResets])
@@ -131,10 +132,13 @@ class RemoveWires extends Transform with DependencyAPIMigration {
case con @ Connect(cinfo, lhs, rhs) =>
kind(lhs) match {
case WireKind =>
- // Be sure to pad the rhs since nodes get their type from the rhs
- val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
+ // be sure that connects have the same bit widths on rhs and lhs
+ assert(
+ bitWidth(lhs.tpe) == bitWidth(rhs.tpe),
+ "Connection widths should have been taken care of by LegalizeConnects!"
+ )
val dinfo = wireInfo(lhs)
- netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo))
+ netlist(we(lhs)) = (Seq(rhs), MultiInfo(dinfo, cinfo))
case _ => otherStmts += con // Other connections just pass through
}
case invalid @ IsInvalid(info, expr) =>