aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Magyar2020-04-07 15:04:17 -0700
committerGitHub2020-04-07 22:04:17 +0000
commit1a03e6356e451136d522d5a9acba374dd8972b24 (patch)
tree24e568872cb6db4ac9ce87080d1e06a4001a3017 /src
parenta9034bac8df5672b04a53c0ad99d82f94465d678 (diff)
Fix dynamic SubAccess of zero-length vectors (#1450)
* Fix dynamic SubAccess of zero-length vectors * Fixes #230 * Add new ZeroLengthVecs pass that occurs before RemoveAccesses * Include this in stage.Forms.MidForm * Add to High->Mid order in compiler test based on @seldridge feedback * Use validif to produce out-of-bounds value in ZeroLengthVecs * Update scaladoc * Fix test imports
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala1
-rw-r--r--src/main/scala/firrtl/passes/ZeroLengthVecs.scala69
-rw-r--r--src/main/scala/firrtl/stage/Forms.scala1
-rw-r--r--src/test/scala/firrtlTests/LoweringCompilersSpec.scala13
-rw-r--r--src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala68
5 files changed, 146 insertions, 6 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index ac5d8a4e..5c6dfc3f 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -18,6 +18,7 @@ object RemoveAccesses extends Pass {
override val prerequisites =
Seq( Dependency(PullMuxes),
+ Dependency(ZeroLengthVecs),
Dependency(ReplaceAccesses),
Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped
diff --git a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala
new file mode 100644
index 00000000..67d9bce4
--- /dev/null
+++ b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala
@@ -0,0 +1,69 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.PrimOps._
+import firrtl.options.{Dependency, PreservesAll}
+
+/** Handles dynamic accesses to zero-length vectors.
+ *
+ * @note Removes assignments that use a zero-length vector as a sink
+ * @note Removes signals resulting from accesses to a zero-length vector from attach groups
+ * @note Removes attaches that become degenerate after zero-length-accessor removal
+ * @note Replaces "source" references to elements of zero-length vectors with always-invalid validif
+ */
+object ZeroLengthVecs extends Pass with PreservesAll[Transform] {
+ override val prerequisites =
+ Seq( Dependency(PullMuxes),
+ Dependency(ResolveKinds),
+ Dependency(InferTypes),
+ Dependency(ExpandConnects) )
+
+ // Pass in an expression, not just a type, since it's not possible to generate an expression of
+ // interval type with the type alone unless you declare a component
+ private def replaceWithDontCare(toReplace: Expression): Expression = {
+ val default = toReplace.tpe match {
+ case UIntType(w) => UIntLiteral(0, w)
+ case SIntType(w) => SIntLiteral(0, w)
+ case FixedType(w, p) => FixedLiteral(0, w, p)
+ case it: IntervalType =>
+ val zeroType = IntervalType(Closed(0), Closed(0), IntWidth(0))
+ val zeroLit = DoPrim(AsInterval, Seq(SIntLiteral(0)), Seq(0, 0, 0), zeroType)
+ DoPrim(Clip, Seq(zeroLit, toReplace), Nil, it)
+ }
+ ValidIf(UIntLiteral(0), default, toReplace.tpe)
+ }
+
+ private def zeroLenDerivedRefLike(expr: Expression): Boolean = (expr, expr.tpe) match {
+ case (_, VectorType(_, 0)) => true
+ case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e)
+ case (WSubAccess(e, _, _, _), _) => zeroLenDerivedRefLike(e)
+ case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e)
+ case _ => false
+ }
+
+ // The connects have all been lowered, so all aggregate-typed expressions are "grounded" by WSubField/WSubAccess/WSubIndex
+ // Map before matching because we want don't-cares to propagate UP expression trees
+ private def dropZeroLenSubAccesses(expr: Expression): Expression = expr match {
+ case _: WSubIndex | _: WSubAccess | _: WSubField =>
+ if (zeroLenDerivedRefLike(expr)) replaceWithDontCare(expr) else expr
+ case e => e map dropZeroLenSubAccesses
+ }
+
+ // Attach semantics: drop all zero-length-derived members of attach group, drop stmt if trivial
+ private def onStmt(stmt: Statement): Statement = stmt match {
+ case Connect(_, sink, _) if zeroLenDerivedRefLike(sink) => EmptyStmt
+ case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt
+ case Attach(info, sinks) =>
+ val filtered = Attach(info, sinks.filterNot(zeroLenDerivedRefLike))
+ if (filtered.exprs.length < 2) EmptyStmt else filtered
+ case s => s.map(onStmt).map(dropZeroLenSubAccesses)
+ }
+
+ override def run(c: Circuit): Circuit = {
+ c.copy(modules = c.modules.map(m => m.map(onStmt)))
+ }
+}
diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala
index 3e9803b7..76587abc 100644
--- a/src/main/scala/firrtl/stage/Forms.scala
+++ b/src/main/scala/firrtl/stage/Forms.scala
@@ -51,6 +51,7 @@ object Forms {
Dependency(passes.ReplaceAccesses),
Dependency(passes.ExpandConnects),
Dependency(passes.RemoveAccesses),
+ Dependency(passes.ZeroLengthVecs),
Dependency[passes.ExpandWhensAndCheck],
Dependency[passes.RemoveIntervals],
Dependency(passes.ConvertFixedToSInt),
diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
index dcc4e48d..648e45cd 100644
--- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
+++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
@@ -61,6 +61,7 @@ class LoweringCompilersSpec extends FlatSpec with Matchers {
passes.PullMuxes,
passes.ReplaceAccesses,
passes.ExpandConnects,
+ passes.ZeroLengthVecs,
passes.RemoveAccesses,
passes.Uniquify,
passes.ExpandWhens,
@@ -156,17 +157,17 @@ class LoweringCompilersSpec extends FlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.MidForm, Forms.Deduped)
val patches = Seq(
- Add(5, Seq(Dependency(firrtl.passes.ResolveKinds),
+ Add(6, Seq(Dependency(firrtl.passes.ResolveKinds),
Dependency(firrtl.passes.InferTypes))),
- Del(6),
Del(7),
- Add(6, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])),
- Del(10),
+ Del(8),
+ Add(7, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])),
Del(11),
Del(12),
- Add(11, Seq(Dependency(firrtl.passes.ResolveFlows),
+ Del(13),
+ Add(12, Seq(Dependency(firrtl.passes.ResolveFlows),
Dependency[firrtl.passes.InferWidths])),
- Del(13)
+ Del(14)
)
compare(legacyTransforms(new HighFirrtlToMiddleFirrtl), tm, patches)
}
diff --git a/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala b/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala
new file mode 100644
index 00000000..715714dd
--- /dev/null
+++ b/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala
@@ -0,0 +1,68 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.passes._
+import firrtl.testutils.FirrtlFlatSpec
+
+class ZeroLengthVecsSpec extends FirrtlFlatSpec {
+ val transforms = Seq(
+ ToWorkingIR,
+ ResolveKinds,
+ InferTypes,
+ ResolveFlows,
+ new InferWidths,
+ ZeroLengthVecs,
+ CheckTypes)
+ protected def exec(input: String) = {
+ transforms.foldLeft(CircuitState(parse(input), UnknownForm)) {
+ (c: CircuitState, t: Transform) => t.runTransform(c)
+ }.circuit.serialize
+ }
+
+ "ZeroLengthVecs" should "drop subaccesses to zero-length vectors" in {
+ val input =
+ """circuit bar :
+ | module bar :
+ | input i : { a : UInt<8>, b : UInt<4> }[0]
+ | input sel : UInt<1>
+ | output foo : UInt<1>[0]
+ | output o : UInt<8>
+ | foo[UInt<1>(0)] <= UInt<1>(0)
+ | o <= i[sel].a
+ |""".stripMargin
+ val check =
+ """circuit bar :
+ | module bar :
+ | input i : { a : UInt<8>, b : UInt<4> }[0]
+ | input sel : UInt<1>
+ | output foo : UInt<1>[0]
+ | output o : UInt<8>
+ | skip
+ | o <= validif(UInt<1>(0), UInt<8>(0))
+ |""".stripMargin
+ (parse(exec(input))) should be (parse(check))
+ }
+
+ "ZeroLengthVecs" should "handle intervals correctly" in {
+ val input =
+ """circuit bar :
+ | module bar :
+ | input i : Interval[3,4].0[0]
+ | input sel : UInt<1>
+ | output o : Interval[3,4].0
+ | o <= i[sel]
+ |""".stripMargin
+ val check =
+ """circuit bar :
+ | module bar :
+ | input i : Interval[3,4].0[0]
+ | input sel : UInt<1>
+ | output o : Interval[3,4].0
+ | o <= validif(UInt<1>(0), clip(asInterval(SInt<1>(0), 0, 0, 0), i[sel]))
+ |""".stripMargin
+ (parse(exec(input))) should be (parse(check))
+ }
+
+}