aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala4
-rw-r--r--src/main/scala/firrtl/PrimOps.scala2
-rw-r--r--src/main/scala/firrtl/ir/StructuralHash.scala395
-rw-r--r--src/main/scala/firrtl/transforms/Dedup.scala167
-rw-r--r--src/test/scala/firrtl/ir/StructuralHashSpec.scala277
5 files changed, 707 insertions, 138 deletions
diff --git a/benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala b/benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala
index 05b759c6..abbee5a9 100644
--- a/benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala
+++ b/benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala
@@ -7,8 +7,8 @@ package hot
import firrtl._
import firrtl.passes.LowerTypes
import firrtl.stage.TransformManager
-
import firrtl.benchmark.util._
+import firrtl.transforms.DedupModules
abstract class TransformBenchmark(factory: () => Transform) extends App {
val inputFile = args(0)
@@ -25,3 +25,5 @@ abstract class TransformBenchmark(factory: () => Transform) extends App {
}
object LowerTypesBenchmark extends TransformBenchmark(() => LowerTypes)
+
+/** Benchmark entry point: a [[TransformBenchmark]] whose factory creates a new `DedupModules` transform on each call. */
+object DedupBenchmark extends TransformBenchmark(() => new DedupModules()) \ No newline at end of file
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala
index cbe3a027..768dcb9f 100644
--- a/src/main/scala/firrtl/PrimOps.scala
+++ b/src/main/scala/firrtl/PrimOps.scala
@@ -477,7 +477,7 @@ object PrimOps extends LazyLogging {
override def toString = "clip"
}
- private lazy val builtinPrimOps: Seq[PrimOp] =
+ private[firrtl] lazy val builtinPrimOps: Seq[PrimOp] =
Seq(Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Pad, AsUInt, AsSInt, AsInterval, AsClock, AsAsyncReset, Shl, Shr,
Dshl, Dshr, Neg, Cvt, Not, And, Or, Xor, Andr, Orr, Xorr, Cat, Bits, Head, Tail, AsFixedPoint, IncP, DecP,
SetP, Wrap, Clip, Squeeze)
diff --git a/src/main/scala/firrtl/ir/StructuralHash.scala b/src/main/scala/firrtl/ir/StructuralHash.scala
new file mode 100644
index 00000000..b6222ed7
--- /dev/null
+++ b/src/main/scala/firrtl/ir/StructuralHash.scala
@@ -0,0 +1,395 @@
+// See LICENSE for license details.
+
+package firrtl.ir
+import firrtl.PrimOps
+
+import java.security.MessageDigest
+import scala.collection.mutable
+
+/** Performs a "structural" hash over any firrtl Module.
+  * It ignores:
+  * - [[firrtl.ir.Expression Expression]] types
+  * - Any [[firrtl.ir.Info Info]] fields
+  * - Description on DescribedStmt
+  * - each identifier name is replaced by a unique integer which only depends on the order of declaration
+  *   and is thus deterministic
+  * - Module names are ignored.
+  *
+  * Because of the way we "agnostify" bundle types, all SubField access nodes need to have a known
+  * bundle type. Thus - in a lot of cases, like after reading firrtl from a file - you need to run
+  * the firrtl type inference before hashing.
+  *
+  * Please note that module hashes don't include any submodules.
+  * Two structurally equivalent modules are only functionally equivalent if they are part
+  * of the same circuit and thus all modules referred to in DefInstance are the same.
+  *
+  * @author Kevin Laeufer <laeufer@cs.berkeley.edu>
+  */
+object StructuralHash {
+  /** Computes the SHA-256 structural hash of a module.
+    *
+    * @param node the module to hash
+    * @param moduleRename applied to the module name of every `DefInstance` before that name is hashed
+    * @return hash code wrapping the resulting digest
+    */
+  def sha256(node: DefModule, moduleRename: String => String = identity): HashCode = {
+    val m = MessageDigest.getInstance(SHA256)
+    new StructuralHash(new MessageDigestHasher(m), moduleRename).hash(node)
+    new MDHashCode(m.digest())
+  }
+
+  /** This includes the names of ports and any port bundle field names in the hash.
+    *
+    * @param module the module to hash
+    * @param moduleRename applied to the module name of every `DefInstance` before that name is hashed
+    * @return hash code wrapping the resulting digest
+    */
+  def sha256WithSignificantPortNames(module: DefModule, moduleRename: String => String = identity): HashCode = {
+    val m = MessageDigest.getInstance(SHA256)
+    hashModuleAndPortNames(module, new MessageDigestHasher(m), moduleRename)
+    new MDHashCode(m.digest())
+  }
+
+  /** Computes the SHA-256 hash of a plain string (no structural interpretation).
+    *
+    * @param str the string to hash
+    * @return hash code wrapping the resulting digest
+    */
+  private[firrtl] def sha256(str: String): HashCode = {
+    val m = MessageDigest.getInstance(SHA256)
+    // Use an explicit charset: `getBytes()` without arguments uses the platform default,
+    // which would make the resulting hash depend on the JVM configuration.
+    m.update(str.getBytes(java.nio.charset.StandardCharsets.UTF_8))
+    new MDHashCode(m.digest())
+  }
+
+  /** Using this to hash arbitrary nodes can have unexpected results like:
+    * hash(`a <= 1`) == hash(`b <= 1`).
+    * This method is package private to allow for unit testing but should not be exposed to the user.
+    */
+  private[firrtl] def sha256Node(node: FirrtlNode): HashCode = {
+    val m = MessageDigest.getInstance(SHA256)
+    hash(node, new MessageDigestHasher(m), identity)
+    new MDHashCode(m.digest())
+  }
+
+  // see: https://docs.oracle.com/javase/7/docs/api/java/security/MessageDigest.html
+  private val SHA256 = "SHA-256"
+
+  /** Dispatches on the node kind and feeds the node into `h`.
+    * A fresh [[StructuralHash]] (and thus a fresh identifier namespace) is used for every call.
+    * `Info` nodes are rejected with a `RuntimeException`.
+    */
+  //scalastyle:off cyclomatic.complexity
+  private def hash(node: FirrtlNode, h: Hasher, rename: String => String): Unit = node match {
+    case n : Expression => new StructuralHash(h, rename).hash(n)
+    case n : Statement => new StructuralHash(h, rename).hash(n)
+    case n : Type => new StructuralHash(h, rename).hash(n)
+    case n : Width => new StructuralHash(h, rename).hash(n)
+    case n : Orientation => new StructuralHash(h, rename).hash(n)
+    case n : Field => new StructuralHash(h, rename).hash(n)
+    case n : Direction => new StructuralHash(h, rename).hash(n)
+    case n : Port => new StructuralHash(h, rename).hash(n)
+    case n : Param => new StructuralHash(h, rename).hash(n)
+    case _ : Info => throw new RuntimeException("The structural hash of Info is meaningless.")
+    case n : DefModule => new StructuralHash(h, rename).hash(n)
+    case n : Circuit => hashCircuit(n, h, rename)
+    case n : StringLit => h.update(n.toString)
+  }
+  //scalastyle:on cyclomatic.complexity
+
+  /** Hashes the module structurally and afterwards appends the port names
+    * (including the field names of bundle typed ports) to the hash.
+    */
+  private def hashModuleAndPortNames(m: DefModule, h: Hasher, rename: String => String): Unit = {
+    val sh = new StructuralHash(h, rename)
+    sh.hash(m)
+    // hash port names
+    m.ports.foreach { p =>
+      h.update(p.name)
+      hashPortTypeName(p.tpe, h.update)
+    }
+  }
+
+  /** Feeds every bundle field name found in `tpe` (recursively, in declaration order) to `h`. */
+  private def hashPortTypeName(tpe: Type, h: String => Unit): Unit = tpe match {
+    case BundleType(fields) => fields.foreach{ f => h(f.name) ; hashPortTypeName(f.tpe, h) }
+    case VectorType(vt, _) => hashPortTypeName(vt, h)
+    case _ => // ignore ground types since they do not have field names nor sub-types
+  }
+
+  /** Hashes the name of the main module followed by all modules sorted by name. */
+  private def hashCircuit(c: Circuit, h: Hasher, rename: String => String): Unit = {
+    h.update(127)
+    h.update(c.main)
+    // sort modules to make hash more useful
+    val mods = c.modules.sortBy(_.name)
+    // we create a new StructuralHash for each module since each module has its own namespace
+    mods.foreach { m =>
+      new StructuralHash(h, rename).hash(m)
+    }
+  }
+
+  // every builtin PrimOp is assigned a negative id based on its position in `PrimOps.builtinPrimOps`
+  private val primOpToId = PrimOps.builtinPrimOps.zipWithIndex.map{ case (op, i) => op -> (-i -1).toByte }.toMap
+  assert(primOpToId.values.max == -1, "PrimOp nodes use ids -1 ... -50")
+  assert(primOpToId.values.min >= -50, "PrimOp nodes use ids -1 ... -50")
+  private def primOp(p: PrimOp): Byte = primOpToId(p)
+
+  // verification ops are not firrtl nodes and thus not part of the same id namespace
+  private def verificationOp(op: Formal.Value): Byte = op match {
+    case Formal.Assert => 0
+    case Formal.Assume => 1
+    case Formal.Cover => 2
+  }
+}
+
+/** A hash value that is compared and hashed through its string representation. */
+trait HashCode {
+  /** String form of the hash; it alone determines equality and `hashCode`. */
+  protected val str: String
+
+  override def hashCode(): Int = str.hashCode
+
+  /** Two hash codes are equal iff their string forms are equal; anything else is never equal. */
+  override def equals(obj: Any): Boolean = obj match {
+    case that: HashCode => str.equals(that.str)
+    case _ => false
+  }
+}
+
+/** [[HashCode]] built from raw digest bytes, rendered as a lowercase hex string with two digits per byte. */
+private class MDHashCode(code: Array[Byte]) extends HashCode {
+  protected override val str: String = {
+    val hex = new StringBuilder(code.length * 2)
+    // mask with 0xff so that negative bytes are printed as their unsigned value
+    code.foreach(b => hex.append("%02x".format(b & 0xff)))
+    hex.toString
+  }
+}
+
+/** Generic hashing interface which allows us to use different backends to trade off speed and collision resistance.
+  * Backends implement the primitive `update` overloads; the remaining overloads are expressed through them.
+  */
+private trait Hasher {
+  def update(b: Byte): Unit
+  def update(i: Int): Unit
+  def update(l: Long): Unit
+  def update(s: String): Unit
+  def update(b: Array[Byte]): Unit
+
+  /** Hashes the raw 64-bit pattern of the double. */
+  def update(d: Double): Unit = update(java.lang.Double.doubleToRawLongBits(d))
+
+  /** Hashes the byte array representation of the integer. */
+  def update(i: BigInt): Unit = update(i.toByteArray)
+
+  /** Hashes `true` as the byte 1 and `false` as the byte 0. */
+  def update(b: Boolean): Unit = update((if (b) 1 else 0).toByte)
+
+  /** Hashes the decimal, with trailing zeros stripped, as an integer followed by its scale. */
+  def update(i: BigDecimal): Unit = {
+    // this might be broken, tried to borrow some code from BigDecimal.computeHashCode
+    val normalized = i.bigDecimal.stripTrailingZeros()
+    val scale = normalized.scale
+    val unscaled = normalized.scaleByPowerOfTen(scale).toBigInteger
+    update(unscaled)
+    update(scale)
+  }
+}
+
+/** [[Hasher]] backend that feeds all data into a `java.security.MessageDigest`.
+  *
+  * @param m digest that accumulates the data; the caller obtains the result by calling `m.digest()`
+  */
+private class MessageDigestHasher(m: MessageDigest) extends Hasher {
+  override def update(b: Byte): Unit = m.update(b)
+  // integers are fed to the digest in little endian byte order
+  override def update(i: Int): Unit = {
+    m.update(((i >> 0) & 0xff).toByte)
+    m.update(((i >> 8) & 0xff).toByte)
+    m.update(((i >> 16) & 0xff).toByte)
+    m.update(((i >> 24) & 0xff).toByte)
+  }
+  override def update(l: Long): Unit = {
+    m.update(((l >> 0) & 0xff).toByte)
+    m.update(((l >> 8) & 0xff).toByte)
+    m.update(((l >> 16) & 0xff).toByte)
+    m.update(((l >> 24) & 0xff).toByte)
+    m.update(((l >> 32) & 0xff).toByte)
+    m.update(((l >> 40) & 0xff).toByte)
+    m.update(((l >> 48) & 0xff).toByte)
+    m.update(((l >> 56) & 0xff).toByte)
+  }
+  override def update(s: String): Unit = {
+    // A fixed charset keeps the hash independent of the platform default encoding.
+    val bytes = s.getBytes(java.nio.charset.StandardCharsets.UTF_8)
+    // The length prefix keeps consecutive strings apart: without it the sequences
+    // ("ab", "c") and ("a", "bc") would feed identical bytes into the digest.
+    update(bytes.length)
+    m.update(bytes)
+  }
+  override def update(b: Array[Byte]): Unit = m.update(b)
+}
+
+/** Feeds a structural description of firrtl IR nodes into a [[Hasher]].
+  *
+  * Every node kind is tagged with a numeric id before its children are hashed.
+  * One instance represents one identifier namespace: each distinct name is replaced by an integer
+  * that is assigned in order of first appearance.
+  *
+  * @param h receives all hashed data
+  * @param renameModule applied to the module name of a `DefInstance` before that name is hashed
+  */
+class StructuralHash private(h: Hasher, renameModule: String => String) {
+  // replace identifiers with incrementing integers
+  private val nameToInt = mutable.HashMap[String, Int]()
+  private var nameCounter: Int = 0
+  @inline private def n(name: String): Unit = hash(nameToInt.getOrElseUpdate(name, {
+    val ii = nameCounter
+    nameCounter = nameCounter + 1
+    ii
+  }))
+
+  // internal convenience methods
+  @inline private def id(b: Byte): Unit = h.update(b)
+  @inline private def hash(i: Int): Unit = h.update(i)
+  @inline private def hash(b: Boolean): Unit = h.update(b)
+  @inline private def hash(d: Double): Unit = h.update(d)
+  @inline private def hash(i: BigInt): Unit = h.update(i)
+  @inline private def hash(i: BigDecimal): Unit = h.update(i)
+  @inline private def hash(s: String): Unit = h.update(s)
+
+  //scalastyle:off magic.number
+  //scalastyle:off cyclomatic.complexity
+  private def hash(node: Expression): Unit = node match {
+    case Reference(name, _, _, _) => id(0) ; n(name)
+    case DoPrim(op, args, consts, _) =>
+      // no need to hash the number of arguments or constants since that is implied by the op
+      id(1) ; h.update(StructuralHash.primOp(op)) ; args.foreach(hash) ; consts.foreach(hash)
+    case UIntLiteral(value, width) => id(2) ; hash(value) ; hash(width)
+    // We hash bundles as if fields are accessed by their index.
+    // Thus we need to also hash field accesses that way.
+    // This has the side-effect that `x.y` might hash to the same value as `z.r`, for example if the
+    // types are `x: {y: UInt<1>, ...}` and `z: {r: UInt<1>, ...}` respectively.
+    // They do not hash to the same value if the type of `z` is e.g., `z: {..., r: UInt<1>, ...}`
+    // as that would have the `r` field at a different index.
+    case SubField(expr, name, _, _) =>
+      id(3) ; hash(expr)
+      // find field index and hash that instead of the field name
+      val fields = expr.tpe match {
+        case b: BundleType => b.fields
+        case other =>
+          throw new RuntimeException(s"Unexpected type $other for SubField access. Did you run the type checker?")
+      }
+      val index = fields.indexWhere(_.name == name)
+      if (index < 0) {
+        // fail with a descriptive message instead of an opaque `None.get`
+        throw new RuntimeException(s"Field $name not found in type ${expr.tpe} for SubField access.")
+      }
+      hash(index)
+    case SubIndex(expr, value, _, _) => id(4) ; hash(expr) ; hash(value)
+    case SubAccess(expr, index, _, _) => id(5) ; hash(expr) ; hash(index)
+    case Mux(cond, tval, fval, _) => id(6) ; hash(cond) ; hash(tval) ; hash(fval)
+    case ValidIf(cond, value, _) => id(7) ; hash(cond) ; hash(value)
+    case SIntLiteral(value, width) => id(8) ; hash(value) ; hash(width)
+    case FixedLiteral(value, width, point) => id(9) ; hash(value) ; hash(width) ; hash(point)
+    // WIR
+    case firrtl.WVoid => id(10)
+    case firrtl.WInvalid => id(11)
+    case firrtl.EmptyExpression => id(12)
+    // VRandom is used in the Emitter
+    case firrtl.VRandom(width) => id(13) ; hash(width)
+    // ids 14 ... 19 are reserved for future Expression nodes
+  }
+  //scalastyle:on cyclomatic.complexity
+
+  //scalastyle:off cyclomatic.complexity method.length
+  private def hash(node: Statement): Unit = node match {
+    // all info fields are ignored
+    case DefNode(_, name, value) => id(20) ; n(name) ; hash(value)
+    case Connect(_, loc, expr) => id(21) ; hash(loc) ; hash(expr)
+    // we place the unique id 23 between conseq and alt to distinguish between them in case conseq is empty
+    // we place the unique id 24 after alt to distinguish between alt and the next statement in case alt is empty
+    case Conditionally(_, pred, conseq, alt) => id(22) ; hash(pred) ; hash(conseq) ; id(23) ; hash(alt) ; id(24)
+    case EmptyStmt => // empty statements are ignored
+    case Block(stmts) => stmts.foreach(hash) // block structure is ignored
+    case Stop(_, ret, clk, en) => id(25) ; hash(ret) ; hash(clk) ; hash(en)
+    case Print(_, string, args, clk, en) =>
+      // the string is part of the side effect and thus part of the circuit behavior
+      id(26) ; hash(string.string) ; hash(args.length) ; args.foreach(hash) ; hash(clk) ; hash(en)
+    case IsInvalid(_, expr) => id(27) ; hash(expr)
+    case DefWire(_, name, tpe) => id(28) ; n(name) ; hash(tpe)
+    case DefRegister(_, name, tpe, clock, reset, init) =>
+      id(29) ; n(name) ; hash(tpe) ; hash(clock) ; hash(reset) ; hash(init)
+    case DefInstance(_, name, module, _) =>
+      // Module is in the global namespace which is why we cannot replace it with a numeric id.
+      // However, it might have been renamed as part of the dedup consolidation.
+      id(30) ; n(name) ; hash(renameModule(module))
+    // descriptions on statements are ignored
+    case firrtl.DescribedStmt(_, stmt) => hash(stmt)
+    case DefMemory(_, name, dataType, depth, writeLatency, readLatency, readers, writers,
+                   readwriters, readUnderWrite) =>
+      // id 30 is already taken by DefInstance, thus memories use the next free Statement id
+      id(37) ; n(name) ; hash(dataType) ; hash(depth) ; hash(writeLatency) ; hash(readLatency)
+      hash(readers.length) ; readers.foreach(hash)
+      hash(writers.length) ; writers.foreach(hash)
+      hash(readwriters.length) ; readwriters.foreach(hash)
+      hash(readUnderWrite)
+    case PartialConnect(_, loc, expr) => id(31) ; hash(loc) ; hash(expr)
+    case Attach(_, exprs) => id(32) ; hash(exprs.length) ; exprs.foreach(hash)
+    // WIR
+    case firrtl.CDefMemory(_, name, tpe, size, seq, readUnderWrite) =>
+      id(33) ; n(name) ; hash(tpe); hash(size) ; hash(seq) ; hash(readUnderWrite)
+    case firrtl.CDefMPort(_, name, _, mem, exps, direction) =>
+      // the type of the MPort depends only on the memory (in well-typed firrtl) and can thus be ignored
+      id(34) ; n(name) ; n(mem) ; hash(exps.length) ; exps.foreach(hash) ; hash(direction)
+    // DefAnnotatedMemory from MemIR.scala
+    case firrtl.passes.memlib.DefAnnotatedMemory(_, name, dataType, depth, writeLatency, readLatency, readers, writers,
+                                                 readwriters, readUnderWrite, maskGran, memRef) =>
+      id(35) ; n(name) ; hash(dataType) ; hash(depth) ; hash(writeLatency) ; hash(readLatency)
+      hash(readers.length) ; readers.foreach(hash)
+      hash(writers.length) ; writers.foreach(hash)
+      hash(readwriters.length) ; readwriters.foreach(hash)
+      hash(readUnderWrite.toString)
+      hash(maskGran.size) ; maskGran.foreach(hash)
+      hash(memRef.size) ; memRef.foreach{ case (a, b) => hash(a) ; hash(b) }
+    case Verification(op, _, clk, pred, en, msg) =>
+      id(36) ; hash(StructuralHash.verificationOp(op)) ; hash(clk) ; hash(pred) ; hash(en) ; hash(msg.string)
+    // ids 38 ... 39 are reserved for future Statement nodes
+  }
+  //scalastyle:on cyclomatic.complexity method.length
+
+  // ReadUnderWrite is never used in place of a FirrtlNode and thus we can start a new id namespace
+  private def hash(ruw: ReadUnderWrite.Value): Unit = ruw match {
+    case ReadUnderWrite.New => id(0)
+    case ReadUnderWrite.Old => id(1)
+    case ReadUnderWrite.Undefined => id(2)
+  }
+
+  private def hash(node: Width): Unit = node match {
+    case IntWidth(width) => id(40) ; hash(width)
+    case UnknownWidth => id(41)
+    case CalcWidth(arg) => id(42) ; hash(arg)
+    // we are hashing the name of the `VarWidth` instead of using `n` since these Vars exist in a different namespace
+    case VarWidth(name) => id(43) ; hash(name)
+    // ids 44 + 45 are reserved for future Width nodes
+  }
+
+  private def hash(node: Orientation): Unit = node match {
+    case Default => id(46)
+    case Flip => id(47)
+  }
+
+  private def hash(node: Field): Unit = {
+    // since we are only interested in a structural hash, we ignore field names
+    // this means that: hash(`{x : UInt<1>, y: UInt<2>}`) == hash(`{y : UInt<1>, x: UInt<2>}`)
+    // but: hash(`{x : UInt<1>, y: UInt<2>}`) != hash(`{y : UInt<2>, x: UInt<1>}`)
+    // which seems strange, since the connect semantics rely on field names, but it is the behavior that
+    // has been used in the Dedup pass for a long time.
+    // This position-based notion of equality requires us to replace field names with field indexes when hashing
+    // SubField accesses.
+    id(48) ; hash(node.flip) ; hash(node.tpe)
+  }
+
+  //scalastyle:off cyclomatic.complexity
+  private def hash(node: Type): Unit = node match {
+    // Types
+    case UIntType(width: Width) => id(50) ; hash(width)
+    case SIntType(width: Width) => id(51) ; hash(width)
+    case FixedType(width, point) => id(52) ; hash(width) ; hash(point)
+    case BundleType(fields) => id(53) ; hash(fields.length) ; fields.foreach(hash)
+    case VectorType(tpe, size) => id(54) ; hash(tpe) ; hash(size)
+    case ClockType => id(55)
+    case ResetType => id(56)
+    case AsyncResetType => id(57)
+    case AnalogType(width) => id(58) ; hash(width)
+    case UnknownType => id(59)
+    case IntervalType(lower, upper, point) => id(60) ; hash(lower) ; hash(upper) ; hash(point)
+    // ids 61 ... 65 are reserved for future Type nodes
+  }
+  //scalastyle:on cyclomatic.complexity
+
+  private def hash(node: Direction): Unit = node match {
+    case Input => id(66)
+    case Output => id(67)
+  }
+
+  private def hash(node: Port): Unit = {
+    id(68) ; n(node.name) ; hash(node.direction) ; hash(node.tpe)
+  }
+
+  private def hash(node: Param): Unit = node match {
+    case IntParam(name, value) => id(70) ; n(name) ; hash(value)
+    case DoubleParam(name, value) => id(71) ; n(name) ; hash(value)
+    case StringParam(name, value) => id(72) ; n(name) ; hash(value.string)
+    case RawStringParam(name, value) => id(73) ; n(name) ; hash(value)
+    // id 74 is reserved for future use
+  }
+
+  private def hash(node: DefModule): Unit = node match {
+    // the module name is ignored since it does not affect module functionality
+    case Module(_, _name, ports, body) =>
+      id(75) ; hash(ports.length) ; ports.foreach(hash) ; hash(body)
+    // the module name is ignored since it does not affect module functionality
+    case ExtModule(_, _name, ports, defname, params) =>
+      id(76) ; hash(ports.length) ; ports.foreach(hash) ; hash(defname)
+      hash(params.length) ; params.foreach(hash)
+  }
+
+  // id 127 is reserved for Circuit nodes
+
+  private def hash(d: firrtl.MPortDir): Unit = d match {
+    case firrtl.MInfer => id(-70)
+    case firrtl.MRead => id(-71)
+    case firrtl.MWrite => id(-72)
+    case firrtl.MReadWrite => id(-73)
+  }
+
+  private def hash(c: firrtl.constraint.Constraint): Unit = c match {
+    case b: Bound => hash(b) /* uses ids -80 ... -84 */
+    case firrtl.constraint.IsAdd(known, maxs, mins, others) =>
+      id(-85) ; hash(known.nonEmpty) ; known.foreach(hash)
+      hash(maxs.length) ; maxs.foreach(hash)
+      hash(mins.length) ; mins.foreach(hash)
+      hash(others.length) ; others.foreach(hash)
+    case firrtl.constraint.IsFloor(child, dummyArg) => id(-86) ; hash(child) ; hash(dummyArg)
+    case firrtl.constraint.IsKnown(decimal) => id(-87) ; hash(decimal)
+    case firrtl.constraint.IsNeg(child, dummyArg) => id(-88) ; hash(child) ; hash(dummyArg)
+    case firrtl.constraint.IsPow(child, dummyArg) => id(-89) ; hash(child) ; hash(dummyArg)
+    case firrtl.constraint.IsVar(str) => id(-90) ; n(str)
+  }
+
+  private def hash(b: Bound): Unit = b match {
+    case UnknownBound => id(-80)
+    case CalcBound(arg) => id(-81) ; hash(arg)
+    // we are hashing the name of the `VarBound` instead of using `n` since these Vars exist in a different namespace
+    case VarBound(name) => id(-82) ; hash(name)
+    case Open(value) => id(-83) ; hash(value)
+    case Closed(value) => id(-84) ; hash(value)
+  }
+} \ No newline at end of file
diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala
index dc182858..ba06ba4b 100644
--- a/src/main/scala/firrtl/transforms/Dedup.scala
+++ b/src/main/scala/firrtl/transforms/Dedup.scala
@@ -244,35 +244,6 @@ class DedupModules extends Transform with DependencyAPIMigration {
/** Utility functions for [[DedupModules]] */
object DedupModules extends LazyLogging {
- def fastSerializedHash(s: Statement): Int ={
- def serialize(builder: StringBuilder, nindent: Int)(s: Statement): Unit = s match {
- case Block(stmts) => stmts.map {
- val x = serialize(builder, nindent)(_)
- builder ++= "\n"
- x
- }
- case Conditionally(info, pred, conseq, alt) =>
- builder ++= (" " * nindent)
- builder ++= s"when ${pred.serialize} :"
- builder ++= info.serialize
- serialize(builder, nindent + 1)(conseq)
- builder ++= "\n" + (" " * nindent)
- builder ++= "else :\n"
- serialize(builder, nindent + 1)(alt)
- case Print(info, string, args, clk, en) =>
- builder ++= (" " * nindent)
- val strs = Seq(clk.serialize, en.serialize, string.string) ++
- (args map (_.serialize))
- builder ++= "printf(" + (strs mkString ", ") + ")" + info.serialize
- case other: Statement =>
- builder ++= (" " * nindent)
- builder ++= other.serialize
- }
- val builder = new mutable.StringBuilder()
- serialize(builder, 0)(s)
- builder.hashCode()
- }
-
/** Changes a module's internal signal names, types, infos, and modules.
* @param rename Function to rename a signal. Called on declaration and references.
* @param retype Function to retype a signal. Called on declaration, references, and subfields
@@ -341,60 +312,6 @@ object DedupModules extends LazyLogging {
module map onPort map onStmt
}
- def uniquifyField(ref: String, depth: Int, field: String): String = ref + depth + field
-
- /** Turns a module into a name-agnostic module
- * @param module module to change
- * @return name-agnostic module
- */
- def agnostify(top: CircuitTarget,
- module: DefModule,
- renameMap: RenameMap,
- agnosticModuleName: String
- ): DefModule = {
-
-
- val namespace = Namespace()
- val typeMap = mutable.HashMap[String, Type]()
- val nameMap = mutable.HashMap[String, String]()
-
- val mod = top.module(module.name)
- val agnosticMod = top.module(agnosticModuleName)
-
- def rename(name: String): String = {
- nameMap.getOrElseUpdate(name, {
- val newName = namespace.newTemp
- renameMap.record(mod.ref(name), agnosticMod.ref(newName))
- newName
- })
- }
-
- def retype(name: String)(tpe: Type): Type = {
- if (typeMap.contains(name)) typeMap(name) else {
- def onType(depth: Int)(tpe: Type): Type = tpe map onType(depth + 1) match {
- //TODO bugfix: ref.data.data and ref.datax.data will not rename to the right tags, even if they should be
- case BundleType(fields) =>
- BundleType(fields.map(f => Field(rename(uniquifyField(name, depth, f.name)), f.flip, f.tpe)))
- case other => other
- }
- val newType = onType(0)(tpe)
- typeMap(name) = newType
- newType
- }
- }
-
- def reOfModule(instance: String, ofModule: String): String = {
- renameMap.get(top.module(ofModule)) match {
- case Some(Seq(Target(_, Some(ofModuleTag), Nil))) => ofModuleTag
- case None => ofModule
- case other => throwInternalError(other.toString)
- }
- }
-
- val renamedModule = changeInternals(rename, retype, {i: Info => NoInfo}, reOfModule)(module)
- renamedModule
- }
-
/** Dedup a module's instances based on dedup map
*
* Will fix up the module if the deduped instance's ports are differently named
@@ -491,77 +408,55 @@ object DedupModules extends LazyLogging {
dontDedup.toSet
}
- //scalastyle:off
- /** Returns
- * 1) map of tag to all matching module names,
- * 2) renameMap of module name to tag (agnostic name)
- * 3) maps module name to agnostic renameMap
+ /** Visits every module in the circuit, starting at the leaf nodes.
+ * Every module is hashed in order to find ones that have the exact
+ * same structure and are thus functionally equivalent.
+ * Every unique hash is mapped to a human-readable tag which starts with `Dedup#`.
* @param top CircuitTarget
* @param moduleLinearization Sequence of modules from leaf to top
- * @param noDedups Set of modules to not dedup
- * @return
+ * @param noDedups names of modules that should not be deduped
+ * @return A map from tag to names of modules with the same structure and
+ * a RenameMap which maps Module names to their Tag.
*/
def buildRTLTags(top: CircuitTarget,
moduleLinearization: Seq[DefModule],
noDedups: Set[String]
): (collection.Map[String, collection.Set[String]], RenameMap) = {
+ // maps hash code to human readable tag
+ val hashToTag = mutable.HashMap[ir.HashCode, String]()
+ // remembers all modules with the same hash
+ val hashToNames = mutable.HashMap[ir.HashCode, List[String]]()
- // Maps a module name to its agnostic name
- val tagMap = RenameMap()
-
- // Maps a tag to all matching module names
- val tag2all = mutable.HashMap.empty[String, mutable.HashSet[String]]
-
- val agnosticRename = RenameMap()
+ // rename modules that we have already visited to their hash-derived tag name
+ val moduleNameToTag = mutable.HashMap[String, String]()
val dontAgnostifyPorts = modsToNotAgnostifyPorts(moduleLinearization)
moduleLinearization.foreach { originalModule =>
- // Replace instance references to new deduped modules
- val dontcare = RenameMap()
- dontcare.setCircuit("dontcare")
-
- if (noDedups.contains(originalModule.name)) {
- // Don't dedup. Set dedup module to be the same as fixed module
- tag2all(originalModule.name) = mutable.HashSet(originalModule.name)
- } else { // Try to dedup
-
- // Build name-agnostic module
- val agnosticModule = DedupModules.agnostify(top, originalModule, agnosticRename, "thisModule")
- agnosticRename.record(top.module(originalModule.name), top.module("thisModule"))
- agnosticRename.delete(top.module(originalModule.name))
-
- // Build tag
- val builder = new mutable.ArrayBuffer[Any]()
-
- // It may seem weird to use non-agnostified ports with an agnostified body because
- // technically it would be invalid FIRRTL, but it is logically sound for the purpose of
- // calculating deduplication tags
- val ports =
- if (dontAgnostifyPorts(originalModule.name)) originalModule.ports else agnosticModule.ports
- ports.foreach { builder ++= _.serialize }
-
- agnosticModule match {
- case Module(i, n, ps, b) => builder ++= fastSerializedHash(b).toString()//.serialize
- case ExtModule(i, n, ps, dn, p) =>
- builder ++= dn
- p.foreach { builder ++= _.serialize }
- }
- val tag = builder.hashCode().toString
-
- // Match old module name to its tag
- agnosticRename.record(top.module(originalModule.name), top.module(tag))
- tagMap.record(top.module(originalModule.name), top.module(tag))
+ val hash = if (noDedups.contains(originalModule.name)) {
+ // if we do not want to dedup we just hash the name of the module which is guaranteed to be unique
+ StructuralHash.sha256(originalModule.name)
+ } else if (dontAgnostifyPorts(originalModule.name)) {
+ StructuralHash.sha256WithSignificantPortNames(originalModule, moduleNameToTag)
+ } else {
+ StructuralHash.sha256(originalModule, moduleNameToTag)
+ }
- // Set tag's module to be the first matching module
- val all = tag2all.getOrElseUpdate(tag, mutable.HashSet.empty[String])
- all += originalModule.name
+ if (hashToTag.contains(hash)) {
+ hashToNames(hash) = hashToNames(hash) :+ originalModule.name
+ } else {
+ hashToTag(hash) = "Dedup#" + originalModule.name
+ hashToNames(hash) = List(originalModule.name)
}
+ moduleNameToTag(originalModule.name) = hashToTag(hash)
}
+
+ val tag2all = hashToNames.map{ case (hash, names) => hashToTag(hash) -> names.toSet }
+ val tagMap = RenameMap()
+ moduleNameToTag.foreach{ case (name, tag) => tagMap.record(top.module(name), top.module(tag)) }
(tag2all, tagMap)
}
- //scalastyle:on
/** Deduplicate
* @param circuit Circuit
diff --git a/src/test/scala/firrtl/ir/StructuralHashSpec.scala b/src/test/scala/firrtl/ir/StructuralHashSpec.scala
new file mode 100644
index 00000000..17fe0b84
--- /dev/null
+++ b/src/test/scala/firrtl/ir/StructuralHashSpec.scala
@@ -0,0 +1,277 @@
+// See LICENSE for license details.
+
+package firrtl.ir
+
+import firrtl.PrimOps._
+import org.scalatest.flatspec.AnyFlatSpec
+
+class StructuralHashSpec extends AnyFlatSpec {
+ private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n)
+ private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c)
+ private def hash(e: Expression): HashCode = StructuralHash.sha256Node(e)
+ private def hash(t: Type): HashCode = StructuralHash.sha256Node(t)
+ private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s)
+ private val highFirrtlCompiler = new firrtl.stage.transforms.Compiler(
+ targets = firrtl.stage.Forms.HighForm
+ )
+ private def parse(circuit: String): Circuit = {
+ val rawFirrtl = firrtl.Parser.parse(circuit)
+ // TODO: remove requirement that Firrtl needs to be type checked.
+ // The only reason this is needed for the structural hash right now is because we
+ // define bundles with the same list of field types to be the same, regardless of the
+ // name of these fields. Thus when the fields are accessed, we need to know their position
+ // in order to appropriately hash them.
+ highFirrtlCompiler.transform(firrtl.CircuitState(rawFirrtl, Seq())).circuit
+ }
+
+ private val b0 = UIntLiteral(0,IntWidth(1))
+ private val b1 = UIntLiteral(1,IntWidth(1))
+ private val add = DoPrim(Add, Seq(b0, b1), Seq(), UnknownType)
+
+ it should "generate the same hash if the objects are structurally the same" in {
+ assert(hash(b0) == hash(UIntLiteral(0,IntWidth(1))))
+ assert(hash(b0) != hash(UIntLiteral(1,IntWidth(1))))
+ assert(hash(b0) != hash(UIntLiteral(1,IntWidth(2))))
+
+ assert(hash(b1) == hash(UIntLiteral(1,IntWidth(1))))
+ assert(hash(b1) != hash(UIntLiteral(0,IntWidth(1))))
+ assert(hash(b1) != hash(UIntLiteral(1,IntWidth(2))))
+ }
+
+ it should "ignore expression types" in {
+ assert(hash(add) == hash(DoPrim(Add, Seq(b0, b1), Seq(), UnknownType)))
+ assert(hash(add) == hash(DoPrim(Add, Seq(b0, b1), Seq(), UIntType(UnknownWidth))))
+ assert(hash(add) != hash(DoPrim(Add, Seq(b0, b0), Seq(), UnknownType)))
+ }
+
+ it should "ignore variable names" in { // fixtures: a = baseline, b = ports renamed, c = different expression, d = circuit and module renamed
+ val a =
+ """circuit a:
+ | module a:
+ | input x : UInt<1>
+ | output y: UInt<1>
+ | y <= x
+ |""".stripMargin
+
+ assert(hash(parse(a)) == hash(parse(a)), "the same circuit should always be equivalent")
+
+ val b =
+ """circuit a:
+ | module a:
+ | input abc : UInt<1>
+ | output haha: UInt<1>
+ | haha <= abc
+ |""".stripMargin
+
+ assert(hash(parse(a)) == hash(parse(b)), "renaming ports should not affect the hash by default")
+
+ val c =
+ """circuit a:
+ | module a:
+ | input x : UInt<1>
+ | output y: UInt<1>
+ | y <= and(x, UInt<1>(0))
+ |""".stripMargin
+
+ assert(hash(parse(a)) != hash(parse(c)), "changing an expression should affect the hash")
+
+ val d =
+ """circuit c:
+ | module c:
+ | input abc : UInt<1>
+ | output haha: UInt<1>
+ | haha <= abc
+ |""".stripMargin
+
+ assert(hash(parse(a)) != hash(parse(d)), "circuits with different names are always different")
+ assert(hash(parse(a).modules.head) == hash(parse(d).modules.head),
+ "modules with different names can be structurally equivalent")
+
+ // for the Dedup pass we do need a way to take the port names into account
+ assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) !=
+ StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head),
+ "renaming ports does affect the hash if we ask to")
+ }
+
+
+ it should "not ignore port names if asked to" in {
+ val e =
+ """circuit a:
+ | module a:
+ | input x : UInt<1>
+ | wire y: UInt<1>
+ | y <= x
+ |""".stripMargin
+
+ val f =
+ """circuit a:
+ | module a:
+ | input z : UInt<1>
+ | wire y: UInt<1>
+ | y <= z
+ |""".stripMargin
+
+ val g =
+ """circuit a:
+ | module a:
+ | input x : UInt<1>
+ | wire z: UInt<1>
+ | z <= x
+ |""".stripMargin
+
+ assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) !=
+ StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head),
+ "renaming ports does affect the hash if we ask to")
+ assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) ==
+ StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head),
+ "renaming internal wires should never affect the hash")
+ assert(hash(parse(e).modules.head) == hash(parse(g).modules.head),
+ "renaming internal wires should never affect the hash")
+ }
+
+ it should "not ignore port bundle names if asked to" in {
+ val e =
+ """circuit a:
+ | module a:
+ | input x : {x: UInt<1>}
+ | wire y: {x: UInt<1>}
+ | y.x <= x.x
+ |""".stripMargin
+
+ val f =
+ """circuit a:
+ | module a:
+ | input x : {z: UInt<1>}
+ | wire y: {x: UInt<1>}
+ | y.x <= x.z
+ |""".stripMargin
+
+ val g =
+ """circuit a:
+ | module a:
+ | input x : {x: UInt<1>}
+ | wire y: {z: UInt<1>}
+ | y.z <= x.x
+ |""".stripMargin
+
+ assert(hash(parse(e).modules.head) == hash(parse(f).modules.head),
+ "renaming port bundles does normally not affect the hash")
+ assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) !=
+ StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head),
+ "renaming port bundles does affect the hash if we ask to")
+ assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) ==
+ StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head),
+ "renaming internal wire bundles should never affect the hash")
+ assert(hash(parse(e).modules.head) == hash(parse(g).modules.head),
+ "renaming internal wire bundles should never affect the hash")
+ }
+
+
+ it should "fail on Info" in {
+ // it does not make sense to hash Info nodes
+ assertThrows[RuntimeException] {
+ StructuralHash.sha256Node(FileInfo(StringLit("")))
+ }
+ }
+
+ "Bundles with different field names" should "be structurally equivalent" in {
+ def parse(str: String): BundleType = {
+ val src =
+ s"""circuit c:
+ | module c:
+ | input z: $str
+ |""".stripMargin
+ val c = firrtl.Parser.parse(src)
+ val tpe = c.modules.head.ports.head.tpe
+ tpe.asInstanceOf[BundleType]
+ }
+
+ val a = "{x: UInt<1>, y: UInt<1>}"
+ assert(hash(parse(a)) == hash(parse(a)), "the same bundle should always be equivalent")
+
+ val b = "{z: UInt<1>, y: UInt<1>}"
+ assert(hash(parse(a)) == hash(parse(b)), "changing a field name should maintain equivalence")
+
+ val c = "{x: UInt<2>, y: UInt<1>}"
+ assert(hash(parse(a)) != hash(parse(c)), "changing a field type should not maintain equivalence")
+
+ val d = "{x: UInt<1>, y: {y: UInt<1>}}"
+ assert(hash(parse(a)) != hash(parse(d)), "changing the structure should not maintain equivalence")
+
+ assert(hash(parse("{z: {y: {x: UInt<1>}}, a: UInt<1>}")) == hash(parse("{a: {b: {c: UInt<1>}}, z: UInt<1>}")))
+ }
+
+ "ExtModules with different names but the same defname" should "be structurally equivalent" in {
+ val a =
+ """circuit a:
+ | extmodule a:
+ | input x : UInt<1>
+ | defname = xyz
+ |""".stripMargin
+
+ val b =
+ """circuit b:
+ | extmodule b:
+ | input y : UInt<1>
+ | defname = xyz
+ |""".stripMargin
+
+ // Q: should extmodule portnames always be significant since they map to the verilog pins?
+ // A: It would be a bug for two extmodules in the same circuit to have the same defname but different
+ // port names. This should be detected by an earlier pass and thus we do not have to deal with that situation.
+ assert(hash(parse(a).modules.head) == hash(parse(b).modules.head),
+ "two ext modules with the same defname and the same type and number of ports")
+ assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) !=
+ StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head),
+ "two ext modules with significant port names")
+ }
+
+ "Blocks and empty statements" should "not affect structural equivalence" in {
+ val stmtA = DefNode(NoInfo, "a", UIntLiteral(1))
+ val stmtB = DefNode(NoInfo, "b", UIntLiteral(1))
+
+ val a = Block(Seq(Block(Seq(stmtA)), stmtB))
+ val b = Block(Seq(stmtA, stmtB))
+ assert(hash(a) == hash(b))
+
+ val c = Block(Seq(Block(Seq(Block(Seq(stmtA, stmtB))))))
+ assert(hash(a) == hash(c))
+
+ val d = Block(Seq(stmtA))
+ assert(hash(a) != hash(d))
+
+ val e = Block(Seq(Block(Seq(stmtB)), stmtB))
+ assert(hash(a) != hash(e))
+
+ val f = Block(Seq(Block(Seq(Block(Seq(stmtA, EmptyStmt, stmtB))))))
+ assert(hash(a) == hash(f))
+ }
+
+ "Conditionally" should "properly separate if and else branch" in {
+ val stmtA = DefNode(NoInfo, "a", UIntLiteral(1))
+ val stmtB = DefNode(NoInfo, "b", UIntLiteral(1))
+ val cond = UIntLiteral(1)
+
+ val a = Conditionally(NoInfo, cond, stmtA, stmtB)
+ val b = Conditionally(NoInfo, cond, Block(Seq(stmtA)), stmtB)
+ assert(hash(a) == hash(b))
+
+ val c = Conditionally(NoInfo, cond, Block(Seq(stmtA)), Block(Seq(EmptyStmt, stmtB)))
+ assert(hash(a) == hash(c))
+
+ val d = Block(Seq(Conditionally(NoInfo, cond, stmtA, EmptyStmt), stmtB))
+ assert(hash(a) != hash(d))
+
+ val e = Conditionally(NoInfo, cond, stmtA, EmptyStmt)
+ val f = Conditionally(NoInfo, cond, EmptyStmt, stmtA)
+ assert(hash(e) != hash(f))
+ }
+}
+
+private case object DebugHasher extends Hasher { // debugging aid: prints every value it is fed instead of hashing it
+ override def update(b: Byte): Unit = println(s"b(${b.toInt & 0xff})") // & 0xff prints the byte as an unsigned value (0-255)
+ override def update(i: Int): Unit = println(s"i(${i})")
+ override def update(l: Long): Unit = println(s"l(${l})")
+ override def update(s: String): Unit = println(s"s(${s})")
+ override def update(b: Array[Byte]): Unit = println(s"bytes(${b.map(x => x.toInt & 0xff).mkString(", ")})") // each byte shown unsigned, comma-separated
+} \ No newline at end of file