diff options
| -rw-r--r-- | benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ir/StructuralHash.scala | 395 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/Dedup.scala | 167 | ||||
| -rw-r--r-- | src/test/scala/firrtl/ir/StructuralHashSpec.scala | 277 |
5 files changed, 707 insertions, 138 deletions
diff --git a/benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala b/benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala index 05b759c6..abbee5a9 100644 --- a/benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala +++ b/benchmark/src/main/scala/firrtl/benchmark/hot/TransformBenchmark.scala @@ -7,8 +7,8 @@ package hot import firrtl._ import firrtl.passes.LowerTypes import firrtl.stage.TransformManager - import firrtl.benchmark.util._ +import firrtl.transforms.DedupModules abstract class TransformBenchmark(factory: () => Transform) extends App { val inputFile = args(0) @@ -25,3 +25,5 @@ abstract class TransformBenchmark(factory: () => Transform) extends App { } object LowerTypesBenchmark extends TransformBenchmark(() => LowerTypes) + +object DedupBenchmark extends TransformBenchmark(() => new DedupModules())
\ No newline at end of file diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index cbe3a027..768dcb9f 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -477,7 +477,7 @@ object PrimOps extends LazyLogging { override def toString = "clip" } - private lazy val builtinPrimOps: Seq[PrimOp] = + private[firrtl] lazy val builtinPrimOps: Seq[PrimOp] = Seq(Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Pad, AsUInt, AsSInt, AsInterval, AsClock, AsAsyncReset, Shl, Shr, Dshl, Dshr, Neg, Cvt, Not, And, Or, Xor, Andr, Orr, Xorr, Cat, Bits, Head, Tail, AsFixedPoint, IncP, DecP, SetP, Wrap, Clip, Squeeze) diff --git a/src/main/scala/firrtl/ir/StructuralHash.scala b/src/main/scala/firrtl/ir/StructuralHash.scala new file mode 100644 index 00000000..b6222ed7 --- /dev/null +++ b/src/main/scala/firrtl/ir/StructuralHash.scala @@ -0,0 +1,395 @@ +// See LICENSE for license details. + +package firrtl.ir +import firrtl.PrimOps + +import java.security.MessageDigest +import scala.collection.mutable + +/** This object can perform a "structural" hash over any firrtl Module. + * It ignores: + * - [firrtl.ir.Expression Expression] types + * - Any [firrtl.ir.Info Info] fields + * - Description on DescribedStmt + * - each identifier name is replaced by a unique integer which only depends on the order of declaration + * and is thus deterministic + * - Module names are ignored. + * + * Because of the way we "agnostify" bundle types, all SubField access nodes need to have a known + * bundle type. Thus - in a lot of cases, like after reading firrtl from a file - you need to run + * the firrtl type inference before hashing. + * + * Please note that module hashes don't include any submodules. + * Two structurally equivalent modules are only functionally equivalent if they are part + * of the same circuit and thus all modules referred to in DefInstance are the same. 
+ * + * @author Kevin Laeufer <laeufer@cs.berkeley.edu> + * */ +object StructuralHash { + def sha256(node: DefModule, moduleRename: String => String = identity): HashCode = { + val m = MessageDigest.getInstance(SHA256) + new StructuralHash(new MessageDigestHasher(m), moduleRename).hash(node) + new MDHashCode(m.digest()) + } + + /** This includes the names of ports and any port bundle field names in the hash. */ + def sha256WithSignificantPortNames(module: DefModule, moduleRename: String => String = identity): HashCode = { + val m = MessageDigest.getInstance(SHA256) + hashModuleAndPortNames(module, new MessageDigestHasher(m), moduleRename) + new MDHashCode(m.digest()) + } + + private[firrtl] def sha256(str: String): HashCode = { + val m = MessageDigest.getInstance(SHA256) + m.update(str.getBytes()) + new MDHashCode(m.digest()) + } + + /** Using this to hash arbitrary nodes can have unexpected results like: + * hash(`a <= 1`) == hash(`b <= 1`). + * This method is package private to allow for unit testing but should not be exposed to the user. 
+ */ + private[firrtl] def sha256Node(node: FirrtlNode): HashCode = { + val m = MessageDigest.getInstance(SHA256) + hash(node, new MessageDigestHasher(m), identity) + new MDHashCode(m.digest()) + } + + // see: https://docs.oracle.com/javase/7/docs/api/java/security/MessageDigest.html + private val SHA256 = "SHA-256" + + //scalastyle:off cyclomatic.complexity + private def hash(node: FirrtlNode, h: Hasher, rename: String => String): Unit = node match { + case n : Expression => new StructuralHash(h, rename).hash(n) + case n : Statement => new StructuralHash(h, rename).hash(n) + case n : Type => new StructuralHash(h, rename).hash(n) + case n : Width => new StructuralHash(h, rename).hash(n) + case n : Orientation => new StructuralHash(h, rename).hash(n) + case n : Field => new StructuralHash(h, rename).hash(n) + case n : Direction => new StructuralHash(h, rename).hash(n) + case n : Port => new StructuralHash(h, rename).hash(n) + case n : Param => new StructuralHash(h, rename).hash(n) + case _ : Info => throw new RuntimeException("The structural hash of Info is meaningless.") + case n : DefModule => new StructuralHash(h, rename).hash(n) + case n : Circuit => hashCircuit(n, h, rename) + case n : StringLit => h.update(n.toString) + } + //scalastyle:on cyclomatic.complexity + + private def hashModuleAndPortNames(m: DefModule, h: Hasher, rename: String => String): Unit = { + val sh = new StructuralHash(h, rename) + sh.hash(m) + // hash port names + m.ports.foreach { p => + h.update(p.name) + hashPortTypeName(p.tpe, h.update) + } + } + + private def hashPortTypeName(tpe: Type, h: String => Unit): Unit = tpe match { + case BundleType(fields) => fields.foreach{ f => h(f.name) ; hashPortTypeName(f.tpe, h) } + case VectorType(vt, _) => hashPortTypeName(vt, h) + case _ => // ignore ground types since they do not have field names nor sub-types + } + + private def hashCircuit(c: Circuit, h: Hasher, rename: String => String): Unit = { + h.update(127) + h.update(c.main) + // sort 
modules to make hash more useful + val mods = c.modules.sortBy(_.name) + // we create a new StructuralHash for each module since each module has its own namespace + mods.foreach { m => + new StructuralHash(h, rename).hash(m) + } + } + + private val primOpToId = PrimOps.builtinPrimOps.zipWithIndex.map{ case (op, i) => op -> (-i -1).toByte }.toMap + assert(primOpToId.values.max == -1, "PrimOp nodes use ids -1 ... -50") + assert(primOpToId.values.min >= -50, "PrimOp nodes use ids -1 ... -50") + private def primOp(p: PrimOp): Byte = primOpToId(p) + + // verification ops are not firrtl nodes and thus not part of the same id namespace + private def verificationOp(op: Formal.Value): Byte = op match { + case Formal.Assert => 0 + case Formal.Assume => 1 + case Formal.Cover => 2 + } +} + +trait HashCode { + protected val str: String + override def hashCode(): Int = str.hashCode + override def equals(obj: Any): Boolean = obj match { + case hashCode: HashCode => this.str.equals(hashCode.str) + case _ => false + } +} + +private class MDHashCode(code: Array[Byte]) extends HashCode { + protected override val str: String = code.map(b => f"${b.toInt & 0xff}%02x").mkString("") +} + +/** Generic hashing interface which allows us to use different backends to trade off speed and collision resistance */ +private trait Hasher { + def update(b: Byte): Unit + def update(i: Int): Unit + def update(l: Long): Unit + def update(s: String): Unit + def update(b: Array[Byte]): Unit + def update(d: Double): Unit = update(java.lang.Double.doubleToRawLongBits(d)) + def update(i: BigInt): Unit = update(i.toByteArray) + def update(b: Boolean): Unit = if(b) update(1.toByte) else update(0.toByte) + def update(i: BigDecimal): Unit = { + // this might be broken, tried to borrow some code from BigDecimal.computeHashCode + val temp = i.bigDecimal.stripTrailingZeros() + val bigInt = temp.scaleByPowerOfTen(temp.scale).toBigInteger + update(bigInt) + update(temp.scale) + } +} + +private class 
MessageDigestHasher(m: MessageDigest) extends Hasher { + override def update(b: Byte): Unit = m.update(b) + override def update(i: Int): Unit = { + m.update(((i >> 0) & 0xff).toByte) + m.update(((i >> 8) & 0xff).toByte) + m.update(((i >> 16) & 0xff).toByte) + m.update(((i >> 24) & 0xff).toByte) + } + override def update(l: Long): Unit = { + m.update(((l >> 0) & 0xff).toByte) + m.update(((l >> 8) & 0xff).toByte) + m.update(((l >> 16) & 0xff).toByte) + m.update(((l >> 24) & 0xff).toByte) + m.update(((l >> 32) & 0xff).toByte) + m.update(((l >> 40) & 0xff).toByte) + m.update(((l >> 48) & 0xff).toByte) + m.update(((l >> 56) & 0xff).toByte) + } + // the encoding of the bytes should not matter as long as we are on the same platform + override def update(s: String): Unit = m.update(s.getBytes()) + override def update(b: Array[Byte]): Unit = m.update(b) +} + +class StructuralHash private(h: Hasher, renameModule: String => String) { + // replace identifiers with incrementing integers + private val nameToInt = mutable.HashMap[String, Int]() + private var nameCounter: Int = 0 + @inline private def n(name: String): Unit = hash(nameToInt.getOrElseUpdate(name, { + val ii = nameCounter + nameCounter = nameCounter + 1 + ii + })) + + // internal convenience methods + @inline private def id(b: Byte): Unit = h.update(b) + @inline private def hash(i: Int): Unit = h.update(i) + @inline private def hash(b: Boolean): Unit = h.update(b) + @inline private def hash(d: Double): Unit = h.update(d) + @inline private def hash(i: BigInt): Unit = h.update(i) + @inline private def hash(i: BigDecimal): Unit = h.update(i) + @inline private def hash(s: String): Unit = h.update(s) + + //scalastyle:off magic.number + //scalastyle:off cyclomatic.complexity + private def hash(node: Expression): Unit = node match { + case Reference(name, _, _, _) => id(0) ; n(name) + case DoPrim(op, args, consts, _) => + // no need to hash the number of arguments or constants since that is implied by the op + id(1) ; 
h.update(StructuralHash.primOp(op)) ; args.foreach(hash) ; consts.foreach(hash) + case UIntLiteral(value, width) => id(2) ; hash(value) ; hash(width) + // We hash bundles as if fields are accessed by their index. + // Thus we need to also hash field accesses that way. + // This has the side-effect that `x.y` might hash to the same value as `z.r`, for example if the + // types are `x: {y: UInt<1>, ...}` and `z: {r: UInt<1>, ...}` respectively. + // They do not hash to the same value if the type of `z` is e.g., `z: {..., r: UInt<1>, ...}` + // as that would have the `r` field at a different index. + case SubField(expr, name, _, _) => id(3) ; hash(expr) + // find field index and hash that instead of the field name + val fields = expr.tpe match { + case b: BundleType => b.fields + case other => + throw new RuntimeException(s"Unexpected type $other for SubField access. Did you run the type checker?") + } + val index = fields.zipWithIndex.find(_._1.name == name).map(_._2).get + hash(index) + case SubIndex(expr, value, _, _) => id(4) ; hash(expr) ; hash(value) + case SubAccess(expr, index, _, _) => id(5) ; hash(expr) ; hash(index) + case Mux(cond, tval, fval, _) => id(6) ; hash(cond) ; hash(tval) ; hash(fval) + case ValidIf(cond, value, _) => id(7) ; hash(cond) ; hash(value) + case SIntLiteral(value, width) => id(8) ; hash(value) ; hash(width) + case FixedLiteral(value, width, point) => id(9) ; hash(value) ; hash(width) ; hash(point) + // WIR + case firrtl.WVoid => id(10) + case firrtl.WInvalid => id(11) + case firrtl.EmptyExpression => id(12) + // VRandom is used in the Emitter + case firrtl.VRandom(width) => id(13) ; hash(width) + // ids 14 ... 
19 are reserved for future Expression nodes + } + //scalastyle:on cyclomatic.complexity + + //scalastyle:off cyclomatic.complexity method.length + private def hash(node: Statement): Unit = node match { + // all info fields are ignored + case DefNode(_, name, value) => id(20) ; n(name) ; hash(value) + case Connect(_, loc, expr) => id(21) ; hash(loc) ; hash(expr) + // we place the unique id 23 between conseq and alt to distinguish between them in case conseq is empty + // we place the unique id 24 after alt to distinguish between alt and the next statement in case alt is empty + case Conditionally(_, pred, conseq, alt) => id(22) ; hash(pred) ; hash(conseq) ; id(23) ; hash(alt) ; id(24) + case EmptyStmt => // empty statements are ignored + case Block(stmts) => stmts.foreach(hash) // block structure is ignored + case Stop(_, ret, clk, en) => id(25) ; hash(ret) ; hash(clk) ; hash(en) + case Print(_, string, args, clk, en) => + // the string is part of the side effect and thus part of the circuit behavior + id(26) ; hash(string.string) ; hash(args.length) ; args.foreach(hash) ; hash(clk) ; hash(en) + case IsInvalid(_, expr) => id(27) ; hash(expr) + case DefWire(_, name, tpe) => id(28) ; n(name) ; hash(tpe) + case DefRegister(_, name, tpe, clock, reset, init) => + id(29) ; n(name) ; hash(tpe) ; hash(clock) ; hash(reset) ; hash(init) + case DefInstance(_, name, module, _) => + // Module is in the global namespace which is why we cannot replace it with a numeric id. + // However, it might have been renamed as part of the dedup consolidation. 
+ id(30) ; n(name) ; hash(renameModule(module)) + // descriptions on statements are ignored + case firrtl.DescribedStmt(_, stmt) => hash(stmt) + case DefMemory(_, name, dataType, depth, writeLatency, readLatency, readers, writers, + readwriters, readUnderWrite) => + id(30) ; n(name) ; hash(dataType) ; hash(depth) ; hash(writeLatency) ; hash(readLatency) + hash(readers.length) ; readers.foreach(hash) + hash(writers.length) ; writers.foreach(hash) + hash(readwriters.length) ; readwriters.foreach(hash) + hash(readUnderWrite) + case PartialConnect(_, loc, expr) => id(31) ; hash(loc) ; hash(expr) + case Attach(_, exprs) => id(32) ; hash(exprs.length) ; exprs.foreach(hash) + // WIR + case firrtl.CDefMemory(_, name, tpe, size, seq, readUnderWrite) => + id(33) ; n(name) ; hash(tpe); hash(size) ; hash(seq) ; hash(readUnderWrite) + case firrtl.CDefMPort(_, name, _, mem, exps, direction) => + // the type of the MPort depends only on the memory (in well-typed firrtl) and can thus be ignored + id(34) ; n(name) ; n(mem) ; hash(exps.length) ; exps.foreach(hash) ; hash(direction) + // DefAnnotatedMemory from MemIR.scala + case firrtl.passes.memlib.DefAnnotatedMemory(_, name, dataType, depth, writeLatency, readLatency, readers, writers, + readwriters, readUnderWrite, maskGran, memRef) => + id(35) ; n(name) ; hash(dataType) ; hash(depth) ; hash(writeLatency) ; hash(readLatency) + hash(readers.length) ; readers.foreach(hash) + hash(writers.length) ; writers.foreach(hash) + hash(readwriters.length) ; readwriters.foreach(hash) + hash(readUnderWrite.toString) + hash(maskGran.size) ; maskGran.foreach(hash) + hash(memRef.size) ; memRef.foreach{ case (a, b) => hash(a) ; hash(b) } + case Verification(op, _, clk, pred, en, msg) => + id(36) ; hash(StructuralHash.verificationOp(op)) ; hash(clk) ; hash(pred) ; hash(en) ; hash(msg.string) + // ids 37 ... 
39 are reserved for future Statement nodes + } + //scalastyle:on cyclomatic.complexity method.length + + // ReadUnderWrite is never used in place of a FirrtlNode and thus we can start a new id namespace + private def hash(ruw: ReadUnderWrite.Value): Unit = ruw match { + case ReadUnderWrite.New => id(0) + case ReadUnderWrite.Old => id(1) + case ReadUnderWrite.Undefined => id(2) + } + + private def hash(node: Width): Unit = node match { + case IntWidth(width) => id(40) ; hash(width) + case UnknownWidth => id(41) + case CalcWidth(arg) => id(42) ; hash(arg) + // we are hashing the name of the `VarWidth` instead of using `n` since these Vars exist in a different namespace + case VarWidth(name) => id(43) ; hash(name) + // ids 44 + 45 are reserved for future Width nodes + } + + private def hash(node: Orientation): Unit = node match { + case Default => id(46) + case Flip => id(47) + } + + private def hash(node: Field): Unit = { + // since we are only interested in a structural hash, we ignore field names + // this means that: hash(`{x : UInt<1>, y: UInt<2>}`) == hash(`{y : UInt<1>, x: UInt<2>}`) + // but: hash(`{x : UInt<1>, y: UInt<2>}`) != hash(`{y : UInt<2>, x: UInt<1>}`) + // which seems strange, since the connect semantics rely on field names, but it is the behavior that + // has been used in the Dedup pass for a long time. + // This position-based notion of equality requires us to replace field names with field indexes when hashing + // SubField accesses. 
+ id(48) ; hash(node.flip) ; hash(node.tpe) + } + + //scalastyle:off cyclomatic.complexity + private def hash(node: Type): Unit = node match { + // Types + case UIntType(width: Width) => id(50) ; hash(width) + case SIntType(width: Width) => id(51) ; hash(width) + case FixedType(width, point) => id(52) ; hash(width) ; hash(point) + case BundleType(fields) => id(53) ; hash(fields.length) ; fields.foreach(hash) + case VectorType(tpe, size) => id(54) ; hash(tpe) ; hash(size) + case ClockType => id(55) + case ResetType => id(56) + case AsyncResetType => id(57) + case AnalogType(width) => id(58) ; hash(width) + case UnknownType => id(59) + case IntervalType(lower, upper, point) => id(60) ; hash(lower) ; hash(upper) ; hash(point) + // ids 61 ... 65 are reserved for future Type nodes + } + //scalastyle:on cyclomatic.complexity + + private def hash(node: Direction): Unit = node match { + case Input => id(66) + case Output => id(67) + } + + private def hash(node: Port): Unit = { + id(68) ; n(node.name) ; hash(node.direction) ; hash(node.tpe) + } + + private def hash(node: Param): Unit = node match { + case IntParam(name, value) => id(70) ; n(name) ; hash(value) + case DoubleParam(name, value) => id(71) ; n(name) ; hash(value) + case StringParam(name, value) => id(72) ; n(name) ; hash(value.string) + case RawStringParam(name, value) => id(73) ; n(name) ; hash(value) + // id 74 is reserved for future use + } + + private def hash(node: DefModule): Unit = node match { + // the module name is ignored since it does not affect module functionality + case Module(_, _name, ports, body) => + id(75) ; hash(ports.length) ; ports.foreach(hash) ; hash(body) + // the module name is ignored since it does not affect module functionality + case ExtModule(_, name, ports, defname, params) => + id(76) ; hash(ports.length) ; ports.foreach(hash) ; hash(defname) + hash(params.length) ; params.foreach(hash) + } + + // id 127 is reserved for Circuit nodes + + private def hash(d: firrtl.MPortDir): 
Unit = d match { + case firrtl.MInfer => id(-70) + case firrtl.MRead => id(-71) + case firrtl.MWrite => id(-72) + case firrtl.MReadWrite => id(-73) + } + + private def hash(c: firrtl.constraint.Constraint): Unit = c match { + case b: Bound => hash(b) /* uses ids -80 ... -84 */ + case firrtl.constraint.IsAdd(known, maxs, mins, others) => + id(-85) ; hash(known.nonEmpty) ; known.foreach(hash) + hash(maxs.length) ; maxs.foreach(hash) + hash(mins.length) ; mins.foreach(hash) + hash(others.length) ; others.foreach(hash) + case firrtl.constraint.IsFloor(child, dummyArg) => id(-86) ; hash(child) ; hash(dummyArg) + case firrtl.constraint.IsKnown(decimal) => id(-87) ; hash(decimal) + case firrtl.constraint.IsNeg(child, dummyArg) => id(-88) ; hash(child) ; hash(dummyArg) + case firrtl.constraint.IsPow(child, dummyArg) => id(-89) ; hash(child) ; hash(dummyArg) + case firrtl.constraint.IsVar(str) => id(-90) ; n(str) + } + + private def hash(b: Bound): Unit = b match { + case UnknownBound => id(-80) + case CalcBound(arg) => id(-81) ; hash(arg) + // we are hashing the name of the `VarBound` instead of using `n` since these Vars exist in a different namespace + case VarBound(name) => id(-82) ; hash(name) + case Open(value) => id(-83) ; hash(value) + case Closed(value) => id(-84) ; hash(value) + } +}
\ No newline at end of file diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index dc182858..ba06ba4b 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -244,35 +244,6 @@ class DedupModules extends Transform with DependencyAPIMigration { /** Utility functions for [[DedupModules]] */ object DedupModules extends LazyLogging { - def fastSerializedHash(s: Statement): Int ={ - def serialize(builder: StringBuilder, nindent: Int)(s: Statement): Unit = s match { - case Block(stmts) => stmts.map { - val x = serialize(builder, nindent)(_) - builder ++= "\n" - x - } - case Conditionally(info, pred, conseq, alt) => - builder ++= (" " * nindent) - builder ++= s"when ${pred.serialize} :" - builder ++= info.serialize - serialize(builder, nindent + 1)(conseq) - builder ++= "\n" + (" " * nindent) - builder ++= "else :\n" - serialize(builder, nindent + 1)(alt) - case Print(info, string, args, clk, en) => - builder ++= (" " * nindent) - val strs = Seq(clk.serialize, en.serialize, string.string) ++ - (args map (_.serialize)) - builder ++= "printf(" + (strs mkString ", ") + ")" + info.serialize - case other: Statement => - builder ++= (" " * nindent) - builder ++= other.serialize - } - val builder = new mutable.StringBuilder() - serialize(builder, 0)(s) - builder.hashCode() - } - /** Change's a module's internal signal names, types, infos, and modules. * @param rename Function to rename a signal. Called on declaration and references. * @param retype Function to retype a signal. 
Called on declaration, references, and subfields @@ -341,60 +312,6 @@ object DedupModules extends LazyLogging { module map onPort map onStmt } - def uniquifyField(ref: String, depth: Int, field: String): String = ref + depth + field - - /** Turns a module into a name-agnostic module - * @param module module to change - * @return name-agnostic module - */ - def agnostify(top: CircuitTarget, - module: DefModule, - renameMap: RenameMap, - agnosticModuleName: String - ): DefModule = { - - - val namespace = Namespace() - val typeMap = mutable.HashMap[String, Type]() - val nameMap = mutable.HashMap[String, String]() - - val mod = top.module(module.name) - val agnosticMod = top.module(agnosticModuleName) - - def rename(name: String): String = { - nameMap.getOrElseUpdate(name, { - val newName = namespace.newTemp - renameMap.record(mod.ref(name), agnosticMod.ref(newName)) - newName - }) - } - - def retype(name: String)(tpe: Type): Type = { - if (typeMap.contains(name)) typeMap(name) else { - def onType(depth: Int)(tpe: Type): Type = tpe map onType(depth + 1) match { - //TODO bugfix: ref.data.data and ref.datax.data will not rename to the right tags, even if they should be - case BundleType(fields) => - BundleType(fields.map(f => Field(rename(uniquifyField(name, depth, f.name)), f.flip, f.tpe))) - case other => other - } - val newType = onType(0)(tpe) - typeMap(name) = newType - newType - } - } - - def reOfModule(instance: String, ofModule: String): String = { - renameMap.get(top.module(ofModule)) match { - case Some(Seq(Target(_, Some(ofModuleTag), Nil))) => ofModuleTag - case None => ofModule - case other => throwInternalError(other.toString) - } - } - - val renamedModule = changeInternals(rename, retype, {i: Info => NoInfo}, reOfModule)(module) - renamedModule - } - /** Dedup a module's instances based on dedup map * * Will fixes up module if deduped instance's ports are differently named @@ -491,77 +408,55 @@ object DedupModules extends LazyLogging { dontDedup.toSet } - 
//scalastyle:off - /** Returns - * 1) map of tag to all matching module names, - * 2) renameMap of module name to tag (agnostic name) - * 3) maps module name to agnostic renameMap + /** Visits every module in the circuit, starting at the leaf nodes. + * Every module is hashed in order to find ones that have the exact + * same structure and are thus functionally equivalent. + * Every unique hash is mapped to a human-readable tag which starts with `Dedup#`. * @param top CircuitTarget * @param moduleLinearization Sequence of modules from leaf to top - * @param noDedups Set of modules to not dedup - * @return + * @param noDedups names of modules that should not be deduped + * @return A map from tag to names of modules with the same structure and + * a RenameMap which maps Module names to their Tag. */ def buildRTLTags(top: CircuitTarget, moduleLinearization: Seq[DefModule], noDedups: Set[String] ): (collection.Map[String, collection.Set[String]], RenameMap) = { + // maps hash code to human readable tag + val hashToTag = mutable.HashMap[ir.HashCode, String]() + // remembers all modules with the same hash + val hashToNames = mutable.HashMap[ir.HashCode, List[String]]() - // Maps a module name to its agnostic name - val tagMap = RenameMap() - - // Maps a tag to all matching module names - val tag2all = mutable.HashMap.empty[String, mutable.HashSet[String]] - - val agnosticRename = RenameMap() + // rename modules that we have already visited to their hash-derived tag name + val moduleNameToTag = mutable.HashMap[String, String]() val dontAgnostifyPorts = modsToNotAgnostifyPorts(moduleLinearization) moduleLinearization.foreach { originalModule => - // Replace instance references to new deduped modules - val dontcare = RenameMap() - dontcare.setCircuit("dontcare") - - if (noDedups.contains(originalModule.name)) { - // Don't dedup. 
Set dedup module to be the same as fixed module - tag2all(originalModule.name) = mutable.HashSet(originalModule.name) - } else { // Try to dedup - - // Build name-agnostic module - val agnosticModule = DedupModules.agnostify(top, originalModule, agnosticRename, "thisModule") - agnosticRename.record(top.module(originalModule.name), top.module("thisModule")) - agnosticRename.delete(top.module(originalModule.name)) - - // Build tag - val builder = new mutable.ArrayBuffer[Any]() - - // It may seem weird to use non-agnostified ports with an agnostified body because - // technically it would be invalid FIRRTL, but it is logically sound for the purpose of - // calculating deduplication tags - val ports = - if (dontAgnostifyPorts(originalModule.name)) originalModule.ports else agnosticModule.ports - ports.foreach { builder ++= _.serialize } - - agnosticModule match { - case Module(i, n, ps, b) => builder ++= fastSerializedHash(b).toString()//.serialize - case ExtModule(i, n, ps, dn, p) => - builder ++= dn - p.foreach { builder ++= _.serialize } - } - val tag = builder.hashCode().toString - - // Match old module name to its tag - agnosticRename.record(top.module(originalModule.name), top.module(tag)) - tagMap.record(top.module(originalModule.name), top.module(tag)) + val hash = if (noDedups.contains(originalModule.name)) { + // if we do not want to dedup we just hash the name of the module which is guaranteed to be unique + StructuralHash.sha256(originalModule.name) + } else if (dontAgnostifyPorts(originalModule.name)) { + StructuralHash.sha256WithSignificantPortNames(originalModule, moduleNameToTag) + } else { + StructuralHash.sha256(originalModule, moduleNameToTag) + } - // Set tag's module to be the first matching module - val all = tag2all.getOrElseUpdate(tag, mutable.HashSet.empty[String]) - all += originalModule.name + if (hashToTag.contains(hash)) { + hashToNames(hash) = hashToNames(hash) :+ originalModule.name + } else { + hashToTag(hash) = "Dedup#" + 
originalModule.name + hashToNames(hash) = List(originalModule.name) } + moduleNameToTag(originalModule.name) = hashToTag(hash) } + + val tag2all = hashToNames.map{ case (hash, names) => hashToTag(hash) -> names.toSet } + val tagMap = RenameMap() + moduleNameToTag.foreach{ case (name, tag) => tagMap.record(top.module(name), top.module(tag)) } (tag2all, tagMap) } - //scalastyle:on /** Deduplicate * @param circuit Circuit diff --git a/src/test/scala/firrtl/ir/StructuralHashSpec.scala b/src/test/scala/firrtl/ir/StructuralHashSpec.scala new file mode 100644 index 00000000..17fe0b84 --- /dev/null +++ b/src/test/scala/firrtl/ir/StructuralHashSpec.scala @@ -0,0 +1,277 @@ +// See LICENSE for license details. + +package firrtl.ir + +import firrtl.PrimOps._ +import org.scalatest.flatspec.AnyFlatSpec + +class StructuralHashSpec extends AnyFlatSpec { + private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n) + private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c) + private def hash(e: Expression): HashCode = StructuralHash.sha256Node(e) + private def hash(t: Type): HashCode = StructuralHash.sha256Node(t) + private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s) + private val highFirrtlCompiler = new firrtl.stage.transforms.Compiler( + targets = firrtl.stage.Forms.HighForm + ) + private def parse(circuit: String): Circuit = { + val rawFirrtl = firrtl.Parser.parse(circuit) + // TODO: remove requirement that Firrtl needs to be type checked. + // The only reason this is needed for the structural hash right now is because we + // define bundles with the same list of field types to be the same, regardless of the + // name of these fields. Thus when the fields are accessed, we need to know their position + // in order to appropriately hash them. 
+ highFirrtlCompiler.transform(firrtl.CircuitState(rawFirrtl, Seq())).circuit + } + + private val b0 = UIntLiteral(0,IntWidth(1)) + private val b1 = UIntLiteral(1,IntWidth(1)) + private val add = DoPrim(Add, Seq(b0, b1), Seq(), UnknownType) + + it should "generate the same hash if the objects are structurally the same" in { + assert(hash(b0) == hash(UIntLiteral(0,IntWidth(1)))) + assert(hash(b0) != hash(UIntLiteral(1,IntWidth(1)))) + assert(hash(b0) != hash(UIntLiteral(1,IntWidth(2)))) + + assert(hash(b1) == hash(UIntLiteral(1,IntWidth(1)))) + assert(hash(b1) != hash(UIntLiteral(0,IntWidth(1)))) + assert(hash(b1) != hash(UIntLiteral(1,IntWidth(2)))) + } + + it should "ignore expression types" in { + assert(hash(add) == hash(DoPrim(Add, Seq(b0, b1), Seq(), UnknownType))) + assert(hash(add) == hash(DoPrim(Add, Seq(b0, b1), Seq(), UIntType(UnknownWidth)))) + assert(hash(add) != hash(DoPrim(Add, Seq(b0, b0), Seq(), UnknownType))) + } + + it should "ignore variable names" in { + val a = + """circuit a: + | module a: + | input x : UInt<1> + | output y: UInt<1> + | y <= x + |""".stripMargin + + assert(hash(parse(a)) == hash(parse(a)), "the same circuit should always be equivalent") + + val b = + """circuit a: + | module a: + | input abc : UInt<1> + | output haha: UInt<1> + | haha <= abc + |""".stripMargin + + assert(hash(parse(a)) == hash(parse(b)), "renaming ports should not affect the hash by default") + + val c = + """circuit a: + | module a: + | input x : UInt<1> + | output y: UInt<1> + | y <= and(x, UInt<1>(0)) + |""".stripMargin + + assert(hash(parse(a)) != hash(parse(c)), "changing an expression should affect the hash") + + val d = + """circuit c: + | module c: + | input abc : UInt<1> + | output haha: UInt<1> + | haha <= abc + |""".stripMargin + + assert(hash(parse(a)) != hash(parse(d)), "circuits with different names are always different") + assert(hash(parse(a).modules.head) == hash(parse(d).modules.head), + "modules with different names can be structurally 
different") + + // for the Dedup pass we do need a way to take the port names into account + assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), + "renaming ports does affect the hash if we ask to") + } + + + it should "not ignore port names if asked to" in { + val e = + """circuit a: + | module a: + | input x : UInt<1> + | wire y: UInt<1> + | y <= x + |""".stripMargin + + val f = + """circuit a: + | module a: + | input z : UInt<1> + | wire y: UInt<1> + | y <= z + |""".stripMargin + + val g = + """circuit a: + | module a: + | input x : UInt<1> + | wire z: UInt<1> + | z <= x + |""".stripMargin + + assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), + "renaming ports does affect the hash if we ask to") + assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == + StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), + "renaming internal wires should never affect the hash") + assert(hash(parse(e).modules.head) == hash(parse(g).modules.head), + "renaming internal wires should never affect the hash") + } + + it should "not ignore port bundle names if asked to" in { + val e = + """circuit a: + | module a: + | input x : {x: UInt<1>} + | wire y: {x: UInt<1>} + | y.x <= x.x + |""".stripMargin + + val f = + """circuit a: + | module a: + | input x : {z: UInt<1>} + | wire y: {x: UInt<1>} + | y.x <= x.z + |""".stripMargin + + val g = + """circuit a: + | module a: + | input x : {x: UInt<1>} + | wire y: {z: UInt<1>} + | y.z <= x.x + |""".stripMargin + + assert(hash(parse(e).modules.head) == hash(parse(f).modules.head), + "renaming port bundles does normally not affect the hash") + assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), + "renaming port 
bundles does affect the hash if we ask to") + assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == + StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), + "renaming internal wire bundles should never affect the hash") + assert(hash(parse(e).modules.head) == hash(parse(g).modules.head), + "renaming internal wire bundles should never affect the hash") + } + + + it should "fail on Info" in { + // it does not make sense to hash Info nodes + assertThrows[RuntimeException] { + StructuralHash.sha256Node(FileInfo(StringLit(""))) + } + } + + "Bundles with different field names" should "be structurally equivalent" in { + def parse(str: String): BundleType = { + val src = + s"""circuit c: + | module c: + | input z: $str + |""".stripMargin + val c = firrtl.Parser.parse(src) + val tpe = c.modules.head.ports.head.tpe + tpe.asInstanceOf[BundleType] + } + + val a = "{x: UInt<1>, y: UInt<1>}" + assert(hash(parse(a)) == hash(parse(a)), "the same bundle should always be equivalent") + + val b = "{z: UInt<1>, y: UInt<1>}" + assert(hash(parse(a)) == hash(parse(b)), "changing a field name should maintain equivalence") + + val c = "{x: UInt<2>, y: UInt<1>}" + assert(hash(parse(a)) != hash(parse(c)), "changing a field type should not maintain equivalence") + + val d = "{x: UInt<1>, y: {y: UInt<1>}}" + assert(hash(parse(a)) != hash(parse(d)), "changing the structure should not maintain equivalence") + + assert(hash(parse("{z: {y: {x: UInt<1>}}, a: UInt<1>}")) == hash(parse("{a: {b: {c: UInt<1>}}, z: UInt<1>}"))) + } + + "ExtModules with different names but the same defname" should "be structurally equivalent" in { + val a = + """circuit a: + | extmodule a: + | input x : UInt<1> + | defname = xyz + |""".stripMargin + + val b = + """circuit b: + | extmodule b: + | input y : UInt<1> + | defname = xyz + |""".stripMargin + + // Q: should extmodule portnames always be significant since they map to the verilog pins? 
+ // A: It would be a bug for two exmodules in the same circuit to have the same defname but different + // port names. This should be detected by an earlier pass and thus we do not have to deal with that situation. + assert(hash(parse(a).modules.head) == hash(parse(b).modules.head), + "two ext modules with the same defname and the same type and number of ports") + assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), + "two ext modules with significant port names") + } + + "Blocks and empty statements" should "not affect structural equivalence" in { + val stmtA = DefNode(NoInfo, "a", UIntLiteral(1)) + val stmtB = DefNode(NoInfo, "b", UIntLiteral(1)) + + val a = Block(Seq(Block(Seq(stmtA)), stmtB)) + val b = Block(Seq(stmtA, stmtB)) + assert(hash(a) == hash(b)) + + val c = Block(Seq(Block(Seq(Block(Seq(stmtA, stmtB)))))) + assert(hash(a) == hash(c)) + + val d = Block(Seq(stmtA)) + assert(hash(a) != hash(d)) + + val e = Block(Seq(Block(Seq(stmtB)), stmtB)) + assert(hash(a) != hash(e)) + + val f = Block(Seq(Block(Seq(Block(Seq(stmtA, EmptyStmt, stmtB)))))) + assert(hash(a) == hash(f)) + } + + "Conditionally" should "properly separate if and else branch" in { + val stmtA = DefNode(NoInfo, "a", UIntLiteral(1)) + val stmtB = DefNode(NoInfo, "b", UIntLiteral(1)) + val cond = UIntLiteral(1) + + val a = Conditionally(NoInfo, cond, stmtA, stmtB) + val b = Conditionally(NoInfo, cond, Block(Seq(stmtA)), stmtB) + assert(hash(a) == hash(b)) + + val c = Conditionally(NoInfo, cond, Block(Seq(stmtA)), Block(Seq(EmptyStmt, stmtB))) + assert(hash(a) == hash(c)) + + val d = Block(Seq(Conditionally(NoInfo, cond, stmtA, EmptyStmt), stmtB)) + assert(hash(a) != hash(d)) + + val e = Conditionally(NoInfo, cond, stmtA, EmptyStmt) + val f = Conditionally(NoInfo, cond, EmptyStmt, stmtA) + assert(hash(e) != hash(f)) + } +} + +private case object DebugHasher extends Hasher { + override def update(b: 
Byte): Unit = println(s"b(${b.toInt & 0xff})") + override def update(i: Int): Unit = println(s"i(${i})") + override def update(l: Long): Unit = println(s"l(${l})") + override def update(s: String): Unit = println(s"s(${s})") + override def update(b: Array[Byte]): Unit = println(s"bytes(${b.map(x => x.toInt & 0xff).mkString(", ")})") +}
\ No newline at end of file |
