diff options
| author | Aditya Naik | 2024-05-29 16:57:13 -0700 |
|---|---|---|
| committer | Aditya Naik | 2024-05-29 16:57:13 -0700 |
| commit | 165804ee58cb18443042b9655328278434ddedf4 (patch) | |
| tree | 4e167eff9e7b3ec09d73dbd9feaa6f9964cd8a68 | |
| parent | 57b8a395ee8d5fdabb2deed3db7d0c644f0a7eed (diff) | |
Add Scala3 support
100 files changed, 284 insertions, 4204 deletions
@@ -2,10 +2,14 @@ enablePlugins(SiteScaladocPlugin) +val scala3Version = "3.4.1" // "3.3.3" + +Compile / compile / logLevel := Level.Error + lazy val commonSettings = Seq( organization := "edu.berkeley.cs", - scalaVersion := "2.12.17", - crossScalaVersions := Seq("2.13.10", "2.12.17") + scalaVersion := scala3Version, + crossScalaVersions := Seq("2.13.10", "2.12.17", scala3Version) ) lazy val isAtLeastScala213 = Def.setting { @@ -16,44 +20,30 @@ lazy val isAtLeastScala213 = Def.setting { lazy val firrtlSettings = Seq( name := "firrtl", version := "1.6-SNAPSHOT", - addCompilerPlugin(scalafixSemanticdb), + // addCompilerPlugin(scalafixSemanticdb), scalacOptions := Seq( "-deprecation", "-unchecked", "-language:reflectiveCalls", "-language:existentials", "-language:implicitConversions", - "-Yrangepos" // required by SemanticDB compiler plugin + // "-rewrite", "-source:3.4-migration" ), // Always target Java8 for maximum compatibility javacOptions ++= Seq("-source", "1.8", "-target", "1.8"), libraryDependencies ++= Seq( - "org.scala-lang" % "scala-reflect" % scalaVersion.value, - "org.scalatest" %% "scalatest" % "3.2.14" % "test", - "org.scalatestplus" %% "scalacheck-1-15" % "3.2.11.0" % "test", - "com.github.scopt" %% "scopt" % "3.7.1", - "net.jcazevedo" %% "moultingyaml" % "0.4.2", - "org.json4s" %% "json4s-native" % "4.0.6", - "org.apache.commons" % "commons-text" % "1.10.0", - "io.github.alexarchambault" %% "data-class" % "0.2.5", - "com.lihaoyi" %% "os-lib" % "0.8.1" + "org.scala-lang" %% "toolkit" % "0.1.7", + // "org.scala-lang" % "scala-reflect" % scalaVersion.value, + // "org.scalatest" %% "scalatest" % "3.2.14" % "test", + // "org.scalatestplus" %% "scalacheck-1-15" % "3.2.11.0" % "test", + "org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4", + "com.github.scopt" %% "scopt" % "4.1.0", + // "net.jcazevedo" %% "moultingyaml" % "0.4.2", + "org.json4s" %% "json4s-native" % "4.1.0-M5", + "org.apache.commons" % "commons-text" % "1.12.0", + // "io.github.alexarchambault" %% "data-class" % "0.2.5", + "com.lihaoyi" %% "os-lib" % "0.9.1" ), - // macros for the data-class library - libraryDependencies ++= { - if (isAtLeastScala213.value) Nil - else Seq(compilerPlugin(("org.scalamacros" % "paradise" % "2.1.1").cross(CrossVersion.full))) - }, - scalacOptions ++= { - if (isAtLeastScala213.value) Seq("-Ymacro-annotations") - else Nil - }, - // starting with scala 2.13 the parallel collections are separate from the standard library - libraryDependencies ++= { - CrossVersion.partialVersion(scalaVersion.value) match { - case Some((2, major)) if major <= 12 => Seq() - case _ => Seq("org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.4") - } - }, resolvers ++= Seq( Resolver.sonatypeRepo("snapshots"), Resolver.sonatypeRepo("releases") @@ -97,70 +87,6 @@ lazy val antlrSettings = Seq( Antlr4 / javaSource := (Compile / sourceManaged).value ) -lazy val publishSettings = Seq( - publishMavenStyle := true, - Test / publishArtifact := false, - pomIncludeRepository := { x => false }, - // scm is set by sbt-ci-release - pomExtra := <url>http://chisel.eecs.berkeley.edu/</url> - <licenses> - <license> - <name>apache_v2</name> - <url>https://opensource.org/licenses/Apache-2.0</url> - <distribution>repo</distribution> - </license> - </licenses> - <developers> - <developer> - <id>jackbackrack</id> - <name>Jonathan Bachrach</name> - <url>http://www.eecs.berkeley.edu/~jrb/</url> - </developer> - </developers>, - publishTo := { - val v = version.value - val nexus = "https://oss.sonatype.org/" - if (v.trim.endsWith("SNAPSHOT")) { - Some("snapshots".at(nexus + "content/repositories/snapshots")) - } else { - Some("releases".at(nexus + "service/local/staging/deploy/maven2")) - } - } -) - -lazy val docSettings = Seq( - Compile / doc := (ScalaUnidoc / doc).value, - autoAPIMappings := true, - Compile / doc / scalacOptions ++= Seq( - // ANTLR-generated classes aren't really part of public API and cause - // errors in ScalaDoc generation - "-skip-packages", - "firrtl.antlr", - "-Xfatal-warnings", - "-feature", - "-diagrams", - "-diagrams-max-classes", - "25", - "-doc-version", - version.value, - "-doc-title", - name.value, - "-doc-root-content", - baseDirectory.value + "/root-doc.txt", - "-sourcepath", - (ThisBuild / baseDirectory).value.toString, - "-doc-source-url", { - val branch = - if (version.value.endsWith("-SNAPSHOT")) { - "1.6.x" - } else { - s"v${version.value}" - } - s"https://github.com/chipsalliance/firrtl/tree/$branch€{FILE_PATH_EXT}#L€{FILE_LINE}" - } - ) -) - lazy val firrtl = (project in file(".")) .enablePlugins(ProtobufPlugin) .enablePlugins(ScalaUnidocPlugin) @@ -176,8 +102,6 @@ lazy val firrtl = (project in file(".")) .settings(assemblySettings) .settings(inConfig(Test)(baseAssemblySettings)) .settings(testAssemblySettings) - .settings(publishSettings) - .settings(docSettings) .enablePlugins(BuildInfoPlugin) .settings( buildInfoPackage := name.value, diff --git a/project/build.properties b/project/build.properties index 46e43a97..081fdbbc 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.8.2 +sbt.version=1.10.0 diff --git a/src/main/scala/firrtl/AddDescriptionNodes.scala b/src/main/scala/firrtl/AddDescriptionNodes.scala index de9ff523..cc4dd49f 100644 --- a/src/main/scala/firrtl/AddDescriptionNodes.scala +++ b/src/main/scala/firrtl/AddDescriptionNodes.scala @@ -154,7 +154,7 @@ class AddDescriptionNodes extends Transform with DependencyAPIMigration { Dependency[firrtl.transforms.LegalizeClocksAndAsyncResetsTransform], Dependency[firrtl.transforms.FlattenRegUpdate], Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], + Dependency[firrtl.transforms.VerilogRename[?]], Dependency(firrtl.passes.VerilogPrep) ) @@ -215,11 +215,11 @@ class AddDescriptionNodes extends Transform with DependencyAPIMigration { val (docs: Seq[DocString] @unchecked, nodocs) = descs.partition { case _: DocString => true case _ => false - } + }: @unchecked val (attrs: Seq[Attribute] @unchecked, rest) = nodocs.partition { case _: Attribute => true case _ => false - } + }: @unchecked val doc = if (docs.nonEmpty) { Seq(DocString(StringLit.unescape(docs.map(_.string.string).mkString("\n\n")))) diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 3466f3e1..7c80e87a 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -76,8 +76,8 @@ case class CircuitState( * @param annoClasses * @return */ - def resolvePathsOf(annoClasses: Class[_]*): CircuitState = { - val targets = getAnnotationsOf(annoClasses: _*).flatMap(_.getTargets) + def resolvePathsOf(annoClasses: Class[?]*): CircuitState = { + val targets = getAnnotationsOf(annoClasses*).flatMap(_.getTargets) if (targets.nonEmpty) resolvePaths(targets.flatMap { _.getComplete }) else this } @@ -85,7 +85,7 @@ case class CircuitState( * @param annoClasses * @return */ - def getAnnotationsOf(annoClasses: Class[_]*): AnnotationSeq = { + def getAnnotationsOf(annoClasses: Class[?]*): AnnotationSeq = { annotations.collect { case a if annoClasses.contains(a.getClass) => a } } } @@ -376,10 +376,10 @@ abstract class SeqTransform extends Transform with SeqTransformBased { trait ResolvedAnnotationPaths { this: Transform => - val annotationClasses: Traversable[Class[_]] + val annotationClasses: Traversable[Class[?]] override def prepare(state: CircuitState): CircuitState = { - state.resolvePathsOf(annotationClasses.toSeq: _*) + state.resolvePathsOf(annotationClasses.toSeq*) } // Any transform with this trait invalidates DedupAnnotationsTransform diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index e0f95dcb..14d5089a 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -4,24 +4,22 @@ package firrtl import java.io.File import firrtl.annotations.NoTargetAnnotation -import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} -import firrtl.backends.experimental.rtlil.RtlilEmitter import firrtl.backends.proto.{Emitter => ProtoEmitter} import firrtl.options.Viewer.view import firrtl.options.{CustomFileEmission, Dependency, HasShellOptions, PhaseException, ShellOption} import firrtl.passes.PassException -import firrtl.stage.{FirrtlFileAnnotation, FirrtlOptions, RunFirrtlTransformAnnotation} +import firrtl.stage.{FirrtlFileAnnotation, FirrtlOptions, RunFirrtlTransformAnnotation, FirrtlOptionsView} case class EmitterException(message: String) extends PassException(message) // ***** Annotations for telling the Emitters what to emit ***** sealed trait EmitAnnotation extends NoTargetAnnotation { - val emitter: Class[_ <: Emitter] + val emitter: Class[? <: Emitter] } -case class EmitCircuitAnnotation(emitter: Class[_ <: Emitter]) extends EmitAnnotation +case class EmitCircuitAnnotation(emitter: Class[? <: Emitter]) extends EmitAnnotation -case class EmitAllModulesAnnotation(emitter: Class[_ <: Emitter]) extends EmitAnnotation +case class EmitAllModulesAnnotation(emitter: Class[? <: Emitter]) extends EmitAnnotation object EmitCircuitAnnotation extends HasShellOptions { val options = Seq( @@ -57,12 +55,6 @@ object EmitCircuitAnnotation extends HasShellOptions { RunFirrtlTransformAnnotation(new SystemVerilogEmitter), EmitCircuitAnnotation(classOf[SystemVerilogEmitter]) ) - case "experimental-btor2" | "btor2" => - Seq(RunFirrtlTransformAnnotation(Dependency(Btor2Emitter)), EmitCircuitAnnotation(Btor2Emitter.getClass)) - case "experimental-smt2" | "smt2" => - Seq(RunFirrtlTransformAnnotation(Dependency(SMTLibEmitter)), EmitCircuitAnnotation(SMTLibEmitter.getClass)) - case "experimental-rtlil" => - Seq(RunFirrtlTransformAnnotation(Dependency[RtlilEmitter]), EmitCircuitAnnotation(classOf[RtlilEmitter])) case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") }, helpText = "Run the specified circuit emitter (all modules in one file)", @@ -149,8 +141,6 @@ object EmitAllModulesAnnotation extends HasShellOptions { RunFirrtlTransformAnnotation(new SystemVerilogEmitter), EmitAllModulesAnnotation(classOf[SystemVerilogEmitter]) ) - case "experimental-rtlil" => - Seq(RunFirrtlTransformAnnotation(Dependency[RtlilEmitter]), EmitAllModulesAnnotation(classOf[RtlilEmitter])) case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") }, helpText = "Run the specified module emitter (one file per module)", diff --git a/src/main/scala/firrtl/Macros.scala b/src/main/scala/firrtl/Macros.scala new file mode 100644 index 00000000..c3683b2e --- /dev/null +++ b/src/main/scala/firrtl/Macros.scala @@ -0,0 +1,13 @@ +package firrtl.macros + +import scala.quoted._ + +object Macros { + inline def isModuleClass[T](obj: T): Boolean = ${ isModuleClassImpl('obj) } + + def isModuleClassImpl[T: Type](obj: Expr[T])(using Quotes): Expr[Boolean] = { + import quotes.reflect._ + val objType = TypeRepr.of[T] + Expr(objType.typeSymbol.flags.is(Flags.Module)) + } +} diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala index a4b7bc7a..81fdaaa1 100644 --- a/src/main/scala/firrtl/Namespace.scala +++ b/src/main/scala/firrtl/Namespace.scala @@ -27,10 +27,12 @@ class Namespace private { else { var idx = indices.getOrElse(value, 0) var str = value - do { + while { + !(tryName(str)) + } do { str = s"${value}_$idx" idx += 1 - } while (!(tryName(str))) + } indices(value) = idx str } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 9e267a39..fab70ee3 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -403,7 +403,7 @@ object Utils extends LazyLogging { /** Returns an inlined expression (replacing node references with values), * stopping on a stopping condition or until the reference is not a node */ - def inline(nodeMap: NodeMap, stop: String => Boolean = { x: String => false })(e: Expression): Expression = { + def inline(nodeMap: NodeMap, stop: String => Boolean = { (x: String) => false })(e: Expression): Expression = { def onExp(e: Expression): Expression = e.map(onExp) match { case Reference(name, _, _, _) if nodeMap.contains(name) && !stop(name) => onExp(nodeMap(name)) case other => other diff --git a/src/main/scala/firrtl/analyses/GetNamespace.scala b/src/main/scala/firrtl/analyses/GetNamespace.scala index dddfc338..35428d59 100644 --- a/src/main/scala/firrtl/analyses/GetNamespace.scala +++ b/src/main/scala/firrtl/analyses/GetNamespace.scala @@ -4,8 +4,10 @@ package firrtl.analyses import firrtl.annotations.NoTargetAnnotation import firrtl.{CircuitState, DependencyAPIMigration, Namespace, Transform} +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.stage.Forms + case class ModuleNamespaceAnnotation(namespace: Namespace) extends NoTargetAnnotation /** Create a namespace with this circuit diff --git a/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala index 7584e3c8..7e2fff48 100644 --- a/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala @@ -37,8 +37,8 @@ class InstanceKeyGraph private (c: ir.Circuit) { circuitTopInstance.OfModule +: internalGraph.reachableFrom(circuitTopInstance).toSeq.map(_.OfModule) private lazy val cachedUnreachableModules: Seq[OfModule] = { - val all = mutable.LinkedHashSet(childInstances.map(c => OfModule(c._1)): _*) - val reachable = mutable.LinkedHashSet(cachedReachableModules: _*) + val all = mutable.LinkedHashSet(childInstances.map(c => OfModule(c._1))*) + val reachable = mutable.LinkedHashSet(cachedReachableModules*) all.diff(reachable).toSeq } @@ -97,9 +97,9 @@ class InstanceKeyGraph private (c: ir.Circuit) { def getChildInstanceMap: mutable.LinkedHashMap[OfModule, mutable.LinkedHashMap[Instance, OfModule]] = mutable.LinkedHashMap(childInstances.map { case (k, v) => - val moduleMap: mutable.LinkedHashMap[Instance, OfModule] = mutable.LinkedHashMap(v.map(_.toTokens): _*) + val moduleMap: mutable.LinkedHashMap[Instance, OfModule] = mutable.LinkedHashMap(v.map(_.toTokens)*) TargetToken.OfModule(k) -> moduleMap - }: _*) + }*) /** All modules in the circuit reachable from the top module */ def reachableModules: Seq[OfModule] = cachedReachableModules diff --git a/src/main/scala/firrtl/annotations/Annotation.scala b/src/main/scala/firrtl/annotations/Annotation.scala index c6145c86..6a350f3c 100644 --- a/src/main/scala/firrtl/annotations/Annotation.scala +++ b/src/main/scala/firrtl/annotations/Annotation.scala @@ -26,7 +26,7 @@ trait Annotation extends Product { * @param ls * @return */ - private def extractComponents(ls: Traversable[_]): Traversable[Target] = { + private def extractComponents(ls: Traversable[?]): Traversable[Target] = { ls.flatMap { case c: Target => Seq(c) case x: scala.collection.Traversable[_] => extractComponents(x) diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index fe35c77d..40a8ac07 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -20,7 +20,7 @@ trait HasSerializationHints { // For serialization of complicated constructor arguments, let the annotation // writer specify additional type hints for relevant classes that might be // contained within - def typeHints: Seq[Class[_]] + def typeHints: Seq[Class[?]] } /** Wrapper [[Annotation]] for Annotations that cannot be serialized */ @@ -30,9 +30,9 @@ object JsonProtocol extends LazyLogging { private val GetClassPattern = "[^']*'([^']+)'.*".r class TransformClassSerializer - extends CustomSerializer[Class[_ <: Transform]](format => + extends CustomSerializer[Class[? <: Transform]](format => ( - { case JString(s) => Class.forName(s).asInstanceOf[Class[_ <: Transform]] }, + { case JString(s) => Class.forName(s).asInstanceOf[Class[? <: Transform]] }, { case x: Class[_] => JString(x.getName) } ) ) @@ -71,7 +71,7 @@ object JsonProtocol extends LazyLogging { { case JString(s) => try { - Class.forName(s).asInstanceOf[Class[_ <: Transform]].newInstance() + Class.forName(s).asInstanceOf[Class[? <: Transform]].newInstance() } catch { case e: java.lang.InstantiationException => throw new FirrtlInternalException( @@ -222,7 +222,7 @@ object JsonProtocol extends LazyLogging { ) /** Construct Json formatter for annotations */ - def jsonFormat(tags: Seq[Class[_]]) = { + def jsonFormat(tags: Seq[Class[?]]) = { Serialization.formats(FullTypeHints(tags.toList, "class")) + new TransformClassSerializer + new NamedSerializer + new CircuitNameSerializer + new ModuleNameSerializer + new ComponentNameSerializer + new TargetSerializer + @@ -245,7 +245,7 @@ object JsonProtocol extends LazyLogging { ): Seq[(Annotation, Throwable)] = annos.map(a => a -> Try(write(a))).collect { case (a, Failure(e)) => (a, e) } - private def getTags(annos: Seq[Annotation]): Seq[Class[_]] = + private def getTags(annos: Seq[Annotation]): Seq[Class[?]] = annos .flatMap({ case anno: HasSerializationHints => anno.getClass +: anno.typeHints @@ -292,7 +292,7 @@ object JsonProtocol extends LazyLogging { } def deserializeTry(in: JsonInput, allowUnrecognizedAnnotations: Boolean = false): Try[Seq[Annotation]] = Try { - val parsed = parse(in) + val parsed: JValue = parse(in) val annos = parsed match { case JArray(objs) => objs case x => @@ -336,7 +336,7 @@ object JsonProtocol extends LazyLogging { case _: java.lang.ClassNotFoundException => classNotFoundBuildingLoaded = true None - }): Option[Class[_]] + }): Option[Class[?]] } implicit val formats = jsonFormat(loaded) try { diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala index 02ec42b8..f913004a 100644 --- a/src/main/scala/firrtl/annotations/Target.scala +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -165,7 +165,7 @@ object Target { * @return */ def isOnly(seq: Seq[TargetToken], keywords: String*): Boolean = { - seq.map(_.is(keywords: _*)).foldLeft(false)(_ || _) && keywords.nonEmpty + seq.map(_.is(keywords*)).foldLeft(false)(_ || _) && keywords.nonEmpty } /** @return [[Target]] from human-readable serialization */ @@ -323,7 +323,7 @@ case class GenericTarget(circuitOpt: Option[String], moduleOpt: Option[String], * @param keywords */ private def requireLast(default: Boolean, keywords: String*): Unit = { - val isOne = if (tokens.isEmpty) default else tokens.last.is(keywords: _*) + val isOne = if (tokens.isEmpty) default else tokens.last.is(keywords*) require(isOne, s"${tokens.last} is not one of $keywords") } @@ -509,7 +509,7 @@ trait IsComponent extends IsMember { override def toNamed: ComponentName = { if (isLocal) { val mn = ModuleName(module, CircuitName(circuit)) - Seq(tokens: _*) match { + Seq(tokens*) match { case Seq(Ref(name)) => ComponentName(name, mn) case Ref(_) :: tail if Target.isOnly(tail, ".", "[]") => val name = tokens.foldLeft("") { diff --git a/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala b/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala index c16a2670..838ba8f5 100644 --- a/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala +++ b/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala @@ -71,7 +71,7 @@ case class DuplicationHelper(existingModules: Set[String]) { case None => // Need a new name val prefix = path.last._2.value + "___" val postfix = top + "_" + path.map { case (i, m) => i.value }.mkString("_") - val ns = mutable.HashSet(allModules.toSeq: _*) + val ns = mutable.HashSet(allModules.toSeq*) val finalName = firrtl.Namespace.findValidPrefix(prefix, Seq(postfix), ns) + postfix allModules += finalName cachedNames((top, path)) = finalName diff --git a/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala index 83bea253..5950fc62 100644 --- a/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala +++ b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala @@ -9,7 +9,7 @@ import firrtl.annotations.TargetToken.{fromDefModuleToTargetToken, Instance, OfM import firrtl.annotations.analysis.DuplicationHelper import firrtl.annotations._ import firrtl.ir._ -import firrtl.{AnnotationSeq, CircuitState, DependencyAPIMigration, FirrtlInternalException, RenameMap, Transform} +import firrtl.{AnnotationSeq, CircuitState, DependencyAPIMigration, FirrtlInternalException, RenameMap, Transform, annoSeqToSeq, seqToAnnoSeq} import firrtl.renamemap.MutableRenameMap import firrtl.stage.Forms import firrtl.transforms.DedupedResult diff --git a/src/main/scala/firrtl/backends/experimental/rtlil/RtlilEmitter.scala b/src/main/scala/firrtl/backends/experimental/rtlil/RtlilEmitter.scala deleted file mode 100644 index 6c6c0b69..00000000 --- a/src/main/scala/firrtl/backends/experimental/rtlil/RtlilEmitter.scala +++ /dev/null @@ -1,1083 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtl.backends.experimental.rtlil - -import java.io.Writer -import firrtl._ -import firrtl.PrimOps._ -import firrtl.ir._ -import firrtl.Utils.{throwInternalError, _} -import firrtl.WrappedExpression._ -import firrtl.traversals.Foreachers._ -import firrtl.annotations._ -import firrtl.options.Viewer.view -import firrtl.options.{CustomFileEmission, Dependency} -import firrtl.passes.LowerTypes -import firrtl.passes.MemPortUtils.memPortField -import firrtl.stage.{FirrtlOptions, TransformManager} - -import scala.annotation.tailrec -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.language.postfixOps - -case class EmittedRtlilCircuitAnnotation(name: String, value: String, outputSuffix: String) - extends NoTargetAnnotation - with CustomFileEmission { - override protected def baseFileName(annotations: AnnotationSeq): String = - view[FirrtlOptions](annotations).outputFileName.getOrElse(name) - override protected def suffix: Option[String] = Some(outputSuffix) - override def getBytes: Iterable[Byte] = value.getBytes -} -case class EmittedRtlilModuleAnnotation(name: String, value: String, outputSuffix: String) - extends NoTargetAnnotation - with CustomFileEmission { - override protected def baseFileName(annotations: AnnotationSeq): String = - view[FirrtlOptions](annotations).outputFileName.getOrElse(name) - override protected def suffix: Option[String] = Some(outputSuffix) - override def getBytes: Iterable[Byte] = value.getBytes -} - -private[firrtl] class RtlilEmitter extends SeqTransform with Emitter with DependencyAPIMigration { - - override def prerequisites: Seq[TransformManager.TransformDependency] = - Seq( - Dependency[firrtl.transforms.CombineCats], - Dependency(firrtl.passes.memlib.VerilogMemDelays) - ) ++: firrtl.stage.Forms.LowFormOptimized - - override def outputSuffix: String = ".il" - val tab = " " - - override def transforms: Seq[Transform] = new TransformManager(prerequisites).flattenedTransformOrder - - def emit(state: CircuitState, writer: Writer): Unit = { - val cs = runTransforms(state) - val emissionOptions = new EmissionOptions(cs.annotations) - val moduleMap = cs.circuit.modules.map(m => m.name -> m).toMap - cs.circuit.modules.foreach { - case DescribedMod(d, pds, m: Module) => - val renderer = new RtlilRender(d, pds, m, moduleMap, cs.circuit.main, emissionOptions)(writer) - renderer.emit_rtlil() - case m: Module => - val renderer = new RtlilRender(m, moduleMap, cs.circuit.main, emissionOptions)(writer) - renderer.emit_rtlil() - case _ => // do nothing - } - } - - override def execute(state: CircuitState): CircuitState = { - val writerToString = - (writer: java.io.StringWriter) => writer.toString.replaceAll("""(?m) +$""", "") // trim trailing whitespace - - val newAnnos = state.annotations.flatMap { - case EmitCircuitAnnotation(a) if this.getClass == a => - val writer = new java.io.StringWriter - emit(state, writer) - Seq( - EmittedRtlilModuleAnnotation(state.circuit.main, writerToString(writer), outputSuffix) - ) - - case EmitAllModulesAnnotation(a) if this.getClass == a => - val cs = runTransforms(state) - val emissionOptions = new EmissionOptions(cs.annotations) - val moduleMap = cs.circuit.modules.map(m => m.name -> m).toMap - - cs.circuit.modules.flatMap { - case DescribedMod(d, pds, module: Module) => - val writer = new java.io.StringWriter - val renderer = new RtlilRender(d, pds, module, moduleMap, cs.circuit.main, emissionOptions)(writer) - renderer.emit_rtlil() - Some( - EmittedRtlilModuleAnnotation(module.name, writerToString(writer), outputSuffix) - ) - case module: Module => - val writer = new java.io.StringWriter - val renderer = new RtlilRender(module, moduleMap, cs.circuit.main, emissionOptions)(writer) - renderer.emit_rtlil() - Some( - EmittedRtlilModuleAnnotation(module.name, writerToString(writer), outputSuffix) - ) - case _ => None - } - case _ => Seq() - } - state.copy(annotations = newAnnos ++ state.annotations) - } - - private class RtlilRender( - description: Seq[Description], - portDescriptions: Map[String, Seq[Description]], - m: Module, - moduleMap: Map[String, DefModule], - circuitName: String, - emissionOptions: EmissionOptions - )( - implicit writer: Writer) { - def this( - m: Module, - moduleMap: Map[String, DefModule], - circuitName: String, - emissionOptions: EmissionOptions - )( - implicit writer: Writer - ) = { - this(Seq(), Map.empty, m, moduleMap, circuitName, emissionOptions)(writer) - } - - private val netlist: mutable.LinkedHashMap[WrappedExpression, InfoExpr] = mutable.LinkedHashMap() - private val namespace: Namespace = Namespace(m) - - private val portdefs: ArrayBuffer[Seq[Any]] = ArrayBuffer[Seq[Any]]() - private val declares: ArrayBuffer[Seq[Any]] = ArrayBuffer() - private val instdeclares: mutable.Map[String, InstInfo] = mutable.Map() - private val assigns: ArrayBuffer[Seq[Any]] = ArrayBuffer() - private val attachSynAssigns: ArrayBuffer[Seq[Any]] = ArrayBuffer() - private val processes: ArrayBuffer[Seq[Any]] = ArrayBuffer() - // Used to determine type of initvar for initializing memories - private val initials: ArrayBuffer[Seq[Any]] = ArrayBuffer() - private val formals: ArrayBuffer[Seq[Any]] = ArrayBuffer() - private val moduleTarget: ModuleTarget = CircuitTarget(circuitName).module(m.name) - - private def getLeadingTabs(x: Any): String = { - x match { - case seq: Seq[_] => - val head = seq.takeWhile(_ == tab).mkString - val tail = seq.dropWhile(_ == tab).headOption.map(getLeadingTabs).getOrElse(tab) - head + tail - case _ => tab - } - } - - private def emit(x: Any)(implicit w: Writer): Unit = { - this.emitCol(x, 0, getLeadingTabs(x))(writer) - } - - private def emit(x: Any, top: Int)(implicit w: Writer): Unit = { - emitCol(x, top, "")(writer) - } - - private def emitCol(x: Any, top: Int, tabs: String)(implicit w: Writer): Unit = { - x match { - case e: SrcInfo => w.write(e.str_rep) - case e: Reference => w.write(ref_to_name(e)) - case e: ValidIf => emitCol(Seq(e.value), top + 1, tabs)(writer) - case e: WSubField => w.write(SrcInfo(e).str_rep) - case e: WSubAccess => - w.write("\\" + s"${LowerTypes.loweredName(e.expr)} [ ${LowerTypes.loweredName(e.index)} ]") - case e: Literal => w.write(bigint_to_str_rep(e.value, get_type_width(e.tpe))) - case t: GroundType => w.write(stringify(t)) - case t: VectorType => - emit(t.tpe, top + 1)(writer) - w.write(s"[${t.size - 1}:0]") - case s: String => w.write(s) - case i: Int => w.write(i.toString) - case i: Long => w.write(i.toString) - case i: BigInt => w.write(bigint_to_str_rep(i, if (i > 0) i.bitLength else i.bitLength + 1)) - case i: Info => - infos_to_attr(i) match { - case Some(attr) => - w.write(attr) - case None => - } - case s: Seq[Any] => - s.foreach { e => emitCol(e, top + 1, tabs)(writer) } - if (top == 0) - w.write("\n") - case x => throwInternalError(s"trying to emit unsupported operation: $x") - } - } - - private def build_netlist(s: Statement): Unit = { - s.foreach(build_netlist) - s match { - case sx: Connect => netlist(sx.loc) = InfoExpr(sx.info, sx.expr) - case _: IsInvalid => error("Should have removed these!") - // TODO Since only register update and memories use the netlist anymore, I think nodes are unnecessary - case sx: DefNode => - val e = WRef(sx.name, sx.value.tpe, NodeKind, SourceFlow) - netlist(e) = InfoExpr(sx.info, sx.value) - case _ => - } - } - - @tailrec - private def remove_root(ex: Expression): Expression = ex match { - case ex: WSubField => - ex.expr match { - case e: WSubField => remove_root(e) - case _: WRef => WRef(ex.name, ex.tpe, InstanceKind, UnknownFlow) - } - case _ => throwInternalError(s"shouldn't be here: remove_root($ex)") - } - - private def stringify(tpe: GroundType): String = tpe match { - case _: UIntType | _: AnalogType => - val wx = bitWidth(tpe) - if (wx > 1) s"width $wx" else "" - case _: SIntType => - val wx = bitWidth(tpe) - if (wx > 1) s"signed width $wx" else "signed" - case ClockType | AsyncResetType => "" - case _ => throwInternalError(s"trying to write unsupported type in the Rtlil Emitter: $tpe") - } - - private def stringify(param: Param): String = param match { - case IntParam(name, value) => - val lit = - if (value.isValidInt) { - s"$value" - } else { - val blen = value.bitLength - if (value > 0) s"$blen'd$value" else s"-${blen + 1}'sd${value.abs}" - } - s"parameter \\$name $lit" - case DoubleParam(name, value) => s"parameter \\$name $value" - case StringParam(name, value) => s"parameter \\$name ${value.verilogEscape}" - case RawStringParam(name, value) => s"parameter \\$name $value" - } - - // turn strings into Seq[String] verilog comments - private def build_comment(desc: String): Seq[Seq[String]] = { - val lines = desc.split("\n").toSeq - lines.tail.map { - case "" => Seq("#") - case nonEmpty => Seq("#", nonEmpty) - } - } - private def build_attribute(attr: String): Seq[Seq[String]] = { - Seq(Seq("attribute \\") ++ Seq(attr)) - } - - private def build_description(d: Seq[Description]): Seq[Seq[String]] = d.flatMap { - case DocString(desc) => build_comment(desc.string) - case Attribute(attr) => build_attribute(attr.string) - } - - // Turn ports into Seq[String] and add to portdefs - private def build_ports(): Unit = { - def padToMax(strs: Seq[String]): Seq[String] = { - val len = if (strs.nonEmpty) strs.map(_.length).max else 0 - strs.map(_.padTo(len, ' ')) - } - - // Turn directions into strings (and AnalogType into inout) - val dirs = m.ports.map { - case Port(_, _, dir, tpe) => - (dir, tpe) match { - case (_, AnalogType(_)) => "inout " // padded to length of output - case (Input, _) => "input " - case (Output, _) => "output" - } - } - // Turn types into strings, all ports must be GroundTypes - val tpes = m.ports.map { - case Port(_, _, _, tpe: GroundType) => stringify(tpe) - case port: Port => error(s"Trying to emit non-GroundType Port $port") - } - - // dirs are already padded - (dirs, padToMax(tpes), m.ports).zipped.toSeq.zipWithIndex.foreach { - case ((dir, tpe, Port(info, name, _, _)), i) => - portDescriptions.get(name).map { d => - portdefs += Seq("") - portdefs ++= build_description(d) - } - portdefs += Seq("wire ", tpe, " ", dir, " ", i + 1, " \\", name, info) - } - } - - private def infos_to_attr(info: Info): Option[String] = { - def info_extract(info: Info, prev: Seq[String] = Seq()): Seq[String] = info match { - case FileInfo(str) => - val (file, line, col) = FileInfo(str).split - prev :+ (file + ":" + line + "." + col) - case MultiInfo(infos) => - infos.foldLeft(prev)((a, b) => { - info_extract(b, a) - }) - case NoInfo => - prev - } - val srcinfo = info_extract(info) - if (srcinfo.isEmpty) - Option.empty - else - Option("attribute \\src \"" + srcinfo.mkString("|") + "\"") - } - - private def string_to_rtlil_name(name: String): String = { - if (name.head == '_') { - "$" + name - } else { - "\\" + name - } - } - - private def ref_to_name(ref: Reference): String = { - string_to_rtlil_name(ref.name) - } - - private def regUpdate(r: Expression, clk: Expression, reset: Expression, init: Expression) = { - val procName = namespace.newName("$process$" + this.m.name) - val regTempName = "\\" + r.serialize + procName - val loweredReset = SrcInfo(reset) - val loweredClk = SrcInfo(clk) - val loweredInit = SrcInfo(init) - val loweredReg = SrcInfo(r) - def addUpdate(info: Info, expr: Expression, tabs: Seq[String]): Seq[Seq[Any]] = expr match { - case m: Mux => - if (m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly") - if (m.tpe == AsyncResetType) throw EmitterException("Cannot emit async reset muxes directly") - - val (eninfo, tinfo, finfo) = MultiInfo.demux(info) - lazy val _if: Seq[Seq[Any]] = - Seq(Seq(tabs, eninfo), Seq(tabs, "switch ", SrcInfo(m.cond, eninfo).str_rep)) ++ ( - if (infos_to_attr(tinfo).nonEmpty) - Seq(Seq(tabs, tab, tinfo), Seq(tabs, tab, "case 1'1")) - else - Seq(Seq(tabs, tab, "case 1'1")) - ) - lazy val _else: Seq[Seq[Any]] = infos_to_attr(finfo) match { - case Some(_) => - Seq(Seq(tabs, tab, finfo), Seq(tabs, tab, "case")) - case None => - Seq(Seq(tabs, tab, "case")) - } - lazy val _ifNot: Seq[Seq[Any]] = - Seq(Seq(tabs, eninfo), Seq(tabs, "switch ", SrcInfo(m.cond, eninfo).str_rep)) ++ ( - if (infos_to_attr(finfo).nonEmpty) - Seq(Seq(tabs, tab, finfo), Seq(tabs, tab, "case 1'0")) - else - Seq(Seq(tabs, tab, "case 1'0")) - ) - lazy val _end = Seq(Seq(tabs, "end")) - lazy val _true = addUpdate(tinfo, m.tval, Seq(tab, tab) ++ tabs) - lazy val _false = addUpdate(finfo, m.fval, Seq(tab, tab) ++ tabs) - /* For a Mux assignment, there are five possibilities, with one subcase for asynchronous reset: - * 1. Both the true and false condition are self-assignments; do nothing - * 2. The true condition is a self-assignment; invert the false condition and use that only - * 3. The false condition is a self-assignment - * a) The reset is asynchronous; emit both 'if' and a trivial 'else' to avoid latches - * b) The reset is synchronous; skip the false condition - * 4. The false condition is a Mux; use the true condition and use 'else if' for the false condition - * 5. Default; use both the true and false conditions - */ - (m.tval, m.fval) match { - case (t, f) if weq(t, r) && weq(f, r) => Nil - case (t, _) if weq(t, r) => _ifNot ++ _false ++ _end - case (_, f) if weq(f, r) => - m.cond.tpe match { - case AsyncResetType => (_if ++ _true ++ _else) ++ _true ++ _end - case _ => _if ++ _true ++ _end - } - case _ => (_if ++ _true ++ _else) ++ _false ++ _end - } - case e => - Seq(Seq(tabs, "assign ", regTempName, " ", SrcInfo(e, info).str_rep)) - } - if (weq(init, r)) { // Synchronous Reset - val InfoExpr(info, e) = netlist(r) - processes += Seq(info) - processes += Seq("wire ", r.tpe, " ", regTempName) - processes += Seq("process ", procName) - processes += Seq("assign ", regTempName, " ", loweredInit.str_rep) - processes ++= addUpdate(info, e, Seq(tab)) - processes += Seq(tab, "sync posedge ", clk) - processes += Seq(tab, tab, "update ", SrcInfo(r).str_rep, " ", regTempName) - processes += Seq("end") - } else { // Asynchronous Reset - assert(reset.tpe == AsyncResetType, "Error! Synchronous reset should have been removed!") - val tv = init - val InfoExpr(finfo, fv) = netlist(r) - processes += Seq(finfo) - processes += Seq("wire ", r.tpe, " ", regTempName) - processes += Seq("process ", procName) - processes += Seq("assign ", regTempName, " ", loweredInit.str_rep) - processes ++= addUpdate(NoInfo, Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), Seq.empty) - processes += Seq("sync posedge ", loweredClk.str_rep) - processes += Seq(tab, "update ", loweredReset.str_rep, " ", regTempName) - processes += Seq("sync posedge ", reset) - processes += Seq(tab, "update ", loweredReg.str_rep, " ", regTempName) - processes += Seq("end") - } - } - - private def bigint_to_str_rep(bigInt: BigInt, width: BigInt): String = { - if (width > 31) { - var bigboi = bigInt - var widthcnt = width - var concatlist: Seq[String] = List() - - while (widthcnt > 32) { - val lowbits = bigboi & 0xffffffff - concatlist = concatlist :+ "%d'%s".format(32, lowbits.toString(2)) - bigboi >>= 32 - widthcnt -= 32 - } - concatlist = concatlist :+ "%d'%s".format(widthcnt, bigboi.toString(2)) - "{ " + concatlist.reverse.mkString(" ") + " }" - } else - "%d'%s".format(width, bigInt.toString(2)) - } - - private case class InstInfo(inst_name: String, mod_name: String, info: Info) { - val conns: mutable.Map[String, String] = mutable.Map() - var params: Seq[String] = Seq() - def getConnection(port: String): Option[String] = { - conns.get(port) - } - def addConnection(port: String, targetValue: String): Unit = { - conns(port) = targetValue - } - } - - private case class SrcInfo(str_rep: String, signed: Boolean, width: BigInt) - private object SrcInfo { - def apply(e: Expression, i: Info = NoInfo): SrcInfo = e match { - case InfoExpr(info, expr) => - SrcInfo(expr, MultiInfo(info, i)) - case x: Reference => - SrcInfo(ref_to_name(x), x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe)) - case x: Literal => - val width = x.width.asInstanceOf[IntWidth].width - SrcInfo(bigint_to_str_rep(x.value, width), x.isInstanceOf[SIntLiteral], width) - case x @ DoPrim(op, args, consts, tpe) => - op match { - case Cat => - SrcInfo( - Seq(" { ", args.map(SrcInfo(_).str_rep).mkString(" "), " }").mkString, - tpe.isInstanceOf[SIntType], - get_type_width(tpe) - ) - case Head => - val src0 = SrcInfo(args.head) - SrcInfo( - Seq(src0.str_rep, " [", (src0.width - 1).toInt, ":", consts.head.toInt, "]").mkString, - tpe.isInstanceOf[SIntType], - get_type_width(tpe) - ) - case Tail => - val src0 = SrcInfo(args.head) - SrcInfo( - Seq(src0.str_rep, " [", (src0.width - 1 - consts.head).toInt, ":0]").mkString, - tpe.isInstanceOf[SIntType], - get_type_width(tpe) - ) - case Pad => - val src0 = SrcInfo(args.head) - if (src0.width >= consts.head) - SrcInfo( - Seq(src0.str_rep, " [", (consts.head - 1).toInt, ":0]").mkString, - tpe.isInstanceOf[SIntType], - get_type_width(tpe) - ) - else if (src0.signed) - SrcInfo( - Seq( - " { ", - s"${src0.str_rep} [${src0.width - 1}] " * (consts.head - src0.width).toInt, - src0.str_rep, - " }" - ).mkString, - tpe.isInstanceOf[SIntType], - get_type_width(tpe) - ) - else - SrcInfo( - Seq(" { ", (consts.head - src0.width).toInt, "'0 ", src0.str_rep, " }").mkString, - tpe.isInstanceOf[SIntType], - get_type_width(tpe) - ) - case _ => - val tempNetName = namespace.newName("$_PRIM_EX") - if (infos_to_attr(i).nonEmpty) declares += Seq(i) - declares += Seq("wire ", x.tpe, " ", tempNetName) - assigns ++= output_expr(tempNetName, x, i) - SrcInfo(tempNetName, x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe)) - } - case x @ SubField(Reference(modname, _, InstanceKind, _), portname, _, _) => - val currentPortConn = instdeclares(modname).getConnection(portname) - if (currentPortConn.isEmpty) { - val tempNetName = "\\" + LowerTypes.loweredName(x) - if (infos_to_attr(i).nonEmpty) declares += Seq(i) - declares += Seq("wire ", x.tpe, " ", tempNetName) - instdeclares(modname).addConnection(portname, tempNetName) - SrcInfo(tempNetName, x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe)) - } else { - SrcInfo(currentPortConn.get, x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe)) - } - case x: SubField => - SrcInfo("\\" + LowerTypes.loweredName(x), x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe)) - case x: Mux => - val tempNetName = namespace.newName("$_MUX_EX") - if (infos_to_attr(i).nonEmpty) declares += Seq(i) - declares += Seq("wire ", x.tpe, " ", tempNetName) - assigns ++= output_expr(tempNetName, e, i) - SrcInfo(tempNetName, x.tpe.isInstanceOf[SIntType], get_type_width(x.tpe)) - case x => - throw EmitterException(s"Internal error! unhandled value $x passed to SrcInfo()") - } - } - - private def emit_streams(): Unit = { - build_description(description).foreach(emit(_)) - emit(Seq("# Generated by firrtl.RtlilEmitter (FIRRTL Version ", BuildInfo.version + ")")) - emit(Seq("autoidx 1")) - emit(Seq("attribute \\cells_not_processed 1")) - emit(Seq("module \\", m.name, m.info)) - for (x <- portdefs) emit(Seq(tab, x)) - for (x <- declares) emit(Seq(tab, x)) - for ((_, x) <- instdeclares) { - emit(Seq(tab, "attribute \\module_not_derived 1")) - emit(Seq(tab, x.info)) - emit(Seq(tab, "cell \\", x.mod_name, " \\", x.inst_name)) - for (p <- x.params) emit(Seq(tab, tab, p)) - for ((a, b) <- x.conns) emit(Seq(tab, tab, "connect \\", a, " ", b)) - emit(Seq(tab, "end")) - } - for (x <- assigns) emit(Seq(tab, x)) - for (x <- processes) emit(Seq(tab, x)) - for (x <- attachSynAssigns) emit(Seq(tab, x)) - for (x <- initials) emit(Seq(tab, x)) - emit(Seq("end")) - emit(Seq()) - } - - private def primop_to_cell(p: PrimOp): String = p match { - case Not => "$not" - case Neg => "$neg" - case Andr => "$reduce_and" - case Orr => "$reduce_or" - case Xorr => "$reduce_xor" - case And => "$and" - case Or => "$or" - case Xor => "$xor" - case Shl => "$shl" - case Dshl => "$shl" - case Eq => "$eq" - case Lt => "$lt" - case Leq => "$le" - case Neq => "$ne" - case Geq => "$ge" - case Gt => "$gt" - case Add => "$add" - case Addw => "$add" - case Sub => "$sub" - case Subw => "$sub" - case Mul => "$mul" - case Div => "$div" - case Rem => "$rem" - case _ => - throwInternalError( - "Internal Error! primop %s shouldn't have propagated this far!".format(p.serialize) - ); - } - - private def unary_cells = List("$not", "$neg", "$reduce_and", "$reduce_or", "$reduce_xor") - private def get_type_width(e: Type): BigInt = { // just trust me bro, its lofirrtl - e.asInstanceOf[GroundType].width.asInstanceOf[IntWidth].width - } - - private def emit_cell( - i: Info, - name: String, - params: Seq[(String, String)], - connections: Seq[(String, String)] - ): Seq[Seq[Any]] = { - Seq(Seq(i), Seq("cell ", name, " ", namespace.newName(name + "$" + m.name))) ++ - params.map { p => Seq(tab, "parameter \\", p._1, " ", p._2) } ++ - connections.map { c => Seq(tab, "connect \\", c._1, " ", c._2) } ++ - Seq(Seq("end")) - } - - private def emit_unary_cell(cell: String, src: SrcInfo, target: String, tgt_width: BigInt): Seq[Seq[Any]] = { - emit_cell( - NoInfo, - cell, - Seq( - ( - "A_SIGNED", - if (src.signed) { "1" } - else { "0" } - ), - ("A_WIDTH", src.width.toString), - ("Y_WIDTH", tgt_width.toString) - ), - Seq(("A", src.str_rep), ("Y", target)) - ) - } - - private def emit_binary_cell( - cell: String, - src_a: SrcInfo, - src_b: SrcInfo, - target: String, - tgt_width: BigInt - ): Seq[Seq[Any]] = { - emit_cell( - NoInfo, - cell, - Seq( - ( - "A_SIGNED", - if (src_a.signed) "1" else "0" - ), - ("A_WIDTH", src_a.width.toString), - ( - "B_SIGNED", - if (src_b.signed) "1" else "0" - ), - ("B_WIDTH", src_b.width.toString), - ("Y_WIDTH", tgt_width.toString) - ), - Seq(("A", src_a.str_rep), ("B", src_b.str_rep), ("Y", target)) - ) - } - - @tailrec - private def output_expr(n: String, d: Expression, i: Info): Seq[Seq[Any]] = d match { - case UIntLiteral(_, _) | SIntLiteral(_, _) | Reference(_, _, _, _) | SubField(_, _, _, _) => - Seq(Seq("connect ", n, " ", SrcInfo(d, i).str_rep)) - case InfoExpr(info, expr) => - output_expr(n, expr, MultiInfo(Seq(i, info))) - case Mux(cond, tval, fval, tpe) => - val (eninfo, tinfo, finfo) = MultiInfo.demux(i) - val csrc = SrcInfo(cond, eninfo) - val tsrc = SrcInfo(tval, tinfo) - val fsrc = SrcInfo(fval, finfo) - emit_cell( - i, - "$mux", - Seq(("WIDTH", get_type_width(tpe).toString)), - Seq(("A", fsrc.str_rep), ("B", tsrc.str_rep), ("S", csrc.str_rep), ("Y", n)) - ) - case DoPrim(op, args, consts, _) => - val sources = args.map(SrcInfo(_, i)) - val src0 = sources.head - if (sources.map(_.width).contains(-1)) return Seq() - op match { - case AsSInt | AsUInt | AsClock | AsAsyncReset => - Seq(Seq("connect ", n, " ", src0)) - case Cvt => - if (src0.signed) - Seq(Seq("connect ", n, " ", src0)) - else - Seq(Seq("connect ", n, " { 1'0 ", src0, " }")) - case Bits => - if (consts.head == consts.last) - Seq(Seq("connect ", n, " ", src0, " [", consts.head.toInt, "]")) - else - Seq(Seq("connect ", n, " ", src0, " [", consts.head.toInt, ":", consts.last.toInt, "]")) - case Shr | Shl => - val prim = if (op == Shr) (if (src0.signed) "$sshr" else "$shr") else "$shl" - emit_binary_cell( - prim, - src0, - SrcInfo(bigint_to_str_rep(consts.head, consts.head.bitLength), signed = false, consts.head.bitLength), - n, - get_type_width(d.tpe) - ) - case Add => - if (src0.signed && sources(1).signed) { - val src0_ext = SrcInfo(s"{ ${src0.str_rep} [${src0.width - 1}] ${src0.str_rep} }", true, src0.width + 1) - val src1_ext = SrcInfo( - s"{ ${sources(1).str_rep} [${sources(1).width - 1}] ${sources(1).str_rep} }", - true, - sources(1).width + 1 - ) - emit_binary_cell("$add", src0_ext, src1_ext, n, get_type_width(d.tpe)) - } else { - emit_binary_cell("$add", src0, sources(1), n, get_type_width(d.tpe)) - } - case Dshr | Dshl => - val prim = if (op == Dshr) (if (src0.signed) "$sshr" else "$shr") else "$shl" - emit_binary_cell(prim, src0, sources(1), n, get_type_width(d.tpe)) - case Cat => - Seq(Seq("connect ", n, " { ", sources.map(_.str_rep).mkString(" "), " }")) - case Head => - Seq(Seq("connect ", n, " ", src0, " [", (src0.width - 1).toInt, ":", consts.head.toInt, "]")) - case Tail => - Seq(Seq("connect ", n, " ", src0, " [", (src0.width - 1 - consts.head).toInt, ":0]")) - case Pad => - if (src0.width >= consts.head) - Seq(Seq("connect ", n, " ", src0, " [", (consts.head - 1).toInt, ":0]")) - else if (src0.signed) - Seq( - Seq("connect ", n) ++ - Seq( - " { ", - s"${src0.str_rep} [${src0.width - 1}] " * (consts.head - src0.width).toInt, - src0.str_rep, - " }" - ) - ) - else - Seq(Seq("connect ", n, " { ", (consts.head - src0.width).toInt, "'0 ", src0, " }")) - case _ => - val cell = primop_to_cell(op) - if (unary_cells.contains(cell)) - Seq(i) +: emit_unary_cell(cell, src0, n, get_type_width(d.tpe)) - else - Seq(i) +: emit_binary_cell(cell, src0, sources(1), n, get_type_width(d.tpe)) - } - case unk => - throw EmitterException(s"Internal error! unhandled output expression $unk passed to output_expr()") - } - - private def build_streams(s: Statement): Unit = { - val withoutDescription = s match { - case DescribedStmt(d, stmt) => - stmt match { - case _: IsDeclaration => - declares ++= build_description(d) - case _ => - } - stmt - case stmt => stmt - } - withoutDescription.foreach(build_streams) - withoutDescription match { - case DefInstance(info, name, mdle, _) => - val (module, params) = moduleMap(mdle) match { - case DescribedMod(_, _, ExtModule(_, _, _, extname, params)) => (extname, params) - case DescribedMod(_, _, Module(_, name, _, _)) => (name, Seq.empty) - case ExtModule(_, _, _, extname, params) => (extname, params) - case Module(_, name, _, _) => (name, Seq.empty) - } - instdeclares(name) = InstInfo(name, module, info) - instdeclares(name).params = if (params.nonEmpty) params.map(stringify) else Seq() - case WDefInstanceConnector(info, name, mdle, _, portCons) => - val (_, params) = moduleMap(mdle) match { - case DescribedMod(_, _, ExtModule(_, _, _, extname, params)) => (extname, params) - case DescribedMod(_, _, Module(_, name, _, _)) => (name, Seq.empty) - case ExtModule(_, _, _, extname, params) => (extname, params) - case Module(_, name, _, _) => (name, Seq.empty) - } - instdeclares(name) = InstInfo(name, mdle, info) - instdeclares(name).params = if (params.nonEmpty) params.map(stringify) else Seq() - for ((port, ref) <- portCons) { - val portName = SrcInfo(remove_root(port)).str_rep.tail - if (instdeclares(name).getConnection(portName).nonEmpty) { - assigns ++= output_expr(instdeclares(name).getConnection(portName).get, ref, NoInfo) - } else { - instdeclares(name).addConnection(SrcInfo(remove_root(port)).str_rep.tail, SrcInfo(ref).str_rep) - } - } - case Connect(info, loc @ WRef(_, _, PortKind | WireKind | InstanceKind, _), expr) => - assigns ++= output_expr(ref_to_name(loc), expr, info) - case Connect(info, SubField(Reference(modname, _, InstanceKind, _), portname, _, _), expr) => - if (instdeclares(modname).getConnection(portname).nonEmpty) { - assigns ++= output_expr(instdeclares(modname).getConnection(portname).get, expr, NoInfo) - } else { - instdeclares(modname).addConnection(portname, SrcInfo(expr, info).str_rep) - } - case sx: DefWire => - declares += Seq(sx.info) - declares += Seq("wire ", sx.tpe, " ", string_to_rtlil_name(sx.name)) - case sx: DefRegister => - val options = emissionOptions.getRegisterEmissionOption(moduleTarget.ref(sx.name)) - val e = WRef(sx.name, sx.tpe, ExpKind, UnknownFlow) - declares += Seq(sx.info) - declares += Seq("wire ", sx.tpe, " ", string_to_rtlil_name(sx.name)) - if (options.useInitAsPreset) - regUpdate(e, sx.clock, sx.reset, e) - else - regUpdate(e, sx.clock, sx.reset, sx.init) - case sx: DefNode => - declares += Seq(sx.info) - declares += Seq("wire ", sx.value.tpe, " ", string_to_rtlil_name(sx.name)) - assigns ++= output_expr(string_to_rtlil_name(sx.name), sx.value, sx.info) - case x @ Verification(value, info, _, pred, en, _) => - value match { - case Formal.Assert => - formals += emit_cell( - info, - "$assert", - Seq(), - Seq(("A", SrcInfo(pred).str_rep), ("EN", SrcInfo(en).str_rep)) - ) - case Formal.Assume => - formals += emit_cell( - info, - "$assume", - Seq(), - Seq(("A", SrcInfo(pred).str_rep), ("EN", SrcInfo(en).str_rep)) - ) - case Formal.Cover => - formals += emit_cell( - info, - "$cover", - Seq(), - Seq(("A", SrcInfo(pred).str_rep), ("EN", SrcInfo(en).str_rep)) - ) - } - case x @ DefMemory(i, name, tpe, depth, wlat, rlat, rd, wr, rdwr, runderw) => - val options = emissionOptions.getMemoryEmissionOption(moduleTarget.ref(name)) - val hasComplexRW = rdwr.nonEmpty && (rlat != 1) - if (rlat > 1 || wlat != 1 || hasComplexRW) - throw EmitterException( - Seq( - s"Memory $name is too complex to emit directly.", - "Consider running VerilogMemDelays to simplify complex memories.", - "Alternatively, add the --repl-seq-mem flag to replace memories with blackboxes." - ).mkString(" ") - ) - val dataWidth = bitWidth(tpe) - val maxDataValue = (BigInt(1) << dataWidth.toInt) - 1 - - def checkValueRange(value: BigInt, at: String): Unit = { - if (value > maxDataValue) - throw EmitterException( - s"Memory $at cannot be initialized with value: $value. Too large (> $maxDataValue)!" - ) - } - declares += Seq("memory width ", dataWidth.toString, " size ", depth.toString, " \\", name) - options.initValue match { - case MemoryArrayInit(values) => - values.zipWithIndex.foreach { - case (value, addr) => - checkValueRange(value, s"$name[$addr]") - initials ++= emit_cell( - i, - "$meminit_v2", - Seq( - ("MEMID", "\"\\\\" + name + "\""), - ("ABITS", "32"), - ("WIDTH", dataWidth.toString), - ("WORDS", "1"), - ("PRIORITY", addr.toString) - ), - Seq( - ("ADDR", addr.toString), - ("DATA", bigint_to_str_rep(value, dataWidth)), - ("EN", bigint_to_str_rep(BigInt(2).pow(dataWidth.toInt) - BigInt(1), dataWidth)) - ) - ) - } - - case MemoryScalarInit(value) => - for (addr <- 0 until depth.intValue) { - initials ++= emit_cell( - i, - "$meminit_v2", - Seq( - ("MEMID", "\"\\\\" + name + "\""), - ("ABITS", "32"), - ("WIDTH", dataWidth.toString), - ("WORDS", "1"), - ("PRIORITY", addr.toString) - ), - Seq( - ("ADDR", addr.toString), - ("DATA", bigint_to_str_rep(value, dataWidth)), - ("EN", bigint_to_str_rep(BigInt(2).pow(dataWidth.toInt) - BigInt(1), dataWidth)) - ) - ) - } - case MemoryRandomInit => - println(s"Memory $name cannot be initialized with random data, RTLIL cannot express this.") - println("Leaving memory uninitialized.") - case MemoryFileInlineInit(_, _) => - throw EmitterException(s"Memory $name cannot be initialized from a file, RTLIL cannot express this.") - case MemoryNoInit => - // No initialization to emit - } - for (r <- rd) { - val data = memPortField(x, r, "data") - val addr = memPortField(x, r, "addr") - val en = memPortField(x, r, "en") - val hasClk = if (rlat == 1) { "1'1" } - else { "1'0" } - val clkSrc = netlist(memPortField(x, r, "clk")).expr - val transparent = runderw match { - case ReadUnderWrite.New => "1'1" - case ReadUnderWrite.Old => "1'0" - case ReadUnderWrite.Undefined => "1'x" - } - declares += Seq("wire ", data.tpe, " ", SrcInfo(data).str_rep) - assigns ++= emit_cell( - i, - "$memrd", - Seq( - ("ABITS", get_type_width(addr.tpe).toString), - ("MEMID", "\"\\\\" + name + "\""), - ("WIDTH", get_type_width(data.tpe).toString), - ("CLK_ENABLE", hasClk), - ("CLK_POLARITY", "1'1"), - ("TRANSPARENT", transparent) - ), - Seq( - ("CLK", SrcInfo(clkSrc, i).str_rep), - ("EN", if (rlat == 1) SrcInfo(netlist(en), i).str_rep else "1'1"), - ("ADDR", SrcInfo(netlist(addr), i).str_rep), - ("DATA", SrcInfo(data, i).str_rep) - ) - ) - } - for (w <- wr) { - val data = memPortField(x, w, "data") - val addr = memPortField(x, w, "addr") - val en = memPortField(x, w, "en") - val mask = memPortField(x, w, "mask") - val enSrc = SrcInfo(netlist(en)) - val maskSrc = SrcInfo(netlist(mask)) - if (maskSrc.width > 1) { - throw EmitterException("Compound type memory write ports arent fully supported yet.") - } - var memwr_enmask = enSrc.str_rep - if (bitWidth(data.tpe) != 1) { - memwr_enmask = namespace.newName("$memwr_enmask$" + m.name) - declares += Seq("wire signed width ", bitWidth(data.tpe).toInt, " ", memwr_enmask) - assigns ++= emit_cell( - i, - "$and", - Seq( - ("A_SIGNED", "1"), - ("B_SIGNED", "1"), - ("A_WIDTH", bitWidth(en.tpe).toString()), - ("B_WIDTH", maskSrc.width.toString()), - ("Y_WIDTH", bitWidth(data.tpe).toString()) - ), - Seq(("A", enSrc.str_rep), ("B", maskSrc.str_rep), ("Y", memwr_enmask)) - ) - } - val hasClk = if (wlat == 1) { "1'1" } - else { "1'0" } - val clkSrc = netlist(memPortField(x, w, "clk")).expr - assigns ++= emit_cell( - i, - "$memwr", - Seq( - ("ABITS", get_type_width(addr.tpe).toString), - ("MEMID", "\"\\\\" + name + "\""), - ("WIDTH", get_type_width(data.tpe).toString), - ("CLK_ENABLE", hasClk), - ("CLK_POLARITY", "1'1"), - ("PRIORITY", "32'1") - ), - Seq( - ("CLK", SrcInfo(clkSrc).str_rep), - ("EN", memwr_enmask), - ("ADDR", SrcInfo(netlist(addr)).str_rep), - ("DATA", SrcInfo(netlist(data)).str_rep) - ) - ) - } - case sx: Attach => - for (set <- sx.exprs.toSet.subsets(2)) { - val (a, b) = set.toSeq match { - case Seq(x, y) => (x, y) - } - attachSynAssigns += Seq("connect ", SrcInfo(a, sx.info).str_rep, " ", SrcInfo(b, sx.info).str_rep) - } - case _ => - } - } - - def emit_rtlil(): DefModule = { - build_netlist(m.body) - build_ports() - build_streams(m.body) - emit_streams() - m - } - } -} - -private[firrtl] class EmissionOptionMap[V <: EmissionOption](val df: V) { - private val m = collection.mutable.HashMap[ReferenceTarget, V]().withDefaultValue(df) - def +=(elem: (ReferenceTarget, V)): EmissionOptionMap.this.type = { - if (m.contains(elem._1)) - throw EmitterException(s"Multiple EmissionOption for the target ${elem._1} (${m(elem._1)} ; ${elem._2})") - m += elem - this - } - def apply(key: ReferenceTarget): V = m.apply(key) -} - -private[firrtl] class EmissionOptions(annotations: AnnotationSeq) { - // Private so that we can present an immutable API - private val memoryEmissionOption = new EmissionOptionMap[MemoryEmissionOption]( - annotations.collectFirst { case a: CustomDefaultMemoryEmission => a }.getOrElse(MemoryEmissionOptionDefault) - ) - private val registerEmissionOption = new EmissionOptionMap[RegisterEmissionOption]( - annotations.collectFirst { case a: CustomDefaultRegisterEmission => a }.getOrElse(RegisterEmissionOptionDefault) - ) - private val wireEmissionOption = new EmissionOptionMap[WireEmissionOption](WireEmissionOptionDefault) - private val portEmissionOption = new EmissionOptionMap[PortEmissionOption](PortEmissionOptionDefault) - private val nodeEmissionOption = new EmissionOptionMap[NodeEmissionOption](NodeEmissionOptionDefault) - private val connectEmissionOption = new EmissionOptionMap[ConnectEmissionOption](ConnectEmissionOptionDefault) - - def getMemoryEmissionOption(target: ReferenceTarget): MemoryEmissionOption = - memoryEmissionOption(target) - - def getRegisterEmissionOption(target: ReferenceTarget): RegisterEmissionOption = - registerEmissionOption(target) - - def getWireEmissionOption(target: ReferenceTarget): WireEmissionOption = - wireEmissionOption(target) - - def getPortEmissionOption(target: ReferenceTarget): PortEmissionOption = - portEmissionOption(target) - - def getNodeEmissionOption(target: ReferenceTarget): NodeEmissionOption = - nodeEmissionOption(target) - - def getConnectEmissionOption(target: ReferenceTarget): ConnectEmissionOption = - connectEmissionOption(target) - - def emitMemoryInitAsNoSynth: Boolean = { - val annos = annotations.collect { case a @ (MemoryNoSynthInit | MemorySynthInit) => a } - annos match { - case Seq() => true - case Seq(MemoryNoSynthInit) => true - case Seq(MemorySynthInit) => false - case _ => - throw new FirrtlUserException( - "There should only be at most one memory initialization option annotation, got $other" - ) - } - } - - private val emissionAnnos = annotations.collect { - case m: SingleTargetAnnotation[ReferenceTarget] @unchecked with EmissionOption => m - } - - annotations.foreach { - case a: Annotation if a.dedup.nonEmpty => - val (_, _, target) = a.dedup.get - if (!target.isLocal) { - throw new FirrtlUserException( - s"At least one dedupable annotation did not deduplicate: got non-local annotation $a from [[DedupAnnotationsTransform]]" - ) - } - case _ => - } - - // using multiple foreach instead of a single partial function as an Annotation can gather multiple EmissionOptions for simplicity - emissionAnnos.foreach { - case a: MemoryEmissionOption => memoryEmissionOption += ((a.target, a)) - case _ => - } - emissionAnnos.foreach { - case a: RegisterEmissionOption => registerEmissionOption += ((a.target, a)) - case _ => - } - emissionAnnos.foreach { - case a: WireEmissionOption => wireEmissionOption += ((a.target, a)) - case _ => - } - emissionAnnos.foreach { - case a: PortEmissionOption => portEmissionOption += ((a.target, a)) - case _ => - } - emissionAnnos.foreach { - case a: NodeEmissionOption => nodeEmissionOption += ((a.target, a)) - case _ => - } - emissionAnnos.foreach { - case a: ConnectEmissionOption => connectEmissionOption += ((a.target, a)) - case _ => - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala deleted file mode 100644 index 37f9228f..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala +++ /dev/null @@ -1,253 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -import scala.collection.mutable - -object Btor2Serializer { - def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = { - new Btor2Serializer().run(sys, skipOutput) - } -} - -private class Btor2Serializer private () { - private val symbols = mutable.HashMap[String, Int]() - private val lines = mutable.ArrayBuffer[String]() - private var index = 1 - - private def line(l: String): Int = { - val ii = index - lines += s"$ii $l" - index += 1 - ii - } - - private def comment(c: String): Unit = { lines += s"; $c" } - private def trailingComment(c: String): Unit = { - val lastLine = lines.last - val newLine = if (lastLine.contains(';')) { lastLine + " " + c } - else { lastLine + " ; " + c } - lines(lines.size - 1) = newLine - } - - // bit vector type serialization - private val bitVecTypeCache = mutable.HashMap[Int, Int]() - - private def t(width: Int): Int = bitVecTypeCache.getOrElseUpdate(width, line(s"sort bitvec $width")) - - // bit vector expression serialization - private def s(expr: BVExpr): Int = expr match { - case BVLiteral(value, width) => lit(value, width) - case BVSymbol(name, _) => symbols.getOrElse(name, throw new RuntimeException(s"Unknown symbol: $name")) - case BVExtend(e, 0, _) => s(e) - case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by") - case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by") - case BVSlice(e, hi, lo) => - if (lo == 0 && hi == e.width - 1) { s(e) } - else { - line(s"slice ${t(expr.width)} ${s(e)} $hi $lo") - } - case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b) - case BVNot(BVNot(e)) => s(e) - case BVNot(e) => unary("not", expr.width, e) - case BVNegate(e) => unary("neg", expr.width, e) - case BVReduceAnd(e) => unary("redand", expr.width, e) - case BVReduceOr(e) => unary("redor", expr.width, e) - case BVReduceXor(e) => unary("redxor", expr.width, e) - case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b) - case BVImplies(a, b) => binary("implies", expr.width, a, b) - case BVEqual(a, b) => binary("eq", expr.width, a, b) - case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}") - case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b) - case BVComparison(Compare.GreaterEqual, a, b, false) => binary("ugte", expr.width, a, b) - case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b) - case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b) - case BVOp(op, a, b) => binary(s(op), expr.width, a, b) - case BVConcat(a, b) => binary("concat", expr.width, a, b) - case call: BVFunctionCall => s(functionCallToArrayRead(call)) - case ArrayRead(array, index) => - line(s"read ${t(expr.width)} ${s(array)} ${s(index)}") - case BVIte(cond, tru, fals) => - line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}") - case b @ BVAnd(terms) => variadic("and", b.width, terms) - case b @ BVOr(terms) => variadic("or", b.width, terms) - case forall: BVForall => - throw new RuntimeException(s"Quantifiers are not supported by the btor2 format: ${forall}") - } - - private def s(op: Op.Value): String = op match { - case Op.Xor => "xor" - case Op.ArithmeticShiftRight => "sra" - case Op.ShiftRight => "srl" - case Op.ShiftLeft => "sll" - case Op.Add => "add" - case Op.Mul => "mul" - case Op.Sub => "sub" - case Op.SignedDiv => "sdiv" - case Op.UnsignedDiv => "udiv" - case Op.SignedMod => "smod" - case Op.SignedRem => "srem" - case Op.UnsignedRem => "urem" - } - - private def unary(op: String, width: Int, e: BVExpr): Int = line(s"$op ${t(width)} ${s(e)}") - - private def binary(op: String, width: Int, a: BVExpr, b: BVExpr): Int = - line(s"$op ${t(width)} ${s(a)} ${s(b)}") - - private def variadic(op: String, width: Int, terms: List[BVExpr]): Int = terms match { - case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op") - case Seq(a, b) => binary(op, width, a, b) - case head :: tail => - val tailId = variadic(op, width, tail) - line(s"$op ${t(width)} ${s(head)} ${tailId}") - } - - private def lit(value: BigInt, w: Int): Int = { - val typ = t(w) - lazy val mask = (BigInt(1) << w) - 1 - if (value == 0) line(s"zero $typ") - else if (value == 1) line(s"one $typ") - else if (value == mask) line(s"ones $typ") - else { - val digits = value.toString(2) - val padded = digits.reverse.padTo(w, '0').reverse - line(s"const $typ $padded") - } - } - - // array type serialization - private val arrayTypeCache = mutable.HashMap[(Int, Int), Int]() - - private def t(indexWidth: Int, dataWidth: Int): Int = - arrayTypeCache.getOrElseUpdate((indexWidth, dataWidth), line(s"sort array ${t(indexWidth)} ${t(dataWidth)}")) - - // array expression serialization - private def s(expr: ArrayExpr): Int = expr match { - case ArraySymbol(name, _, _) => symbols(name) - case ArrayStore(array, index, data) => - line(s"write ${t(expr.indexWidth, expr.dataWidth)} ${s(array)} ${s(index)} ${s(data)}") - case ArrayIte(cond, tru, fals) => - // println("WARN: ITE on array is probably not supported by btor2") - // While the spec does not seem to allow array ite, it seems to be supported in practice. - // It is essential to model memories, so any support in the wild should be fairly well tested. - line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}") - case ArrayConstant(e, indexWidth) => - // The problem we are facing here is that the only way to create a constant array from a bv expression - // seems to be to use the bv expression as the init value of a state variable. - // Thus we need to create a fake state for every array init expression. - arrayConstants.getOrElseUpdate( - e.toString, { - comment(s"$expr") - val eId = s(e) - val tpeId = t(indexWidth, e.width) - val state = line(s"state $tpeId") - line(s"init $tpeId $state $eId") - state - } - ) - case f: ArrayFunctionCall => - throw new RuntimeException(s"The btor2 format does not support uninterpreted functions that return arrays!: $f") - } - private val arrayConstants = mutable.HashMap[String, Int]() - - private def s(expr: SMTExpr): Int = expr match { - case b: BVExpr => s(b) - case a: ArrayExpr => s(a) - } - - // serialize the type of the expression - private def t(expr: SMTExpr): Int = expr match { - case b: BVExpr => t(b.width) - case a: ArrayExpr => t(a.indexWidth, a.dataWidth) - } - - private def functionCallToArrayRead(call: BVFunctionCall): BVExpr = { - if (call.args.isEmpty) { - BVSymbol(call.name, call.width) - } else { - val args: List[BVExpr] = call.args.map { - case b: BVExpr => b - case other => throw new RuntimeException(s"Unsupported call argument: $other in $call") - } - val index = concat(args) - val a = ArraySymbol(call.name, indexWidth = index.width, dataWidth = call.width) - ArrayRead(a, index) - } - } - private def concat(e: Iterable[BVExpr]): BVExpr = { - require(e.nonEmpty) - e.reduce((a, b) => BVConcat(a, b)) - } - - def run(sys: TransitionSystem, skipOutput: Boolean): Iterable[String] = { - def declare(name: String, lbl: Option[SignalLabel], expr: => Int): Unit = { - assert(!symbols.contains(name), s"Trying to redeclare `$name`") - val id = expr - symbols(name) = id - // add label - lbl match { - case Some(IsOutput) => if (!skipOutput) line(s"output $id ; $name") - case Some(IsConstraint) => line(s"constraint $id ; $name") - case Some(IsBad) => line(s"bad $id ; $name") - case Some(IsFair) => line(s"fair $id ; $name") - case _ => - } - // add trailing comment - sys.comments.get(name).foreach(trailingComment) - } - - // header - if (sys.header.nonEmpty) { - sys.header.split('\n').foreach(comment) - } - - // declare inputs - sys.inputs.foreach { ii => - declare(ii.name, None, line(s"input ${t(ii.width)} ${ii.name}")) - } - - // declare uninterpreted functions a constant arrays - val ufs = TransitionSystem.findUninterpretedFunctions(sys) - ufs.foreach { foo => - // only functions returning bit-vectors are supported! - val bvSym = foo.sym.asInstanceOf[BVSymbol] - val sym = if (foo.args.isEmpty) { bvSym } - else { - ArraySymbol(bvSym.name, foo.args.map(_.asInstanceOf[BVExpr].width).sum, bvSym.width) - } - comment(foo.toString) - declare(sym.name, None, line(s"state ${t(sym)} ${sym.name}")) - line(s"next ${t(sym)} ${s(sym)} ${s(sym)}") - } - - // define state init - sys.states.foreach { st => - // calculate init expression before declaring the state - // this is required by btormc (presumably to avoid cycles in the init expression) - val initId = st.init.map { - // only in the context of initializing a state can we use a bv expression to model an array - case ArrayConstant(e, _) => comment(s"${st.sym}.init"); s(e) - case init => comment(s"${st.sym}.init"); s(init) - } - declare(st.sym.name, None, line(s"state ${t(st.sym)} ${st.sym.name}")) - st.init.foreach { init => line(s"init ${t(init)} ${s(st.sym)} ${initId.get}") } - } - - // define all other signals - sys.signals.foreach { signal => - declare(signal.name, Some(signal.lbl), s(signal.e)) - } - - // define state next - sys.states.foreach { st => - st.next.foreach { next => - comment(s"${st.sym}.next") - line(s"next ${t(next)} ${s(st.sym)} ${s(next)}") - } - } - - lines - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala deleted file mode 100644 index 865382c9..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala +++ /dev/null @@ -1,179 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -import firrtl.ir -import firrtl.PrimOps -import firrtl.passes.CheckWidths.WidthTooBig - -private object FirrtlExpressionSemantics { - def toSMT(e: ir.Expression): BVExpr = { - val eSMT = e match { - case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts) - case r: ir.RefLikeExpression => BVSymbol(r.serialize, getWidth(r)) - case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt) - case ir.SIntLiteral(value, ir.IntWidth(width)) => - val twosComplementValue = value & ((BigInt(1) << width.toInt) - 1) - BVLiteral(twosComplementValue, width.toInt) - case ir.Mux(cond, tval, fval, _) => - val width = List(tval, fval).map(getWidth).max - BVIte(toSMT(cond), toSMT(tval, width), toSMT(fval, width)) - case v: ir.ValidIf => - throw new RuntimeException(s"Unsupported expression: ValidIf ${v.serialize}") - } - assert( - eSMT.width == getWidth(e), - "We aim to always produce a SMT expression of the same width as the firrtl expression." - ) - eSMT - } - - /** Ensures that the result has the desired width by appropriately extending it. */ - def toSMT(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = - forceWidth(toSMT(e), isSigned(e), width, allowNarrow) - - private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = { - if (eSMT.width == width) { eSMT } - else if (width < eSMT.width) { - assert(allowNarrow, s"Narrowing from ${eSMT.width} bits to $width bits is not allowed!") - BVSlice(eSMT, width - 1, 0) - } else { - BVExtend(eSMT, width - eSMT.width, eSigned) - } - } - - // see "Primitive Operations" section in the Firrtl Specification - private def onPrim( - op: ir.PrimOp, - args: Seq[ir.Expression], - consts: Seq[BigInt] - ): BVExpr = { - (op, args, consts) match { - case (PrimOps.Add, Seq(e1, e2), _) => - val width = args.map(getWidth).max + 1 - BVOp(Op.Add, toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Sub, Seq(e1, e2), _) => - val width = args.map(getWidth).max + 1 - BVOp(Op.Sub, toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Mul, Seq(e1, e2), _) => - val width = args.map(getWidth).sum - BVOp(Op.Mul, toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Div, Seq(num, den), _) => - val signed = isSigned(num) - val resWidth = if (signed) { getWidth(num) + 1 } - else { getWidth(num) } - val op = if (signed) { Op.SignedDiv } - else { Op.UnsignedDiv } - // we do the calculation on the widened values and then narrow the result if needed - val width = args.map(getWidth).max + (if (signed) 1 else 0) - val res = BVOp(op, toSMT(num, width), toSMT(den, width)) - forceWidth(res, signed, resWidth, allowNarrow = true) - case (PrimOps.Rem, Seq(num, den), _) => - val signed = isSigned(num) - val op = if (signed) Op.SignedRem else Op.UnsignedRem - val width = args.map(getWidth).max - val resWidth = args.map(getWidth).min - val res = BVOp(op, toSMT(num, width), toSMT(den, width)) - forceWidth(res, signed, resWidth, allowNarrow = true) - case (PrimOps.Lt, Seq(e1, e2), _) => - val width = args.map(getWidth).max - BVNot(BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1))) - case (PrimOps.Leq, Seq(e1, e2), _) => - val width = args.map(getWidth).max - BVNot(BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1))) - case (PrimOps.Gt, Seq(e1, e2), _) => - val width = args.map(getWidth).max - BVComparison(Compare.Greater, toSMT(e1, width), toSMT(e2, width), isSigned(e1)) - case (PrimOps.Geq, Seq(e1, e2), _) => - val width = args.map(getWidth).max - BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1)) - case (PrimOps.Eq, Seq(e1, e2), _) => - val width = args.map(getWidth).max - BVEqual(toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Neq, Seq(e1, e2), _) => - val width = args.map(getWidth).max - BVNot(BVEqual(toSMT(e1, width), toSMT(e2, width))) - case (PrimOps.Pad, Seq(e), Seq(n)) => - val width = getWidth(e) - if (n <= width) { toSMT(e) } - else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) } - case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e); toSMT(e) - case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e); toSMT(e) - case (PrimOps.AsFixedPoint, Seq(e), _) => throw new AssertionError("Fixed-Point numbers need to be lowered!") - case (PrimOps.AsClock, Seq(e), _) => toSMT(e) - case (PrimOps.AsAsyncReset, Seq(e), _) => - checkForClockInCast(PrimOps.AsAsyncReset, e) - throw new AssertionError(s"Asynchronous resets are not supported! Cannot cast ${e.serialize}.") - case (PrimOps.Shl, Seq(e), Seq(n)) => - if (n == 0) { toSMT(e) } - else { - val zeros = BVLiteral(0, n.toInt) - BVConcat(toSMT(e), zeros) - } - case (PrimOps.Shr, Seq(e), Seq(n)) => - val width = getWidth(e) - // "If n is greater than or equal to the bit-width of e, - // the resulting value will be zero for unsigned types - // and the sign bit for signed types" - if (n >= width) { - if (isSigned(e)) { BVSlice(toSMT(e), width - 1, width - 1) } - else { BV1BitZero } - } else { - BVSlice(toSMT(e), width - 1, n.toInt) - } - case (PrimOps.Dshl, Seq(e1, e2), _) => - val width = getWidth(e1) + (1 << getWidth(e2)) - 1 - BVOp(Op.ShiftLeft, toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Dshr, Seq(e1, e2), _) => - val width = getWidth(e1) - val o = if (isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight - BVOp(o, toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Cvt, Seq(e), _) => - if (isSigned(e)) { toSMT(e) } - else { BVConcat(BV1BitZero, toSMT(e)) } - case (PrimOps.Neg, Seq(e), _) => BVNegate(BVExtend(toSMT(e), 1, isSigned(e))) - case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e)) - case (PrimOps.And, Seq(e1, e2), _) => - val width = args.map(getWidth).max - BVAnd(toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Or, Seq(e1, e2), _) => - val width = args.map(getWidth).max - BVOr(toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Xor, Seq(e1, e2), _) => - val width = args.map(getWidth).max - BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e)) - case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e)) - case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e)) - case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2)) - case (PrimOps.Bits, Seq(e), Seq(hi, lo)) => BVSlice(toSMT(e), hi.toInt, lo.toInt) - case (PrimOps.Head, Seq(e), Seq(n)) => - val width = getWidth(e) - assert(n >= 0 && n <= width) - BVSlice(toSMT(e), width - 1, width - n.toInt) - case (PrimOps.Tail, Seq(e), Seq(n)) => - val width = getWidth(e) - assert(n >= 0 && n <= width) - assert(n < width, "While allowed by the firrtl standard, we do not support 0-bit values in this backend!") - BVSlice(toSMT(e), width - n.toInt - 1, 0) - } - } - - /** For now we strictly forbid casting clocks to anything else. - * Eventually this should be replaced by a more sophisticated clock analysis pass. - */ - private def checkForClockInCast(cast: ir.PrimOp, signal: ir.Expression): Unit = { - assert(signal.tpe != ir.ClockType, s"Cannot cast (${cast.serialize}) clock expression ${signal.serialize}!") - } - - private val BV1BitZero = BVLiteral(0, 1) - - private def isSigned(e: ir.Expression): Boolean = e.tpe match { - case _: ir.SIntType => true - case _ => false - } - - // Helper function - private def getWidth(e: ir.Expression): Int = firrtl.bitWidth(e.tpe).toInt -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala deleted file mode 100644 index 7da2e1e6..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ /dev/null @@ -1,379 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation} -import firrtl._ -import firrtl.backends.experimental.smt.random._ -import firrtl.options.Dependency -import firrtl.passes.MemPortUtils.memPortField -import firrtl.passes.PassException -import firrtl.passes.memlib.VerilogMemDelays -import firrtl.stage.Forms -import firrtl.stage.TransformManager.TransformDependency -import firrtl.transforms.{EnsureNamedStatements, PropagatePresetAnnotations} -import logger.LazyLogging - -import scala.collection.mutable - -case class TransitionSystemAnnotation(sys: TransitionSystem) extends NoTargetAnnotation - -/** Contains code to convert a flat firrtl module into a functional transition system which - * can then be exported as SMTLib or Btor2 file. - */ -object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { - override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++ - Seq( - Dependency(VerilogMemDelays), - Dependency(EnsureNamedStatements), // this is required to give assert/assume statements good names - Dependency[PropagatePresetAnnotations] - ) - override def invalidates(a: Transform): Boolean = false - // since this pass only runs on the main module, inlining needs to happen before - override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) - - override protected def execute(state: CircuitState): CircuitState = { - val circuit = state.circuit - val presetRegs = state.annotations.collect { - case PresetRegAnnotation(target) if target.module == circuit.main => target.ref - }.toSet - - // collect all non-random memory initialization - val memInit = state.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a } - .filter(_.target.module == circuit.main) - .map(a => a.target.ref -> a.initValue) - .toMap - - // module look up table - val modules = circuit.modules.map(m => m.name -> m).toMap - - // collect uninterpreted module annotations - val uninterpreted = state.annotations.collect { - case a: UninterpretedModuleAnnotation => - UninterpretedModuleAnnotation.checkModule(modules(a.target.module), a) - a.target.module -> a - }.toMap - - // convert the main module - val main = modules(circuit.main) - val sys = main match { - case _: ir.ExtModule => - throw new ExtModuleException( - "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog." - ) - case m: ir.Module => - new ModuleToTransitionSystem(presetRegs = presetRegs, memInit = memInit, uninterpreted = uninterpreted).run(m) - } - - val sortedSys = TopologicalSort.run(sys) - val anno = TransitionSystemAnnotation(sortedSys) - state.copy(circuit = circuit, annotations = state.annotations :+ anno) - } -} - -private object UnsupportedException { - val HowToRunStuttering: String = - """ - |You can run the StutteringClockTransform which - |replaces all clock inputs with a clock enable signal. - |This is required not only for multi-clock designs, but also to - |accurately model asynchronous reset which could happen even if there - |isn't a clock edge. - | If you are using the firrtl CLI, please add: - | -fct firrtl.backends.experimental.smt.StutteringClockTransform - | If you are calling into firrtl programmatically you can use: - | RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]) - | To designate a clock to be the global_clock (i.e. the simulation tick), use: - | GlobalClockAnnotation(CircuitTarget(...).module(...).ref("your_clock"))) - |""".stripMargin -} - -private class ExtModuleException(s: String) extends PassException(s) -private class AsyncResetException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering) -private class MultiClockException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering) -private class MissingFeatureException(s: String) - extends PassException("Unfortunately the SMT backend does not yet support: " + s) - -private class ModuleToTransitionSystem( - presetRegs: Set[String], - memInit: Map[String, MemoryInitValue], - uninterpreted: Map[String, UninterpretedModuleAnnotation]) - extends LazyLogging { - - def run(m: ir.Module): TransitionSystem = { - // first pass over the module to convert expressions; discover state and I/O - m.foreachPort(onPort) - m.foreachStmt(onStatement) - - // multi-clock support requires the StutteringClock transform to be run - if (clocks.size > 1) { - throw new MultiClockException(s"The module ${m.name} has more than one clock: ${clocks.mkString(", ")}") - } - - // generate comments from infos - val comments = mutable.HashMap[String, String]() - infos.foreach { - case (name, info) => - val infoStr = info.serialize.trim - if (infoStr.nonEmpty) { - val prefix = comments.get(name).map(_ + ", ").getOrElse("") - comments(name) = prefix + infoStr - } - } - - // module info to the comment header - val header = m.info.serialize.trim - - TransitionSystem(m.name, inputs.toList, states.values.toList, signals.toList, comments.toMap, header) - } - - private val inputs = mutable.ArrayBuffer[BVSymbol]() - private val clocks = mutable.ArrayBuffer[String]() - private val signals = mutable.ArrayBuffer[Signal]() - private val states = mutable.LinkedHashMap[String, State]() - private val infos = mutable.ArrayBuffer[(String, ir.Info)]() - - private def onPort(p: ir.Port): Unit = { - if (isAsyncReset(p.tpe)) { - throw new AsyncResetException(s"Found AsyncReset ${p.name}.") - } - infos.append(p.name -> p.info) - p.direction match { - case ir.Input => - if (isClock(p.tpe)) { - clocks.append(p.name) - } else { - inputs.append(BVSymbol(p.name, bitWidth(p.tpe).toInt)) - } - case ir.Output => - } - } - - private def onStatement(s: ir.Statement): Unit = s match { - case DefRandom(info, name, tpe, _, en) => - assert(!isClock(tpe), "rand should never be a clock!") - // we model random sources as inputs and the enable signal as output - infos.append(name -> info) - inputs.append(BVSymbol(name, bitWidth(tpe).toInt)) - signals.append(Signal(name + ".en", onExpression(en, 1), IsOutput)) - case w: ir.DefWire => - if (!isClock(w.tpe)) { - // InlineInstances can insert wires without re-running RemoveWires for now we just deal with it when - // the Wires is connected to (ir.Connect). - } - case ir.DefNode(info, name, expr) => - if (!isClock(expr.tpe) && !isAsyncReset(expr.tpe)) { - infos.append(name -> info) - signals.append(Signal(name, onExpression(expr), IsNode)) - } - case r: ir.DefRegister => - infos.append(r.name -> r.info) - states(r.name) = onRegister(r) - case m: ir.DefMemory => - infos.append(m.name -> m.info) - states(m.name) = onMemory(m) - case ir.Connect(info, loc, expr) => - if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") - if (!isClock(loc.tpe) && !isAsyncReset(expr.tpe)) { // we ignore clock connections - val name = loc.serialize - val e = onExpression(expr, bitWidth(loc.tpe).toInt, allowNarrow = false) - Utils.kind(loc) match { - case RegKind => states(name) = states(name).copy(next = Some(e)) - case PortKind | InstanceKind => // module output or submodule input - infos.append(name -> info) - signals.append(Signal(name, e, IsOutput)) - case MemKind | WireKind => - // InlineInstances can insert wires without re-running RemoveWires for now we just deal with it. - infos.append(name -> info) - signals.append(Signal(name, e, IsNode)) - } - } - case i: ir.IsInvalid => - throw new UnsupportedFeatureException(s"IsInvalid statements are not supported: ${i.serialize}") - case ir.DefInstance(info, name, module, tpe) => onInstance(info, name, module, tpe) - case s: ir.Verification => - if (s.op == ir.Formal.Cover) { - logger.info(s"[info] Cover statement was ignored: ${s.serialize}") - } else { - val name = s.name - val predicate = onExpression(s.pred) - val enabled = onExpression(s.en) - val e = BVImplies(enabled, predicate) - infos.append(name -> s.info) - val signal = if (s.op == ir.Formal.Assert) { - Signal(name, BVNot(e), IsBad) - } else { - Signal(name, e, IsConstraint) - } - signals.append(signal) - } - case s: ir.Conditionally => - error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}") - case s: ir.PartialConnect => - error(s"PartialConnects are not supported. Please run ExpandConnects: ${s.serialize}") - case s: ir.Attach => - error(s"Analog wires are not supported in the SMT backend: ${s.serialize}") - case s: ir.Stop => - if (s.ret == 0) { - logger.info( - s"[info] Stop statements with a return code of 0 are currently not supported. Ignoring: ${s.serialize}" - ) - } else { - // we treat Stop statements with a non-zero exit value as assertions that en will always be false! - val name = s.name - infos.append(name -> s.info) - signals.append(Signal(name, onExpression(s.en), IsBad)) - } - case s: ir.Print => - logger.info(s"Info: ignoring: ${s.serialize}") - case other => other.foreachStmt(onStatement) - } - - private def onRegister(r: ir.DefRegister): State = { - val width = bitWidth(r.tpe).toInt - val resetExpr = onExpression(r.reset, 1) - assert(resetExpr == False(), s"Expected reset expression of ${r.name} to be 0, not $resetExpr") - val initExpr = onExpression(r.init, width) - val sym = BVSymbol(r.name, width) - val hasReset = initExpr != sym - val isPreset = presetRegs.contains(r.name) - assert(!isPreset || hasReset, s"Expected preset register ${r.name} to have a reset value, not just $initExpr!") - val state = State(sym, if (isPreset) Some(initExpr) else None, None) - state - } - - private def onInstance(info: ir.Info, name: String, module: String, tpe: ir.Type): Unit = { - if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") - if (uninterpreted.contains(module)) { - onUninterpretedInstance(info: ir.Info, name: String, module: String, tpe: ir.Type) - } else { - // We treat all instances that aren't annotated as uninterpreted as blackboxes - // this means that their outputs could be any value, no matter what their inputs are. - logger.warn( - s"WARN: treating instance $name of $module as blackbox. " + - "Please flatten your hierarchy if you want to include submodules in the formal model." - ) - val ports = tpe.asInstanceOf[ir.BundleType].fields - // skip async reset ports - ports.filterNot(p => isAsyncReset(p.tpe)).foreach { p => - if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p") - val isOutput = p.flip == ir.Default - val pName = name + "." + p.name - infos.append(pName -> info) - // outputs of the submodule become inputs to our module - if (isOutput) { - if (isClock(p.tpe)) { - clocks.append(pName) - } else { - inputs.append(BVSymbol(pName, bitWidth(p.tpe).toInt)) - } - } - } - } - } - - private def onUninterpretedInstance(info: ir.Info, instanceName: String, module: String, tpe: ir.Type): Unit = { - val anno = uninterpreted(module) - - // sanity checks for ports were done already using the UninterpretedModule.checkModule function - val ports = tpe.asInstanceOf[ir.BundleType].fields - - val outputs = ports.filter(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt)) - val inputs = ports.filterNot(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt)) - - assert(anno.stateBits == 0, "TODO: implement support for uninterpreted stateful modules!") - - // for state-less (i.e. combinatorial) circuits, the outputs only depend on the inputs - val args = inputs.map(i => BVSymbol(instanceName + "." + i.name, i.width)).toList - outputs.foreach { out => - val functionName = anno.prefix + "." + out.name - val call = BVFunctionCall(functionName, args, out.width) - val wireName = instanceName + "." + out.name - signals.append(Signal(wireName, call)) - } - } - - private def onMemory(m: ir.DefMemory): State = { - checkMem(m) - - // derive the type of the memory from the dataType and depth - val dataWidth = bitWidth(m.dataType).toInt - val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) - val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth) - - // there could be a constant init - val init = memInit.get(m.name).map(getMemInit(m, indexWidth, dataWidth, _)) - init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth)) - - // derive next state expression - val next = if (m.writers.isEmpty) { - memSymbol - } else { - m.writers.foldLeft[ArrayExpr](memSymbol) { - case (prev, write) => - // update - val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth) - val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth) - val update = ArrayStore(prev, index = addr, data = data) - - // update guard - val en = BVSymbol(memPortField(m, write, "en").serialize, 1) - val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1) - ArrayIte(BVAnd(en, mask), update, prev) - } - } - - val state = State(memSymbol, init, Some(next)) - - // derive read expressions - val readSignals = m.readers.map { read => - val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth) - Signal(memPortField(m, read, "data").serialize, ArrayRead(memSymbol, addr), IsNode) - } - signals ++= readSignals - - state - } - - private def getMemInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr = - initValue match { - case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth) - case MemoryArrayInit(values) => - assert( - values.length == m.depth, - s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" - ) - // in order to get a more compact encoding try to find the most common values - val histogram = mutable.LinkedHashMap[BigInt, Int]() - values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) - val baseValue = histogram.maxBy(_._2)._1 - val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth) - values.zipWithIndex - .filterNot(_._1 == baseValue) - .foldLeft[ArrayExpr](base) { - case (array, (value, index)) => - ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth)) - } - case other => throw new RuntimeException(s"Unsupported memory init option: $other") - } - - private def checkMem(m: ir.DefMemory): Unit = { - assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?") - assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?") - assert( - m.dataType.isInstanceOf[ir.GroundType], - s"Memory $m is of type ${m.dataType} which is not a ground type!" - ) - assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") - } - - private def onExpression(e: ir.Expression, width: Int, allowNarrow: Boolean = false): BVExpr = - FirrtlExpressionSemantics.toSMT(e, width, allowNarrow) - private def onExpression(e: ir.Expression): BVExpr = FirrtlExpressionSemantics.toSMT(e) - - private def error(msg: String): Unit = throw new RuntimeException(msg) - private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType] - private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType - private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala deleted file mode 100644 index 7b332b83..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTCommand.scala +++ /dev/null @@ -1,12 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -sealed trait SMTCommand -case class Comment(msg: String) extends SMTCommand -case class SetLogic(logic: String) extends SMTCommand -case class DefineFunction(name: String, args: Seq[SMTFunctionArg], e: SMTExpr) extends SMTCommand -case class DeclareFunction(sym: SMTSymbol, args: Seq[SMTFunctionArg]) extends SMTCommand -case class DeclareUninterpretedSort(name: String) extends SMTCommand -case class DeclareUninterpretedSymbol(name: String, tpe: String) extends SMTCommand diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala deleted file mode 100644 index 45ec6898..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala +++ /dev/null @@ -1,81 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -import java.io.Writer - -import firrtl._ -import firrtl.annotations.{Annotation, NoTargetAnnotation} -import firrtl.options.Viewer.view -import firrtl.options.{CustomFileEmission, Dependency} -import firrtl.stage.FirrtlOptions - -private[firrtl] abstract class SMTEmitter private[firrtl] () - extends Transform - with Emitter - with DependencyAPIMigration { - override def prerequisites: Seq[Dependency[Transform]] = Seq(Dependency(FirrtlToTransitionSystem)) - override def invalidates(a: Transform): Boolean = false - - override def emit(state: CircuitState, writer: Writer): Unit = error("Deprecated since firrtl 1.0!") - - protected def serialize(sys: TransitionSystem): Annotation - - override protected def execute(state: CircuitState): CircuitState = { - val emitCircuit = state.annotations.exists { - case EmitCircuitAnnotation(a) if this.getClass == a => true - case EmitAllModulesAnnotation(a) if this.getClass == a => error("EmitAllModulesAnnotation not supported!") - case _ => false - } - - if (!emitCircuit) { return state } - - val sys = state.annotations.collectFirst { case TransitionSystemAnnotation(sys) => sys }.getOrElse { - error("Could not find the transition system!") - } - state.copy(annotations = state.annotations :+ serialize(sys)) - } - - protected def generatedHeader(format: String, name: String): String = - s"; $format description generated by firrtl ${BuildInfo.version} for module $name.\n" - - protected def error(msg: String): Nothing = throw new RuntimeException(msg) -} - -case class EmittedSMTModelAnnotation(name: String, src: String, outputSuffix: String) - extends NoTargetAnnotation - with CustomFileEmission { - override protected def baseFileName(annotations: AnnotationSeq): String = - view[FirrtlOptions](annotations).outputFileName.getOrElse(name) - override protected def suffix: Option[String] = Some(outputSuffix) - override def getBytes: Iterable[Byte] = src.getBytes -} - -/** Turns the transition system generated by [[FirrtlToTransitionSystem]] into a btor2 file. */ -object Btor2Emitter extends SMTEmitter { - override def outputSuffix: String = ".btor2" - override protected def serialize(sys: TransitionSystem): Annotation = { - val btor = generatedHeader("BTOR", sys.name) + Btor2Serializer.serialize(sys).mkString("\n") + "\n" - EmittedSMTModelAnnotation(sys.name, btor, outputSuffix) - } -} - -/** Turns the transition system generated by [[FirrtlToTransitionSystem]] into an SMTLib file. */ -object SMTLibEmitter extends SMTEmitter { - override def outputSuffix: String = ".smt2" - override protected def serialize(sys: TransitionSystem): Annotation = { - val hasMemory = sys.states.exists(_.sym.isInstanceOf[ArrayExpr]) - val logic = if (hasMemory) "QF_AUFBV" else "QF_UFBV" - val logicCmd = SMTLibSerializer.serialize(SetLogic(logic)) + "\n" - val header = if (hasMemory) { - "; We have to disable the logic for z3 to accept the non-standard \"as const\"\n" + - "; see https://github.com/Z3Prover/z3/issues/1803\n" + - "; for CVC4 you probably want to include the logic\n" + - ";" + logicCmd - } else { logicCmd } - val smt = generatedHeader("SMT-LIBv2", sys.name) + header + - SMTTransitionSystemEncoder.encode(sys).map(SMTLibSerializer.serialize).mkString("\n") + "\n" - EmittedSMTModelAnnotation(sys.name, smt, outputSuffix) - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala deleted file mode 100644 index f2eae58a..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala +++ /dev/null @@ -1,342 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> -// Inspired by the uclid5 SMT library (https://github.com/uclid-org/uclid). -// And the btor2 documentation (BTOR2 , BtorMC and Boolector 3.0 by Niemetz et.al.) - -package firrtl.backends.experimental.smt - -/** base trait for all SMT expressions */ -sealed trait SMTExpr extends SMTFunctionArg { - def tpe: SMTType - def children: List[SMTExpr] -} -sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { - def name: String - - /** keeps the type of the symbol while changing the name */ - def rename(newName: String): SMTSymbol -} -object SMTSymbol { - - /** makes a SMTSymbol of the same type as the expression */ - def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match { - case b: BVExpr => BVSymbol(name, b.width) - case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth) - } -} -sealed trait SMTNullaryExpr extends SMTExpr { - override def children: List[SMTExpr] = List() -} - -/** a SMT bit vector expression: https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml */ -sealed trait BVExpr extends SMTExpr { - def width: Int - def tpe: BVType = BVType(width) - override def toString: String = SMTExprSerializer.serialize(this) -} -case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr { - private def minWidth = value.bitLength + (if (value <= 0) 1 else 0) - assert(value >= 0, "Negative values are not supported! Please normalize by calculating 2s complement.") - assert(width > 0, "Zero or negative width literals are not allowed!") - assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.") -} -object BVLiteral { - def apply(nums: String): BVLiteral = nums.head match { - case 'b' => BVLiteral(BigInt(nums.drop(1), 2), nums.length - 1) - } -} -case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol { - assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") - assert(width > 0, "Zero width bit vectors are not supported!") - override def rename(newName: String) = BVSymbol(newName, width) -} - -sealed trait BVUnaryExpr extends BVExpr { - def e: BVExpr - - /** same function, different child, e.g.: not(x) -- reapply(Y) --> not(Y) */ - def reapply(expr: BVExpr): BVUnaryExpr - override def children: List[BVExpr] = List(e) -} -case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr { - assert(by >= 0, "Extension must be non-negative!") - override val width: Int = e.width + by - override def reapply(expr: BVExpr) = BVExtend(expr, by, signed) -} -// also known as bit extract operation -case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr { - assert(lo >= 0, s"lo (lsb) must be non-negative!") - assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo") - assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi") - override def width: Int = hi - lo + 1 - override def reapply(expr: BVExpr) = BVSlice(expr, hi, lo) -} -case class BVNot(e: BVExpr) extends BVUnaryExpr { - override val width: Int = e.width - override def reapply(expr: BVExpr) = new BVNot(expr) -} -case class BVNegate(e: BVExpr) extends BVUnaryExpr { - override val width: Int = e.width - override def reapply(expr: BVExpr) = BVNegate(expr) -} - -case class BVReduceOr(e: BVExpr) extends BVUnaryExpr { - override def width: Int = 1 - override def reapply(expr: BVExpr) = BVReduceOr(expr) -} -case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr { - override def width: Int = 1 - override def reapply(expr: BVExpr) = BVReduceAnd(expr) -} -case class BVReduceXor(e: BVExpr) extends BVUnaryExpr { - override def width: Int = 1 - override def reapply(expr: BVExpr) = BVReduceXor(expr) -} - -sealed trait BVBinaryExpr extends BVExpr { - def a: BVExpr - def b: BVExpr - override def children: List[BVExpr] = List(a, b) - - /** same function, different child, e.g.: add(a,b) -- reapply(a,c) --> add(a,c) */ - def reapply(nA: BVExpr, nB: BVExpr): BVBinaryExpr -} -case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr { - assert(a.width == b.width, s"Both argument need to be the same width!") - override def width: Int = 1 - override def reapply(nA: BVExpr, nB: BVExpr) = BVEqual(nA, nB) -} -// added as a separate node because it is used a lot in model checking and benefits from pretty printing -class BVImplies(val a: BVExpr, val b: BVExpr) extends BVBinaryExpr { - assert(a.width == 1, s"The antecedent needs to be a boolean expression!") - assert(b.width == 1, s"The consequent needs to be a boolean expression!") - override def width: Int = 1 - override def reapply(nA: BVExpr, nB: BVExpr) = new BVImplies(nA, nB) -} -object BVImplies { - def apply(a: BVExpr, b: BVExpr): BVExpr = { - assert(a.width == b.width, s"Both argument need to be the same width!") - (a, b) match { - case (True(), b) => b // (!1 || b) = b - case (False(), _) => True() // (!0 || _) = (1 || _) = 1 - case (_, True()) => True() // (!a || 1) = 1 - case (a, False()) => BVNot(a) // (!a || 0) = !a - case (a, b) => new BVImplies(a, b) - } - } - def unapply(i: BVImplies): Some[(BVExpr, BVExpr)] = Some((i.a, i.b)) -} - -object Compare extends Enumeration { - val Greater, GreaterEqual = Value -} -case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed: Boolean) extends BVBinaryExpr { - assert(a.width == b.width, s"Both argument need to be the same width!") - override def width: Int = 1 - override def reapply(nA: BVExpr, nB: BVExpr) = BVComparison(op, nA, nB, signed) -} - -object Op extends Enumeration { - val Xor = Value("xor") - val ShiftLeft = Value("logical_shift_left") - val ArithmeticShiftRight = Value("arithmetic_shift_right") - val ShiftRight = Value("logical_shift_right") - val Add = Value("add") - val Mul = Value("mul") - val SignedDiv = Value("sdiv") - val UnsignedDiv = Value("udiv") - val SignedMod = Value("smod") - val SignedRem = Value("srem") - val UnsignedRem = Value("urem") - val Sub = Value("sub") -} -case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr { - assert(a.width == b.width, s"Both argument need to be the same width!") - override val width: Int = a.width - override def reapply(nA: BVExpr, nB: BVExpr) = BVOp(op, nA, nB) -} -case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr { - override val width: Int = a.width + b.width - override def reapply(nA: BVExpr, nB: BVExpr) = BVConcat(nA, nB) -} -case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr { - assert(array.indexWidth == index.width, "Index with does not match expected array index width!") - override val width: Int = array.dataWidth - override def children: List[SMTExpr] = List(array, index) -} -case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr { - assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!") - assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}") - override val width: Int = tru.width - override def children: List[BVExpr] = List(cond, tru, fals) -} - -case class BVAnd(terms: List[BVExpr]) extends BVExpr { - require(terms.size > 1) - override val width: Int = terms.head.width - require(terms.forall(_.width == width)) - override def children: List[BVExpr] = terms -} - -case class BVOr(terms: List[BVExpr]) extends BVExpr { - require(terms.size > 1) - override val width: Int = terms.head.width - require(terms.forall(_.width == width)) - override def children: List[BVExpr] = terms -} - -sealed trait ArrayExpr extends SMTExpr { - val indexWidth: Int - val dataWidth: Int - def tpe: ArrayType = ArrayType(indexWidth = indexWidth, dataWidth = dataWidth) - override def toString: String = SMTExprSerializer.serialize(this) -} -case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol { - assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") - assert(!name.contains("\\"), s"Invalid id $name contains `\\`") - override def rename(newName: String) = ArraySymbol(newName, indexWidth, dataWidth) -} -case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr { - override val dataWidth: Int = e.width - override def children: List[SMTExpr] = List(e) -} -case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr { - assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!") - assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!") - override def width: Int = 1 - override def children: List[SMTExpr] = List(a, b) -} -case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr { - assert(array.indexWidth == index.width, "Index with does not match expected array index width!") - assert(array.dataWidth == data.width, "Data with does not match expected array data width!") - override val dataWidth: Int = array.dataWidth - override val indexWidth: Int = array.indexWidth - override def children: List[SMTExpr] = List(array, index, data) -} -case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr { - assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!") - assert( - tru.indexWidth == fals.indexWidth, - s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}" - ) - assert( - tru.dataWidth == fals.dataWidth, - s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}" - ) - override val dataWidth: Int = tru.dataWidth - override val indexWidth: Int = tru.indexWidth - override def children: List[SMTExpr] = List(cond, tru, fals) -} - -case class BVForall(variable: BVSymbol, e: BVExpr) extends BVUnaryExpr { - assert(e.width == 1, "Can only quantify over boolean expressions!") - override def width = 1 - override def reapply(expr: BVExpr) = BVForall(variable, expr) -} - -/** apply arguments to a function which returns a result of bit vector type */ -case class BVFunctionCall(name: String, args: List[SMTFunctionArg], width: Int) extends BVExpr { - override def children = args.map(_.asInstanceOf[SMTExpr]) -} - -/** apply arguments to a function which returns a result of array type */ -case class ArrayFunctionCall(name: String, args: List[SMTFunctionArg], indexWidth: Int, dataWidth: Int) - extends ArrayExpr { - override def children = args.map(_.asInstanceOf[SMTExpr]) -} -sealed trait SMTFunctionArg -// we allow symbols with uninterpreted type to be function arguments -case class UTSymbol(name: String, tpe: String) extends SMTFunctionArg - -object BVAnd { - def apply(a: BVExpr, b: BVExpr): BVExpr = { - assert(a.width == b.width, s"Both argument need to be the same width!") - (a, b) match { - case (True(), b) => b - case (a, True()) => a - case (False(), _) => False() - case (_, False()) => False() - case (a, b) => new BVAnd(List(a, b)) - } - } - def apply(exprs: List[BVExpr]): BVExpr = { - assert(exprs.nonEmpty, "Don't know what to do with an empty list!") - val nonTriviallyTrue = exprs.filterNot(_ == True()) - nonTriviallyTrue.distinct match { - case Seq() => True() - case Seq(one) => one - case terms => new BVAnd(terms) - } - } -} -object BVOr { - def apply(a: BVExpr, b: BVExpr): BVExpr = { - assert(a.width == b.width, s"Both argument need to be the same width!") - (a, b) match { - case (True(), _) => True() - case (_, True()) => True() - case (False(), b) => b - case (a, False()) => a - case (a, b) => new BVOr(List(a, b)) - } - } - def apply(exprs: List[BVExpr]): BVExpr = { - assert(exprs.nonEmpty, "Don't know what to do with an empty list!") - val nonTriviallyFalse = exprs.filterNot(_ == False()) - nonTriviallyFalse.distinct match { - case Seq() => False() - case Seq(one) => one - case terms => new BVOr(terms) - } - } -} - -object BVNot { - def apply(e: BVExpr): BVExpr = e match { - case True() => False() - case False() => True() - case BVNot(inner) => inner - case other => new BVNot(other) - } -} - -object SMTEqual { - def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a, b) match { - case (ab: BVExpr, bb: BVExpr) => BVEqual(ab, bb) - case (aa: ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba) - case _ => throw new RuntimeException(s"Cannot compare $a and $b") - } -} - -object SMTIte { - def apply(cond: BVExpr, tru: SMTExpr, fals: SMTExpr): SMTExpr = (tru, fals) match { - case (ab: BVExpr, bb: BVExpr) => BVIte(cond, ab, bb) - case (aa: ArrayExpr, ba: ArrayExpr) => ArrayIte(cond, aa, ba) - case _ => throw new RuntimeException(s"Cannot mux $tru and $fals") - } -} - -object SMTExpr { - def serializeType(e: SMTExpr): String = e match { - case b: BVExpr => s"bv<${b.width}>" - case a: ArrayExpr => s"bv<${a.indexWidth}> -> bv<${a.dataWidth}>" - } -} - -// unapply for matching BVLiteral(1, 1) -object True { - private val _True = BVLiteral(1, 1) - def apply(): BVLiteral = _True - def unapply(l: BVLiteral): Boolean = l.value == 1 && l.width == 1 -} - -// unapply for matching BVLiteral(0, 1) -object False { - private val _False = BVLiteral(0, 1) - def apply(): BVLiteral = _False - def unapply(l: BVLiteral): Boolean = l.value == 0 && l.width == 1 -} - -sealed trait SMTType -case class BVType(width: Int) extends SMTType -case class ArrayType(indexWidth: Int, dataWidth: Int) extends SMTType diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala deleted file mode 100644 index 8e035186..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprMap.scala +++ /dev/null @@ -1,88 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> -package firrtl.backends.experimental.smt - -object SMTExprMap { - - /** maps f over subexpressions of expr and returns expr with the results replaced */ - def mapExpr(expr: SMTExpr, f: SMTExpr => SMTExpr): SMTExpr = { - val bv = (b: BVExpr) => f(b).asInstanceOf[BVExpr] - val ar = (a: ArrayExpr) => f(a).asInstanceOf[ArrayExpr] - expr match { - case b: BVExpr => mapExpr(b, bv, ar) - case a: ArrayExpr => mapExpr(a, bv, ar) - } - } - - /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */ - def mapExpr(expr: BVExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): BVExpr = expr match { - // nullary - case old: BVLiteral => old - case old: BVSymbol => old - // unary - case old @ BVExtend(e, by, signed) => val n = bv(e); if (n.eq(e)) old else BVExtend(n, by, signed) - case old @ BVSlice(e, hi, lo) => val n = bv(e); if (n.eq(e)) old else BVSlice(n, hi, lo) - case old @ BVNot(e) => val n = bv(e); if (n.eq(e)) old else BVNot(n) - case old @ BVNegate(e) => val n = bv(e); if (n.eq(e)) old else BVNegate(n) - case old @ BVForall(variables, e) => val n = bv(e); if (n.eq(e)) old else BVForall(variables, n) - case old @ BVReduceAnd(e) => val n = bv(e); if (n.eq(e)) old else BVReduceAnd(n) - case old @ BVReduceOr(e) => val n = bv(e); if (n.eq(e)) old else BVReduceOr(n) - case old @ BVReduceXor(e) => val n = bv(e); if (n.eq(e)) old else BVReduceXor(n) - // binary - case old @ BVEqual(a, b) => - val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB) - case old @ ArrayEqual(a, b) => - val (nA, nB) = (ar(a), ar(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB) - case old @ BVComparison(op, a, b, signed) => - val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed) - case old @ BVOp(op, a, b) => - val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB) - case old @ BVConcat(a, b) => - val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB) - case old @ ArrayRead(a, b) => - val (nA, nB) = (ar(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB) - case old @ BVImplies(a, b) => - val (nA, nB) = (bv(a), bv(b)); if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB) - // ternary - case old @ BVIte(a, b, c) => - val (nA, nB, nC) = (bv(a), bv(b), bv(c)) - if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC) - // n-ary - case old @ BVFunctionCall(name, args, width) => - val nArgs = args.map { - case b: BVExpr => bv(b) - case a: ArrayExpr => ar(a) - case u: UTSymbol => u - } - val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) } - if (anyNew) BVFunctionCall(name, nArgs, width) else old - case old @ BVAnd(terms) => - val nTerms = terms.map(bv) - val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) } - if (anyNew) BVAnd(nTerms) else old - case old @ BVOr(terms) => - val nTerms = terms.map(bv) - val anyNew = nTerms.zip(terms).exists { case (n, o) => !n.eq(o) } - if (anyNew) BVOr(nTerms) else old - } - - /** maps bv/ar over subexpressions of expr and returns expr with the results replaced */ - def mapExpr(expr: ArrayExpr, bv: BVExpr => BVExpr, ar: ArrayExpr => ArrayExpr): ArrayExpr = expr match { - case old: ArraySymbol => old - case old @ ArrayConstant(e, indexWidth) => val n = bv(e); if (n.eq(e)) old else ArrayConstant(n, indexWidth) - case old @ ArrayStore(a, b, c) => - val (nA, nB, nC) = (ar(a), bv(b), bv(c)) - if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC) - case old @ ArrayIte(a, b, c) => - val (nA, nB, nC) = (bv(a), ar(b), ar(c)) - if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC) - case old @ ArrayFunctionCall(name, args, indexWidth, dataWidth) => - val nArgs = args.map { - case b: BVExpr => bv(b) - case a: ArrayExpr => ar(a) - case u: UTSymbol => u - } - val anyNew = nArgs.zip(args).exists { case (n, o) => !n.eq(o) } - if (anyNew) ArrayFunctionCall(name, nArgs, indexWidth, dataWidth) else old - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala deleted file mode 100644 index 4aaf78a2..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprSerializer.scala +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -private object SMTExprSerializer { - def serialize(expr: BVExpr): String = expr match { - // nullary - case lit: BVLiteral => - if (lit.width <= 8) { - lit.width.toString + "'b" + lit.value.toString(2) - } else { - lit.width.toString + "'x" + lit.value.toString(16) - } - case BVSymbol(name, _) => name - // unary - case BVExtend(e, by, false) => s"zext(${serialize(e)}, $by)" - case BVExtend(e, by, true) => s"sext(${serialize(e)}, $by)" - case BVSlice(e, hi, lo) if hi == lo => s"${serialize(e)}[$hi]" - case BVSlice(e, hi, lo) => s"${serialize(e)}[$hi:$lo]" - case BVNot(e) => s"not(${serialize(e)})" - case BVNegate(e) => s"neg(${serialize(e)})" - case BVForall(variable, e) => s"forall(${variable.name} : bv<${variable.width}, ${serialize(e)})" - case BVReduceAnd(e) => s"redand(${serialize(e)})" - case BVReduceOr(e) => s"redor(${serialize(e)})" - case BVReduceXor(e) => s"redxor(${serialize(e)})" - // binary - case BVEqual(a, b) => s"eq(${serialize(a)}, ${serialize(b)})" - case BVComparison(Compare.Greater, a, b, false) => s"ugt(${serialize(a)}, ${serialize(b)})" - case BVComparison(Compare.Greater, a, b, true) => s"sgt(${serialize(a)}, ${serialize(b)})" - case BVComparison(Compare.GreaterEqual, a, b, false) => s"ugeq(${serialize(a)}, ${serialize(b)})" - case BVComparison(Compare.GreaterEqual, a, b, true) => s"sgeq(${serialize(a)}, ${serialize(b)})" - case BVOp(op, a, b) => s"$op(${serialize(a)}, ${serialize(b)})" - case BVConcat(a, b) => s"concat(${serialize(a)}, ${serialize(b)})" - case ArrayRead(array, index) => s"${serialize(array)}[${serialize(index)}]" - case ArrayEqual(a, b) => s"eq(${serialize(a)}, ${serialize(b)})" - case BVImplies(a, b) => s"implies(${serialize(a)}, ${serialize(b)})" - // ternary - case BVIte(cond, tru, fals) => s"ite(${serialize(cond)}, ${serialize(tru)}, ${serialize(fals)})" - // n-ary - case BVFunctionCall(name, args, _) => name + serialize(args).mkString("(", ",", ")") - case BVAnd(terms) => terms.map(serialize).mkString("and(", ", ", ")") - case BVOr(terms) => terms.map(serialize).mkString("or(", ", ", ")") - } - - def serialize(expr: ArrayExpr): String = expr match { - case ArraySymbol(name, _, _) => name - case ArrayConstant(e, indexWidth) => s"([${serialize(e)}] x ${(BigInt(1) << indexWidth)})" - case ArrayStore(array, index, data) => s"${serialize(array)}[${serialize(index)} := ${serialize(data)}]" - case ArrayIte(cond, tru, fals) => s"ite(${serialize(cond)}, ${serialize(tru)}, ${serialize(fals)})" - case ArrayFunctionCall(name, args, _, _) => name + serialize(args).mkString("(", ",", ")") - } - - private def serialize(args: Iterable[SMTFunctionArg]): Iterable[String] = - args.map { - case b: BVExpr => serialize(b) - case a: ArrayExpr => serialize(a) - case u: UTSymbol => u.name - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala deleted file mode 100644 index 20a499b9..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala +++ /dev/null @@ -1,177 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -import scala.util.matching.Regex - -/** Converts STM Expressions to a SMTLib compatible string representation. - * See http://smtlib.cs.uiowa.edu/ - * Assumes well typed expression, so it is advisable to run the TypeChecker - * before serializing! - * Automatically converts 1-bit vectors to bool. - */ -object SMTLibSerializer { - def serialize(e: SMTExpr): String = e match { - case b: BVExpr => serialize(b) - case a: ArrayExpr => serialize(a) - } - - def serialize(t: SMTType): String = t match { - case BVType(width) => serializeBitVectorType(width) - case ArrayType(indexWidth, dataWidth) => serializeArrayType(indexWidth, dataWidth) - } - - private def serialize(e: BVExpr): String = e match { - case BVLiteral(value, width) => - val mask = (BigInt(1) << width) - 1 - val twosComplement = if (value < 0) { ((~(-value)) & mask) + 1 } - else value - if (width == 1) { - if (twosComplement == 1) "true" else "false" - } else { - s"(_ bv$twosComplement $width)" - } - case BVSymbol(name, _) => escapeIdentifier(name) - case BVExtend(e, 0, _) => serialize(e) - case BVExtend(BVLiteral(value, width), by, false) => serialize(BVLiteral(value, width + by)) - case BVExtend(e, by, signed) => - val foo = if (signed) "sign_extend" else "zero_extend" - s"((_ $foo $by) ${asBitVector(e)})" - case BVSlice(e, hi, lo) => - if (lo == 0 && hi == e.width - 1) { serialize(e) } - else { - val bits = s"((_ extract $hi $lo) ${asBitVector(e)})" - // 1-bit extracts need to be turned into a boolean - if (lo == hi) { toBool(bits) } - else { bits } - } - case BVNot(BVEqual(a, b)) if a.width == 1 => s"(distinct ${serialize(a)} ${serialize(b)})" - case BVNot(BVNot(e)) => serialize(e) - case BVNot(e) => - if (e.width == 1) { s"(not ${serialize(e)})" } - else { s"(bvnot ${serialize(e)})" } - case BVNegate(e) => s"(bvneg ${asBitVector(e)})" - case r: BVReduceAnd => serialize(Expander.expand(r)) - case r: BVReduceOr => serialize(Expander.expand(r)) - case r: BVReduceXor => serialize(Expander.expand(r)) - case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b) - case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})" - case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" - case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" - case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})" - case BVComparison(Compare.GreaterEqual, a, b, false) => s"(bvuge ${asBitVector(a)} ${asBitVector(b)})" - case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})" - case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})" - // boolean operations get a special treatment for 1-bit vectors aka bools - case b: BVAnd => serializeVariadic(if (b.width == 1) "and" else "bvand", b.terms) - case b: BVOr => serializeVariadic(if (b.width == 1) "or" else "bvor", b.terms) - case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})" - case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})") - case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})" - case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})" - case ArrayRead(array, index) => s"(select ${serialize(array)} ${serialize(index)})" - case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" - case BVFunctionCall(name, args, _) => args.map(serializeArg).mkString(s"($name ", " ", ")") - case BVForall(variable, e) => s"(forall ((${variable.name} ${serialize(variable.tpe)})) ${serialize(e)})" - } - - private def serializeVariadic(op: String, terms: List[BVExpr]): String = terms match { - case Seq() | Seq(_) => throw new RuntimeException(s"expected at least two elements in variadic op $op") - case Seq(a, b) => s"($op ${serialize(a)} ${serialize(b)})" - case head :: tail => s"($op ${serialize(head)} ${serializeVariadic(op, tail)})" - } - - def serialize(e: ArrayExpr): String = e match { - case ArraySymbol(name, _, _) => escapeIdentifier(name) - case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})" - case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" - case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})" - case ArrayFunctionCall(name, args, _, _) => args.map(serializeArg).mkString(s"($name ", " ", ")") - } - - def serialize(c: SMTCommand): String = c match { - case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n") - case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)" - case DefineFunction(name, args, e) => - val aa = args.map(a => s"(${serializeArg(a)} ${serializeArgTpe(a)})").mkString(" ") - s"(define-fun ${escapeIdentifier(name)} ($aa) ${serialize(e.tpe)} ${serialize(e)})" - case DeclareFunction(sym, tpes) => - val aa = tpes.map(serializeArgTpe).mkString(" ") - s"(declare-fun ${escapeIdentifier(sym.name)} ($aa) ${serialize(sym.tpe)})" - case SetLogic(logic) => s"(set-logic $logic)" - case DeclareUninterpretedSymbol(name, tpe) => - s"(declare-fun ${escapeIdentifier(name)} () ${escapeIdentifier(tpe)})" - } - - private def serializeArgTpe(a: SMTFunctionArg): String = - a match { - case u: UTSymbol => escapeIdentifier(u.tpe) - case s: SMTExpr => serialize(s.tpe) - } - private def serializeArg(a: SMTFunctionArg): String = - a match { - case u: UTSymbol => escapeIdentifier(u.name) - case s: SMTExpr => serialize(s) - } - - private def serializeArrayType(indexWidth: Int, dataWidth: Int): String = - s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})" - private def serializeBitVectorType(width: Int): String = - if (width == 1) { "Bool" } - else { assert(width > 1); s"(_ BitVec $width)" } - - private def serialize(op: Op.Value): String = op match { - case Op.Xor => "bvxor" - case Op.ArithmeticShiftRight => "bvashr" - case Op.ShiftRight => "bvlshr" - case Op.ShiftLeft => "bvshl" - case Op.Add => "bvadd" - case Op.Mul => "bvmul" - case Op.Sub => "bvsub" - case Op.SignedDiv => "bvsdiv" - case Op.UnsignedDiv => "bvudiv" - case Op.SignedMod => "bvsmod" - case Op.SignedRem => "bvsrem" - case Op.UnsignedRem => "bvurem" - } - - private def toBool(e: String): String = s"(= $e (_ bv1 1))" - - private val bvZero = "(_ bv0 1)" - private val bvOne = "(_ bv1 1)" - private def asBitVector(e: BVExpr): String = - if (e.width > 1) { serialize(e) } - else { s"(ite ${serialize(e)} $bvOne $bvZero)" } - - // See <simple_symbol> definition in the Concrete Syntax Appendix of the SMTLib Spec - private val simple: Regex = raw"[a-zA-Z\+-/\*\=%\?!\.$$_~&\^<>@][a-zA-Z0-9\+-/\*\=%\?!\.$$_~&\^<>@]*".r - def escapeIdentifier(name: String): String = name match { - case simple() => name - case _ => if (name.startsWith("|") && name.endsWith("|")) name else s"|$name|" - } -} - -/** Expands expressions that are not natively supported by SMTLib */ -private object Expander { - def expand(r: BVReduceAnd): BVExpr = { - if (r.e.width == 1) { r.e } - else { - val allOnes = (BigInt(1) << r.e.width) - 1 - BVEqual(r.e, BVLiteral(allOnes, r.e.width)) - } - } - def expand(r: BVReduceOr): BVExpr = { - if (r.e.width == 1) { r.e } - else { - BVNot(BVEqual(r.e, BVLiteral(0, r.e.width))) - } - } - def expand(r: BVReduceXor): BVExpr = { - if (r.e.width == 1) { r.e } - else { - val bits = (0 until r.e.width).map(ii => BVSlice(r.e, ii, ii)) - bits.reduce[BVExpr]((a, b) => BVOp(Op.Xor, a, b)) - } - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala deleted file mode 100644 index 4f096c28..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala +++ /dev/null @@ -1,133 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -import scala.collection.mutable - -/** This Transition System encoding is directly inspired by yosys' SMT backend: - * https://github.com/YosysHQ/yosys/blob/master/backends/smt2/smt2.cc - * It if fairly compact, but unfortunately, the use of an uninterpreted sort for the state - * prevents this encoding from working with boolector. - * For simplicity reasons, we do not support hierarchical designs (no `_h` function). - */ -object SMTTransitionSystemEncoder { - - def encode(sys: TransitionSystem): Iterable[SMTCommand] = { - val cmds = mutable.ArrayBuffer[SMTCommand]() - val name = sys.name - - // declare UFs if necessary - cmds ++= TransitionSystem.findUninterpretedFunctions(sys) - - // emit header as comments - if (sys.header.nonEmpty) { - cmds ++= sys.header.split('\n').map(Comment) - } - - // declare state type - val stateType = id(name + "_s") - cmds += DeclareUninterpretedSort(stateType) - - // state symbol - val State = UTSymbol("state", stateType) - val StateNext = UTSymbol("state_n", stateType) - - // inputs and states are modelled as constants - def declare(sym: SMTSymbol, kind: String): Unit = { - cmds ++= toDescription(sym, kind, sys.comments.get) - val s = SMTSymbol.fromExpr(sym.name + SignalSuffix, sym) - cmds += DeclareFunction(s, List(State)) - } - sys.inputs.foreach(i => declare(i, "input")) - sys.states.foreach(s => declare(s.sym, "register")) - - // signals are just functions of other signals, inputs and state - def define(sym: SMTSymbol, e: SMTExpr, suffix: String = SignalSuffix): Unit = { - val withReplacedSymbols = replaceSymbols(SignalSuffix, State)(e) - cmds += DefineFunction(sym.name + suffix, List(State), withReplacedSymbols) - } - sys.signals.foreach { signal => - val sym = signal.sym - cmds ++= toDescription(sym, lblToKind(signal.lbl), sys.comments.get) - val e = if (signal.lbl == IsBad) BVNot(signal.e.asInstanceOf[BVExpr]) else signal.e - define(sym, e) - } - - // define the next and init functions for all states - sys.states.foreach { state => - assert(state.next.nonEmpty, "Next function required") - define(state.sym, state.next.get, NextSuffix) - // init is optional - state.init.foreach { init => - define(state.sym, init, InitSuffix) - } - } - - def defineConjunction(e: List[BVExpr], suffix: String): Unit = { - define(BVSymbol(name, 1), if (e.isEmpty) True() else BVAnd(e), suffix) - } - - // the transition relation asserts that the value of the next state is the next value from the previous state - // e.g., (reg state_n) == (reg_next state) - val transitionRelations = sys.states.map { state => - val newState = replaceSymbols(SignalSuffix, StateNext)(state.sym) - val nextOldState = replaceSymbols(NextSuffix, State)(state.sym) - SMTEqual(newState, nextOldState) - } - // the transition relation is over two states - val transitionExpr = if (transitionRelations.isEmpty) { True() } - else { - replaceSymbols(SignalSuffix, State)(BVAnd(transitionRelations)) - } - cmds += DefineFunction(name + "_t", List(State, StateNext), transitionExpr) - - // The init relation just asserts that all init function hold - val initRelations = sys.states.filter(_.init.isDefined).map { state => - val stateSignal = replaceSymbols(SignalSuffix, State)(state.sym) - val initSignal = replaceSymbols(InitSuffix, State)(state.sym) - SMTEqual(stateSignal, initSignal) - } - defineConjunction(initRelations, "_i") - - // assertions and assumptions - val assertions = sys.signals.filter(_.lbl == IsBad).map(a => replaceSymbols(SignalSuffix, State)(a.sym)) - defineConjunction(assertions.map(_.asInstanceOf[BVExpr]), AssertionSuffix) - val assumptions = sys.signals.filter(_.lbl == IsConstraint).map(a => replaceSymbols(SignalSuffix, State)(a.sym)) - defineConjunction(assumptions.map(_.asInstanceOf[BVExpr]), AssumptionSuffix) - - cmds - } - - private def id(s: String): String = SMTLibSerializer.escapeIdentifier(s) - private val SignalSuffix = "_f" - private val NextSuffix = "_next" - private val InitSuffix = "_init" - val AssertionSuffix = "_a" - val AssumptionSuffix = "_u" - private def lblToKind(lbl: SignalLabel): String = lbl match { - case IsNode | IsInit | IsNext => "wire" - case IsOutput => "output" - // for the SMT encoding we turn bad state signals back into assertions - case IsBad => "assert" - case IsConstraint => "assume" - case IsFair => "fair" - } - private def toDescription(sym: SMTSymbol, kind: String, comments: String => Option[String]): List[Comment] = { - List(sym match { - case BVSymbol(name, width) => Comment(s"firrtl-smt2-$kind $name $width") - case ArraySymbol(name, indexWidth, dataWidth) => - Comment(s"firrtl-smt2-$kind $name $indexWidth $dataWidth") - }) ++ comments(sym.name).map(Comment) - } - // All signals are modelled with functions that need to be called with the state as argument, - // this replaces all Symbols with function applications to the state. - private def replaceSymbols(suffix: String, arg: SMTFunctionArg, vars: Set[String] = Set())(e: SMTExpr): SMTExpr = - e match { - case BVSymbol(name, width) if !vars(name) => BVFunctionCall(id(name + suffix), List(arg), width) - case ArraySymbol(name, indexWidth, dataWidth) if !vars(name) => - ArrayFunctionCall(id(name + suffix), List(arg), indexWidth, dataWidth) - case fa @ BVForall(variable, _) => SMTExprMap.mapExpr(fa, replaceSymbols(suffix, arg, vars + variable.name)) - case other => SMTExprMap.mapExpr(other, replaceSymbols(suffix, arg, vars)) - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala deleted file mode 100644 index 534db217..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala +++ /dev/null @@ -1,272 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -import firrtl._ -import firrtl.annotations._ -import firrtl.ir.EmptyStmt -import firrtl.options.Dependency -import firrtl.passes.PassException -import firrtl.stage.Forms -import firrtl.stage.TransformManager.TransformDependency -import firrtl.transforms.PropagatePresetAnnotations -import firrtl.renamemap.MutableRenameMap - -import scala.collection.mutable - -case class GlobalClockAnnotation(target: ReferenceTarget) extends SingleTargetAnnotation[ReferenceTarget] { - override def duplicate(n: ReferenceTarget): Annotation = this.copy(n) -} - -/** Converts every input clock into a clock enable input and adds a single global clock. - * - all registers and memory ports will be connected to the new global clock - * - all registers and memory ports will be guarded by the enable signal of their original clock - * - the clock enabled signal can be understood as a clock tick or posedge - * - this transform can be used in order to (formally) verify designs with multiple clocks or asynchronous resets - */ -class StutteringClockTransform extends Transform with DependencyAPIMigration { - override def prerequisites: Seq[TransformDependency] = Forms.LowForm - override def invalidates(a: Transform): Boolean = false - - // this pass needs to run *before* converting to a transition system - override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq(Dependency(FirrtlToTransitionSystem)) - // since this pass only runs on the main module, inlining needs to happen before - override def optionalPrerequisites: Seq[TransformDependency] = Seq( - Dependency[firrtl.passes.InlineInstances], - Dependency[PropagatePresetAnnotations] - ) - - override protected def execute(state: CircuitState): CircuitState = { - if (state.circuit.modules.size > 1) { - logger.warn( - "WARN: StutteringClockTransform currently only supports running on a single module.\n" + - s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want." - ) - } - - // get main module - val main = state.circuit.modules.find(_.name == state.circuit.main).get match { - case m: ir.Module => m - case e: ir.ExtModule => unsupportedError(s"Cannot run on extmodule $e") - } - mainName = main.name - - val namespace = Namespace(main) - - // create a global clock - val globalClocks = state.annotations.collect { case GlobalClockAnnotation(c) => c } - assert(globalClocks.size < 2, "There can only be a single global clock: " + globalClocks.mkString(", ")) - val (globalClock, portsWithGlobalClock) = globalClocks.headOption match { - case Some(clock) => - assert(clock.module == main.name, "GlobalClock needs to be an input of the main module!") - assert(main.ports.exists(_.name == clock.ref), "GlobalClock needs to be an input port!") - assert(main.ports.find(_.name == clock.ref).get.direction == ir.Input, "GlobalClock needs to be an input port!") - (clock.ref, main.ports) - case None => - val name = namespace.newName("global_clock") - (name, ir.Port(ir.NoInfo, name, ir.Input, ir.ClockType) +: main.ports) - } - - // replace all other clocks with enable signals, unless they are the global clock - val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name) - val clockToEnable = clocks.map { c => - c -> ir.Reference(namespace.newName(c + "_en"), Utils.BoolType, firrtl.PortKind, firrtl.SourceFlow) - }.toMap - val portsWithEnableSignals = portsWithGlobalClock.map { p => - if (clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Utils.BoolType) } - else { p } - } - // replace async reset with synchronous reset (since everything will we synchronous with the global clock) - // unless it is a preset reset - val asyncResets = portsWithEnableSignals.filter(_.tpe == ir.AsyncResetType).map(_.name) - val isPresetReset = state.annotations.collect { case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet - val resetsToChange = asyncResets.filterNot(isPresetReset).toSet - val portsWithSyncReset = portsWithEnableSignals.map { p => - if (resetsToChange.contains(p.name)) { p.copy(tpe = Utils.BoolType) } - else { p } - } - val presetRegs = state.annotations.collect { - case PresetRegAnnotation(target) if target.module == mainName => target.ref - }.toSet - - // discover clock and reset connections - val scan = scanClocks(main, clockToEnable, resetsToChange) - - // rename clocks to clock enable signals - val mRef = CircuitTarget(state.circuit.main).module(main.name) - val renameMap = MutableRenameMap() - scan.clockToEnable.foreach { - case (clk, en) => - renameMap.record(mRef.ref(clk), mRef.ref(en.name)) - } - - // make changes - implicit val ctx: Context = new Context(globalClock, scan, presetRegs) - val newMain = main.copy(ports = portsWithSyncReset).mapStmt(onStatement) - - val nonMainModules = state.circuit.modules.filterNot(_.name == state.circuit.main) - val newCircuit = state.circuit.copy(modules = nonMainModules :+ newMain) - state.copy(circuit = newCircuit, renames = Some(renameMap)) - } - - private def onStatement(s: ir.Statement)(implicit ctx: Context): ir.Statement = { - s.foreachExpr(checkExpr) - s match { - // memory field connects - case c @ ir.Connect(_, ir.SubField(ir.SubField(ir.Reference(mem, _, _, _), port, _, _), field, _, _), _) - if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) => - // replace clock with the global clock - if (field == "clk") { - c.copy(expr = ctx.globalClock) - } else if (field == "en") { - val m = ctx.memInfo(mem) - val isWritePort = m.writers.contains(port) - assert(isWritePort || m.readers.contains(port)) - - // for write ports we guard the write enable with the clock enable signal, similar to registers - if (isWritePort) { - val clockEn = ctx.memPortToClockEnable(mem + "." + port) - val guardedEnable = Utils.and(clockEn, c.expr) - c.copy(expr = guardedEnable) - } else { c } - } else { c } - // register field connects - case c @ ir.Connect(_, r: ir.Reference, next) if ctx.registerToEnable.contains(r.name) => - val clockEnable = ctx.registerToEnable(r.name) - val guardedNext = Utils.mux(clockEnable, next, r) - val withReset = ctx.registerToAsyncReset.get(r.name) match { - case None => guardedNext - case Some((asyncReset, init)) => Utils.mux(asyncReset, init, guardedNext) - } - c.copy(expr = withReset) - // remove other clock wires and nodes - case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt - case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt - case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt - // change async reset to synchronous reset - case ir.Connect(info, loc: ir.Reference, expr: ir.Reference) - if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) => - ir.Connect(info, loc.copy(tpe = Utils.BoolType), expr.copy(tpe = Utils.BoolType)) - case d @ ir.DefNode(_, name, value: ir.Reference) - if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) => - d.copy(value = value.copy(tpe = Utils.BoolType)) - case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => - d.copy(tpe = Utils.BoolType) - // change memory clock and synchronize reset - case ir.DefRegister(info, name, tpe, _, _, init) if ctx.registerToEnable.contains(name) => - val newInit = if (ctx.isPresetReg(name)) init else ir.Reference(name, tpe, RegKind, SourceFlow) - ir.DefRegister(info, name, tpe, ctx.globalClock, Utils.False(), newInit) - case other => other.mapStmt(onStatement) - } - } - - private def scanClocks( - m: ir.Module, - initialClockToEnable: Map[String, ir.Reference], - resetsToChange: Set[String] - ): ScanCtx = { - implicit val ctx: ScanCtx = new ScanCtx(initialClockToEnable, resetsToChange) - m.foreachStmt(scanClocksAndResets) - ctx - } - - private def scanClocksAndResets(s: ir.Statement)(implicit ctx: ScanCtx): Unit = { - s.foreachExpr(checkExpr) - s match { - // track clock aliases - case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType => - val locName = loc.serialize - ctx.clockToEnable.get(expr.serialize).foreach { clockEn => - ctx.clockToEnable(locName) = clockEn - // keep track of memory clocks - if (loc.isInstanceOf[ir.SubField]) { - val parts = locName.split('.') - if (ctx.mems.contains(parts.head)) { - assert(parts.length == 3 && parts.last == "clk") - ctx.memPortToClockEnable.append(parts.dropRight(1).mkString(".") -> clockEn) - } - } - } - case ir.DefNode(_, name, value) if value.tpe == ir.ClockType => - ctx.clockToEnable.get(value.serialize).foreach(c => ctx.clockToEnable(name) = c) - // track reset aliases - case ir.Connect(_, loc, expr) if expr.tpe == ir.AsyncResetType && ctx.resetsToChange(expr.serialize) => - ctx.resetsToChange.add(loc.serialize) - case ir.DefNode(_, name, value) if value.tpe == ir.AsyncResetType && ctx.resetsToChange(value.serialize) => - ctx.resetsToChange.add(name) - // modify clocked elements - case ir.DefRegister(_, name, _, clock, reset, init) => - ctx.clockToEnable.get(clock.serialize).foreach { clockEnable => - ctx.registerToEnable.append(name -> clockEnable) - } - reset match { - case Utils.False() => - case other => ctx.registerToAsyncReset.append(name -> (other, init)) - } - case m: ir.DefMemory => - assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") - assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!") - assert(m.writeLatency == 1, "Only write-latency 1 is supported!") - if (m.readers.nonEmpty && m.readLatency == 1) { - unsupportedError("Registers memory read ports are not properly implemented yet :(") - } - ctx.mems(m.name) = m - case other => other.foreachStmt(scanClocksAndResets) - } - } - - // we rely on people not casting clocks or async resets - private def checkExpr(expr: ir.Expression): Unit = expr match { - case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.ClockType => - unsupportedError(s"Clock casts are not supported: ${expr.serialize}") - case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.ClockType => - unsupportedError(s"Clock casts are not supported: ${expr.serialize}") - case ir.DoPrim(PrimOps.AsUInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType => - unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}") - case ir.DoPrim(PrimOps.AsSInt, Seq(e), _, _) if e.tpe == ir.AsyncResetType => - unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}") - case ir.DoPrim(PrimOps.AsAsyncReset, _, _, _) => - unsupportedError(s"AsyncReset casts are not supported: ${expr.serialize}") - case ir.DoPrim(PrimOps.AsClock, _, _, _) => - unsupportedError(s"Clock casts are not supported: ${expr.serialize}") - case other => other.foreachExpr(checkExpr) - } - - private class ScanCtx(initialClockToEnable: Map[String, ir.Reference], initialResetsToChange: Set[String]) { - // keeps track of which clock signals will be replaced by which clock enable signal - val clockToEnable = mutable.HashMap[String, ir.Reference]() ++ initialClockToEnable - // kepp track of asynchronous resets that need to be changed to bool - val resetsToChange = mutable.HashSet[String]() ++ initialResetsToChange - // registers whose next function needs to be guarded with a clock enable - val registerToEnable = mutable.ArrayBuffer[(String, ir.Reference)]() - // registers with asynchronous reset - val registerToAsyncReset = mutable.ArrayBuffer[(String, (ir.Expression, ir.Expression))]() - // memory enables which need to be guarded with clock enables - val memPortToClockEnable = mutable.ArrayBuffer[(String, ir.Reference)]() - // keep track of memory names - val mems = mutable.HashMap[String, ir.DefMemory]() - } - - private class Context(globalClockName: String, scanResults: ScanCtx, val isPresetReg: String => Boolean) { - val globalClock: ir.Reference = ir.Reference(globalClockName, ir.ClockType, firrtl.PortKind, firrtl.SourceFlow) - // keeps track of which clock signals will be replaced by which clock enable signal - val isRemovedClock: String => Boolean = scanResults.clockToEnable.contains - // registers whose next function needs to be guarded with a clock enable - val registerToEnable: Map[String, ir.Reference] = scanResults.registerToEnable.toMap - // registers with asynchronous reset - val registerToAsyncReset: Map[String, (ir.Expression, ir.Expression)] = scanResults.registerToAsyncReset.toMap - // memory enables which need to be guarded with clock enables - val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap - // keep track of memory names - val isMem: String => Boolean = scanResults.mems.contains - val memInfo: String => ir.DefMemory = scanResults.mems - val isResetToChange: String => Boolean = scanResults.resetsToChange.contains - } - - private var mainName: String = "" // for debugging - private def unsupportedError(msg: String): Nothing = - throw new UnsupportedFeatureException(s"StutteringClockTransform: [$mainName] $msg") -} - -private class UnsupportedFeatureException(s: String) extends PassException(s) diff --git a/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala deleted file mode 100644 index bd3ad740..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/TransitionSystem.scala +++ /dev/null @@ -1,120 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -import firrtl.graph.MutableDiGraph -import scala.collection.mutable - -case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr]) { - def name: String = sym.name -} -case class Signal(name: String, e: SMTExpr, lbl: SignalLabel = IsNode) { - def toSymbol: SMTSymbol = SMTSymbol.fromExpr(name, e) - def sym: SMTSymbol = toSymbol -} -case class TransitionSystem( - name: String, - inputs: List[BVSymbol], - states: List[State], - signals: List[Signal], - comments: Map[String, String] = Map(), - header: String = "") { - def serialize: String = TransitionSystem.serialize(this) -} - -sealed trait SignalLabel -case object IsNode extends SignalLabel -case object IsOutput extends SignalLabel -case object IsConstraint extends SignalLabel -case object IsBad extends SignalLabel -case object IsFair extends SignalLabel -case object IsNext extends SignalLabel -case object IsInit extends SignalLabel - -object SignalLabel { - private val labels = Seq(IsNode, IsOutput, IsConstraint, IsBad, IsFair, IsNext, IsInit) - val labelStrings = Seq("node", "output", "constraint", "bad", "fair", "next", "init") - val labelToString: SignalLabel => String = labels.zip(labelStrings).toMap - val stringToLabel: String => SignalLabel = labelStrings.zip(labels).toMap -} - -object TransitionSystem { - def serialize(sys: TransitionSystem): String = { - (Iterator(sys.name) ++ - sys.inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++ - sys.signals.map(s => s"${SignalLabel.labelToString(s.lbl)} ${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++ - sys.states.map(serialize)).mkString("\n") - } - - def serialize(s: State): String = { - s"state ${s.sym.name} : ${SMTExpr.serializeType(s.sym)}" + - s.init.map("\n [init] " + _).getOrElse("") + - s.next.map("\n [next] " + _).getOrElse("") - } - - def systemExpressions(sys: TransitionSystem): List[SMTExpr] = - sys.signals.map(_.e) ++ sys.states.flatMap(s => s.init ++ s.next) - - def findUninterpretedFunctions(sys: TransitionSystem): List[DeclareFunction] = { - val calls = systemExpressions(sys).flatMap(findUFCalls) - // find unique functions - calls.groupBy(_.sym.name).map(_._2.head).toList - } - - private def findUFCalls(e: SMTExpr): List[DeclareFunction] = { - val f = e match { - case BVFunctionCall(name, args, width) => - Some(DeclareFunction(BVSymbol(name, width), args)) - case ArrayFunctionCall(name, args, indexWidth, dataWidth) => - Some(DeclareFunction(ArraySymbol(name, indexWidth, dataWidth), args)) - case _ => None - } - f.toList ++ e.children.flatMap(findUFCalls) - } -} - -private object TopologicalSort { - - /** Ensures that all signals in the resulting system are topologically sorted. - * This is necessary because [[firrtl.transforms.RemoveWires]] does - * not sort assignments to outputs, submodule inputs nor memory ports. - */ - def run(sys: TransitionSystem): TransitionSystem = { - val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name) - val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates) - // TODO: maybe sort init expressions of states (this should not be needed most of the time) - signalOrder match { - case None => sys - case Some(order) => - val signalMap = sys.signals.map(s => s.name -> s).toMap - // we flatMap over `get` in order to ignore inputs/states in the order - sys.copy(signals = order.flatMap(signalMap.get).toList) - } - } - - private def sort(signals: Iterable[(String, SMTExpr)], globalSignals: Iterable[String]): Option[Iterable[String]] = { - val known = new mutable.HashSet[String]() ++ globalSignals - var needsReordering = false - val digraph = new MutableDiGraph[String] - signals.foreach { - case (name, expr) => - digraph.addVertex(name) - val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr) - uniqueDependencies.foreach { d => - if (!known.contains(d)) { needsReordering = true } - digraph.addPairWithEdge(name, d) - } - known.add(name) - } - if (needsReordering) { - Some(digraph.linearize.reverse) - } else { None } - } - - private def findDependencies(expr: SMTExpr): List[String] = expr match { - case BVSymbol(name, _) => List(name) - case ArraySymbol(name, _, _) => List(name) - case other => other.children.flatMap(findDependencies) - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala b/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala deleted file mode 100644 index c7442f69..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Author: Kevin Laeufer <laeufer@cs.berkeley.edu> - -package firrtl.backends.experimental.smt - -import firrtl.annotations._ -import firrtl.ir -import firrtl.passes.PassException - -/** ExtModules annotated as UninterpretedModule will be modelled as - * UninterpretedFunction (SMTLib) or constant arrays (btor2). - * This can be useful when trying to abstract over a function that the - * SMT solver or model checker is struggling with. - * - * E.g., one could declare an abstract 64bit multiplier like this: - * ``` - * extmodule Mul64 : - * input a : UInt<64> - * input b : UInt<64> - * output r : UInt<64> - * ``` - * Now instead of using Chisel to actually implement a multiplication circuit - * we can instantiate this Mul64 module twice: Once in our implementation - * and once for our correctness property that might specify how the - * multiply instruction is supposed to be executed on our CPU. - * Now instead of having to prove equivalence of multiplication circuits, the - * solver only has to make sure that the connections to the multiplier are correct, - * since if `a` and `b` are the same on both instances of `Mul64`, then the `r` output - * will also be the same. This is a much easier problem and will result in much faster - * solving due to manual abstraction. - * - * When [[stateBits]] is 0, we model the module as purely combinatorial circuit and - * thus expect there to be no clock wire going into the module. - * Every output is thus a function of all inputs of the module. - * - * When [[stateBits]] is an N greater than zero, we will model the module as having an abstract state of width N. - * Thus on every clock transition the abstract state is updated and all outputs will take the state - * as well as the current inputs as arguments. - * TODO: Support for stateful circuits is work in progress. - * - * All output functions well be prefixed with [[prefix]] and end in the name of the output pin. - * It is the users responsibility to ensure that all function names will be unique by choosing apropriate - * prefixes. - * - * The annotation is consumed by the [[FirrtlToTransitionSystem]] pass. - */ -case class UninterpretedModuleAnnotation(target: ModuleTarget, prefix: String, stateBits: Int = 0) - extends SingleTargetAnnotation[ModuleTarget] { - require(stateBits >= 0, "negative number of bits is forbidden") - if (stateBits > 0) throw new NotImplementedError("TODO: support for stateful circuits is not implemented yet!") - override def duplicate(n: ModuleTarget) = copy(n) -} - -object UninterpretedModuleAnnotation { - - /** checks to see whether the annotation module can actually be abstracted. Use *after* LowerTypes! */ - def checkModule(m: ir.DefModule, anno: UninterpretedModuleAnnotation): Unit = m match { - case _: ir.Module => - throw new UninterpretedModuleException(s"UninterpretedModuleAnnotation can only be used with extmodule! $anno") - case m: ir.ExtModule => - val clockInputs = m.ports.collect { case p @ ir.Port(_, _, ir.Input, ir.ClockType) => p.name } - val clockOutput = m.ports.collect { case p @ ir.Port(_, _, ir.Output, ir.ClockType) => p.name } - val asyncResets = m.ports.collect { case p @ ir.Port(_, _, _, ir.AsyncResetType) => p.name } - if (clockOutput.nonEmpty) { - throw new UninterpretedModuleException( - s"We do not support clock outputs for uninterpreted modules! $clockOutput" - ) - } - if (asyncResets.nonEmpty) { - throw new UninterpretedModuleException( - s"We do not support async reset I/O for uninterpreted modules! $asyncResets" - ) - } - if (anno.stateBits == 0) { - if (clockInputs.nonEmpty) { - throw new UninterpretedModuleException(s"A combinatorial module may not have any clock inputs! $clockInputs") - } - } else { - if (clockInputs.size != 1) { - throw new UninterpretedModuleException(s"A stateful module must have exactly one clock input! $clockInputs") - } - } - } -} - -private class UninterpretedModuleException(s: String) extends PassException(s) diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala b/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala deleted file mode 100644 index 7381056e..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala +++ /dev/null @@ -1,31 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtl.backends.experimental.smt.random - -import firrtl.Utils -import firrtl.ir._ - -/** Named source of random values. If there is no clock expression, than it will be clocked by the global clock. */ -case class DefRandom( - info: Info, - name: String, - tpe: Type, - clock: Option[Expression], - en: Expression = Utils.True()) - extends Statement - with HasInfo - with IsDeclaration - with CanBeReferenced - with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = - DefRandom(info, name, tpe, clock.map(f), f(en)) - def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = { clock.foreach(f); f(en) } - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala deleted file mode 100644 index c7eaad74..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala +++ /dev/null @@ -1,125 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtl.backends.experimental.smt.random - -import firrtl._ -import firrtl.annotations.NoTargetAnnotation -import firrtl.ir._ -import firrtl.passes._ -import firrtl.options.Dependency -import firrtl.stage.Forms -import firrtl.transforms.RemoveWires - -import scala.collection.mutable - -/** Chooses how to model explicit and implicit invalid values in the circuit */ -case class InvalidToRandomOptions( - randomizeInvalidSignals: Boolean = true, - randomizeDivisionByZero: Boolean = true) - extends NoTargetAnnotation - -/** Replaces all explicit and implicit "invalid" values with random values. - * Explicit invalids are: - * - signal is invalid - * - signal <= valid(..., expr) - * Implicit invalids are: - * - a / b when eq(b, 0) - */ -object InvalidToRandomPass extends Transform with DependencyAPIMigration { - override def prerequisites = Forms.LowForm - // once ValidIf has been removed, we can no longer detect and randomize them - override def optionalPrerequisiteOf = Seq(Dependency(RemoveValidIf)) - override def invalidates(a: Transform) = a match { - // this pass might destroy SSA form, as we add a wire for the data field of every read port - case _: RemoveWires => true - // TODO: should we add some optimization passes here? we could be generating some dead code. - case _ => false - } - - override protected def execute(state: CircuitState): CircuitState = { - val opts = state.annotations.collect { case o: InvalidToRandomOptions => o } - require(opts.size < 2, s"Multiple options: $opts") - val opt = opts.headOption.getOrElse(InvalidToRandomOptions()) - - // quick exit if we just want to skip this pass - if (!opt.randomizeDivisionByZero && !opt.randomizeInvalidSignals) { - state - } else { - val c = state.circuit.mapModule(onModule(_, opt)) - state.copy(circuit = c) - } - } - - private def onModule(m: DefModule, opt: InvalidToRandomOptions): DefModule = m match { - case d: DescribedMod => - throw new RuntimeException(s"CompilerError: Unexpected internal node: ${d.serialize}") - case e: ExtModule => e - case mod: Module => - val namespace = Namespace(mod) - mod.mapStmt(onStmt(namespace, opt, _)) - } - - private def onStmt(namespace: Namespace, opt: InvalidToRandomOptions, s: Statement): Statement = s match { - case IsInvalid(info, loc: RefLikeExpression) if opt.randomizeInvalidSignals => - val name = namespace.newName(loc.serialize.replace('.', '_') + "_invalid") - val rand = DefRandom(info, name, loc.tpe, None) - Block(List(rand, Connect(info, loc, Reference(rand)))) - case other => - val info = other match { - case h: HasInfo => h.info - case _ => NoInfo - } - val prefix = other match { - case c: Connect => c.loc.serialize.replace('.', '_') - case h: HasName => h.name - case _ => "" - } - val ctx = ExprCtx(namespace, opt, prefix, info, mutable.ListBuffer[Statement]()) - val stmt = other.mapExpr(onExpr(ctx, _)).mapStmt(onStmt(namespace, opt, _)) - if (ctx.rands.isEmpty) { stmt } - else { Block(Block(ctx.rands.toList), stmt) } - } - - private case class ExprCtx( - namespace: Namespace, - opt: InvalidToRandomOptions, - prefix: String, - info: Info, - rands: mutable.ListBuffer[Statement]) - - private def onExpr(ctx: ExprCtx, e: Expression): Expression = - e.mapExpr(onExpr(ctx, _)) match { - case ValidIf(_, value, tpe) if tpe == ClockType => - // we currently assume that clocks are always valid - // TODO: is that a good assumption? - value - case ValidIf(cond, value, tpe) if ctx.opt.randomizeInvalidSignals => - makeRand(ctx, cond, tpe, value, invert = true) - case d @ DoPrim(PrimOps.Div, Seq(_, den), _, tpe) if ctx.opt.randomizeDivisionByZero => - val denIsZero = Utils.eq(den, Utils.getGroundZero(den.tpe.asInstanceOf[GroundType])) - makeRand(ctx, denIsZero, tpe, d, invert = false) - case other => other - } - - private def makeRand( - ctx: ExprCtx, - cond: Expression, - tpe: Type, - value: Expression, - invert: Boolean - ): Expression = { - val name = ctx.namespace.newName(if (ctx.prefix.isEmpty) "invalid" else ctx.prefix + "_invalid") - // create a condition node if the condition isn't a reference already - val condRef = cond match { - case r: RefLikeExpression => if (invert) Utils.not(r) else r - case other => - val cond = if (invert) Utils.not(other) else other - val condNode = DefNode(ctx.info, ctx.namespace.newName(name + "_cond"), cond) - ctx.rands.append(condNode) - Reference(condNode) - } - val rand = DefRandom(ctx.info, name, tpe, None, condRef) - ctx.rands.append(rand) - Utils.mux(condRef, Reference(rand), value) - } -} diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala deleted file mode 100644 index 96582778..00000000 --- a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala +++ /dev/null @@ -1,461 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtl.backends.experimental.smt.random - -import firrtl.Utils.{isLiteral, BoolType} -import firrtl._ -import firrtl.annotations.NoTargetAnnotation -import firrtl.backends.experimental.smt._ -import firrtl.ir._ -import firrtl.options.Dependency -import firrtl.passes.MemPortUtils.memPortField -import firrtl.passes.memlib.AnalysisUtils.Connects -import firrtl.passes.memlib.InferReadWritePass.checkComplement -import firrtl.passes.memlib.{AnalysisUtils, InferReadWritePass, VerilogMemDelays} -import firrtl.stage.Forms -import firrtl.transforms.RemoveWires - -import scala.collection.mutable - -/** Chooses which undefined memory behaviors should be instrumented. */ -case class UndefinedMemoryBehaviorOptions( - randomizeWriteWriteConflicts: Boolean = true, - assertNoOutOfBoundsWrites: Boolean = false, - randomizeOutOfBoundsRead: Boolean = true, - randomizeDisabledReads: Boolean = true, - randomizeReadWriteConflicts: Boolean = true) - extends NoTargetAnnotation - -/** Adds sources of randomness to model the various "undefined behaviors" of firrtl memory. - * - Write/Write conflict: leads to arbitrary value written to write address - * - Out-of-bounds write: assertion failure (disabled by default) - * - Out-Of-bounds read: leads to arbitrary value being read - * - Read w/ en=0: leads to arbitrary value being read - * - Read/Write conflict: leads to arbitrary value being read - */ -object UndefinedMemoryBehaviorPass extends Transform with DependencyAPIMigration { - override def prerequisites = Forms.LowForm - override def optionalPrerequisiteOf = Seq(Dependency(VerilogMemDelays)) - override def invalidates(a: Transform) = a match { - // this pass might destroy SSA form, as we add a wire for the data field of every read port - case _: RemoveWires => true - // TODO: should we add some optimization passes here? we could be generating some dead code. - case _ => false - } - - override protected def execute(state: CircuitState): CircuitState = { - val opts = state.annotations.collect { case o: UndefinedMemoryBehaviorOptions => o } - require(opts.size < 2, s"Multiple options: $opts") - val opt = opts.headOption.getOrElse(UndefinedMemoryBehaviorOptions()) - - val c = state.circuit.mapModule(onModule(_, opt)) - state.copy(circuit = c) - } - - private def onModule(m: DefModule, opt: UndefinedMemoryBehaviorOptions): DefModule = m match { - case mod: Module => - val mems = findMems(mod) - if (mems.isEmpty) { mod } - else { - val namespace = Namespace(mod) - val connects = AnalysisUtils.getConnects(mod) - new InstrumentMems(opt, mems, connects, namespace).run(mod) - } - case other => other - } - - /** finds all memory instantiations in a circuit */ - private def findMems(m: Module): List[DefMemory] = { - val mems = mutable.ListBuffer[DefMemory]() - m.foreachStmt(findMems(_, mems)) - mems.toList - } - private def findMems(s: Statement, mems: mutable.ListBuffer[DefMemory]): Unit = s match { - case mem: DefMemory => mems.append(mem) - case other => other.foreachStmt(findMems(_, mems)) - } -} - -private class InstrumentMems( - opt: UndefinedMemoryBehaviorOptions, - mems: List[DefMemory], - connects: Connects, - namespace: Namespace) { - def run(m: Module): DefModule = { - // ensure that all memories are the kind we can support - mems.foreach(checkSupported(m.name, _)) - - // transform circuit - val body = m.body.mapStmt(transform) - m.copy(body = Block(body +: newStmts.toList)) - } - - // used to replace memory signals like `m.r.data` in RHS expressions - private val exprReplacements = mutable.HashMap[String, Expression]() - // add new statements at the end of the circuit - private val newStmts = mutable.ListBuffer[Statement]() - // disconnect references so that they can be reassigned - private val doDisconnect = mutable.HashSet[String]() - - // generates new expression replacements and immediately uses them - private def transform(s: Statement): Statement = s.mapStmt(transform) match { - case mem: DefMemory => onMem(mem) - case sx: Connect if doDisconnect.contains(sx.loc.serialize) => EmptyStmt // Filter old mem connections - case sx => sx.mapExpr(swapMemRefs) - } - private def swapMemRefs(e: Expression): Expression = e.mapExpr(swapMemRefs) match { - case sf: RefLikeExpression => exprReplacements.getOrElse(sf.serialize, sf) - case ex => ex - } - - private def onMem(m: DefMemory): Statement = { - // collect wire and random statement defines - implicit val declarations: mutable.ListBuffer[Statement] = mutable.ListBuffer[Statement]() - - // cache for the expressions of memory inputs - implicit val cache: mutable.HashMap[String, Expression] = mutable.HashMap[String, Expression]() - - // only for non power of 2 memories do we have to worry about reading or writing out of bounds - val canBeOutOfBounds = !isPow2(m.depth) - - // only if we have at least two write ports, can there be conflicts - val canHaveWriteWriteConflicts = m.writers.size > 1 - - // only certain memory types exhibit undefined read/write conflicts - val readWriteUndefined = (m.readLatency == m.writeLatency) && (m.readUnderWrite == ReadUnderWrite.Undefined) - assert( - m.readLatency == 0 || m.readLatency == m.writeLatency, - "TODO: what happens if a sync read mem has asymmetrical latencies?" - ) - - // a write port is enabled iff mask & en - val writeEn = m.writers.map { write => - val enRef = memPortField(m, write, "en") - val maskRef = memPortField(m, write, "mask") - - val prods = getProductTerms(enRef) ++ getProductTerms(maskRef) - val expr = Utils.and(readInput(m.info, enRef), readInput(m.info, maskRef)) - - (expr, prods) - } - - // implement the three undefined read behaviors - m.readers.foreach { read => - // many memories have their read enable hard wired to true - val canBeDisabled = !isTrue(readInput(m, read, "en")) - val readEn = if (canBeDisabled) readInput(m, read, "en") else Utils.True() - - // collect signals that would lead to a randomization - var doRand = List[Expression]() - - // randomize the read value when the address is out of bounds - if (canBeOutOfBounds && opt.randomizeOutOfBoundsRead) { - val addr = readInput(m, read, "addr") - val cond = Utils.and(readEn, Utils.not(isInBounds(m.depth, addr))) - val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_oob"), cond) - declarations += node - doRand = Reference(node) +: doRand - } - - if (readWriteUndefined && opt.randomizeReadWriteConflicts) { - val cond = readWriteConflict(m, read, writeEn) - val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_rwc"), cond) - declarations += node - doRand = Reference(node) +: doRand - } - - // randomize the read value when the read is disabled - if (canBeDisabled && opt.randomizeDisabledReads) { - val cond = Utils.not(readEn) - val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_disabled"), cond) - declarations += node - doRand = Reference(node) +: doRand - } - - // if there are no signals that would require a randomization, there is nothing to do - if (doRand.isEmpty) { - // nothing to do - } else { - val doRandName = s"${m.name}_${read}_do_rand" - val doRandNode = if (doRand.size == 1) { doRand.head } - else { - val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_do_rand"), doRand.reduce(Utils.or)) - declarations += node - Reference(node) - } - val doRandSignal = if (m.readLatency == 0) { doRandNode } - else { - val clock = readInput(m, read, "clk") - val (signal, regDecls) = pipeline(m.info, clock, doRandName, doRandNode, m.readLatency) - declarations ++= regDecls - signal - } - - // all old rhs references to m.r.data need to replace with m_r_data which might be random - val dataRef = memPortField(m, read, "data") - val dataWire = DefWire(m.info, namespace.newName(s"${m.name}_${read}_data"), m.dataType) - declarations += dataWire - exprReplacements(dataRef.serialize) = Reference(dataWire) - - // create a source of randomness and connect the new wire either to the actual data port or to the random value - val randName = namespace.newName(s"${m.name}_${read}_rand_data") - val random = DefRandom(m.info, randName, m.dataType, Some(readInput(m, read, "clk")), doRandSignal) - declarations += random - val data = Utils.mux(doRandSignal, Reference(random), dataRef) - newStmts.append(Connect(m.info, Reference(dataWire), data)) - } - } - - // write - if (opt.randomizeWriteWriteConflicts) { - writeWriteConflicts(m, writeEn) - } - - // add an assertion that if the write is taking place, then the address must be in range - if (canBeOutOfBounds && opt.assertNoOutOfBoundsWrites) { - m.writers.zip(writeEn).foreach { - case (write, (combinedEn, _)) => - val addr = readInput(m, write, "addr") - val cond = Utils.implies(combinedEn, isInBounds(m.depth, addr)) - val clk = readInput(m, write, "clk") - val a = Verification(Formal.Assert, m.info, clk, cond, Utils.True(), StringLit("out of bounds read")) - newStmts.append(a) - } - } - - Block(m +: declarations.toList) - } - - private def pipeline( - info: Info, - clk: Expression, - prefix: String, - e: Expression, - latency: Int - ): (Expression, Seq[Statement]) = { - require(latency > 0) - val regs = (1 to latency).map { i => - val name = namespace.newName(prefix + s"_r$i") - DefRegister(info, name, e.tpe, clk, Utils.False(), Reference(name, e.tpe, RegKind, UnknownFlow)) - } - val expr = regs.foldLeft(e) { - case (prev, reg) => - newStmts.append(Connect(info, Reference(reg), prev)) - Reference(reg) - } - (expr, regs) - } - - private def readWriteConflict( - m: DefMemory, - read: String, - writeEn: Seq[(Expression, ProdTerms)] - )( - implicit cache: mutable.HashMap[String, Expression], - declarations: mutable.ListBuffer[Statement] - ): Expression = { - if (m.writers.isEmpty) return Utils.False() - - val readEn = readInput(m, read, "en") - val readProd = getProductTerms(readEn) - - // create all conflict signals - val conflicts = m.writers.zip(writeEn).map { - case (write, (writeEn, writeProd)) => - if (isMutuallyExclusive(readProd, writeProd)) { - Utils.False() - } else { - val name = namespace.newName(s"${m.name}_${read}_${write}_rwc") - val bothEn = Utils.and(readEn, writeEn) - val sameAddr = Utils.eq(readInput(m, read, "addr"), readInput(m, write, "addr")) - // we need a wire because this condition might be used in a random statement - val wire = DefWire(m.info, name, BoolType) - declarations += wire - newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr))) - Reference(wire) - } - } - - conflicts.reduce(Utils.or) - } - - private type ProdTerms = Seq[Expression] - private def writeWriteConflicts( - m: DefMemory, - writeEn: Seq[(Expression, ProdTerms)] - )( - implicit cache: mutable.HashMap[String, Expression], - declarations: mutable.ListBuffer[Statement] - ): Unit = { - if (m.writers.size < 2) return - - // we first create all conflict signals: - val conflict = - m.writers - .zip(writeEn) - .zipWithIndex - .flatMap { - case ((w1, (en1, en1Prod)), i1) => - m.writers.zip(writeEn).drop(i1 + 1).map { - case (w2, (en2, en2Prod)) => - if (isMutuallyExclusive(en1Prod, en2Prod)) { - (w1, w2) -> Utils.False() - } else { - val name = namespace.newName(s"${m.name}_${w1}_${w2}_wwc") - val bothEn = Utils.and(en1, en2) - val sameAddr = Utils.eq(readInput(m, w1, "addr"), readInput(m, w2, "addr")) - // we need a wire because this condition might be used in a random statement - val wire = DefWire(m.info, name, BoolType) - declarations += wire - newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr))) - (w1, w2) -> Reference(wire) - } - } - } - .toMap - - // now we calculate the new enable and data signals - m.writers.zip(writeEn).zipWithIndex.foreach { - case ((w1, (en1, _)), i1) => - val prev = m.writers.take(i1) - val next = m.writers.drop(i1 + 1) - - // the write is enabled if the original enable is true and there are no prior conflicts - val en = if (prev.isEmpty) { - en1 - } else { - val prevConflicts = prev.map(o => conflict(o, w1)).reduce(Utils.or) - Utils.and(en1, Utils.not(prevConflicts)) - } - - // we write random data if there is a conflict with any of the next ports - if (next.isEmpty) { - // nothing to do, leave data as is - } else { - val nextConflicts = next.map(n => conflict(w1, n)).reduce(Utils.or) - // if the conflict expression is more complex, create a node for the signal - val hasConflict = nextConflicts match { - case _: DoPrim | _: Mux => - val node = DefNode(m.info, namespace.newName(s"${m.name}_${w1}_wwc_active"), nextConflicts) - declarations += node - Reference(node) - case _ => nextConflicts - } - - // create the source of randomness - val name = namespace.newName(s"${m.name}_${w1}_wwc_data") - val random = DefRandom(m.info, name, m.dataType, Some(readInput(m, w1, "clk")), hasConflict) - declarations.append(random) - - // generate new data input - val data = Utils.mux(hasConflict, Reference(random), readInput(m, w1, "data")) - newStmts.append(Connect(m.info, memPortField(m, w1, "data"), data)) - doDisconnect.add(memPortField(m, w1, "data").serialize) - } - - // connect data enable signals - val maskIsOne = isTrue(readInput(m, w1, "mask")) - if (!maskIsOne) { - newStmts.append(Connect(m.info, memPortField(m, w1, "mask"), Utils.True())) - doDisconnect.add(memPortField(m, w1, "mask").serialize) - } - newStmts.append(Connect(m.info, memPortField(m, w1, "en"), en)) - doDisconnect.add(memPortField(m, w1, "en").serialize) - } - } - - /** check whether two signals can be proven to be mutually exclusive */ - private def isMutuallyExclusive(prodA: ProdTerms, prodB: ProdTerms): Boolean = { - // this uses the same approach as the InferReadWrite pass - val proofOfMutualExclusion = prodA.find(a => prodB.exists(b => checkComplement(a, b))) - proofOfMutualExclusion.nonEmpty - } - - /** memory inputs my not be read, only assigned to, thus we might need to add a wire to make them accessible */ - private def readInput( - info: Info, - signal: RefLikeExpression - )( - implicit cache: mutable.HashMap[String, Expression], - declarations: mutable.ListBuffer[Statement] - ): Expression = - cache.getOrElseUpdate( - signal.serialize, { - // if it is a literal, we just return it - val value = connects(signal.serialize) - if (isLiteral(value)) { - value - } else { - // otherwise we make a wire that refelect the value - val wire = DefWire(info, copyName(signal), signal.tpe) - declarations += wire - - // connect the old expression to the new wire - val con = Connect(info, Reference(wire), value) - newStmts.append(con) - - // use a reference to this new wire - Reference(wire) - } - } - ) - private def readInput( - m: DefMemory, - port: String, - field: String - )( - implicit cache: mutable.HashMap[String, Expression], - declarations: mutable.ListBuffer[Statement] - ): Expression = - readInput(m.info, memPortField(m, port, field)) - - private def copyName(ref: RefLikeExpression): String = - namespace.newName(ref.serialize.replace('.', '_')) - - private def isInBounds(depth: BigInt, addr: Expression): Expression = { - val width = getWidth(addr) - // depth > addr (e.g. if the depth is 3, then the address must be in {0, 1, 2}) - DoPrim(PrimOps.Gt, List(UIntLiteral(depth, width), addr), List(), BoolType) - } - - private def isPow2(v: BigInt): Boolean = ((v - 1) & v) == 0 - - private def checkSupported(modName: String, m: DefMemory): Unit = { - assert(m.readwriters.isEmpty, s"[$modName] Combined read/write ports are currently not supported!") - if (m.writeLatency != 1) { - throw new UnsupportedFeatureException(s"[$modName] memories with write latency > 1 (${m.name})") - } - if (m.readLatency > 1) { - throw new UnsupportedFeatureException(s"[$modName] memories with read latency > 1 (${m.name})") - } - } - - private def getProductTerms(e: Expression): ProdTerms = - InferReadWritePass.getProductTerms(connects)(e) - - /** tries to expand the expression based on the connects we collected */ - private def expandExpr(e: Expression, fuel: Int): Expression = { - e match { - case m @ Mux(cond, tval, fval, _) => - m.copy(cond = expandExpr(cond, fuel), tval = expandExpr(tval, fuel), fval = expandExpr(fval, fuel)) - case p @ DoPrim(_, args, _, _) => - p.copy(args = args.map(expandExpr(_, fuel))) - case r: RefLikeExpression => - if (fuel > 0) { - connects.get(r.serialize) match { - case None => r - case Some(expr) => expandExpr(expr, fuel - 1) - } - } else { - r - } - case other => other - } - } - - private def isTrue(e: Expression): Boolean = simplifyExpr(expandExpr(e, fuel = 2)) == Utils.True() - - private def simplifyExpr(e: Expression): Expression = { - e // TODO: better simplification could improve the resulting circuit size - } -} diff --git a/src/main/scala/firrtl/backends/proto/ProtoBufEmitter.scala b/src/main/scala/firrtl/backends/proto/ProtoBufEmitter.scala index c617ea27..9af68235 100644 --- a/src/main/scala/firrtl/backends/proto/ProtoBufEmitter.scala +++ b/src/main/scala/firrtl/backends/proto/ProtoBufEmitter.scala @@ -7,7 +7,7 @@ import firrtl.annotations.NoTargetAnnotation import firrtl.options.CustomFileEmission import firrtl.options.Viewer.view import firrtl.proto.ToProto -import firrtl.stage.{FirrtlOptions, Forms} +import firrtl.stage.{FirrtlOptions, Forms, FirrtlOptionsView} import firrtl.stage.TransformManager.TransformDependency import firrtl.traversals.Foreachers._ import java.io.{ByteArrayOutputStream, Writer} diff --git a/src/main/scala/firrtl/backends/verilog/VerilogEmitter.scala b/src/main/scala/firrtl/backends/verilog/VerilogEmitter.scala index 30d2e891..2634a8e1 100644 --- a/src/main/scala/firrtl/backends/verilog/VerilogEmitter.scala +++ b/src/main/scala/firrtl/backends/verilog/VerilogEmitter.scala @@ -511,7 +511,7 @@ class VerilogEmitter extends SeqTransform with Emitter { } private val emissionAnnos = annotations.collect { - case m: SingleTargetAnnotation[ReferenceTarget] @unchecked with EmissionOption => m + case m: SingleTargetAnnotation[ReferenceTarget] @unchecked & EmissionOption => m } annotations.foreach { diff --git a/src/main/scala/firrtl/constraint/Constraint.scala b/src/main/scala/firrtl/constraint/Constraint.scala index 87d15737..a8e49963 100644 --- a/src/main/scala/firrtl/constraint/Constraint.scala +++ b/src/main/scala/firrtl/constraint/Constraint.scala @@ -6,7 +6,7 @@ package firrtl.constraint trait Constraint { def serialize: String def map(f: Constraint => Constraint): Constraint - val children: Vector[Constraint] + lazy val children: Vector[Constraint] def reduce(): Constraint } diff --git a/src/main/scala/firrtl/constraint/IsFloor.scala b/src/main/scala/firrtl/constraint/IsFloor.scala index 48173b55..4e3e9166 100644 --- a/src/main/scala/firrtl/constraint/IsFloor.scala +++ b/src/main/scala/firrtl/constraint/IsFloor.scala @@ -22,7 +22,7 @@ case class IsFloor private (child: Constraint, dummyArg: Int) extends Constraint case x: IsFloor => x case _ => this } - val children = Vector(child) + lazy val children = Vector(child) override def map(f: Constraint => Constraint): Constraint = IsFloor(f(child)) diff --git a/src/main/scala/firrtl/constraint/IsKnown.scala b/src/main/scala/firrtl/constraint/IsKnown.scala index b11adc3a..4da2efed 100644 --- a/src/main/scala/firrtl/constraint/IsKnown.scala +++ b/src/main/scala/firrtl/constraint/IsKnown.scala @@ -36,7 +36,7 @@ trait IsKnown extends Constraint { override def map(f: Constraint => Constraint): Constraint = this - val children: Vector[Constraint] = Vector.empty[Constraint] + lazy val children: Vector[Constraint] = Vector.empty[Constraint] def reduce(): IsKnown = this } diff --git a/src/main/scala/firrtl/constraint/IsPow.scala b/src/main/scala/firrtl/constraint/IsPow.scala index 92ca649f..f5db351f 100644 --- a/src/main/scala/firrtl/constraint/IsPow.scala +++ b/src/main/scala/firrtl/constraint/IsPow.scala @@ -23,7 +23,7 @@ case class IsPow private (child: Constraint, dummyArg: Int) extends Constraint { case _ => this } - val children = Vector(child) + lazy val children = Vector(child) override def map(f: Constraint => Constraint): Constraint = IsPow(f(child)) diff --git a/src/main/scala/firrtl/constraint/IsVar.scala b/src/main/scala/firrtl/constraint/IsVar.scala index 1675a6fc..11a25e2c 100644 --- a/src/main/scala/firrtl/constraint/IsVar.scala +++ b/src/main/scala/firrtl/constraint/IsVar.scala @@ -20,7 +20,7 @@ trait IsVar extends Constraint { override def reduce() = this - val children = Vector() + lazy val children = Vector() } case class VarCon(name: String) extends IsVar diff --git a/src/main/scala/firrtl/features/LetterCaseTransform.scala b/src/main/scala/firrtl/features/LetterCaseTransform.scala index 6d0c2b30..8686908a 100644 --- a/src/main/scala/firrtl/features/LetterCaseTransform.scala +++ b/src/main/scala/firrtl/features/LetterCaseTransform.scala @@ -8,7 +8,7 @@ import firrtl.transforms.ManipulateNames import scala.reflect.ClassTag /** Parent of transforms that do change the letter case of names in a FIRRTL circuit */ -abstract class LetterCaseTransform[A <: ManipulateNames[_]: ClassTag] extends ManipulateNames[A] { +abstract class LetterCaseTransform[A <: ManipulateNames[?]: ClassTag] extends ManipulateNames[A] { protected def newName: String => String diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index 99bf8403..154aaab3 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -279,11 +279,12 @@ class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { if (frame.childCall.isEmpty) { if (lowlinks(v) == indices(v)) { val scc = new mutable.ArrayBuffer[T] + while (scc.last != v) do { val w = stack.pop() onstack -= w scc += w - } while (scc.last != v); + } sccs.append(scc.toSeq) } callStack.pop() diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 8ba29d8e..df32d287 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -4,8 +4,6 @@ package firrtl package ir import Utils.{dec2string, trim} -import firrtl.backends.experimental.smt.random.DefRandom -import dataclass.{data, since} import firrtl.constraint.{Constraint, IsKnown, IsVar} import org.apache.commons.text.translate.{AggregateTranslator, JavaUnicodeEscaper, LookupTranslator} @@ -349,9 +347,6 @@ object Reference { /** Creates a Reference from a Register */ def apply(reg: DefRegister): Reference = Reference(reg.name, reg.tpe, RegKind, UnknownFlow) - /** Creates a Reference from a Random Source */ - def apply(rnd: DefRandom): Reference = Reference(rnd.name, rnd.tpe, RandomKind, UnknownFlow) - /** Creates a Reference from a Node */ def apply(node: DefNode): Reference = Reference(node.name, node.value.tpe, NodeKind, SourceFlow) @@ -720,7 +715,12 @@ case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with Has def foreachInfo(f: Info => Unit): Unit = f(info) } -@data class Stop(info: Info, ret: Int, clk: Expression, en: Expression, @since("FIRRTL 1.5") name: String = "") +class Stop( + val info: Info, + val ret: Int, + val clk: Expression, + val en: Expression, + val name: String = "") extends Statement with HasInfo with IsDeclaration @@ -728,7 +728,7 @@ case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with Has def mapStmt(f: Statement => Statement): Statement = this def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en), name) def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = withName(f(name)) + def mapString(f: String => String): Statement = withName(f(name)).asInstanceOf[Statement] def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(en) } @@ -736,6 +736,9 @@ case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with Has def foreachString(f: String => Unit): Unit = f(name) def foreachInfo(f: Info => Unit): Unit = f(info) def copy(info: Info = info, ret: Int = ret, clk: Expression = clk, en: Expression = en): Stop = { + copyWithName(info, ret, clk, en, name) + } + def copyWithName(info: Info = info, ret: Int = ret, clk: Expression = clk, en: Expression = en, name: String): Stop = { Stop(info, ret, clk, en, name) } } @@ -744,14 +747,13 @@ object Stop { Some((s.info, s.ret, s.clk, s.en)) } } -@data class Print( - info: Info, - string: StringLit, - args: Seq[Expression], - clk: Expression, - en: Expression, - @since("FIRRTL 1.5") - name: String = "") +class Print( + val info: Info, + val string: StringLit, + val args: Seq[Expression], + val clk: Expression, + val en: Expression, + val name: String = "") extends Statement with HasInfo with IsDeclaration @@ -759,7 +761,7 @@ object Stop { def mapStmt(f: Statement => Statement): Statement = this def mapExpr(f: Expression => Expression): Statement = Print(info, string, args.map(f), f(clk), f(en), name) def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = withName(f(name)) + def mapString(f: String => String): Statement = withName(f(name)).asInstanceOf[Statement] def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { args.foreach(f); f(clk); f(en) } @@ -773,6 +775,16 @@ object Stop { clk: Expression = clk, en: Expression = en ): Print = { + copyWithName(info, string, args, clk, en, name) + } + def copyWithName( + info: Info = info, + string: StringLit = string, + args: Seq[Expression] = args, + clk: Expression = clk, + en: Expression = en, + name: String + ): Print = { Print(info, string, args, clk, en, name) } } @@ -789,15 +801,14 @@ object Formal extends Enumeration { val Cover = Value("cover") } -@data class Verification( - op: Formal.Value, - info: Info, - clk: Expression, - pred: Expression, - en: Expression, - msg: StringLit, - @since("FIRRTL 1.5") - name: String = "") +class Verification( + val op: Formal.Value, + val info: Info, + val clk: Expression, + val pred: Expression, + val en: Expression, + val msg: StringLit, + val name: String = "") extends Statement with HasInfo with IsDeclaration @@ -806,7 +817,7 @@ object Formal extends Enumeration { def mapExpr(f: Expression => Expression): Statement = copy(clk = f(clk), pred = f(pred), en = f(en)) def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = withName(f(name)) + def mapString(f: String => String): Statement = withName(f(name)).asInstanceOf[Statement] def mapInfo(f: Info => Info): Statement = copy(info = f(info)) def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(pred); f(en); } @@ -821,6 +832,17 @@ object Formal extends Enumeration { en: Expression = en, msg: StringLit = msg ): Verification = { + copyWithName(op, info, clk, pred, en, msg, name) + } + def copyWithName( + op: Formal.Value = op, + info: Info = info, + clk: Expression = clk, + pred: Expression = pred, + en: Expression = en, + msg: StringLit = msg, + name: String + ): Verification = { Verification(op, info, clk, pred, en, msg, name) } } @@ -924,13 +946,13 @@ case object UnknownBound extends Bound { def serialize: String = Serializer.serialize(this) def map(f: Constraint => Constraint): Constraint = this override def reduce(): Constraint = this - val children = Vector() + lazy val children = Vector() } case class CalcBound(arg: Constraint) extends Bound { def serialize: String = Serializer.serialize(this) def map(f: Constraint => Constraint): Constraint = f(arg) override def reduce(): Constraint = arg - val children = Vector(arg) + lazy val children = Vector(arg) } case class VarBound(name: String) extends IsVar with Bound { override def serialize: String = Serializer.serialize(this) @@ -1079,7 +1101,7 @@ case class IntervalType(lower: Bound, upper: Bound, point: Width) extends Ground }) /** If bounds are known, calculates the width, otherwise returns UnknownWidth */ - lazy val width: Width = (point, lower, upper) match { + val width: Width = (point, lower, upper) match { case (IntWidth(i), l: IsKnown, u: IsKnown) => IntWidth(Math.max(Utils.getSIntWidth(minAdjusted.get), Utils.getSIntWidth(maxAdjusted.get))) case _ => UnknownWidth diff --git a/src/main/scala/firrtl/ir/Serializer.scala b/src/main/scala/firrtl/ir/Serializer.scala index cf919b37..a2864d33 100644 --- a/src/main/scala/firrtl/ir/Serializer.scala +++ b/src/main/scala/firrtl/ir/Serializer.scala @@ -3,7 +3,6 @@ package firrtl.ir import firrtl.Utils -import firrtl.backends.experimental.smt.random.DefRandom import firrtl.constraint.Constraint case class Version(major: Int, minor: Int, patch: Int) { @@ -190,7 +189,7 @@ object Serializer { // We could initialze the StringBuilder size, but this is bad for small modules which may not // even reach the bufferSize. - private implicit val b = new StringBuilder + private implicit val b: StringBuilder = new StringBuilder // The flattening of Whens into WhenBegin and friends requires us to keep track of the // indention level @@ -264,11 +263,6 @@ object Serializer { case DefRegister(info, name, tpe, clock, reset, init) => b ++= "reg "; b ++= name; b ++= " : "; s(tpe); b ++= ", "; s(clock); b ++= " with :"; newLineAndIndent(1) b ++= "reset => ("; s(reset); b ++= ", "; s(init); b += ')'; s(info) - case DefRandom(info, name, tpe, clock, en) => - b ++= "rand "; b ++= name; b ++= " : "; s(tpe); - if (clock.isDefined) { b ++= ", "; s(clock.get); } - en match { case Utils.True() => case _ => b ++= " when "; s(en) } - s(info) case DefInstance(info, name, module, _) => b ++= "inst "; b ++= name; b ++= " of "; b ++= module; s(info) case DefMemory( info, diff --git a/src/main/scala/firrtl/options/DependencyManager.scala b/src/main/scala/firrtl/options/DependencyManager.scala index ae22b4b4..32f6b3c9 100644 --- a/src/main/scala/firrtl/options/DependencyManager.scala +++ b/src/main/scala/firrtl/options/DependencyManager.scala @@ -17,7 +17,7 @@ case class DependencyManagerException(message: String, cause: Throwable = null) * @tparam A the type over which this transforms * @tparam B the type of the [[firrtl.options.TransformLike TransformLike]] */ -trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends TransformLike[A] with DependencyAPI[B] { +trait DependencyManager[A, B <: TransformLike[A] & DependencyAPI[B]] extends TransformLike[A] with DependencyAPI[B] { import DependencyManagerUtils.CharSet override def prerequisites: Seq[Dependency[B]] = currentState @@ -52,7 +52,7 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends /** Store of conversions between classes and objects. Objects that do not exist in the map will be lazily constructed. */ protected lazy val dependencyToObject: LinkedHashMap[Dependency[B], B] = { - val init = LinkedHashMap[Dependency[B], B](knownObjects.map(x => oToD(x) -> x).toSeq: _*) + val init = LinkedHashMap[Dependency[B], B](knownObjects.map(x => oToD(x) -> x).toSeq*) (_targets ++ _currentState) .filter(!init.contains(_)) .map(x => init(x) = x.getObject()) @@ -84,9 +84,9 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends ): LinkedHashMap[B, LinkedHashSet[B]] = { val (queue, edges) = { - val a: Queue[Dependency[B]] = Queue(start.toSeq: _*) + val a: Queue[Dependency[B]] = Queue(start.toSeq*) val b: LinkedHashMap[B, LinkedHashSet[B]] = - LinkedHashMap[B, LinkedHashSet[B]](start.map((dToO(_) -> LinkedHashSet[B]())).toSeq: _*) + LinkedHashMap[B, LinkedHashSet[B]](start.map((dToO(_) -> LinkedHashSet[B]())).toSeq*) (a, b) } @@ -183,7 +183,7 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends } /** Wrap a possible [[CyclicException]] thrown by a thunk in a [[DependencyManagerException]] */ - private def cyclePossible[A](a: String, diGraph: DiGraph[_])(thunk: => A): A = try { thunk } + private def cyclePossible[A](a: String, diGraph: DiGraph[?])(thunk: => A): A = try { thunk } catch { case e: CyclicException => throw new DependencyManagerException( diff --git a/src/main/scala/firrtl/options/OptionParser.scala b/src/main/scala/firrtl/options/OptionParser.scala index 79163aea..27251b00 100644 --- a/src/main/scala/firrtl/options/OptionParser.scala +++ b/src/main/scala/firrtl/options/OptionParser.scala @@ -11,7 +11,7 @@ case object OptionsHelpException extends Exception("Usage help invoked") /** OptionParser mixin that causes the OptionParser to not call exit (call `sys.exit`) if the `--help` option is * passed */ -trait DoNotTerminateOnExit { this: OptionParser[_] => +trait DoNotTerminateOnExit { this: OptionParser[?] => override def terminate(exitState: Either[String, Unit]): Unit = () } @@ -21,7 +21,7 @@ trait DoNotTerminateOnExit { this: OptionParser[_] => * [[StageUtils.dramaticError]]. By converting this to an [[OptionsException]], a [[Stage]] can then catch the error an * convert it to an [[OptionsException]] that a [[Stage]] can get at. */ -trait ExceptOnError { this: OptionParser[_] => +trait ExceptOnError { this: OptionParser[?] => override def reportError(msg: String): Unit = throw new OptionsException(msg) } diff --git a/src/main/scala/firrtl/options/Phase.scala b/src/main/scala/firrtl/options/Phase.scala index b836c386..04538fa5 100644 --- a/src/main/scala/firrtl/options/Phase.scala +++ b/src/main/scala/firrtl/options/Phase.scala @@ -3,7 +3,7 @@ package firrtl.options import firrtl.AnnotationSeq - +import firrtl.macros.Macros import logger.LazyLogging import scala.collection.mutable.LinkedHashSet @@ -12,33 +12,33 @@ import scala.reflect import scala.reflect.ClassTag object Dependency { - def apply[A <: DependencyAPI[_]: ClassTag]: Dependency[A] = { + def apply[A <: DependencyAPI[?]: ClassTag]: Dependency[A] = { val clazz = reflect.classTag[A].runtimeClass Dependency(Left(clazz.asInstanceOf[Class[A]])) } - def apply[A <: DependencyAPI[_]](c: Class[_ <: A]): Dependency[A] = { + def apply[A <: DependencyAPI[?]](c: Class[? <: A]): Dependency[A] = { // It's forbidden to wrap the class of a singleton as a Dependency require(c.getName.last != '$') Dependency(Left(c)) } - def apply[A <: DependencyAPI[_]](o: A with Singleton): Dependency[A] = Dependency(Right(o)) + def apply[A <: DependencyAPI[?]](o: A & Singleton): Dependency[A] = Dependency(Right(o)) - def fromTransform[A <: DependencyAPI[_]](t: A): Dependency[A] = { + def fromTransform[A <: DependencyAPI[?]](t: A): Dependency[A] = { if (isSingleton(t)) { - Dependency[A](Right(t.asInstanceOf[A with Singleton])) + Dependency[A](Right(t.asInstanceOf[A & Singleton])) } else { Dependency[A](Left(t.getClass)) } } private def isSingleton(obj: AnyRef): Boolean = { - reflect.runtime.currentMirror.reflect(obj).symbol.isModuleClass + Macros.isModuleClass(obj) } } -case class Dependency[+A <: DependencyAPI[_]](id: Either[Class[_ <: A], A with Singleton]) { +case class Dependency[+A <: DependencyAPI[?]](id: Either[Class[? <: A], A & Singleton]) { def getObject(): A = id match { case Left(c) => safeConstruct(c) case Right(o) => o @@ -55,7 +55,7 @@ case class Dependency[+A <: DependencyAPI[_]](id: Either[Class[_ <: A], A with S } /** Wrap an [[IllegalAccessException]] due to attempted object construction in a [[DependencyManagerException]] */ - private def safeConstruct[A](a: Class[_ <: A]): A = try { a.newInstance } + private def safeConstruct[A](a: Class[? <: A]): A = try { a.newInstance } catch { case e: IllegalAccessException => throw new DependencyManagerException(s"Failed to construct '$a'! (Did you try to construct an object?)", e) @@ -123,7 +123,7 @@ trait IdentityLike[A] { this: TransformLike[A] => * @define seqNote @note The use of a Seq here is to preserve input order. Internally, this will be converted to a private, * ordered Set. */ -trait DependencyAPI[A <: DependencyAPI[A]] { this: TransformLike[_] => +trait DependencyAPI[A <: DependencyAPI[A]] { this: TransformLike[?] => /** All transform that must run before this transform * $seqNote diff --git a/src/main/scala/firrtl/options/Registration.scala b/src/main/scala/firrtl/options/Registration.scala index 1ebccea4..c0eead69 100644 --- a/src/main/scala/firrtl/options/Registration.scala +++ b/src/main/scala/firrtl/options/Registration.scala @@ -3,7 +3,7 @@ package firrtl.options import firrtl.{AnnotationSeq, Transform} - +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import scopt.{OptionDef, OptionParser, Read} /** Contains information about a [[Shell]] command line option @@ -43,7 +43,7 @@ trait HasShellOptions { /** A sequence of options provided */ - def options: Seq[ShellOption[_]] + def options: Seq[ShellOption[?]] /** Add all shell (command line) options to an option parser * @param p an option parser diff --git a/src/main/scala/firrtl/options/Shell.scala b/src/main/scala/firrtl/options/Shell.scala index f7b9371f..89ae0f57 100644 --- a/src/main/scala/firrtl/options/Shell.scala +++ b/src/main/scala/firrtl/options/Shell.scala @@ -3,7 +3,7 @@ package firrtl.options import firrtl.AnnotationSeq - +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import logger.{ClassLogLevelAnnotation, LogClassNamesAnnotation, LogFileAnnotation, LogLevelAnnotation} import scopt.OptionParser @@ -16,7 +16,7 @@ import java.util.ServiceLoader class Shell(val applicationName: String) { /** Command line argument parser (OptionParser) with modifications */ - protected val parser = new OptionParser[AnnotationSeq](applicationName) with DuplicateHandling with ExceptOnError + val parser = new OptionParser[AnnotationSeq](applicationName) with DuplicateHandling with ExceptOnError /** Contains all discovered [[RegisteredLibrary]] */ final lazy val registeredLibraries: Seq[RegisteredLibrary] = { diff --git a/src/main/scala/firrtl/options/Stage.scala b/src/main/scala/firrtl/options/Stage.scala index cefdd957..2744bc00 100644 --- a/src/main/scala/firrtl/options/Stage.scala +++ b/src/main/scala/firrtl/options/Stage.scala @@ -3,6 +3,7 @@ package firrtl.options import firrtl.AnnotationSeq +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import logger.Logger diff --git a/src/main/scala/firrtl/options/StageAnnotations.scala b/src/main/scala/firrtl/options/StageAnnotations.scala index 1642e248..bc5952b6 100644 --- a/src/main/scala/firrtl/options/StageAnnotations.scala +++ b/src/main/scala/firrtl/options/StageAnnotations.scala @@ -3,6 +3,7 @@ package firrtl.options import firrtl.AnnotationSeq +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.annotations.{Annotation, NoTargetAnnotation} import firrtl.options.Viewer.view diff --git a/src/main/scala/firrtl/options/phases/AddDefaults.scala b/src/main/scala/firrtl/options/phases/AddDefaults.scala index dcd5b031..626410fd 100644 --- a/src/main/scala/firrtl/options/phases/AddDefaults.scala +++ b/src/main/scala/firrtl/options/phases/AddDefaults.scala @@ -3,6 +3,7 @@ package firrtl.options.phases import firrtl.AnnotationSeq +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.{Dependency, Phase, TargetDirAnnotation} /** Add default annotations for a [[Stage]] diff --git a/src/main/scala/firrtl/options/phases/Checks.scala b/src/main/scala/firrtl/options/phases/Checks.scala index 64d81fb4..4ffe0510 100644 --- a/src/main/scala/firrtl/options/phases/Checks.scala +++ b/src/main/scala/firrtl/options/phases/Checks.scala @@ -4,6 +4,7 @@ package firrtl.options.phases import firrtl.AnnotationSeq import firrtl.annotations.Annotation +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.{OptionsException, OutputAnnotationFileAnnotation, Phase, TargetDirAnnotation} import firrtl.options.Dependency diff --git a/src/main/scala/firrtl/options/phases/DeletedWrapper.scala b/src/main/scala/firrtl/options/phases/DeletedWrapper.scala index fe2c6d78..2b2ccf16 100644 --- a/src/main/scala/firrtl/options/phases/DeletedWrapper.scala +++ b/src/main/scala/firrtl/options/phases/DeletedWrapper.scala @@ -3,6 +3,7 @@ package firrtl.options.phases import firrtl.AnnotationSeq +import firrtl.{annoSeqToSeq, seqToAnnoSeq} import firrtl.annotations.DeletedAnnotation import firrtl.options.{Phase, Translator} @@ -25,14 +26,10 @@ class DeletedWrapper(p: Phase) extends Phase with Translator[AnnotationSeq, (Ann def aToB(a: AnnotationSeq): (AnnotationSeq, AnnotationSeq) = (a, a) def bToA(b: (AnnotationSeq, AnnotationSeq)): AnnotationSeq = { - - val (in, out) = (mutable.LinkedHashSet() ++ b._1, mutable.LinkedHashSet() ++ b._2) - - (in -- out).map { + b._1.diff(b._2).map { case DeletedAnnotation(n, a) => DeletedAnnotation(s"$n+$name", a) case a => DeletedAnnotation(name, a) }.toSeq ++ b._2 - } def internalTransform(b: (AnnotationSeq, AnnotationSeq)): (AnnotationSeq, AnnotationSeq) = (b._1, p.transform(b._2)) diff --git a/src/main/scala/firrtl/options/phases/GetIncludes.scala b/src/main/scala/firrtl/options/phases/GetIncludes.scala index d50b2c6f..05810a31 100644 --- a/src/main/scala/firrtl/options/phases/GetIncludes.scala +++ b/src/main/scala/firrtl/options/phases/GetIncludes.scala @@ -3,9 +3,11 @@ package firrtl.options.phases import firrtl.AnnotationSeq +import org.json4s.convertToJsonInput import firrtl.annotations.{AnnotationFileNotFoundException, JsonProtocol} import firrtl.options.{InputAnnotationFileAnnotation, Phase, StageUtils} import firrtl.FileUtils +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.stage.AllowUnrecognizedAnnotations import java.io.File diff --git a/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala b/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala index ba38bb87..4cbb0496 100644 --- a/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala +++ b/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala @@ -3,6 +3,8 @@ package firrtl.options.phases import firrtl.AnnotationSeq +import firrtl.{seqToAnnoSeq, annoSeqToSeq} +import firrtl.options.StageOptionsView import firrtl.annotations.{Annotation, DeletedAnnotation, JsonProtocol} import firrtl.options.{ BufferedCustomFileEmission, diff --git a/src/main/scala/firrtl/package.scala b/src/main/scala/firrtl/package.scala index 67d5e52c..9e2f74a6 100644 --- a/src/main/scala/firrtl/package.scala +++ b/src/main/scala/firrtl/package.scala @@ -6,7 +6,7 @@ package object firrtl { // Force initialization of the Forms object - https://github.com/freechipsproject/firrtl/issues/1462 private val _dummyForms = firrtl.stage.Forms - implicit def seqToAnnoSeq(xs: Seq[Annotation]) = AnnotationSeq(xs) + implicit def seqToAnnoSeq(xs: Seq[Annotation]): AnnotationSeq = AnnotationSeq(xs) implicit def annoSeqToSeq(as: AnnotationSeq): Seq[Annotation] = as.toSeq /* Options as annotations compatibility items */ diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 02d35740..dad0f69c 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -73,7 +73,7 @@ object CheckWidths extends Pass { (w, t) match { case (IntWidth(width), _) if width >= MaxWidth => errors.append(new WidthTooBig(info, target.serialize, width)) - case (w: IntWidth, f: FixedType) if (w.width < 0 && w.width == f.width) => + case (w: IntWidth, f: FixedType) if (w.width < 0 && w.width == f.width.asInstanceOf[IntWidth].width) => errors.append(new NegWidthException(info, target.serialize)) case (_: IntWidth, _) => case _ => diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 8fb4e5fb..1fd79f3a 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -125,13 +125,22 @@ object ExpandWhens extends Pass { EmptyStmt // For simulation constructs, update simlist with predicated statement and return EmptyStmt case sx: Print => - simlist += (if (weq(p, one)) sx else sx.withEn(AND(p, sx.en))) + simlist += (if (weq(p, one)) sx else sx match { + case s: Print => s.copy(en = AND(p, s.en)) + case o => o + }) EmptyStmt case sx: Stop => - simlist += (if (weq(p, one)) sx else sx.withEn(AND(p, sx.en))) + simlist += (if (weq(p, one)) sx else sx match { + case s: Stop => s.copy(en = AND(p, s.en)) + case o => o + }) EmptyStmt case sx: Verification => - simlist += (if (weq(p, one)) sx else sx.withEn(AND(p, sx.en))) + simlist += (if (weq(p, one)) sx else sx match { + case s: Verification => s.copy(en = AND(p, s.en)) + case o => o + }) EmptyStmt // Expand conditionally, see comments below case sx: Conditionally => diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 8ab78fee..b1910831 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -41,7 +41,7 @@ object InferTypes extends Pass { // we first need to remove the unknown widths and bounds from all ports, // as their type will determine the module types - val portsKnown = c.modules.map(_.map { p: Port => p.copy(tpe = remove_unknowns(p.tpe)) }) + val portsKnown = c.modules.map(_.map { (p: Port) => p.copy(tpe = remove_unknowns(p.tpe)) }) val mtypes = portsKnown.map(m => m.name -> module_type(m)).toMap def infer_types_e(types: TypeLookup)(e: Expression): Expression = diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index d0677fad..b3ef569a 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -31,7 +31,7 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint() ) } - } + }: @unchecked (newLoc, newExp) match { case (Some(l: ReferenceTarget), Some(e: ReferenceTarget)) => Seq(WidthGeqConstraintAnnotation(l, e)) @@ -265,7 +265,7 @@ class InferWidths extends Transform with ResolvedAnnotationPaths with Dependency } leafType - } + }: @unchecked //get_constraints_t(locType, expType) addTypeConstraints(anno.loc, anno.exp)(locType, expType) diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index e2105ff2..496b4ecf 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -244,7 +244,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe * The [[RenameMap]]s in renamesMap are appear in renamesSeq * in the order that they should be applied */ - val (renamesMap, renamesSeq) = { + val (renamesMap, renamesSeq): (Map[(OfModule, Instance), MutableRenameMap], Seq[RenameMap]) = {{ val mutableDiGraph = new MutableDiGraph[(OfModule, Instance)] // compute instance graph instMaps.foreach { @@ -284,7 +284,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe val resultMap = indexMap.mapValues(idx => resultSeq(maxIdx - idx)) (resultMap, resultSeq) } - } + }}: @unchecked def fixupRefs( instMap: collection.Map[Instance, OfModule], @@ -369,7 +369,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe }) // Upcast so reduce works (andThen returns RenameMap) - val renames = (renamesSeq: Seq[RenameMap]).reduceLeftOption(_ andThen _) + val renames = (renamesSeq: Seq[RenameMap]).reduceLeftOption(_ `andThen` _) val cleanedAnnos = annos.filterNot { case InlineAnnotation(_) => true diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 976741fd..7bec15fc 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -25,6 +25,7 @@ import firrtl.{ UnknownForm, Utils } +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.ir._ import firrtl.options.Dependency import firrtl.stage.TransformManager.TransformDependency diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala index 745be1e2..e3218467 100644 --- a/src/main/scala/firrtl/passes/ResolveKinds.scala +++ b/src/main/scala/firrtl/passes/ResolveKinds.scala @@ -5,7 +5,6 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Mappers._ -import firrtl.backends.experimental.smt.random.DefRandom import firrtl.traversals.Foreachers._ object ResolveKinds extends Pass { @@ -32,7 +31,6 @@ object ResolveKinds extends Pass { case sx: DefRegister => kinds(sx.name) = RegKind case sx: WDefInstance => kinds(sx.name) = InstanceKind case sx: DefMemory => kinds(sx.name) = MemKind - case sx: DefRandom => kinds(sx.name) = RandomKind case _ => } s.map(resolve_stmt(kinds)) diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala index 99a97a38..bb927381 100644 --- a/src/main/scala/firrtl/passes/TrimIntervals.scala +++ b/src/main/scala/firrtl/passes/TrimIntervals.scala @@ -76,10 +76,10 @@ class TrimIntervals extends Pass { case DoPrim(o, args, consts, t) if opsToFix.contains(o) && (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size => - val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) + val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ `max` _) DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t) case Mux(cond, tval, fval, t: IntervalType) => - val maxBP = Seq(tval, fval).map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) + val maxBP = Seq(tval, fval).map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ `max` _) Mux(cond, fixBP(maxBP)(tval), fixBP(maxBP)(fval), t) case other => other } diff --git a/src/main/scala/firrtl/passes/VerilogPrep.scala b/src/main/scala/firrtl/passes/VerilogPrep.scala index 2bd17519..358c34e2 100644 --- a/src/main/scala/firrtl/passes/VerilogPrep.scala +++ b/src/main/scala/firrtl/passes/VerilogPrep.scala @@ -32,7 +32,7 @@ object VerilogPrep extends Pass { Dependency[firrtl.transforms.LegalizeClocksAndAsyncResetsTransform], Dependency[firrtl.transforms.FlattenRegUpdate], Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename] + Dependency[firrtl.transforms.VerilogRename[?]] ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/passes/memlib/CreateMemoryAnnotations.scala b/src/main/scala/firrtl/passes/memlib/CreateMemoryAnnotations.scala index 240c2c9a..c2d7a6f1 100644 --- a/src/main/scala/firrtl/passes/memlib/CreateMemoryAnnotations.scala +++ b/src/main/scala/firrtl/passes/memlib/CreateMemoryAnnotations.scala @@ -22,8 +22,9 @@ class CreateMemoryAnnotations extends Transform with DependencyAPIMigration { Seq(MemLibOutConfigFileAnnotation(outputConfig, Nil)) ++ { if (inputFileName.isEmpty) None else if (new File(inputFileName).exists) { - import CustomYAMLProtocol._ - Some(PinAnnotation(new YamlFileReader(inputFileName).parse[Config].map(_.pin.name))) + error("custom yaml protocol not supported in scala 3") + // import CustomYAMLProtocol._ + // Some(PinAnnotation(new YamlFileReader(inputFileName).parse[Config].map(_.pin.name))) } else error("Input configuration file does not exist!") } case a => Seq(a) diff --git a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala index 186ca78c..f19db4e5 100644 --- a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala +++ b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala @@ -7,7 +7,7 @@ import firrtl.options.{RegisteredLibrary, ShellOption} class MemLibOptions extends RegisteredLibrary { val name: String = "MemLib Options" - val options: Seq[ShellOption[_]] = Seq(new InferReadWrite, new ReplSeqMem) + val options: Seq[ShellOption[?]] = Seq(new InferReadWrite, new ReplSeqMem) .flatMap(_.options) } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 7a1a57fb..add05fe2 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -90,7 +90,7 @@ object ResolveMaskGranularity extends Pass { */ def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { val wenOrigin = getOrigin(connects)(wen) - val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { s: String => getOrigin(connects, s) } + val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { (s: String) => getOrigin(connects, s) } // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one)) if (redundantMask) None else Some(wmaskOrigin.size) diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala index 99344973..fbb25656 100644 --- a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala @@ -1,45 +1,45 @@ -// SPDX-License-Identifier: Apache-2.0 +// // SPDX-License-Identifier: Apache-2.0 -package firrtl.passes -package memlib -import net.jcazevedo.moultingyaml._ -import java.io.{CharArrayWriter, File, PrintWriter} -import firrtl.FileUtils +// package firrtl.passes +// package memlib +// import net.jcazevedo.moultingyaml._ +// import java.io.{CharArrayWriter, File, PrintWriter} +// import firrtl.FileUtils -object CustomYAMLProtocol extends DefaultYamlProtocol { - // bottom depends on top - implicit val _pin = yamlFormat1(Pin) - implicit val _source = yamlFormat2(Source) - implicit val _top = yamlFormat1(Top) - implicit val _configs = yamlFormat3(Config) -} +// object CustomYAMLProtocol extends DefaultYamlProtocol { +// // bottom depends on top +// implicit val _pin = yamlFormat1(Pin) +// implicit val _source = yamlFormat2(Source) +// implicit val _top = yamlFormat1(Top) +// implicit val _configs = yamlFormat3(Config) +// } -case class Pin(name: String) -case class Source(name: String, module: String) -case class Top(name: String) -case class Config(pin: Pin, source: Source, top: Top) +// case class Pin(name: String) +// case class Source(name: String, module: String) +// case class Top(name: String) +// case class Config(pin: Pin, source: Source, top: Top) -class YamlFileReader(file: String) { - def parse[A](implicit reader: YamlReader[A]): Seq[A] = { - if (new File(file).exists) { - val yamlString = FileUtils.getText(file) - yamlString.parseYamls.flatMap(x => - try Some(reader.read(x)) - catch { case e: Exception => None } - ) - } else sys.error("Yaml file doesn't exist!") - } -} +// class YamlFileReader(file: String) { +// def parse[A](implicit reader: YamlReader[A]): Seq[A] = { +// if (new File(file).exists) { +// val yamlString = FileUtils.getText(file) +// yamlString.parseYamls.flatMap(x => +// try Some(reader.read(x)) +// catch { case e: Exception => None } +// ) +// } else sys.error("Yaml file doesn't exist!") +// } +// } -class YamlFileWriter(file: String) { - val outputBuffer = new CharArrayWriter - val separator = "--- \n" - def append(in: YamlValue): Unit = { - outputBuffer.append(s"$separator${in.prettyPrint}") - } - def dump(): Unit = { - val outputFile = new PrintWriter(file) - outputFile.write(outputBuffer.toString) - outputFile.close() - } -} +// class YamlFileWriter(file: String) { +// val outputBuffer = new CharArrayWriter +// val separator = "--- \n" +// def append(in: YamlValue): Unit = { +// outputBuffer.append(s"$separator${in.prettyPrint}") +// } +// def dump(): Unit = { +// val outputFile = new PrintWriter(file) +// outputFile.write(outputBuffer.toString) +// outputFile.close() +// } +// } diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala index 45eb61bf..c5faefac 100644 --- a/src/main/scala/firrtl/passes/wiring/Wiring.scala +++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala @@ -50,7 +50,7 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { portNames(i) = portNames(i) + (m.name -> { if (si.exists(getModuleName(_) == m.name)) ns.newName(p) - else ns.newName(tokenize(c).filterNot("[]." contains _).mkString("_")) + else ns.newName(tokenize(c).filterNot("[]." `contains` _).mkString("_")) }) } } diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala index 7f4266fd..f8273b9f 100644 --- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala +++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala @@ -238,7 +238,7 @@ object RunFirrtlTransformAnnotation extends HasShellOptions { longOption = "custom-transforms", toAnnotationSeq = _.map(txName => try { - val tx = Class.forName(txName).asInstanceOf[Class[_ <: Transform]].newInstance() + val tx = Class.forName(txName).asInstanceOf[Class[? <: Transform]].newInstance() RunFirrtlTransformAnnotation(tx) } catch { case e: ClassNotFoundException => diff --git a/src/main/scala/firrtl/stage/FirrtlCompilerTargets.scala b/src/main/scala/firrtl/stage/FirrtlCompilerTargets.scala index e31acbbf..47bc9ed5 100644 --- a/src/main/scala/firrtl/stage/FirrtlCompilerTargets.scala +++ b/src/main/scala/firrtl/stage/FirrtlCompilerTargets.scala @@ -2,6 +2,7 @@ package firrtl.stage +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.transforms._ import firrtl.passes.memlib._ import firrtl.options.{HasShellOptions, ShellOption} diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala index f36e8f87..9f0c9c55 100644 --- a/src/main/scala/firrtl/stage/Forms.scala +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -3,8 +3,6 @@ package firrtl.stage import firrtl._ -import firrtl.backends.experimental.rtlil.RtlilEmitter -import firrtl.backends.experimental.smt.{Btor2Emitter, SMTLibEmitter} import firrtl.options.Dependency import firrtl.stage.TransformManager.TransformDependency @@ -115,7 +113,7 @@ object Forms { Dependency[firrtl.transforms.LegalizeClocksAndAsyncResetsTransform], Dependency[firrtl.transforms.FlattenRegUpdate], Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], + Dependency[firrtl.transforms.VerilogRename[?]], Dependency(passes.VerilogPrep), Dependency[firrtl.AddDescriptionNodes] ) @@ -136,9 +134,6 @@ object Forms { Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter], Dependency[SystemVerilogEmitter], - Dependency(SMTLibEmitter), - Dependency(Btor2Emitter), - Dependency[RtlilEmitter] ) val LowEmitters = Dependency[LowFirrtlEmitter] +: BackendEmitters diff --git a/src/main/scala/firrtl/stage/phases/AddCircuit.scala b/src/main/scala/firrtl/stage/phases/AddCircuit.scala index 1401560c..420d813e 100644 --- a/src/main/scala/firrtl/stage/phases/AddCircuit.scala +++ b/src/main/scala/firrtl/stage/phases/AddCircuit.scala @@ -5,6 +5,7 @@ package firrtl.stage.phases import firrtl.stage._ import firrtl.{AnnotationSeq, Parser} +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.{Dependency, Phase, PhasePrerequisiteException} /** [[firrtl.options.Phase Phase]] that expands [[FirrtlFileAnnotation]]/[[FirrtlSourceAnnotation]] into diff --git a/src/main/scala/firrtl/stage/phases/AddDefaults.scala b/src/main/scala/firrtl/stage/phases/AddDefaults.scala index 63de1a7b..2fa621b9 100644 --- a/src/main/scala/firrtl/stage/phases/AddDefaults.scala +++ b/src/main/scala/firrtl/stage/phases/AddDefaults.scala @@ -3,6 +3,7 @@ package firrtl.stage.phases import firrtl.{AnnotationSeq, VerilogEmitter} +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.{Dependency, Phase, TargetDirAnnotation} import firrtl.stage.TransformManager.TransformDependency import firrtl.transforms.BlackBoxTargetDirAnno diff --git a/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala b/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala index a252f4f3..b7736f55 100644 --- a/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala +++ b/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala @@ -2,6 +2,7 @@ package firrtl.stage.phases +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.{AnnotationSeq, EmitAnnotation, EmitCircuitAnnotation, Emitter} import firrtl.stage.{CompilerAnnotation, RunFirrtlTransformAnnotation} import firrtl.options.{Dependency, Phase} diff --git a/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala b/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala index dcc2be4f..1c82736a 100644 --- a/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala +++ b/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala @@ -2,9 +2,12 @@ package firrtl.stage.phases +import firrtl.{seqToAnnoSeq, annoSeqToSeq} +import firrtl.stage.FirrtlOptionsView import firrtl.{AnnotationSeq, EmitAllModulesAnnotation} import firrtl.options.{Dependency, Phase, Viewer} import firrtl.stage.{FirrtlOptions, OutputFileAnnotation} +import firrtl.options.StageOptionsView /** [[firrtl.options.Phase Phase]] that adds an [[OutputFileAnnotation]] if one does not already exist. * diff --git a/src/main/scala/firrtl/stage/phases/Checks.scala b/src/main/scala/firrtl/stage/phases/Checks.scala index edfb7d03..0b83cdbf 100644 --- a/src/main/scala/firrtl/stage/phases/Checks.scala +++ b/src/main/scala/firrtl/stage/phases/Checks.scala @@ -3,7 +3,7 @@ package firrtl.stage.phases import firrtl.stage._ - +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.{AnnotationSeq, EmitAllModulesAnnotation, EmitCircuitAnnotation} import firrtl.annotations.Annotation import firrtl.options.{Dependency, OptionsException, Phase} diff --git a/src/main/scala/firrtl/stage/phases/Compiler.scala b/src/main/scala/firrtl/stage/phases/Compiler.scala index d24775e7..6c44e73f 100644 --- a/src/main/scala/firrtl/stage/phases/Compiler.scala +++ b/src/main/scala/firrtl/stage/phases/Compiler.scala @@ -2,7 +2,8 @@ package firrtl.stage.phases -import firrtl.{AnnotationSeq, ChirrtlForm, CircuitState, Compiler => FirrtlCompiler, Transform, seqToAnnoSeq} +import firrtl.{AnnotationSeq, ChirrtlForm, CircuitState, Compiler => FirrtlCompiler, Transform} +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.{Dependency, Phase, PhasePrerequisiteException, Translator} import firrtl.stage.{ CompilerAnnotation, @@ -12,7 +13,7 @@ import firrtl.stage.{ RunFirrtlTransformAnnotation } import firrtl.stage.TransformManager.TransformDependency - +import scala.collection.parallel.CollectionConverters._ import scala.collection.mutable /** An encoding of the information necessary to run the FIRRTL compiler once */ @@ -142,7 +143,7 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { if (b.size <= 1) { b.map(f) } else { - collection.parallel.immutable.ParVector(b: _*).par.map(f).seq + scala.collection.immutable.Vector(b*).par.map(f).seq } } diff --git a/src/main/scala/firrtl/stage/phases/ConvertCompilerAnnotations.scala b/src/main/scala/firrtl/stage/phases/ConvertCompilerAnnotations.scala index edd38dc6..c21b031f 100644 --- a/src/main/scala/firrtl/stage/phases/ConvertCompilerAnnotations.scala +++ b/src/main/scala/firrtl/stage/phases/ConvertCompilerAnnotations.scala @@ -3,6 +3,7 @@ package firrtl.stage.phases import firrtl.AnnotationSeq +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.{Dependency, OptionsException, Phase} import firrtl.stage.{CompilerAnnotation, RunFirrtlTransformAnnotation} diff --git a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala index 546a18f5..5d816f23 100644 --- a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala +++ b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala @@ -5,10 +5,12 @@ package firrtl.stage.phases import firrtl.stage._ import firrtl.{AnnotationSeq, EmitAllModulesAnnotation, EmitCircuitAnnotation, Emitter, Parser} +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.annotations.NoTargetAnnotation import firrtl.FileUtils import firrtl.proto.FromProto import firrtl.options.{InputAnnotationFileAnnotation, OptionsException, Phase, StageOptions, StageUtils} +import firrtl.options.StageOptionsView import firrtl.options.Viewer import firrtl.options.Dependency diff --git a/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala b/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala index 62363a77..4a5b7260 100644 --- a/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala +++ b/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala @@ -3,6 +3,7 @@ package firrtl.stage.transforms import firrtl.{AnnotationSeq, CircuitState, Transform} +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.annotations.NoTargetAnnotation import firrtl.options.{Dependency, DependencyManagerException} diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index eec9d1af..21356752 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -54,8 +54,8 @@ object LogicNode { object CheckCombLoops { type AbstractConnMap = DiGraph[LogicNode] - type ConnMap = DiGraph[LogicNode] with EdgeData[LogicNode, Info] - type MutableConnMap = MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info] + type ConnMap = DiGraph[LogicNode] & EdgeData[LogicNode, Info] + type MutableConnMap = MutableDiGraph[LogicNode] & MutableEdgeData[LogicNode, Info] class CombLoopException(info: Info, mname: String, cycle: Seq[String]) extends PassException(s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n")) diff --git a/src/main/scala/firrtl/transforms/CustomRadixTransform.scala b/src/main/scala/firrtl/transforms/CustomRadixTransform.scala index 42724be8..d5afc49b 100644 --- a/src/main/scala/firrtl/transforms/CustomRadixTransform.scala +++ b/src/main/scala/firrtl/transforms/CustomRadixTransform.scala @@ -4,6 +4,7 @@ package firrtl.transforms import firrtl.annotations.TargetToken.Instance import firrtl.annotations.{Annotation, NoTargetAnnotation, ReferenceTarget, SingleTargetAnnotation} +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.{CustomFileEmission, Dependency, HasShellOptions, ShellOption} import firrtl.stage.TransformManager.TransformDependency import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index a622feb4..2c9e2547 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -12,7 +12,6 @@ import firrtl.Mappers._ import firrtl.Utils.{kind, throwInternalError} import firrtl.MemoizedHash._ import firrtl.renamemap.MutableRenameMap -import firrtl.backends.experimental.smt.random.DefRandom import firrtl.options.{Dependency, RegisteredTransform, ShellOption} import collection.mutable @@ -44,7 +43,7 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], Dependency[firrtl.transforms.FlattenRegUpdate], Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], + Dependency[firrtl.transforms.VerilogRename[?]], Dependency(passes.VerilogPrep), Dependency[firrtl.AddDescriptionNodes] ) @@ -120,11 +119,6 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend val node = LogicNode(mod.name, name) depGraph.addVertex(node) Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(node, ref)) - case DefRandom(_, name, _, clock, en) => - val node = LogicNode(mod.name, name) - depGraph.addVertex(node) - val inputs = clock ++: en +: Nil - inputs.flatMap(getDeps).foreach(ref => depGraph.addPairWithEdge(node, ref)) case DefNode(_, name, value) => val node = LogicNode(mod.name, name) depGraph.addVertex(node) @@ -224,7 +218,6 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend val tpe = decl match { case _: DefNode => "node" case _: DefRegister => "reg" - case _: DefRandom => "rand" case _: DefWire => "wire" case _: Port => "port" case _: DefMemory => "mem" diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index 373066c8..6cd528c2 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -424,7 +424,7 @@ object DedupModules extends LazyLogging { val instToModule = mutable.HashMap.empty[String, String] def markAggregatePorts(expr: Expression): Unit = { if (kind(expr) == InstanceKind && hasBundleType(expr.tpe)) { - val (WRef(inst, _, _, _), _) = splitRef(expr) + val (WRef(inst, _, _, _), _) = splitRef(expr): @unchecked dontDedup += instToModule(inst) } } diff --git a/src/main/scala/firrtl/transforms/EnsureNamedStatements.scala b/src/main/scala/firrtl/transforms/EnsureNamedStatements.scala index a40409f9..bdef14f9 100644 --- a/src/main/scala/firrtl/transforms/EnsureNamedStatements.scala +++ b/src/main/scala/firrtl/transforms/EnsureNamedStatements.scala @@ -25,15 +25,15 @@ object EnsureNamedStatements extends Transform with DependencyAPIMigration { } private def onStmt(namespace: Namespace)(stmt: Statement): Statement = stmt match { - case s: Print if s.name.isEmpty => s.withName(namespace.newName("print")) - case s: Stop if s.name.isEmpty => s.withName(namespace.newName("stop")) + case s: Print if s.name.isEmpty => s.copyWithName(name = namespace.newName("print")) + case s: Stop if s.name.isEmpty => s.copyWithName(name = namespace.newName("stop")) case s: Verification if s.name.isEmpty => val baseName = s.op match { case Formal.Cover => "cover" case Formal.Assert => "assert" case Formal.Assume => "assume" } - s.withName(namespace.newName(baseName)) + s.copyWithName(name = namespace.newName(baseName)) case other => other.mapStmt(onStmt(namespace)) } } diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala index c2a79d53..01b85aac 100644 --- a/src/main/scala/firrtl/transforms/GroupComponents.scala +++ b/src/main/scala/firrtl/transforms/GroupComponents.scala @@ -97,7 +97,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { // The label "" indicates the original module, and components belonging to that group will remain // in the original module (not get moved into a new module) val label2group: Map[String, MSet[String]] = groups.collect { - case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name): _*) + case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name)*) }.toMap + ("" -> mutable.Set("")) // Name of new module containing each group, by label diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala index 9e3e4a61..821df054 100644 --- a/src/main/scala/firrtl/transforms/InferResets.scala +++ b/src/main/scala/firrtl/transforms/InferResets.scala @@ -72,12 +72,12 @@ object InferResets { // Vectors must all have the same type, so we only process Index 0 // If the subtype is an aggregate, there can be multiple of each index val ts = tokens.collect { case (TargetToken.Index(0) +: tail, tpe) => (tail, tpe) } - VectorTree(fromTokens(ts: _*)) + VectorTree(fromTokens(ts*)) // BundleTree case (TargetToken.Field(_) +: _, _) +: _ => val fields = tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n }.mapValues { ts => - fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }: _*) + fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }*) }.toMap BundleTree(fields) } @@ -274,7 +274,7 @@ class InferResets extends Transform with DependencyAPIMigration { map .groupBy(_._1.ref) .mapValues { ts => - TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }: _*) + TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }*) } .toMap diff --git a/src/main/scala/firrtl/transforms/ManipulateNames.scala b/src/main/scala/firrtl/transforms/ManipulateNames.scala index 3596b7e6..758628b5 100644 --- a/src/main/scala/firrtl/transforms/ManipulateNames.scala +++ b/src/main/scala/firrtl/transforms/ManipulateNames.scala @@ -33,7 +33,7 @@ import scala.reflect.ClassTag * behavior, use a combination of a sub-class of this annotation and a [[firrtl.transforms.NoDedupAnnotation * NoDedupAnnotation]]. */ -sealed trait ManipulateNamesListAnnotation[A <: ManipulateNames[_]] extends MultiTargetAnnotation { +sealed trait ManipulateNamesListAnnotation[A <: ManipulateNames[?]] extends MultiTargetAnnotation { def transform: Dependency[A] @@ -58,7 +58,7 @@ sealed trait ManipulateNamesListAnnotation[A <: ManipulateNames[_]] extends Mult * @throws java.lang.IllegalArgumentException if any non-local targets are given * @note $noteLocalTargets */ -case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]]( +case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[?]]( targets: Seq[Seq[Target]], transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] { @@ -79,7 +79,7 @@ case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]]( * @throws java.lang.IllegalArgumentException if any non-local targets are given * @note $noteLocalTargets */ -case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]]( +case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[?]]( targets: Seq[Seq[Target]], transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] { @@ -97,7 +97,7 @@ case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]]( * @param transform the transform that performed this rename * @param oldTargets the old targets */ -case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]]( +case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[?]]( targets: Seq[Seq[Target]], transform: Dependency[A], oldTargets: Seq[Seq[Target]]) @@ -178,7 +178,7 @@ private class RenameDataStructure( /** Transform for manipulate all the names in a FIRRTL circuit. * @tparam A the type of the child transform */ -abstract class ManipulateNames[A <: ManipulateNames[_]: ClassTag] extends Transform with DependencyAPIMigration { +abstract class ManipulateNames[A <: ManipulateNames[?]: ClassTag] extends Transform with DependencyAPIMigration { /** A function used to manipulate a name in a FIRRTL circuit */ def manipulate: (String, Namespace) => Option[String] diff --git a/src/main/scala/firrtl/transforms/MustDedup.scala b/src/main/scala/firrtl/transforms/MustDedup.scala index 417e46ac..7f751ff3 100644 --- a/src/main/scala/firrtl/transforms/MustDedup.scala +++ b/src/main/scala/firrtl/transforms/MustDedup.scala @@ -150,7 +150,7 @@ object MustDeduplicateTransform { val nodesToKeep = findNodesToKeep(failure, graph) graph.subgraph(nodesToKeep) + // Add fake nodes to represent parents of the "shouldDedup" nodes - DiGraph(shouldDedup.map(n => getParents(n).mkString(", ") -> n): _*) + DiGraph(shouldDedup.map(n => getParents(n).mkString(", ") -> n)*) } // Gather candidate modules and assign indices for reference val candidateIdx: Map[String, Int] = diff --git a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala index 6e2c9a4a..e0bcc697 100644 --- a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala +++ b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala @@ -6,11 +6,12 @@ import firrtl._ import firrtl.Utils.v_keywords import firrtl.options.Dependency +import scala.reflect.ClassTag /** Transform that removes collisions with reserved keywords * @param keywords a set of reserved words */ -class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames { +class RemoveKeywordCollisions[A <: ManipulateNames[?]: ClassTag](keywords: Set[String]) extends ManipulateNames[A] { private val inlineDelim = "_" @@ -29,7 +30,7 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames { } /** Transform that removes collisions with Verilog keywords */ -class VerilogRename extends RemoveKeywordCollisions(v_keywords) { +class VerilogRename[A <: ManipulateNames[?]: ClassTag] extends RemoveKeywordCollisions[A](v_keywords) { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ Seq( diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index 4fa70002..9089ad83 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -11,7 +11,6 @@ import firrtl.WrappedExpression._ import firrtl.graph.{CyclicException, MutableDiGraph} import firrtl.options.Dependency import firrtl.Utils.getGroundZero -import firrtl.backends.experimental.smt.random.DefRandom import firrtl.passes.PadWidths import scala.collection.mutable @@ -63,7 +62,6 @@ class RemoveWires extends Transform with DependencyAPIMigration { private def getOrderedNodes( netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)], regInfo: mutable.Map[WrappedExpression, DefRegister], - randInfo: mutable.Map[WrappedExpression, DefRandom] ): Try[Seq[Statement]] = { val digraph = new MutableDiGraph[WrappedExpression] for ((sink, (exprs, _)) <- netlist) { @@ -81,10 +79,9 @@ class RemoveWires extends Transform with DependencyAPIMigration { Try { val ordered = digraph.linearize.reverse ordered.map { key => - val WRef(name, _, kind, _) = key.e1 + val WRef(name, _, kind, _) = key.e1: @unchecked kind match { case RegKind => regInfo(key) - case RandomKind => randInfo(key) case WireKind | NodeKind => val (Seq(rhs), info) = netlist(key) DefNode(info, name, rhs) @@ -104,8 +101,6 @@ class RemoveWires extends Transform with DependencyAPIMigration { val wireInfo = mutable.HashMap.empty[WrappedExpression, Info] // Additional info about registers val regInfo = mutable.HashMap.empty[WrappedExpression, DefRegister] - // Additional info about rand statements - val randInfo = mutable.HashMap.empty[WrappedExpression, DefRandom] def onStmt(stmt: Statement): Statement = { stmt match { @@ -121,9 +116,6 @@ class RemoveWires extends Transform with DependencyAPIMigration { val initDep = Some(reg.init).filter(we(WRef(reg)) != we(_)) // Dependency exists IF reg doesn't init itself regInfo(we(WRef(reg))) = reg netlist(we(WRef(reg))) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info) - case rand: DefRandom => - randInfo(we(Reference(rand))) = rand - netlist(we(Reference(rand))) = (rand.clock ++: rand.en +: List(), rand.info) case decl: CanBeReferenced => // Keep all declarations except for nodes and non-Analog wires and "other" statements. // Thus this is expected to match DefInstance and DefMemory which both do not connect to @@ -160,7 +152,7 @@ class RemoveWires extends Transform with DependencyAPIMigration { m match { case mod @ Module(info, name, ports, body) => onStmt(body) - getOrderedNodes(netlist, regInfo, randInfo) match { + getOrderedNodes(netlist, regInfo) match { case Success(logic) => Module(info, name, ports, Block(List() ++ decls ++ logic ++ otherStmts)) // If we hit a CyclicException, just abort removing wires diff --git a/src/main/scala/firrtl/transforms/TopWiring.scala b/src/main/scala/firrtl/transforms/TopWiring.scala index 9fc40c59..abab471a 100644 --- a/src/main/scala/firrtl/transforms/TopWiring.scala +++ b/src/main/scala/firrtl/transforms/TopWiring.scala @@ -151,7 +151,7 @@ class TopWiringTransform extends Transform with DependencyAPIMigration { // Map of component name to relative instance paths that result in a debug wire val sourcemods: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]] = - mutable.Map(sSourcesModNames.map(_ -> Seq()): _*) + mutable.Map(sSourcesModNames.map(_ -> Seq())*) state.circuit.modules.foreach { m => m.map(getSourceTypes(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)), state)) diff --git a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala index 3199cedf..2dd6c9da 100644 --- a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala +++ b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala @@ -5,6 +5,7 @@ package firrtl.transforms.formal import firrtl.ir.{Circuit, Formal, Statement, Verification} import firrtl.stage.TransformManager.TransformDependency import firrtl.{CircuitState, DependencyAPIMigration, Transform} +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.annotations.NoTargetAnnotation import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption} @@ -36,7 +37,7 @@ class AssertSubmoduleAssumptions ) def assertAssumption(s: Statement): Statement = s match { - case v: Verification if v.op == Formal.Assume => v.withOp(Formal.Assert) + case v: Verification if v.op == Formal.Assume => v.copy(op = Formal.Assert) case t => t.mapStmt(assertAssumption) } diff --git a/src/main/scala/logger/Logger.scala b/src/main/scala/logger/Logger.scala index f14eec71..9da25656 100644 --- a/src/main/scala/logger/Logger.scala +++ b/src/main/scala/logger/Logger.scala @@ -5,6 +5,7 @@ package logger import java.io.{ByteArrayOutputStream, File, FileOutputStream, PrintStream} import firrtl.AnnotationSeq +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.Viewer.view import logger.phases.{AddDefaults, Checks} @@ -268,7 +269,7 @@ object Logger { * @param classType Kind of class * @param level log level to set */ - def setLevel(classType: Class[_ <: LazyLogging], level: LogLevel.Value): Unit = { + def setLevel(classType: Class[? <: LazyLogging], level: LogLevel.Value): Unit = { clearCache() val name = classType.getCanonicalName state.classLevels(name) = level diff --git a/src/main/scala/logger/LoggerAnnotations.scala b/src/main/scala/logger/LoggerAnnotations.scala index 0185492e..afecc011 100644 --- a/src/main/scala/logger/LoggerAnnotations.scala +++ b/src/main/scala/logger/LoggerAnnotations.scala @@ -3,6 +3,7 @@ package logger import firrtl.annotations.{Annotation, NoTargetAnnotation} +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.{HasShellOptions, ShellOption} /** An annotation associated with a Logger command line option */ @@ -47,7 +48,7 @@ object ClassLogLevelAnnotation extends HasShellOptions { longOption = "class-log-level", toAnnotationSeq = (a: Seq[String]) => a.map { aa => - val className :: levelName :: _ = aa.split(":").toList + val className :: levelName :: _ = aa.split(":").toList: @unchecked val level = LogLevel(levelName) ClassLogLevelAnnotation(className, level) }, diff --git a/src/main/scala/logger/package.scala b/src/main/scala/logger/package.scala index 320b06cb..377989b5 100644 --- a/src/main/scala/logger/package.scala +++ b/src/main/scala/logger/package.scala @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 import firrtl.AnnotationSeq +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.OptionsView package object logger { diff --git a/src/main/scala/logger/phases/AddDefaults.scala b/src/main/scala/logger/phases/AddDefaults.scala index 722b7c78..88db8b88 100644 --- a/src/main/scala/logger/phases/AddDefaults.scala +++ b/src/main/scala/logger/phases/AddDefaults.scala @@ -3,6 +3,7 @@ package logger.phases import firrtl.AnnotationSeq +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.options.Phase import logger.{LogLevelAnnotation, LoggerOption} diff --git a/src/main/scala/logger/phases/Checks.scala b/src/main/scala/logger/phases/Checks.scala index 96df4a14..4fa22f25 100644 --- a/src/main/scala/logger/phases/Checks.scala +++ b/src/main/scala/logger/phases/Checks.scala @@ -3,6 +3,7 @@ package logger.phases import firrtl.AnnotationSeq +import firrtl.{seqToAnnoSeq, annoSeqToSeq} import firrtl.annotations.Annotation import firrtl.options.{Dependency, Phase} |
