aboutsummaryrefslogtreecommitdiff
path: root/src/test/scala/firrtl/testutils/LeanTransformSpec.scala
blob: 2d1cad8de7a294c4f96ccee901680fd43842e3d3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// SPDX-License-Identifier: Apache-2.0

package firrtl.testutils

import firrtl.{ir, AnnotationSeq, CircuitState, EmitCircuitAnnotation}
import firrtl.options.Dependency
import firrtl.passes.RemoveEmpty
import firrtl.stage.TransformManager.TransformDependency
import logger.LazyLogging
import org.scalatest.flatspec.AnyFlatSpec

/** A [[LeanTransformSpec]] whose custom compiler is built from the [[firrtl.VerilogEmitter]] dependency alone. */
class VerilogTransformSpec extends LeanTransformSpec(Dependency[firrtl.VerilogEmitter] :: Nil)

/** A [[LeanTransformSpec]] whose custom compiler is built from the [[firrtl.LowFirrtlEmitter]] dependency alone. */
class LowFirrtlTransformSpec extends LeanTransformSpec(Dependency[firrtl.LowFirrtlEmitter] :: Nil)

/** The new cool kid on the block, creates a custom compiler for your transform.
  *
  * A single [[firrtl.stage.transforms.Compiler]] is built from `transforms` when the spec is constructed.
  * For every dependency in `transforms` that is a [[firrtl.Emitter]], an [[EmitCircuitAnnotation]] is derived.
  * Those annotations are put in front of the caller's annotations on every compilation run by this spec.
  *
  * @param transforms the transform/emitter dependencies the custom compiler is built from
  */
class LeanTransformSpec(protected val transforms: Seq[TransformDependency])
    extends AnyFlatSpec
    with FirrtlMatchers
    with LazyLogging {
  private val compiler = new firrtl.stage.transforms.Compiler(transforms)
  private val emitterAnnos = LeanTransformSpec.deriveEmitCircuitAnnotations(transforms)

  /** Parses `src` and compiles it with no additional annotations. */
  protected def compile(src: String): CircuitState = compile(src, Seq())

  /** Parses `src` with [[firrtl.Parser]] and compiles it together with `annos`. */
  protected def compile(src: String, annos: AnnotationSeq): CircuitState = compile(firrtl.Parser.parse(src), annos)

  /** Compiles an already-built circuit with no additional annotations. */
  protected def compile(c:   ir.Circuit): CircuitState = compile(c, Seq())

  /** Runs the custom compiler on `c`; the derived emitter annotations are placed before `annos`. */
  protected def compile(c:   ir.Circuit, annos: AnnotationSeq): CircuitState =
    compiler.transform(CircuitState(c, emitterAnnos ++ annos))

  /** Same as the three-argument `execute` with no input annotations. */
  protected def execute(input: String, check: String): CircuitState = execute(input, check, Seq())

  /** Compiles `input` and asserts that the emitted circuit matches `check`.
    *
    * The emitted text is re-parsed as FIRRTL and passed through [[RemoveEmpty]] before the two serialized
    * forms are compared. NOTE(review): because of the re-parse, this presumably only works with a
    * FIRRTL-producing emitter (e.g. [[firrtl.LowFirrtlEmitter]]).
    *
    * @param input   FIRRTL source to compile
    * @param check   FIRRTL source the emitted circuit is expected to equal
    * @param inAnnos extra annotations passed to the compiler
    * @return the final state produced by the compiler
    */
  protected def execute(input: String, check: String, inAnnos: AnnotationSeq): CircuitState = {
    // Go through `compile` so the derived emitter annotations are included, as for every other entry point.
    // Previously they were dropped here, even though `getEmittedCircuit` below needs an emitter to have emitted
    // (presumably emitters only emit when requested via an EmitCircuitAnnotation).
    val finalState = compile(parse(input), inAnnos)
    val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize
    val expected = parse(check).serialize
    logger.debug(actual)
    logger.debug(expected)
    actual should be(expected)
    finalState
  }

  /** Removes every `skip` ([[ir.EmptyStmt]]) that appears as an element of a [[ir.Block]], at any nesting depth.
    *
    * Children are processed first (via `mapStmt`), so nested blocks are cleaned before their parents.
    */
  protected def removeSkip(c: ir.Circuit): ir.Circuit = {
    def onStmt(s: ir.Statement): ir.Statement = s.mapStmt(onStmt) match {
      case ir.Block(stmts) => ir.Block(stmts.filterNot(_ == ir.EmptyStmt))
      case other           => other
    }
    c.mapModule(m => m.mapStmt(onStmt))
  }
}

private object LeanTransformSpec {

  /** Builds one [[EmitCircuitAnnotation]] for each dependency in `transforms` whose object is a
    * [[firrtl.Emitter]], keeping the order in which the dependencies are given.
    *
    * @param transforms dependencies to scan; `getObject()` is called once on each of them
    * @return the emitter annotations (empty when no dependency is an emitter)
    */
  private def deriveEmitCircuitAnnotations(transforms: Iterable[TransformDependency]): AnnotationSeq = {
    val emitCircuitAnnos: Seq[EmitCircuitAnnotation] = transforms.toSeq.flatMap { dependency =>
      dependency.getObject() match {
        case emitter: firrtl.Emitter => Seq(EmitCircuitAnnotation(emitter.getClass))
        case _                       => Nil
      }
    }
    emitCircuitAnnos
  }
}

/** Use this if you just need to create a standard compiler and want to save some typing.
  *
  * Each factory returns a [[firrtl.stage.transforms.Compiler]] whose target list is the named emitter
  * followed by whatever extra `transforms` the caller supplies (none by default).
  */
trait MakeCompiler {

  /** @param transforms extra targets placed after the [[firrtl.VerilogEmitter]] dependency */
  protected def makeVerilogCompiler(transforms: Seq[TransformDependency] = Seq()): firrtl.stage.transforms.Compiler =
    new firrtl.stage.transforms.Compiler(Dependency[firrtl.VerilogEmitter] +: transforms)

  /** @param transforms extra targets placed after the [[firrtl.MinimumVerilogEmitter]] dependency */
  protected def makeMinimumVerilogCompiler(
    transforms: Seq[TransformDependency] = Seq()
  ): firrtl.stage.transforms.Compiler =
    new firrtl.stage.transforms.Compiler(Dependency[firrtl.MinimumVerilogEmitter] +: transforms)

  /** @param transforms extra targets placed after the [[firrtl.LowFirrtlEmitter]] dependency */
  protected def makeLowFirrtlCompiler(transforms: Seq[TransformDependency] = Seq()): firrtl.stage.transforms.Compiler =
    new firrtl.stage.transforms.Compiler(Dependency[firrtl.LowFirrtlEmitter] +: transforms)
}