aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJiuyang Liu2020-12-02 01:53:04 +0000
committerGitHub2020-12-02 01:53:04 +0000
commit6c5ce834e26386100b196881f6e487aed26c9c0a (patch)
treef2b9225dc42fd04ea7e7c8fb4d80bd2071b68966 /src
parent4e46f8c614b81143621f2b4187392f6912d882bf (diff)
Fix subaccess (#1984)
* add test for RemoveAccessesSpec. * fix nested SubAccess bug. Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala35
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala17
2 files changed, 39 insertions, 13 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index f49af935..90437e56 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -47,8 +47,7 @@ object RemoveAccesses extends Pass {
* Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1)))
*/
private def getLocations(e: Expression): Seq[Location] = e match {
- case e: WRef => create_exps(e).map(Location(_, one))
- case e: WSubIndex =>
+ case e: SubIndex =>
val ls = getLocations(e.expr)
val start = get_point(e)
val end = start + get_size(e.tpe)
@@ -57,7 +56,7 @@ object RemoveAccesses extends Pass {
(l, i) <- ls.zipWithIndex
if ((i % stride) >= start) & ((i % stride) < end)
) yield l
- case e: WSubField =>
+ case e: SubField =>
val ls = getLocations(e.expr)
val start = get_point(e)
val end = start + get_size(e.tpe)
@@ -66,17 +65,27 @@ object RemoveAccesses extends Pass {
(l, i) <- ls.zipWithIndex
if ((i % stride) >= start) & ((i % stride) < end)
) yield l
- case e: WSubAccess =>
- val ls = getLocations(e.expr)
- val stride = get_size(e.tpe)
- val wrap = e.expr.tpe.asInstanceOf[VectorType].size
- ls.zipWithIndex.map {
- case (l, i) =>
- val c = (i / stride) % wrap
- val basex = l.base
- val guardx = AND(l.guard, EQV(UIntLiteral(c), e.index))
- Location(basex, guardx)
+ case SubAccess(expr, index, tpe, _) =>
+ getLocations(expr).zipWithIndex.flatMap {
+ case (Location(exprBase, exprGuard), exprIndex) =>
+ getLocations(index).map {
+ case Location(indexBase, indexGuard) =>
+ Location(
+ exprBase,
+ AND(
+ AND(
+ indexGuard,
+ exprGuard
+ ),
+ EQV(
+ UIntLiteral((exprIndex / get_size(tpe)) % expr.tpe.asInstanceOf[VectorType].size),
+ indexBase
+ )
+ )
+ )
+ }
}
+ case e => create_exps(e).map(Location(_, one))
}
/** Returns true if e contains a [[firrtl.WSubAccess]]
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index 9e58b74c..da84b362 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -456,6 +456,23 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec {
executeTest(input, expected)
}
+ it should "remove index express in SubAccess" in {
+ val input =
+ s"""circuit Bug :
+ | module Bug :
+ | input in0 : UInt<1> [2][2]
+ | input in1 : UInt<1> [2]
+ | input in2 : UInt<1> [2]
+ | output out : UInt<1>
+ | out <= in0[in1[in2[0]]][in1[in2[1]]]
+ |""".stripMargin
+ val expected = Seq(
+ "out <= _in0_in1_in1_in2_1"
+ )
+
+ executeTest(input, expected)
+ }
+
it should "rename memories" in {
val input =
"""circuit Test :