aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
blob: b916842fab900fb0f5b250da4562c29a5c8fa070 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// SPDX-License-Identifier: Apache-2.0

package firrtl.passes
package memlib
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import firrtl.stage.Forms

/** Prevents [[ResolveMemoryReference]] from deduplicating the targeted memory.
  *
  * The target identifies the memory by its component name and enclosing module name.
  * [[ResolveMemoryReference]] consumes these annotations: they are not present in the
  * annotations it returns.
  */
case class NoDedupMemAnnotation(target: ComponentName) extends SingleTargetAnnotation[ComponentName] {

  /** Returns a copy of this annotation that targets `n` instead. */
  def duplicate(n: ComponentName): NoDedupMemAnnotation = copy(target = n)
}

/** Resolves annotation ref to memories that exactly match (except name) another memory
  *
  * Every [[DefAnnotatedMemory]] is compared against the memories already visited, in the
  * order of `c.modules` and statement traversal, ignoring name, source info and `memRef`.
  * If an equivalent one was already seen, the later memory's `memRef` is set to the
  * `(moduleName, memoryName)` of the first one. Memories named by a
  * [[NoDedupMemAnnotation]] are never matched, and those annotations are removed from the
  * returned annotations.
  */
class ResolveMemoryReference extends Transform with DependencyAPIMigration {

  override def prerequisites = Forms.MidForm
  override def optionalPrerequisites = Seq.empty
  override def optionalPrerequisiteOf = Forms.MidEmitters
  override def invalidates(a: Transform) = false

  /** Helper class for determining when two memories are equivalent while ignoring
    * irrelevant details like name and info
    */
  private class WrappedDefAnnoMemory(val underlying: DefAnnotatedMemory) {
    // The memory with name, info and memRef stripped. It is computed once because
    // instances are hash-map keys, and hashCode/equals run on every lookup.
    private val generic: DefAnnotatedMemory = underlying.copy(info = NoInfo, name = "", memRef = None)
    override def hashCode: Int = generic.hashCode
    override def equals(that: Any): Boolean = that match {
      case mem: WrappedDefAnnoMemory => this.generic == mem.generic
      case _ => false
    }
  }
  private def wrap(mem: DefAnnotatedMemory) = new WrappedDefAnnoMemory(mem)

  // Values are Tuple of Module Name and Memory Instance Name
  private type AnnotatedMemories = collection.mutable.HashMap[WrappedDefAnnoMemory, (String, String)]

  /** True unless `memory` in `module` is listed in `noDedups` (module name -> memory names). */
  private def dedupable(noDedups: Map[String, Set[String]], module: String, memory: String): Boolean =
    noDedups.get(module).forall(mems => !mems.contains(memory))

  /** Points duplicate memories at the first equivalent memory seen.
    *
    * If a dedupable memory is identical (except name, info and memRef) to one already
    * recorded in `existingMems`, its `memRef` is set to that memory's
    * `(moduleName, memoryName)`. Otherwise the memory is recorded as the prototype for
    * later matches. Other statements are traversed recursively.
    *
    * @param mname        name of the module containing `s`
    * @param existingMems prototypes seen so far; mutated when a new prototype is found
    * @param noDedupMap   module name -> memory names that must not be deduplicated
    * @param s            statement to process
    * @return `s` with `memRef` filled in on any duplicate memories it contains
    */
  def updateMemStmts(
    mname:        String,
    existingMems: AnnotatedMemories,
    noDedupMap:   Map[String, Set[String]]
  )(s:            Statement
  ): Statement = s match {
    // If not dedupable, no need to add to existing (since nothing can dedup with it)
    // We just return the DefAnnotatedMemory as is in the default case below
    case m: DefAnnotatedMemory if dedupable(noDedupMap, mname, m.name) =>
      val wrapped = wrap(m)
      existingMems.get(wrapped) match {
        case proto @ Some(_) =>
          m.copy(memRef = proto)
        case None =>
          existingMems(wrapped) = (mname, m.name)
          m
      }
    case s => s.map(updateMemStmts(mname, existingMems, noDedupMap))
  }

  /** Resolves memory references across all modules of `c`.
    *
    * One prototype table is shared by every module, so a memory may reference an
    * equivalent memory in a different module.
    *
    * @param c          the circuit to process
    * @param noDedupMap module name -> memory names that must not be deduplicated
    * @return `c` with `memRef` set on duplicate memories
    */
  def run(c: Circuit, noDedupMap: Map[String, Set[String]]): Circuit = {
    val existingMems = new AnnotatedMemories
    val modulesx = c.modules.map(m => m.map(updateMemStmts(m.name, existingMems, noDedupMap)))
    c.copy(modules = modulesx)
  }

  /** Consumes [[NoDedupMemAnnotation]]s and resolves memory references in the circuit. */
  def execute(state: CircuitState): CircuitState = {
    val (remainingAnnotations, noDedupMemAnnos) = state.annotations.partition {
      case _: NoDedupMemAnnotation => false
      case _ => true
    }
    // Module name -> names of memories in that module that must not be deduplicated.
    // The map is built strictly rather than via `mapValues`, which is deprecated in
    // Scala 2.13 and returns a non-strict view.
    val noDedupMap: Map[String, Set[String]] = noDedupMemAnnos
      .collect { case NoDedupMemAnnotation(ComponentName(cn, ModuleName(mn, _))) => mn -> cn }
      .groupBy(_._1)
      .map { case (mn, pairs) => mn -> pairs.map(_._2).toSet }
    state.copy(
      circuit = run(state.circuit, noDedupMap),
      annotations = remainingAnnotations
    )
  }
}