aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/ir/StructuralHash.scala
blob: 26e7d210e8374dcc327df6100060063a9032f837 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
// SPDX-License-Identifier: Apache-2.0

package firrtl.ir
import firrtl.PrimOps

import java.security.MessageDigest
import scala.collection.mutable

/** This object can perform a "structural" hash over any firrtl Module.
  * It ignores:
  * - [[firrtl.ir.Expression Expression]] types
  * - Any [[firrtl.ir.Info Info]] fields
  * - Descriptions on DescribedStmt
  * - Identifier names: each identifier name is replaced by a unique integer which only depends on the order of
  *   declaration and is thus deterministic
  * - The name of the hashed module (the names of modules referenced by DefInstance are still hashed)
  *
  * Because of the way we "agnostify" bundle types, all SubField access nodes need to have a known
  * bundle type. Thus - in many cases, like after reading firrtl from a file - you need to run
  * the firrtl type inference before hashing.
  *
  * Please note that module hashes don't include any submodules.
  * Two structurally equivalent modules are only functionally equivalent if they are part
  * of the same circuit and thus all modules referred to in DefInstance are the same.
  *
  * @author Kevin Laeufer <laeufer@cs.berkeley.edu>
  */
object StructuralHash {

  /** Computes the structural SHA-256 hash of a module.
    *
    * Identifier names (including port names) are replaced by integers in order of first occurrence,
    * so they do not influence the result; use [[sha256WithSignificantPortNames]] if port names matter.
    *
    * @param node         the module to hash
    * @param moduleRename applied to the module name referenced by every [[DefInstance]] before it is hashed
    */
  def sha256(node: DefModule, moduleRename: String => String = identity): HashCode = {
    val m = MessageDigest.getInstance(SHA256)
    new StructuralHash(new MessageDigestHasher(m), moduleRename).hash(node)
    new MDHashCode(m.digest())
  }

  /** This includes the names of ports and any port bundle field names in the hash. */
  def sha256WithSignificantPortNames(module: DefModule, moduleRename: String => String = identity): HashCode = {
    val m = MessageDigest.getInstance(SHA256)
    hashModuleAndPortNames(module, new MessageDigestHasher(m), moduleRename)
    new MDHashCode(m.digest())
  }

  /** Plain (non-structural) SHA-256 of the UTF-8 encoding of `str`. */
  private[firrtl] def sha256(str: String): HashCode = {
    val m = MessageDigest.getInstance(SHA256)
    // An explicit charset keeps the digest independent of the JVM's default encoding, which might not be able
    // to represent every character (unmappable characters would all collapse to the same replacement byte).
    m.update(str.getBytes(java.nio.charset.StandardCharsets.UTF_8))
    new MDHashCode(m.digest())
  }

  /** Using this to hash arbitrary nodes can have unexpected results like:
    *   hash(`a <= 1`) == hash(`b <= 1`).
    * This method is package private to allow for unit testing but should not be exposed to the user.
    */
  private[firrtl] def sha256Node(node: FirrtlNode): HashCode = {
    val m = MessageDigest.getInstance(SHA256)
    hash(node, new MessageDigestHasher(m), identity)
    new MDHashCode(m.digest())
  }

  // see: https://docs.oracle.com/javase/7/docs/api/java/security/MessageDigest.html
  private val SHA256 = "SHA-256"

  /** Dispatches `node` to the matching hash routine. Every call uses a fresh [[StructuralHash]],
    * i.e. a fresh identifier namespace.
    *
    * @throws RuntimeException if `node` is an [[Info]], whose structural hash would be meaningless
    */
  private def hash(node: FirrtlNode, h: Hasher, rename: String => String): Unit = node match {
    case n: Expression  => new StructuralHash(h, rename).hash(n)
    case n: Statement   => new StructuralHash(h, rename).hash(n)
    case n: Type        => new StructuralHash(h, rename).hash(n)
    case n: Width       => new StructuralHash(h, rename).hash(n)
    case n: Orientation => new StructuralHash(h, rename).hash(n)
    case n: Field       => new StructuralHash(h, rename).hash(n)
    case n: Direction   => new StructuralHash(h, rename).hash(n)
    case n: Port        => new StructuralHash(h, rename).hash(n)
    case n: Param       => new StructuralHash(h, rename).hash(n)
    case _: Info        => throw new RuntimeException("The structural hash of Info is meaningless.")
    case n: DefModule   => new StructuralHash(h, rename).hash(n)
    case n: Circuit     => hashCircuit(n, h, rename)
    case n: StringLit   => h.update(n.toString)
  }

