diff options
| author | Albert Magyar | 2020-04-07 15:04:17 -0700 |
|---|---|---|
| committer | GitHub | 2020-04-07 22:04:17 +0000 |
| commit | 1a03e6356e451136d522d5a9acba374dd8972b24 (patch) | |
| tree | 24e568872cb6db4ac9ce87080d1e06a4001a3017 /src | |
| parent | a9034bac8df5672b04a53c0ad99d82f94465d678 (diff) | |
Fix dynamic SubAccess of zero-length vectors (#1450)
* Fix dynamic SubAccess of zero-length vectors
* Fixes #230
* Add new ZeroLengthVecs pass that occurs before RemoveAccesses
* Include this in stage.Forms.MidForm
* Add to High->Mid order in compiler test based on @seldridge feedback
* Use validif to produce out-of-bounds value in ZeroLengthVecs
* Update scaladoc
* Fix test imports
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ZeroLengthVecs.scala | 69 | ||||
| -rw-r--r-- | src/main/scala/firrtl/stage/Forms.scala | 1 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/LoweringCompilersSpec.scala | 13 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala | 68 |
5 files changed, 146 insertions, 6 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index ac5d8a4e..5c6dfc3f 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -18,6 +18,7 @@ object RemoveAccesses extends Pass { override val prerequisites = Seq( Dependency(PullMuxes), + Dependency(ZeroLengthVecs), Dependency(ReplaceAccesses), Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped diff --git a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala new file mode 100644 index 00000000..67d9bce4 --- /dev/null +++ b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala @@ -0,0 +1,69 @@ +// See LICENSE for license details. + +package firrtl.passes + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.PrimOps._ +import firrtl.options.{Dependency, PreservesAll} + +/** Handles dynamic accesses to zero-length vectors. + * + * @note Removes assignments that use a zero-length vector as a sink + * @note Removes signals resulting from accesses to a zero-length vector from attach groups + * @note Removes attaches that become degenerate after zero-length-accessor removal + * @note Replaces "source" references to elements of zero-length vectors with always-invalid validif + */ +object ZeroLengthVecs extends Pass with PreservesAll[Transform] { + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ResolveKinds), + Dependency(InferTypes), + Dependency(ExpandConnects) ) + + // Pass in an expression, not just a type, since it's not possible to generate an expression of + // interval type with the type alone unless you declare a component + private def replaceWithDontCare(toReplace: Expression): Expression = { + val default = toReplace.tpe match { + case UIntType(w) => UIntLiteral(0, w) + case SIntType(w) => SIntLiteral(0, w) + case FixedType(w, p) => FixedLiteral(0, w, p) + case it: IntervalType => + val zeroType = IntervalType(Closed(0), Closed(0), IntWidth(0)) + val zeroLit = DoPrim(AsInterval, Seq(SIntLiteral(0)), Seq(0, 0, 0), zeroType) + DoPrim(Clip, Seq(zeroLit, toReplace), Nil, it) + } + ValidIf(UIntLiteral(0), default, toReplace.tpe) + } + + private def zeroLenDerivedRefLike(expr: Expression): Boolean = (expr, expr.tpe) match { + case (_, VectorType(_, 0)) => true + case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e) + case (WSubAccess(e, _, _, _), _) => zeroLenDerivedRefLike(e) + case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e) + case _ => false + } + + // The connects have all been lowered, so all aggregate-typed expressions are "grounded" by WSubField/WSubAccess/WSubIndex + // Map before matching because we want don't-cares to propagate UP expression trees + private def dropZeroLenSubAccesses(expr: Expression): Expression = expr match { + case _: WSubIndex | _: WSubAccess | _: WSubField => + if (zeroLenDerivedRefLike(expr)) replaceWithDontCare(expr) else expr + case e => e map dropZeroLenSubAccesses + } + + // Attach semantics: drop all zero-length-derived members of attach group, drop stmt if trivial + private def onStmt(stmt: Statement): Statement = stmt match { + case Connect(_, sink, _) if zeroLenDerivedRefLike(sink) => EmptyStmt + case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt + case Attach(info, sinks) => + val filtered = Attach(info, sinks.filterNot(zeroLenDerivedRefLike)) + if (filtered.exprs.length < 2) EmptyStmt else filtered + case s => s.map(onStmt).map(dropZeroLenSubAccesses) + } + + override def run(c: Circuit): Circuit = { + c.copy(modules = c.modules.map(m => m.map(onStmt))) + } +} diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala index 3e9803b7..76587abc 100644 --- a/src/main/scala/firrtl/stage/Forms.scala +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -51,6 +51,7 @@ object Forms { Dependency(passes.ReplaceAccesses), Dependency(passes.ExpandConnects), Dependency(passes.RemoveAccesses), + Dependency(passes.ZeroLengthVecs), Dependency[passes.ExpandWhensAndCheck], Dependency[passes.RemoveIntervals], Dependency(passes.ConvertFixedToSInt), diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index dcc4e48d..648e45cd 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -61,6 +61,7 @@ class LoweringCompilersSpec extends FlatSpec with Matchers { passes.PullMuxes, passes.ReplaceAccesses, passes.ExpandConnects, + passes.ZeroLengthVecs, passes.RemoveAccesses, passes.Uniquify, passes.ExpandWhens, @@ -156,17 +157,17 @@ class LoweringCompilersSpec extends FlatSpec with Matchers { it should "replicate the old order" in { val tm = new TransformManager(Forms.MidForm, Forms.Deduped) val patches = Seq( - Add(5, Seq(Dependency(firrtl.passes.ResolveKinds), + Add(6, Seq(Dependency(firrtl.passes.ResolveKinds), Dependency(firrtl.passes.InferTypes))), - Del(6), Del(7), - Add(6, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])), - Del(10), + Del(8), + Add(7, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])), Del(11), Del(12), - Add(11, Seq(Dependency(firrtl.passes.ResolveFlows), + Del(13), + Add(12, Seq(Dependency(firrtl.passes.ResolveFlows), Dependency[firrtl.passes.InferWidths])), - Del(13) + Del(14) ) compare(legacyTransforms(new HighFirrtlToMiddleFirrtl), tm, patches) } diff --git a/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala b/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala new file mode 100644 index 00000000..715714dd --- /dev/null +++ b/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala @@ -0,0 +1,68 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.passes._ +import firrtl.testutils.FirrtlFlatSpec + +class ZeroLengthVecsSpec extends FirrtlFlatSpec { + val transforms = Seq( + ToWorkingIR, + ResolveKinds, + InferTypes, + ResolveFlows, + new InferWidths, + ZeroLengthVecs, + CheckTypes) + protected def exec(input: String) = { + transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { + (c: CircuitState, t: Transform) => t.runTransform(c) + }.circuit.serialize + } + + "ZeroLengthVecs" should "drop subaccesses to zero-length vectors" in { + val input = + """circuit bar : + | module bar : + | input i : { a : UInt<8>, b : UInt<4> }[0] + | input sel : UInt<1> + | output foo : UInt<1>[0] + | output o : UInt<8> + | foo[UInt<1>(0)] <= UInt<1>(0) + | o <= i[sel].a + |""".stripMargin + val check = + """circuit bar : + | module bar : + | input i : { a : UInt<8>, b : UInt<4> }[0] + | input sel : UInt<1> + | output foo : UInt<1>[0] + | output o : UInt<8> + | skip + | o <= validif(UInt<1>(0), UInt<8>(0)) + |""".stripMargin + (parse(exec(input))) should be (parse(check)) + } + + "ZeroLengthVecs" should "handle intervals correctly" in { + val input = + """circuit bar : + | module bar : + | input i : Interval[3,4].0[0] + | input sel : UInt<1> + | output o : Interval[3,4].0 + | o <= i[sel] + |""".stripMargin + val check = + """circuit bar : + | module bar : + | input i : Interval[3,4].0[0] + | input sel : UInt<1> + | output o : Interval[3,4].0 + | o <= validif(UInt<1>(0), clip(asInterval(SInt<1>(0), 0, 0, 0), i[sel])) + |""".stripMargin + (parse(exec(input))) should be (parse(check)) + } + +} |
