aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala16
-rw-r--r--src/main/scala/firrtl/analyses/GetNamespace.scala22
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala13
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala61
-rw-r--r--src/main/scala/firrtl/transforms/RenameModules.scala50
-rw-r--r--src/main/scala/firrtl/util/BackendCompilationUtilities.scala61
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala136
-rw-r--r--src/test/scala/firrtlTests/FirrtlSpec.scala65
8 files changed, 390 insertions, 34 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index c4230b90..ba686f08 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -111,6 +111,15 @@ class LowFirrtlOptimization extends CoreTransform {
passes.CommonSubexpressionElimination,
new firrtl.transforms.DeadCodeElimination)
}
+/** Runs runs only the optimization passes needed for Verilog emission */
+class MinimumLowFirrtlOptimization extends CoreTransform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+ def transforms = Seq(
+ passes.Legalize,
+ passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
+ passes.SplitExpressions)
+}
import CompilerUtils.getLoweringTransforms
@@ -142,3 +151,10 @@ class VerilogCompiler extends Compiler {
def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, LowForm) ++
Seq(new LowFirrtlOptimization, new BlackBoxSourceHelper)
}
+
+/** Emits Verilog without optimizations */
+class MinimumVerilogCompiler extends Compiler {
+ def emitter = new VerilogEmitter
+ def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, LowForm) ++
+ Seq(new MinimumLowFirrtlOptimization, new BlackBoxSourceHelper)
+}
diff --git a/src/main/scala/firrtl/analyses/GetNamespace.scala b/src/main/scala/firrtl/analyses/GetNamespace.scala
new file mode 100644
index 00000000..5ab096b7
--- /dev/null
+++ b/src/main/scala/firrtl/analyses/GetNamespace.scala
@@ -0,0 +1,22 @@
+// See LICENSE for license details.
+
+package firrtl.analyses
+
+import firrtl.annotations.NoTargetAnnotation
+import firrtl.{CircuitState, LowForm, Namespace, Transform}
+
+case class ModuleNamespaceAnnotation(namespace: Namespace) extends NoTargetAnnotation
+
+/** Create a namespace with this circuit
+ *
+ * namespace is used by RenameModules to get unique names
+ */
+class GetNamespace extends Transform {
+ def inputForm: LowForm.type = LowForm
+ def outputForm: LowForm.type = LowForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val namespace = Namespace(state.circuit)
+ state.copy(annotations = new ModuleNamespaceAnnotation(namespace) +: state.annotations)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 9bbde5f6..5e5aa26a 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -3,12 +3,12 @@
package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
-
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.PrimOps._
+import firrtl.transforms.ConstantPropagation
import scala.collection.mutable
@@ -198,14 +198,9 @@ object Legalize extends Pass {
e
}
}
- private def legalizeBits(expr: DoPrim): Expression = {
- lazy val (hi, low) = (expr.consts.head, expr.consts(1))
- lazy val mask = (BigInt(1) << (hi - low + 1).toInt) - 1
- lazy val width = IntWidth(hi - low + 1)
+ private def legalizeBitExtract(expr: DoPrim): Expression = {
expr.args.head match {
- case UIntLiteral(value, _) => UIntLiteral((value >> low.toInt) & mask, width)
- case SIntLiteral(value, _) => SIntLiteral((value >> low.toInt) & mask, width)
- //case FixedLiteral
+ case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr)
case _ => expr
}
}
@@ -236,7 +231,7 @@ object Legalize extends Pass {
case prim: DoPrim => prim.op match {
case Shr => legalizeShiftRight(prim)
case Pad => legalizePad(prim)
- case Bits => legalizeBits(prim)
+ case Bits | Head | Tail => legalizeBitExtract(prim)
case _ => prim
}
case e => e // respect pre-order traversal
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 4a4f41d1..0d30446c 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -17,12 +17,33 @@ import annotation.tailrec
import collection.mutable
object ConstantPropagation {
+ private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)
+
/** Pads e to the width of t */
def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
case (we, wt) if we == wt => e
}
+ def constPropBitExtract(e: DoPrim) = {
+ val arg = e.args.head
+ val (hi, lo) = e.op match {
+ case Bits => (e.consts.head.toInt, e.consts(1).toInt)
+ case Tail => ((bitWidth(arg.tpe) - 1 - e.consts.head).toInt, 0)
+ case Head => ((bitWidth(arg.tpe) - 1).toInt, (bitWidth(arg.tpe) - e.consts.head).toInt)
+ }
+
+ arg match {
+ case lit: Literal =>
+ require(hi >= lo)
+ UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
+ case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
+ case t: UIntType => x
+ case _ => asUInt(x, e.tpe)
+ }
+ case _ => e
+ }
+ }
}
class ConstantPropagation extends Transform {
@@ -30,9 +51,7 @@ class ConstantPropagation extends Transform {
def inputForm = LowForm
def outputForm = LowForm
- private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)
-
- trait FoldLogicalOp {
+ trait FoldCommutativeOp {
def fold(c1: Literal, c2: Literal): Expression
def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression
@@ -44,7 +63,19 @@ class ConstantPropagation extends Transform {
}
}
- object FoldAND extends FoldLogicalOp {
+ object FoldADD extends FoldCommutativeOp {
+ def fold(c1: Literal, c2: Literal) = (c1, c2) match {
+ case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
+ case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
+ }
+ def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
+ case UIntLiteral(v, w) if v == BigInt(0) => rhs
+ case SIntLiteral(v, w) if v == BigInt(0) => rhs
+ case _ => e
+ }
+ }
+
+ object FoldAND extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
@@ -54,7 +85,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldOR extends FoldLogicalOp {
+ object FoldOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
@@ -64,7 +95,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldXOR extends FoldLogicalOp {
+ object FoldXOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
@@ -73,7 +104,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldEqual extends FoldLogicalOp {
+ object FoldEqual extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
@@ -81,7 +112,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldNotEqual extends FoldLogicalOp {
+ object FoldNotEqual extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
@@ -189,6 +220,7 @@ class ConstantPropagation extends Transform {
case Shl => foldShiftLeft(e)
case Shr => foldShiftRight(e)
case Cat => foldConcat(e)
+ case Add => FoldADD(e)
case And => FoldAND(e)
case Or => FoldOR(e)
case Xor => FoldXOR(e)
@@ -215,18 +247,7 @@ class ConstantPropagation extends Transform {
case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
case _ => e
}
- case Bits => e.args.head match {
- case lit: Literal =>
- val hi = e.consts.head.toInt
- val lo = e.consts(1).toInt
- require(hi >= lo)
- UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
- case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
- case t: UIntType => x
- case _ => asUInt(x, e.tpe)
- }
- case _ => e
- }
+ case (Bits | Head | Tail) => constPropBitExtract(e)
case _ => e
}
diff --git a/src/main/scala/firrtl/transforms/RenameModules.scala b/src/main/scala/firrtl/transforms/RenameModules.scala
new file mode 100644
index 00000000..af17dda9
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/RenameModules.scala
@@ -0,0 +1,50 @@
+// See LICENSE for license details.
+
+package firrtl.transforms
+
+import firrtl.analyses.{InstanceGraph, ModuleNamespaceAnnotation}
+import firrtl.ir._
+import firrtl._
+
+import scala.collection.mutable
+
+/** Rename Modules
+ *
+ * using namespace created by [[analyses.GetNamespace]], create unique names for modules
+ */
+class RenameModules extends Transform {
+ def inputForm: LowForm.type = LowForm
+ def outputForm: LowForm.type = LowForm
+
+ def collectNameMapping(namespace: Namespace, moduleNameMap: mutable.HashMap[String, String])(mod: DefModule): Unit = {
+ val newName = namespace.newName(mod.name)
+ moduleNameMap.put(mod.name, newName)
+ }
+
+ def onStmt(moduleNameMap: mutable.HashMap[String, String])(stmt: Statement): Statement = stmt match {
+ case inst: WDefInstance if moduleNameMap.contains(inst.module) => inst.copy(module = moduleNameMap(inst.module))
+ case other => other.mapStmt(onStmt(moduleNameMap))
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ val namespace = state.annotations.collectFirst {
+ case m: ModuleNamespaceAnnotation => m
+ }.map(_.namespace)
+
+ if (namespace.isEmpty) {
+ logger.warn("Skipping Rename Modules")
+ state
+ } else {
+ val moduleOrder = new InstanceGraph(state.circuit).moduleOrder.reverse
+ val nameMappings = new mutable.HashMap[String, String]()
+ moduleOrder.foreach(collectNameMapping(namespace.get, nameMappings))
+
+ val modulesx = state.circuit.modules.map {
+ case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings)
+ case ext: ExtModule => ext
+ }
+
+ state.copy(circuit = state.circuit.copy(modules = modulesx, main = nameMappings(state.circuit.main)))
+ }
+ }
+}
diff --git a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala
index 0c5ab12f..d3d34e87 100644
--- a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala
+++ b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala
@@ -152,4 +152,65 @@ trait BackendCompilationUtilities {
def executeExpectingSuccess(prefix: String, dir: File): Boolean = {
!executeExpectingFailure(prefix, dir)
}
+
+ /** Creates and runs a Yosys script that creates and runs SAT on a miter
+ * circuit. Returns true if SAT succeeds, false otherwise
+ *
+ * The custom and reference Verilog files must not contain any modules with
+ * the same name otherwise Yosys will not be able to create a miter circuit
+ *
+ * @param customTop name of the DUT with custom transforms without the .v
+ * extension
+ * @param referenceTop name of the DUT without custom transforms without the
+ * .v extension
+ * @param testDir directory containing verilog files
+ * @param resets signals to set for SAT, format is
+ * (timestep, signal, value)
+ */
+ def yosysExpectSuccess(customTop: String,
+ referenceTop: String,
+ testDir: File,
+ resets: Seq[(Int, String, Int)] = Seq.empty): Boolean = {
+ !yosysExpectFailure(customTop, referenceTop, testDir, resets)
+ }
+
+ /** Creates and runs a Yosys script that creates and runs SAT on a miter
+ * circuit. Returns false if SAT succeeds, true otherwise
+ *
+ * The custom and reference Verilog files must not contain any modules with
+ * the same name otherwise Yosys will not be able to create a miter circuit
+ *
+ * @param customTop name of the DUT with custom transforms without the .v
+ * extension
+ * @param referenceTop name of the DUT without custom transforms without the
+ * .v extension
+ * @param testDir directory containing verilog files
+ * @param resets signals to set for SAT, format is
+ * (timestep, signal, value)
+ */
+ def yosysExpectFailure(customTop: String,
+ referenceTop: String,
+ testDir: File,
+ resets: Seq[(Int, String, Int)] = Seq.empty): Boolean = {
+
+ val setSignals = resets.map(_._2).toSet[String].map(s => s"-set in_$s 0").mkString(" ")
+ val setAtSignals = resets.map {
+ case (timestep, signal, value) => s"-set-at $timestep in_$signal $value"
+ }.mkString(" ")
+ val scriptFileName = s"${testDir.getAbsolutePath}/yosys_script"
+ val yosysScriptWriter = new PrintWriter(scriptFileName)
+ yosysScriptWriter.write(
+ s"""read_verilog ${testDir.getAbsolutePath}/$customTop.v
+ |read_verilog ${testDir.getAbsolutePath}/$referenceTop.v
+ |prep; proc; opt; memory
+ |miter -equiv -flatten $customTop $referenceTop miter
+ |hierarchy -top miter
+ |sat -verify -tempinduct -prove trigger 0 $setSignals $setAtSignals -seq 1 miter"""
+ .stripMargin)
+ yosysScriptWriter.close()
+
+ val resultFileName = testDir.getAbsolutePath + "/yosys_results"
+ val command = s"yosys -s $scriptFileName" #> new File(resultFileName)
+ command.! != 0
+ }
}
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 6fc685a8..603ddc25 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -694,6 +694,46 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
"""
(parse(exec(input))) should be (parse(check))
}
+
+ "ConstProp" should "propagate constant addition" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<5>
+ | output z : UInt<5>
+ | node _T_1 = add(UInt<5>("h0"), UInt<5>("h1"))
+ | node _T_2 = add(_T_1, UInt<5>("h2"))
+ | z <= add(x, _T_2)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<5>
+ | output z : UInt<5>
+ | node _T_1 = UInt<6>("h1")
+ | node _T_2 = UInt<7>("h3")
+ | z <= add(x, UInt<7>("h3"))
+ """.stripMargin
+ (parse(exec(input))) should be(parse(check))
+ }
+
+ "ConstProp" should "propagate addition with zero" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<5>
+ | output z : UInt<5>
+ | z <= add(x, UInt<5>("h0"))
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<5>
+ | output z : UInt<5>
+ | z <= pad(x, 6)
+ """.stripMargin
+ (parse(exec(input))) should be(parse(check))
+ }
}
// More sophisticated tests of the full compiler
@@ -1065,3 +1105,99 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
execute(input, check, Seq.empty)
}
}
+
+
+class ConstantPropagationEquivalenceSpec extends FirrtlFlatSpec {
+ private val srcDir = "/constant_propagation_tests"
+ private val transforms = Seq(new ConstantPropagation)
+
+ "anything added to zero" should "be equal to itself" in {
+ val input =
+ s"""circuit AddZero :
+ | module AddZero :
+ | input in : UInt<5>
+ | output out1 : UInt<6>
+ | output out2 : UInt<6>
+ | out1 <= add(in, UInt<5>("h0"))
+ | out2 <= add(UInt<5>("h0"), in)""".stripMargin
+ firrtlEquivalenceTest(input, transforms)
+ }
+
+ "constants added together" should "be propagated" in {
+ val input =
+ s"""circuit AddLiterals :
+ | module AddLiterals :
+ | input uin : UInt<5>
+ | input sin : SInt<5>
+ | output uout : UInt<6>
+ | output sout : SInt<6>
+ | node uconst = add(UInt<5>("h1"), UInt<5>("h2"))
+ | uout <= add(uconst, uin)
+ | node sconst = add(SInt<5>("h1"), SInt<5>("h-1"))
+ | sout <= add(sconst, sin)""".stripMargin
+ firrtlEquivalenceTest(input, transforms)
+ }
+
+ "UInt addition" should "have the correct widths" in {
+ val input =
+ s"""circuit WidthsAddUInt :
+ | module WidthsAddUInt :
+ | input in : UInt<3>
+ | output out1 : UInt<10>
+ | output out2 : UInt<10>
+ | wire temp : UInt<5>
+ | temp <= add(in, UInt<1>("h0"))
+ | out1 <= cat(temp, temp)
+ | node const = add(UInt<4>("h1"), UInt<3>("h2"))
+ | out2 <= cat(const, const)""".stripMargin
+ firrtlEquivalenceTest(input, transforms)
+ }
+
+ "SInt addition" should "have the correct widths" in {
+ val input =
+ s"""circuit WidthsAddSInt :
+ | module WidthsAddSInt :
+ | input in : SInt<3>
+ | output out1 : UInt<10>
+ | output out2 : UInt<10>
+ | wire temp : SInt<5>
+ | temp <= add(in, SInt<7>("h0"))
+ | out1 <= cat(temp, temp)
+ | node const = add(SInt<4>("h1"), SInt<3>("h-2"))
+ | out2 <= cat(const, const)""".stripMargin
+ firrtlEquivalenceTest(input, transforms)
+ }
+
+ "addition by zero width wires" should "have the correct widths" in {
+ val input =
+ s"""circuit ZeroWidthAdd:
+ | module ZeroWidthAdd:
+ | input x: UInt<0>
+ | output y: UInt<7>
+ | node temp = add(x, UInt<9>("h0"))
+ | y <= cat(temp, temp)""".stripMargin
+ firrtlEquivalenceTest(input, transforms)
+ }
+
+ "tail of constants" should "be propagated" in {
+ val input =
+ s"""circuit TailTester :
+ | module TailTester :
+ | output out : UInt<1>
+ | node temp = add(UInt<1>("h00"), UInt<5>("h017"))
+ | node tail_temp = tail(temp, 1)
+ | out <= tail_temp""".stripMargin
+ firrtlEquivalenceTest(input, transforms)
+ }
+
+ "head of constants" should "be propagated" in {
+ val input =
+ s"""circuit TailTester :
+ | module TailTester :
+ | output out : UInt<1>
+ | node temp = add(UInt<1>("h00"), UInt<5>("h017"))
+ | node head_temp = head(temp, 3)
+ | out <= head_temp""".stripMargin
+ firrtlEquivalenceTest(input, transforms)
+ }
+}
diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala
index 01ae0431..95b09d93 100644
--- a/src/test/scala/firrtlTests/FirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/FirrtlSpec.scala
@@ -5,22 +5,79 @@ package firrtlTests
import java.io._
import com.typesafe.scalalogging.LazyLogging
+
import scala.sys.process._
import org.scalatest._
import org.scalatest.prop._
-import scala.io.Source
+import scala.io.Source
import firrtl._
import firrtl.ir._
-import firrtl.Parser.UseInfo
+import firrtl.Parser.{IgnoreInfo, UseInfo}
+import firrtl.analyses.{GetNamespace, InstanceGraph, ModuleNamespaceAnnotation}
import firrtl.annotations._
-import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation}
+import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation, RenameModules}
import firrtl.util.BackendCompilationUtilities
+import scala.collection.mutable
+
trait FirrtlRunners extends BackendCompilationUtilities {
val cppHarnessResourceName: String = "/firrtl/testTop.cpp"
+ private class RenameTop(newTopPrefix: String) extends Transform {
+ def inputForm: LowForm.type = LowForm
+ def outputForm: LowForm.type = LowForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val namespace = state.annotations.collectFirst {
+ case m: ModuleNamespaceAnnotation => m
+ }.get.namespace
+
+ val newTopName = namespace.newName(newTopPrefix)
+ val modulesx = state.circuit.modules.map {
+ case mod: Module if mod.name == state.circuit.main => mod.mapString(_ => newTopName)
+ case other => other
+ }
+
+ state.copy(circuit = state.circuit.copy(main = newTopName, modules = modulesx))
+ }
+ }
+
+ /** Check equivalence of Firrtl transforms using yosys
+ *
+ * @param input string containing Firrtl source
+ * @param customTransforms Firrtl transforms to test for equivalence
+ * @param customAnnotations Optional Firrtl annotations
+ * @param resets tell yosys which signals to set for SAT, format is (timestep, signal, value)
+ */
+ def firrtlEquivalenceTest(input: String,
+ customTransforms: Seq[Transform] = Seq.empty,
+ customAnnotations: AnnotationSeq = Seq.empty,
+ resets: Seq[(Int, String, Int)] = Seq.empty): Unit = {
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ val compiler = new MinimumVerilogCompiler
+ val prefix = circuit.main
+ val testDir = createTestDirectory(prefix + "_equivalence_test")
+
+ val customVerilog = compiler.compileAndEmit(CircuitState(circuit, HighForm, customAnnotations),
+ new GetNamespace +: new RenameTop(s"${prefix}_custom") +: customTransforms)
+ val namespaceAnnotation = customVerilog.annotations.collectFirst { case m: ModuleNamespaceAnnotation => m }.get
+ val customTop = customVerilog.circuit.main
+ val customFile = new PrintWriter(s"${testDir.getAbsolutePath}/$customTop.v")
+ customFile.write(customVerilog.getEmittedCircuit.value)
+ customFile.close()
+
+ val referenceVerilog = compiler.compileAndEmit(CircuitState(circuit, HighForm, Seq(namespaceAnnotation)),
+ Seq(new RenameModules, new RenameTop(s"${prefix}_reference")))
+ val referenceTop = referenceVerilog.circuit.main
+ val referenceFile = new PrintWriter(s"${testDir.getAbsolutePath}/$referenceTop.v")
+ referenceFile.write(referenceVerilog.getEmittedCircuit.value)
+ referenceFile.close()
+
+ assert(yosysExpectSuccess(customTop, referenceTop, testDir, resets))
+ }
+
/** Compiles input Firrtl to Verilog */
def compileToVerilog(input: String, annotations: AnnotationSeq = Seq.empty): String = {
val circuit = Parser.parse(input.split("\n").toIterator)
@@ -248,5 +305,3 @@ abstract class CompilationTest(name: String, dir: String) extends FirrtlPropSpec
compileFirrtlTest(name, dir)
}
}
-
-