  /** Hashes the module structurally and then appends every port name and every (nested) bundle
    * field name found in the port types, in declaration order.
    */
  private def hashModuleAndPortNames(m: DefModule, h: Hasher, rename: String => String): Unit = {
    val sh = new StructuralHash(h, rename)
    sh.hash(m)
    // hash port names
    m.ports.foreach { p =>
      h.update(p.name)
      hashPortTypeName(p.tpe, h.update)
    }
  }

  /** Feeds all bundle field names contained in `tpe` (depth first, in field order) to `h`. */
  private def hashPortTypeName(tpe: Type, h: String => Unit): Unit = tpe match {
    case BundleType(fields) => fields.foreach { f => h(f.name); hashPortTypeName(f.tpe, h) }
    case VectorType(vt, _)  => hashPortTypeName(vt, h)
    case _                  => // ignore ground types since they do not have field names nor sub-types
  }

  private def hashCircuit(c: Circuit, h: Hasher, rename: String => String): Unit = {
    // id 127 is reserved for Circuit nodes; emit it as a single id byte like every other node id
    // (a plain `127` literal would select the Int overload and emit four bytes).
    h.update(127.toByte)
    h.update(c.main)
    // sort modules to make hash more useful
    val mods = c.modules.sortBy(_.name)
    // we create a new StructuralHash for each module since each module has its own namespace
    mods.foreach { m =>
      new StructuralHash(h, rename).hash(m)
    }
  }

  // PrimOps are mapped to the negative ids -1, -2, ... in the order of `PrimOps.builtinPrimOps`.
  private val primOpToId = PrimOps.builtinPrimOps.zipWithIndex.map { case (op, i) => op -> (-i - 1).toByte }.toMap
  assert(primOpToId.values.max == -1, "PrimOp nodes use ids -1 ... -50")
  assert(primOpToId.values.min >= -50, "PrimOp nodes use ids -1 ... -50")
  private def primOp(p: PrimOp): Byte = primOpToId(p)

  // verification ops are not firrtl nodes and thus not part of the same id namespace
  private def verificationOp(op: Formal.Value): Byte = op match {
    case Formal.Assert => 0
    case Formal.Assume => 1
    case Formal.Cover  => 2
  }
}

/** An opaque hash value whose identity is entirely defined by [[toHashString]]:
  * both `equals` and `hashCode` delegate to that string.
  */
trait HashCode {

  /** String representation of the hash code.
    * Two instances of [[HashCode]] are equal if and only if their toHashString values are equal.
    */
  def toHashString: String

  override def hashCode(): Int = {
    val repr = toHashString
    repr.hashCode
  }

  override def equals(obj: Any): Boolean = obj match {
    case that: HashCode =>
      val mine = this.toHashString
      mine.equals(that.toHashString)
    case _ =>
      // anything that is not a HashCode can never be equal
      false
  }
}

/** [[HashCode]] backed by a raw message digest, rendered as lower-case hex with two digits per byte. */
private class MDHashCode(code: Array[Byte]) extends HashCode {
  override val toHashString: String = {
    val sb = new StringBuilder(code.length * 2)
    code.foreach { byte =>
      // mask to get the unsigned value of the (signed) JVM byte
      sb.append(f"${byte & 0xff}%02x")
    }
    sb.toString
  }
}

