aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtl
diff options
context:
space:
mode:
authorDavid Biancolin2020-03-17 13:26:40 -0700
committerGitHub2020-03-17 13:26:40 -0700
commitba1f24345ac5ab20c669c73b871920001ac3a8ed (patch)
treea6a55fafd5f68c35e574a34842930165af5631ad /src/test/scala/firrtl
parentd0500b33167cad060a9325d68b939d41279f6c9c (diff)
[RFC] Factor out common test classes; package them (#1412)
* Pull out common test utilities into a separate package * Project a fat jar for test utilities Co-authored-by: Albert Magyar <albert.magyar@gmail.com>
Diffstat (limited to 'src/test/scala/firrtl')
-rw-r--r--src/test/scala/firrtl/testutils/FirrtlSpec.scala406
-rw-r--r--src/test/scala/firrtl/testutils/PassTests.scala106
2 files changed, 512 insertions, 0 deletions
diff --git a/src/test/scala/firrtl/testutils/FirrtlSpec.scala b/src/test/scala/firrtl/testutils/FirrtlSpec.scala
new file mode 100644
index 00000000..46f36e87
--- /dev/null
+++ b/src/test/scala/firrtl/testutils/FirrtlSpec.scala
@@ -0,0 +1,406 @@
+// See LICENSE for license details.
+
+package firrtl.testutils
+
+import java.io._
+import java.security.Permission
+
+import logger.LazyLogging
+
+import org.scalatest._
+import org.scalatestplus.scalacheck._
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Parser.UseInfo
+import firrtl.stage.{FirrtlFileAnnotation, InfoModeAnnotation, RunFirrtlTransformAnnotation}
+import firrtl.analyses.{GetNamespace, ModuleNamespaceAnnotation}
+import firrtl.annotations._
+import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation, RenameModules}
+import firrtl.util.BackendCompilationUtilities
+
+class CheckLowForm extends SeqTransform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+ def transforms = Seq(
+ passes.CheckHighForm
+ )
+}
+
+trait FirrtlRunners extends BackendCompilationUtilities {
+
+ val cppHarnessResourceName: String = "/firrtl/testTop.cpp"
+ /** Extra transforms to run by default */
+ val extraCheckTransforms = Seq(new CheckLowForm)
+
+ private class RenameTop(newTopPrefix: String) extends Transform {
+ def inputForm: LowForm.type = LowForm
+ def outputForm: LowForm.type = LowForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val namespace = state.annotations.collectFirst {
+ case m: ModuleNamespaceAnnotation => m
+ }.get.namespace
+
+ val newTopName = namespace.newName(newTopPrefix)
+ val modulesx = state.circuit.modules.map {
+ case mod: Module if mod.name == state.circuit.main => mod.mapString(_ => newTopName)
+ case other => other
+ }
+
+ state.copy(circuit = state.circuit.copy(main = newTopName, modules = modulesx))
+ }
+ }
+
+ /** Check equivalence of Firrtl transforms using yosys
+ *
+ * @param input string containing Firrtl source
+ * @param customTransforms Firrtl transforms to test for equivalence
+ * @param customAnnotations Optional Firrtl annotations
+ * @param resets tell yosys which signals to set for SAT, format is (timestep, signal, value)
+ */
+ def firrtlEquivalenceTest(input: String,
+ customTransforms: Seq[Transform] = Seq.empty,
+ customAnnotations: AnnotationSeq = Seq.empty,
+ resets: Seq[(Int, String, Int)] = Seq.empty): Unit = {
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ val compiler = new MinimumVerilogCompiler
+ val prefix = circuit.main
+ val testDir = createTestDirectory(prefix + "_equivalence_test")
+ val firrtlWriter = new PrintWriter(s"${testDir.getAbsolutePath}/$prefix.fir")
+ firrtlWriter.write(input)
+ firrtlWriter.close()
+
+ val customVerilog = compiler.compileAndEmit(CircuitState(circuit, HighForm, customAnnotations),
+ new GetNamespace +: new RenameTop(s"${prefix}_custom") +: customTransforms)
+ val namespaceAnnotation = customVerilog.annotations.collectFirst { case m: ModuleNamespaceAnnotation => m }.get
+ val customTop = customVerilog.circuit.main
+ val customFile = new PrintWriter(s"${testDir.getAbsolutePath}/$customTop.v")
+ customFile.write(customVerilog.getEmittedCircuit.value)
+ customFile.close()
+
+ val referenceVerilog = compiler.compileAndEmit(CircuitState(circuit, HighForm, Seq(namespaceAnnotation)),
+ Seq(new RenameModules, new RenameTop(s"${prefix}_reference")))
+ val referenceTop = referenceVerilog.circuit.main
+ val referenceFile = new PrintWriter(s"${testDir.getAbsolutePath}/$referenceTop.v")
+ referenceFile.write(referenceVerilog.getEmittedCircuit.value)
+ referenceFile.close()
+
+ assert(yosysExpectSuccess(customTop, referenceTop, testDir, resets))
+ }
+
+ /** Compiles input Firrtl to Verilog */
+ def compileToVerilog(input: String, annotations: AnnotationSeq = Seq.empty): String = {
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ val compiler = new VerilogCompiler
+ val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms)
+ res.getEmittedCircuit.value
+ }
+ /** Compile a Firrtl file
+ *
+ * @param prefix is the name of the Firrtl file without path or file extension
+ * @param srcDir directory where all Resources for this test are located
+ * @param annotations Optional Firrtl annotations
+ */
+ def compileFirrtlTest(
+ prefix: String,
+ srcDir: String,
+ customTransforms: Seq[Transform] = Seq.empty,
+ annotations: AnnotationSeq = Seq.empty): File = {
+ val testDir = createTestDirectory(prefix)
+ val inputFile = new File(testDir, s"${prefix}.fir")
+ copyResourceToFile(s"${srcDir}/${prefix}.fir", inputFile)
+
+ val annos =
+ FirrtlFileAnnotation(inputFile.toString) +:
+ TargetDirAnnotation(testDir.toString) +:
+ InfoModeAnnotation("ignore") +:
+ annotations ++:
+ (customTransforms ++ extraCheckTransforms).map(RunFirrtlTransformAnnotation(_))
+
+ (new firrtl.stage.FirrtlStage).run(annos)
+
+ testDir
+ }
+ /** Execute a Firrtl Test
+ *
+ * @param prefix is the name of the Firrtl file without path or file extension
+ * @param srcDir directory where all Resources for this test are located
+ * @param verilogPrefixes names of option Verilog resources without path or file extension
+ * @param annotations Optional Firrtl annotations
+ */
+ def runFirrtlTest(
+ prefix: String,
+ srcDir: String,
+ verilogPrefixes: Seq[String] = Seq.empty,
+ customTransforms: Seq[Transform] = Seq.empty,
+ annotations: AnnotationSeq = Seq.empty) = {
+ val testDir = compileFirrtlTest(prefix, srcDir, customTransforms, annotations)
+ val harness = new File(testDir, s"top.cpp")
+ copyResourceToFile(cppHarnessResourceName, harness)
+
+ // Note file copying side effect
+ val verilogFiles = verilogPrefixes map { vprefix =>
+ val file = new File(testDir, s"$vprefix.v")
+ copyResourceToFile(s"$srcDir/$vprefix.v", file)
+ file
+ }
+
+ verilogToCpp(prefix, testDir, verilogFiles, harness) #&&
+ cppToExe(prefix, testDir) !
+ loggingProcessLogger
+ assert(executeExpectingSuccess(prefix, testDir))
+ }
+}
+
+trait FirrtlMatchers extends Matchers {
+ def dontTouch(path: String): Annotation = {
+ val parts = path.split('.')
+ require(parts.size >= 2, "Must specify both module and component!")
+ val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top")))
+ DontTouchAnnotation(name)
+ }
+ def dontDedup(mod: String): Annotation = {
+ require(mod.split('.').size == 1, "Can only specify a Module, not a component or instance")
+ NoDedupAnnotation(ModuleName(mod, CircuitName("Top")))
+ }
+ // Replace all whitespace with a single space and remove leading and
+ // trailing whitespace
+ // Note this is intended for single-line strings, no newlines
+ def normalized(s: String): String = {
+ require(!s.contains("\n"))
+ s.replaceAll("\\s+", " ").trim
+ }
+ /** Helper to make circuits that are the same appear the same */
+ def canonicalize(circuit: Circuit): Circuit = {
+ import firrtl.Mappers._
+ def onModule(mod: DefModule) = mod.map(firrtl.Utils.squashEmpty)
+ circuit.map(onModule)
+ }
+ def parse(str: String) = Parser.parse(str.split("\n").toIterator, UseInfo)
+ /** Helper for executing tests
+ * compiler will be run on input then emitted result will each be split into
+ * lines and normalized.
+ */
+ def executeTest(
+ input: String,
+ expected: Seq[String],
+ compiler: Compiler,
+ annotations: Seq[Annotation] = Seq.empty) = {
+ val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations))
+ val lines = finalState.getEmittedCircuit.value split "\n" map normalized
+ for (e <- expected) {
+ lines should contain (e)
+ }
+ }
+}
+
+object FirrtlCheckers extends FirrtlMatchers {
+ import matchers._
+ implicit class TestingFunctionsOnCircuitState(val state: CircuitState) extends AnyVal {
+ def search(pf: PartialFunction[Any, Boolean]): Boolean = state.circuit.search(pf)
+ }
+ implicit class TestingFunctionsOnCircuit(val circuit: Circuit) extends AnyVal {
+ def search(pf: PartialFunction[Any, Boolean]): Boolean = {
+ val f = pf.lift
+ def rec(node: Any): Boolean = {
+ f(node) match {
+ // If the partial function is defined on this node, return its result
+ case Some(res) => res
+ // Otherwise keep digging
+ case None =>
+ require(node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode],
+ "Error! Unexpected FirrtlNode that does not implement Product!")
+ val iter = node match {
+ case p: Product => p.productIterator
+ case i: Iterable[Any] => i.iterator
+ case _ => Iterator.empty
+ }
+ iter.foldLeft(false) {
+ case (res, elt) => if (res) res else rec(elt)
+ }
+ }
+ }
+ rec(circuit)
+ }
+ }
+
+ /** Checks that the emitted circuit has the expected line, both will be normalized */
+ def containLine(expectedLine: String) = containLines(expectedLine)
+
+ /** Checks that the emitted circuit has the expected lines in order, all lines will be normalized */
+ def containLines(expectedLines: String*) = new CircuitStateStringsMatcher(expectedLines)
+
+ class CircuitStateStringsMatcher(expectedLines: Seq[String]) extends Matcher[CircuitState] {
+ override def apply(state: CircuitState): MatchResult = {
+ val emitted = state.getEmittedCircuit.value
+ MatchResult(
+ emitted.split("\n").map(normalized).containsSlice(expectedLines.map(normalized)),
+ emitted + "\n did not contain \"" + expectedLines + "\"",
+ s"${state.circuit.main} contained $expectedLines"
+ )
+ }
+ }
+
+ def containTree(pf: PartialFunction[Any, Boolean]) = new CircuitStatePFMatcher(pf)
+
+ class CircuitStatePFMatcher(pf: PartialFunction[Any, Boolean]) extends Matcher[CircuitState] {
+ override def apply(state: CircuitState): MatchResult = {
+ MatchResult(
+ state.search(pf),
+ state.circuit.serialize + s"\n did not contain $pf",
+ s"${state.circuit.main} contained $pf"
+ )
+ }
+ }
+}
+
+abstract class FirrtlPropSpec extends PropSpec with ScalaCheckPropertyChecks with FirrtlRunners with LazyLogging
+
+abstract class FirrtlFlatSpec extends FlatSpec with FirrtlRunners with FirrtlMatchers with LazyLogging
+
+// Who tests the testers?
+class TestFirrtlFlatSpec extends FirrtlFlatSpec {
+ import FirrtlCheckers._
+
+ val c = parse("""
+ |circuit Test:
+ | module Test :
+ | input in : UInt<8>
+ | output out : UInt<8>
+ | out <= in
+ |""".stripMargin)
+ val state = CircuitState(c, ChirrtlForm)
+ val compiled = (new LowFirrtlCompiler).compileAndEmit(state, List.empty)
+
+ // While useful, ScalaTest helpers should be used over search
+ behavior of "Search"
+
+ it should "be supported on Circuit" in {
+ assert(c search {
+ case Connect(_, Reference("out",_), Reference("in",_)) => true
+ })
+ }
+ it should "be supported on CircuitStates" in {
+ assert(state search {
+ case Connect(_, Reference("out",_), Reference("in",_)) => true
+ })
+ }
+ it should "be supported on the results of compilers" in {
+ assert(compiled search {
+ case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true
+ })
+ }
+
+ // Use these!!!
+ behavior of "ScalaTest helpers"
+
+ they should "work for lines of emitted text" in {
+ compiled should containLine (s"input in : UInt<8>")
+ compiled should containLine (s"output out : UInt<8>")
+ compiled should containLine (s"out <= in")
+ }
+
+ they should "work for partial functions matching on subtrees" in {
+ val UInt8 = UIntType(IntWidth(8)) // BigInt unapply is weird
+ compiled should containTree { case Port(_, "in", Input, UInt8) => true }
+ compiled should containTree { case Port(_, "out", Output, UInt8) => true }
+ compiled should containTree { case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true }
+ }
+}
+
+/** Super class for execution driven Firrtl tests */
+abstract class ExecutionTest(name: String, dir: String, vFiles: Seq[String] = Seq.empty, annotations: AnnotationSeq = Seq.empty) extends FirrtlPropSpec {
+ property(s"$name should execute correctly") {
+ runFirrtlTest(name, dir, vFiles, annotations = annotations)
+ }
+}
+/** Super class for compilation driven Firrtl tests */
+abstract class CompilationTest(name: String, dir: String) extends FirrtlPropSpec {
+ property(s"$name should compile correctly") {
+ compileFirrtlTest(name, dir)
+ }
+}
+
+trait Utils {
+
+ /** Run some Scala thunk and return STDOUT and STDERR as strings.
+ * @param thunk some Scala code
+ * @return a tuple containing STDOUT, STDERR, and what the thunk returns
+ */
+ def grabStdOutErr[T](thunk: => T): (String, String, T) = {
+ val stdout, stderr = new ByteArrayOutputStream()
+ val ret = scala.Console.withOut(stdout) { scala.Console.withErr(stderr) { thunk } }
+ (stdout.toString, stderr.toString, ret)
+ }
+
+ /** Encodes a System.exit exit code
+ * @param status the exit code
+ */
+ private case class ExitException(status: Int) extends SecurityException(s"Found a sys.exit with code $status")
+
+ /** A security manager that converts calls to System.exit into [[ExitException]]s by explicitly disabling the ability of
+ * a thread to actually exit. For more information, see:
+ * - https://docs.oracle.com/javase/tutorial/essential/environment/security.html
+ */
+ private class ExceptOnExit extends SecurityManager {
+ override def checkPermission(perm: Permission): Unit = {}
+ override def checkPermission(perm: Permission, context: Object): Unit = {}
+ override def checkExit(status: Int): Unit = {
+ super.checkExit(status)
+ throw ExitException(status)
+ }
+ }
+
+ /** Encodes a file that some code tries to write to
+ * @param the file name
+ */
+ private case class WriteException(file: String) extends SecurityException(s"Tried to write to file $file")
+
+ /** A security manager that converts writes to any file into [[WriteException]]s.
+ */
+ private class ExceptOnWrite extends SecurityManager {
+ override def checkPermission(perm: Permission): Unit = {}
+ override def checkPermission(perm: Permission, context: Object): Unit = {}
+ override def checkWrite(file: String): Unit = {
+ super.checkWrite(file)
+ throw WriteException(file)
+ }
+ }
+
+ /** Run some Scala code (a thunk) in an environment where all System.exit are caught and returned. This avoids a
+ * situation where a test results in something actually exiting and killing the entire test. This is necessary if you
+ * want to test a command line program, e.g., the `main` method of [[firrtl.options.Stage Stage]].
+ *
+ * NOTE: THIS WILL NOT WORK IN SITUATIONS WHERE THE THUNK IS CATCHING ALL [[Exception]]s OR [[Throwable]]s, E.G.,
+ * SCOPT. IF THIS IS HAPPENING THIS WILL NOT WORK. REPEAT THIS WILL NOT WORK.
+ * @param thunk some Scala code
+ * @return either the output of the thunk (`Right[T]`) or an exit code (`Left[Int]`)
+ */
+ def catchStatus[T](thunk: => T): Either[Int, T] = {
+ try {
+ System.setSecurityManager(new ExceptOnExit())
+ Right(thunk)
+ } catch {
+ case ExitException(a) => Left(a)
+ } finally {
+ System.setSecurityManager(null)
+ }
+ }
+
+ /** Run some Scala code (a thunk) in an environment where file writes are caught and the file that a program tries to
+ * write to is returned. This is useful if you want to test that some thunk either tries to write to a specific file
+ * or doesn't try to write at all.
+ */
+ def catchWrites[T](thunk: => T): Either[String, T] = {
+ try {
+ System.setSecurityManager(new ExceptOnWrite())
+ Right(thunk)
+ } catch {
+ case WriteException(a) => Left(a)
+ } finally {
+ System.setSecurityManager(null)
+ }
+ }
+
+}
diff --git a/src/test/scala/firrtl/testutils/PassTests.scala b/src/test/scala/firrtl/testutils/PassTests.scala
new file mode 100644
index 00000000..c172163e
--- /dev/null
+++ b/src/test/scala/firrtl/testutils/PassTests.scala
@@ -0,0 +1,106 @@
+// See LICENSE for license details.
+
+package firrtl.testutils
+
+import org.scalatest.FlatSpec
+import firrtl.ir.Circuit
+import firrtl.passes.{PassExceptions, RemoveEmpty}
+import firrtl.transforms.DedupModules
+import firrtl._
+import firrtl.annotations._
+import logger._
+
+// An example methodology for testing Firrtl Passes
+// Spec class should extend this class
+abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Compiler with LazyLogging {
+ // Utility function
+ def squash(c: Circuit): Circuit = RemoveEmpty.run(c)
+
+ // Executes the test. Call in tests.
+ // annotations cannot have default value because scalatest trait Suite has a default value
+ def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = {
+ val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations))
+ val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize
+ val expected = parse(check).serialize
+ logger.debug(actual)
+ logger.debug(expected)
+ (actual) should be (expected)
+ finalState
+ }
+
+ def executeWithAnnos(input: String, check: String, annotations: Seq[Annotation],
+ checkAnnotations: Seq[Annotation]): CircuitState = {
+ val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations))
+ val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize
+ val expected = parse(check).serialize
+ logger.debug(actual)
+ logger.debug(expected)
+ (actual) should be (expected)
+
+ annotations.foreach { anno =>
+ logger.debug(anno.serialize)
+ }
+
+ finalState.annotations.toSeq.foreach { anno =>
+ logger.debug(anno.serialize)
+ }
+ checkAnnotations.foreach { check =>
+ (finalState.annotations.toSeq) should contain (check)
+ }
+ finalState
+ }
+ // Executes the test, should throw an error
+ // No default to be consistent with execute
+ def failingexecute(input: String, annotations: Seq[Annotation]): Exception = {
+ intercept[PassExceptions] {
+ compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty)
+ }
+ }
+}
+
+class CustomResolveAndCheck(form: CircuitForm) extends SeqTransform {
+ def inputForm = form
+ def outputForm = form
+ def transforms: Seq[Transform] = Seq[Transform](new ResolveAndCheck)
+}
+
+trait LowTransformSpec extends SimpleTransformSpec {
+ def emitter = new LowFirrtlEmitter
+ def transform: Transform
+ def transforms: Seq[Transform] = Seq(
+ new ChirrtlToHighFirrtl(),
+ new IRToWorkingIR(),
+ new ResolveAndCheck(),
+ new DedupModules(),
+ new HighFirrtlToMiddleFirrtl(),
+ new MiddleFirrtlToLowFirrtl(),
+ new CustomResolveAndCheck(LowForm),
+ transform
+ )
+}
+
+trait MiddleTransformSpec extends SimpleTransformSpec {
+ def emitter = new MiddleFirrtlEmitter
+ def transform: Transform
+ def transforms: Seq[Transform] = Seq(
+ new ChirrtlToHighFirrtl(),
+ new IRToWorkingIR(),
+ new ResolveAndCheck(),
+ new DedupModules(),
+ new HighFirrtlToMiddleFirrtl(),
+ new CustomResolveAndCheck(MidForm),
+ transform
+ )
+}
+
+trait HighTransformSpec extends SimpleTransformSpec {
+ def emitter = new HighFirrtlEmitter
+ def transform: Transform
+ def transforms = Seq(
+ new ChirrtlToHighFirrtl(),
+ new IRToWorkingIR(),
+ new CustomResolveAndCheck(HighForm),
+ new DedupModules(),
+ transform
+ )
+}