/** Generic hashing interface which allows us to use different backends to trade off speed and collision resistance.
  *
  * The byte stream produced by a sequence of `update` calls must be unambiguous, otherwise structurally
  * different IR can hash identically. Variable-length values are therefore length-prefixed: without the
  * prefix `update(BigInt(1)); update(BigInt(0x2805))` and `update(BigInt(0x128)); update(BigInt(5))`
  * would both produce the bytes `01 28 05`.
  */
private trait Hasher {
  def update(b: Byte):    Unit
  def update(i: Int):     Unit
  def update(l: Long):    Unit
  def update(s: String):  Unit
  def update(b: Array[Byte]): Unit

  /** Hashes the raw IEEE-754 bit pattern of `d`. */
  def update(d: Double): Unit = update(java.lang.Double.doubleToRawLongBits(d))

  /** Hashes the minimal two's-complement encoding of `i`, prefixed by its byte length. */
  def update(i: BigInt): Unit = {
    val bytes = i.toByteArray
    update(bytes.length)
    update(bytes)
  }

  def update(b: Boolean): Unit = if (b) update(1.toByte) else update(0.toByte)

  /** Hashes the canonical `(unscaled value, scale)` pair of `i` after stripping trailing zeros,
    * so numerically equal values such as `1.5` and `1.50` hash identically.
    */
  def update(i: BigDecimal): Unit = {
    val canonical = i.bigDecimal.stripTrailingZeros()
    update(BigInt(canonical.unscaledValue))
    update(canonical.scale)
  }
}

/** [[Hasher]] backend feeding a [[MessageDigest]]. Multi-byte integers are fed in little-endian byte order
  * and strings are UTF-8 encoded with a length prefix.
  */
private class MessageDigestHasher(m: MessageDigest) extends Hasher {
  override def update(b: Byte): Unit = m.update(b)
  override def update(i: Int): Unit = {
    m.update(((i >> 0) & 0xff).toByte)
    m.update(((i >> 8) & 0xff).toByte)
    m.update(((i >> 16) & 0xff).toByte)
    m.update(((i >> 24) & 0xff).toByte)
  }
  override def update(l: Long): Unit = {
    m.update(((l >> 0) & 0xff).toByte)
    m.update(((l >> 8) & 0xff).toByte)
    m.update(((l >> 16) & 0xff).toByte)
    m.update(((l >> 24) & 0xff).toByte)
    m.update(((l >> 32) & 0xff).toByte)
    m.update(((l >> 40) & 0xff).toByte)
    m.update(((l >> 48) & 0xff).toByte)
    m.update(((l >> 56) & 0xff).toByte)
  }

  /** Encodes `s` as UTF-8, so the result does not depend on the platform default charset, which may not
    * be able to represent every character. The encoding is prefixed with its byte length so consecutive
    * strings stay unambiguous. Without the prefix, `update("a"); update("bc")` and
    * `update("ab"); update("c")` would feed identical bytes. That collision would let
    * `sha256WithSignificantPortNames` confuse port/field name splits.
    */
  override def update(s: String): Unit = {
    val bytes = s.getBytes(java.nio.charset.StandardCharsets.UTF_8)
    update(bytes.length)
    m.update(bytes)
  }
  override def update(b: Array[Byte]): Unit = m.update(b)
}

/** Serializes firrtl IR nodes into a [[Hasher]] while abstracting over identifier names.
  *
  * Every identifier passed through `n` is replaced by the integer index of its first occurrence in this
  * instance, so one instance represents one identifier namespace and must not be shared between modules.
  * Most node kinds emit a one-byte id before their children so that different node kinds produce
  * different byte streams.
  *
  * @param h            receives the serialized structure
  * @param renameModule applied to the module name of each [[DefInstance]] before it is hashed
  */
class StructuralHash private (h: Hasher, renameModule: String => String) {
  // replace identifiers with incrementing integers
  private val nameToInt = mutable.HashMap[String, Int]()
  private var nameCounter: Int = 0

  /** Hashes the integer assigned to `name`, allocating the next free integer on first occurrence. */
  @inline private def n(name: String): Unit = hash(
    nameToInt.getOrElseUpdate(
      name, {
        val ii = nameCounter
        nameCounter = nameCounter + 1
        ii
      }
    )
  )

  // internal convenience methods
  @inline private def id(b:   Byte):       Unit = h.update(b)
  @inline private def hash(i: Int):        Unit = h.update(i)
  @inline private def hash(b: Boolean):    Unit = h.update(b)
  @inline private def hash(d: Double):     Unit = h.update(d)
  @inline private def hash(i: BigInt):     Unit = h.update(i)
  @inline private def hash(i: BigDecimal): Unit = h.update(i)
  @inline private def hash(s: String):     Unit = h.update(s)

  private def hash(node: Expression): Unit = node match {
    case Reference(name, _, _, _)    => id(0); n(name)
    case DoPrim(op, args, consts, _) =>
      // no need to hash the number of arguments or constants since that is implied by the op
      id(1); h.update(StructuralHash.primOp(op)); args.foreach(hash); consts.foreach(hash)
    case UIntLiteral(value, width) => id(2); hash(value); hash(width)
    // We hash bundles as if fields are accessed by their index.
    // Thus we need to also hash field accesses that way.
    // This has the side-effect that `x.y` might hash to the same value as `z.r`, for example if the
    // types are `x: {y: UInt<1>, ...}` and `z: {r: UInt<1>, ...}` respectively.
    // They do not hash to the same value if the type of `z` is e.g., `z: {..., r: UInt<1>, ...}`
    // as that would have the `r` field at a different index.
    case SubField(expr, name, _, _) =>
      id(3); hash(expr)
      // find field index and hash that instead of the field name
      val fields = expr.tpe match {
        case b: BundleType => b.fields
        case other =>
          throw new RuntimeException(s"Unexpected type $other for SubField access. Did you run the type checker?")
      }
      val index = fields.indexWhere(_.name == name)
      if (index < 0) {
        throw new RuntimeException(s"Field $name not found in bundle type ${expr.tpe} of SubField access.")
      }
      hash(index)
    case SubIndex(expr, value, _, _)  => id(4); hash(expr); hash(value)
    case SubAccess(expr, index, _, _) => id(5); hash(expr); hash(index)
    case Mux(cond, tval, fval, _)     => id(6); hash(cond); hash(tval); hash(fval)
    case ValidIf(cond, value, _)      => id(7); hash(cond); hash(value)
    case SIntLiteral(value, width)    => id(8); hash(value); hash(width)
    case FixedLiteral(value, width, point) => id(9); hash(value); hash(width); hash(point)
    // WIR
    case firrtl.WVoid           => id(10)
    case firrtl.WInvalid        => id(11)
    case firrtl.EmptyExpression => id(12)
    // VRandom is used in the Emitter
    case firrtl.VRandom(width) => id(13); hash(width)
    // ids 14 ... 19 are reserved for future Expression nodes
  }

  private def hash(node: Statement): Unit = node match {
    // all info fields are ignored
    case DefNode(_, name, value) => id(20); n(name); hash(value)
    case Connect(_, loc, expr) => id(21); hash(loc); hash(expr)
    // we place the unique id 23 between conseq and alt to distinguish between them in case conseq is empty
    // we place the unique id 24 after alt to distinguish between alt and the next statement in case alt is empty
    case Conditionally(_, pred, conseq, alt) => id(22); hash(pred); hash(conseq); id(23); hash(alt); id(24)
    case EmptyStmt                           => // empty statements are ignored
    case Block(stmts)                        => stmts.foreach(hash) // block structure is ignored
    case Stop(_, ret, clk, en)               => id(25); hash(ret); hash(clk); hash(en)
    case Print(_, string, args, clk, en)     =>
      // the string is part of the side effect and thus part of the circuit behavior
      id(26); hash(string.string); hash(args.length); args.foreach(hash); hash(clk); hash(en)
    case IsInvalid(_, expr)    => id(27); hash(expr)
    case DefWire(_, name, tpe) => id(28); n(name); hash(tpe)
    case DefRegister(_, name, tpe, clock, reset, init) =>
      id(29); n(name); hash(tpe); hash(clock); hash(reset); hash(init)
    case DefInstance(_, name, module, _) =>
      // Module is in the global namespace which is why we cannot replace it with a numeric id.
      // However, it might have been renamed as part of the dedup consolidation.
      id(30); n(name); hash(renameModule(module))
    // descriptions on statements are ignored
    case firrtl.DescribedStmt(_, stmt) => hash(stmt)
    case DefMemory(
          _,
          name,
          dataType,
          depth,
          writeLatency,
          readLatency,
          readers,
          writers,
          readwriters,
          readUnderWrite
        ) =>
      // DefMemory needs its own id: id 30 is already taken by DefInstance
      id(37); n(name); hash(dataType); hash(depth); hash(writeLatency); hash(readLatency)
      hash(readers.length); readers.foreach(hash)
      hash(writers.length); writers.foreach(hash)
      hash(readwriters.length); readwriters.foreach(hash)
      hash(readUnderWrite)
    case PartialConnect(_, loc, expr) => id(31); hash(loc); hash(expr)
    case Attach(_, exprs) => id(32); hash(exprs.length); exprs.foreach(hash)
    // WIR
    case firrtl.CDefMemory(_, name, tpe, size, seq, readUnderWrite) =>
      id(33); n(name); hash(tpe); hash(size); hash(seq); hash(readUnderWrite)
    case firrtl.CDefMPort(_, name, _, mem, exps, direction) =>
      // the type of the MPort depends only on the memory (in well types firrtl) and can thus be ignored
      id(34); n(name); n(mem); hash(exps.length); exps.foreach(hash); hash(direction)
    // DefAnnotatedMemory from MemIR.scala
    case firrtl.passes.memlib.DefAnnotatedMemory(
          _,
          name,
          dataType,
          depth,
          writeLatency,
          readLatency,
          readers,
          writers,
          readwriters,
          readUnderWrite,
          maskGran,
          memRef
        ) =>
      id(35); n(name); hash(dataType); hash(depth); hash(writeLatency); hash(readLatency)
      hash(readers.length); readers.foreach(hash)
      hash(writers.length); writers.foreach(hash)
      hash(readwriters.length); readwriters.foreach(hash)
      hash(readUnderWrite.toString)
      hash(maskGran.size); maskGran.foreach(hash)
      hash(memRef.size); memRef.foreach { case (a, b) => hash(a); hash(b) }
    case Verification(op, _, clk, pred, en, msg) =>
      id(36); hash(StructuralHash.verificationOp(op)); hash(clk); hash(pred); hash(en); hash(msg.string)
    // id 37 is used by DefMemory
    // ids 38 ... 39 are reserved for future Statement nodes
  }

  // ReadUnderWrite is never used in place of a FirrtlNode and thus we can start a new id namespace
  private def hash(ruw: ReadUnderWrite.Value): Unit = ruw match {
    case ReadUnderWrite.New       => id(0)
    case ReadUnderWrite.Old       => id(1)
    case ReadUnderWrite.Undefined => id(2)
  }

  private def hash(node: Width): Unit = node match {
    case IntWidth(width) => id(40); hash(width)
    case UnknownWidth    => id(41)
    case CalcWidth(arg) => id(42); hash(arg)
    // we are hashing the name of the `VarWidth` instead of using `n` since these Vars exist in a different namespace
    case VarWidth(name) => id(43); hash(name)
    // ids 44 + 45 are reserved for future Width nodes
  }

  private def hash(node: Orientation): Unit = node match {
    case Default => id(46)
    case Flip    => id(47)
  }

  private def hash(node: Field): Unit = {
    // since we are only interested in a structural hash, we ignore field names
    // this means that: hash(`{x : UInt<1>, y: UInt<2>}`) == hash(`{y : UInt<1>, x: UInt<2>}`)
    // but:             hash(`{x : UInt<1>, y: UInt<2>}`) != hash(`{y : UInt<2>, x: UInt<1>}`)
    // which seems strange, since the connect semantics rely on field names, but it is the behavior that
    // has been used in the Dedup pass for a long time.
    // This position-based notion of equality requires us to replace field names with field indexes when hashing
    // SubField accesses.
    id(48); hash(node.flip); hash(node.tpe)
  }

  private def hash(node: Type): Unit = node match {
    // Types
    case UIntType(width: Width) => id(50); hash(width)
    case SIntType(width: Width) => id(51); hash(width)
    case FixedType(width, point)           => id(52); hash(width); hash(point)
    case BundleType(fields)                => id(53); hash(fields.length); fields.foreach(hash)
    case VectorType(tpe, size)             => id(54); hash(tpe); hash(size)
    case ClockType                         => id(55)
    case ResetType                         => id(56)
    case AsyncResetType                    => id(57)
    case AnalogType(width)                 => id(58); hash(width)
    case UnknownType                       => id(59)
    case IntervalType(lower, upper, point) => id(60); hash(lower); hash(upper); hash(point)
    // ids 61 ... 65 are reserved for future Type nodes
  }

  private def hash(node: Direction): Unit = node match {
    case Input  => id(66)
    case Output => id(67)
  }

  private def hash(node: Port): Unit = {
    id(68); n(node.name); hash(node.direction); hash(node.tpe)
  }

  private def hash(node: Param): Unit = node match {
    case IntParam(name, value)       => id(70); n(name); hash(value)
    case DoubleParam(name, value)    => id(71); n(name); hash(value)
    case StringParam(name, value)    => id(72); n(name); hash(value.string)
    case RawStringParam(name, value) => id(73); n(name); hash(value)
    // id 74 is reserved for future use
  }

  private def hash(node: DefModule): Unit = node match {
    // the module name is ignored since it does not affect module functionality
    case Module(_, _name, ports, body) =>
      id(75); hash(ports.length); ports.foreach(hash); hash(body)
    // the module name is ignored since it does not affect module functionality
    case ExtModule(_, name, ports, defname, params) =>
      id(76); hash(ports.length); ports.foreach(hash); hash(defname)
      hash(params.length); params.foreach(hash)
  }

  // id 127 is reserved for Circuit nodes

  private def hash(d: firrtl.MPortDir): Unit = d match {
    case firrtl.MInfer     => id(-70)
    case firrtl.MRead      => id(-71)
    case firrtl.MWrite     => id(-72)
    case firrtl.MReadWrite => id(-73)
  }

  private def hash(c: firrtl.constraint.Constraint): Unit = c match {
    case b: Bound => hash(b) /* uses ids -80 ... -84 */
    case firrtl.constraint.IsAdd(known, maxs, mins, others) =>
      id(-85); hash(known.nonEmpty); known.foreach(hash)
      hash(maxs.length); maxs.foreach(hash)
      hash(mins.length); mins.foreach(hash)
      hash(others.length); others.foreach(hash)
    case firrtl.constraint.IsFloor(child, dummyArg) => id(-86); hash(child); hash(dummyArg)
    case firrtl.constraint.IsKnown(decimal)         => id(-87); hash(decimal)
    case firrtl.constraint.IsNeg(child, dummyArg)   => id(-88); hash(child); hash(dummyArg)
    case firrtl.constraint.IsPow(child, dummyArg)   => id(-89); hash(child); hash(dummyArg)
    // Like `VarWidth` and `VarBound`, constraint variables live in their own namespace: hash the name instead of
    // using `n`, which would also allocate an index in (and thus perturb) the identifier namespace.
    case firrtl.constraint.IsVar(str)               => id(-90); hash(str)
  }

  private def hash(b: Bound): Unit = b match {
    case UnknownBound => id(-80)
    case CalcBound(arg) => id(-81); hash(arg)
    // we are hashing the name of the `VarBound` instead of using `n` since these Vars exist in a different namespace
    case VarBound(name) => id(-82); hash(name)
    case Open(value)    => id(-83); hash(value)
    case Closed(value)  => id(-84); hash(value)
  }
}