aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authoralbertchen-sifive2018-07-20 14:36:30 -0700
committerAdam Izraelevitz2018-07-20 14:36:30 -0700
commit7dff927840a30893facae957595a8e88ea62509a (patch)
tree08210d9b2936fc4606ae8a0fe1c9f12a8c7c673e /src/main
parent897dad039a12a49b3c4ae833fbf0d02087b26ed5 (diff)
Constant prop add (#849)
* add FoldADD to const prop, add yosys miter tests * add option for verilog compiler without optimizations * rename FoldLogicalOp to FoldCommutativeOp * add GetNamespace and RenameModules, GetNamespace stores namespace as a ModuleNamespaceAnnotation * add constant propagation for Tail DoPrims * add scaladocs for MinimumLowFirrtlOptimization and yosysExpectFalure/Success, add constant propagation for Head DoPrim * add legalize pass to MinimumLowFirrtlOptimizations, use constPropBitExtract in legalize pass
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala16
-rw-r--r--src/main/scala/firrtl/analyses/GetNamespace.scala22
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala13
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala61
-rw-r--r--src/main/scala/firrtl/transforms/RenameModules.scala50
-rw-r--r--src/main/scala/firrtl/util/BackendCompilationUtilities.scala61
6 files changed, 194 insertions, 29 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index c4230b90..ba686f08 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -111,6 +111,15 @@ class LowFirrtlOptimization extends CoreTransform {
passes.CommonSubexpressionElimination,
new firrtl.transforms.DeadCodeElimination)
}
+/** Runs runs only the optimization passes needed for Verilog emission */
+class MinimumLowFirrtlOptimization extends CoreTransform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+ def transforms = Seq(
+ passes.Legalize,
+ passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
+ passes.SplitExpressions)
+}
import CompilerUtils.getLoweringTransforms
@@ -142,3 +151,10 @@ class VerilogCompiler extends Compiler {
def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, LowForm) ++
Seq(new LowFirrtlOptimization, new BlackBoxSourceHelper)
}
+
+/** Emits Verilog without optimizations */
+class MinimumVerilogCompiler extends Compiler {
+ def emitter = new VerilogEmitter
+ def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, LowForm) ++
+ Seq(new MinimumLowFirrtlOptimization, new BlackBoxSourceHelper)
+}
diff --git a/src/main/scala/firrtl/analyses/GetNamespace.scala b/src/main/scala/firrtl/analyses/GetNamespace.scala
new file mode 100644
index 00000000..5ab096b7
--- /dev/null
+++ b/src/main/scala/firrtl/analyses/GetNamespace.scala
@@ -0,0 +1,22 @@
+// See LICENSE for license details.
+
+package firrtl.analyses
+
+import firrtl.annotations.NoTargetAnnotation
+import firrtl.{CircuitState, LowForm, Namespace, Transform}
+
+case class ModuleNamespaceAnnotation(namespace: Namespace) extends NoTargetAnnotation
+
+/** Create a namespace with this circuit
+ *
+ * namespace is used by RenameModules to get unique names
+ */
+class GetNamespace extends Transform {
+ def inputForm: LowForm.type = LowForm
+ def outputForm: LowForm.type = LowForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val namespace = Namespace(state.circuit)
+ state.copy(annotations = new ModuleNamespaceAnnotation(namespace) +: state.annotations)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 9bbde5f6..5e5aa26a 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -3,12 +3,12 @@
package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
-
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import firrtl.PrimOps._
+import firrtl.transforms.ConstantPropagation
import scala.collection.mutable
@@ -198,14 +198,9 @@ object Legalize extends Pass {
e
}
}
- private def legalizeBits(expr: DoPrim): Expression = {
- lazy val (hi, low) = (expr.consts.head, expr.consts(1))
- lazy val mask = (BigInt(1) << (hi - low + 1).toInt) - 1
- lazy val width = IntWidth(hi - low + 1)
+ private def legalizeBitExtract(expr: DoPrim): Expression = {
expr.args.head match {
- case UIntLiteral(value, _) => UIntLiteral((value >> low.toInt) & mask, width)
- case SIntLiteral(value, _) => SIntLiteral((value >> low.toInt) & mask, width)
- //case FixedLiteral
+ case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr)
case _ => expr
}
}
@@ -236,7 +231,7 @@ object Legalize extends Pass {
case prim: DoPrim => prim.op match {
case Shr => legalizeShiftRight(prim)
case Pad => legalizePad(prim)
- case Bits => legalizeBits(prim)
+ case Bits | Head | Tail => legalizeBitExtract(prim)
case _ => prim
}
case e => e // respect pre-order traversal
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 4a4f41d1..0d30446c 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -17,12 +17,33 @@ import annotation.tailrec
import collection.mutable
object ConstantPropagation {
+ private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)
+
/** Pads e to the width of t */
def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
case (we, wt) if we == wt => e
}
+ def constPropBitExtract(e: DoPrim) = {
+ val arg = e.args.head
+ val (hi, lo) = e.op match {
+ case Bits => (e.consts.head.toInt, e.consts(1).toInt)
+ case Tail => ((bitWidth(arg.tpe) - 1 - e.consts.head).toInt, 0)
+ case Head => ((bitWidth(arg.tpe) - 1).toInt, (bitWidth(arg.tpe) - e.consts.head).toInt)
+ }
+
+ arg match {
+ case lit: Literal =>
+ require(hi >= lo)
+ UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
+ case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
+ case t: UIntType => x
+ case _ => asUInt(x, e.tpe)
+ }
+ case _ => e
+ }
+ }
}
class ConstantPropagation extends Transform {
@@ -30,9 +51,7 @@ class ConstantPropagation extends Transform {
def inputForm = LowForm
def outputForm = LowForm
- private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t)
-
- trait FoldLogicalOp {
+ trait FoldCommutativeOp {
def fold(c1: Literal, c2: Literal): Expression
def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression
@@ -44,7 +63,19 @@ class ConstantPropagation extends Transform {
}
}
- object FoldAND extends FoldLogicalOp {
+ object FoldADD extends FoldCommutativeOp {
+ def fold(c1: Literal, c2: Literal) = (c1, c2) match {
+ case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
+ case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
+ }
+ def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
+ case UIntLiteral(v, w) if v == BigInt(0) => rhs
+ case SIntLiteral(v, w) if v == BigInt(0) => rhs
+ case _ => e
+ }
+ }
+
+ object FoldAND extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
@@ -54,7 +85,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldOR extends FoldLogicalOp {
+ object FoldOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
@@ -64,7 +95,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldXOR extends FoldLogicalOp {
+ object FoldXOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width)
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
@@ -73,7 +104,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldEqual extends FoldLogicalOp {
+ object FoldEqual extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
@@ -81,7 +112,7 @@ class ConstantPropagation extends Transform {
}
}
- object FoldNotEqual extends FoldLogicalOp {
+ object FoldNotEqual extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
@@ -189,6 +220,7 @@ class ConstantPropagation extends Transform {
case Shl => foldShiftLeft(e)
case Shr => foldShiftRight(e)
case Cat => foldConcat(e)
+ case Add => FoldADD(e)
case And => FoldAND(e)
case Or => FoldOR(e)
case Xor => FoldXOR(e)
@@ -215,18 +247,7 @@ class ConstantPropagation extends Transform {
case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
case _ => e
}
- case Bits => e.args.head match {
- case lit: Literal =>
- val hi = e.consts.head.toInt
- val lo = e.consts(1).toInt
- require(hi >= lo)
- UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
- case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
- case t: UIntType => x
- case _ => asUInt(x, e.tpe)
- }
- case _ => e
- }
+ case (Bits | Head | Tail) => constPropBitExtract(e)
case _ => e
}
diff --git a/src/main/scala/firrtl/transforms/RenameModules.scala b/src/main/scala/firrtl/transforms/RenameModules.scala
new file mode 100644
index 00000000..af17dda9
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/RenameModules.scala
@@ -0,0 +1,50 @@
+// See LICENSE for license details.
+
+package firrtl.transforms
+
+import firrtl.analyses.{InstanceGraph, ModuleNamespaceAnnotation}
+import firrtl.ir._
+import firrtl._
+
+import scala.collection.mutable
+
+/** Rename Modules
+ *
+ * using namespace created by [[analyses.GetNamespace]], create unique names for modules
+ */
+class RenameModules extends Transform {
+ def inputForm: LowForm.type = LowForm
+ def outputForm: LowForm.type = LowForm
+
+ def collectNameMapping(namespace: Namespace, moduleNameMap: mutable.HashMap[String, String])(mod: DefModule): Unit = {
+ val newName = namespace.newName(mod.name)
+ moduleNameMap.put(mod.name, newName)
+ }
+
+ def onStmt(moduleNameMap: mutable.HashMap[String, String])(stmt: Statement): Statement = stmt match {
+ case inst: WDefInstance if moduleNameMap.contains(inst.module) => inst.copy(module = moduleNameMap(inst.module))
+ case other => other.mapStmt(onStmt(moduleNameMap))
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ val namespace = state.annotations.collectFirst {
+ case m: ModuleNamespaceAnnotation => m
+ }.map(_.namespace)
+
+ if (namespace.isEmpty) {
+ logger.warn("Skipping Rename Modules")
+ state
+ } else {
+ val moduleOrder = new InstanceGraph(state.circuit).moduleOrder.reverse
+ val nameMappings = new mutable.HashMap[String, String]()
+ moduleOrder.foreach(collectNameMapping(namespace.get, nameMappings))
+
+ val modulesx = state.circuit.modules.map {
+ case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings)
+ case ext: ExtModule => ext
+ }
+
+ state.copy(circuit = state.circuit.copy(modules = modulesx, main = nameMappings(state.circuit.main)))
+ }
+ }
+}
diff --git a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala
index 0c5ab12f..d3d34e87 100644
--- a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala
+++ b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala
@@ -152,4 +152,65 @@ trait BackendCompilationUtilities {
def executeExpectingSuccess(prefix: String, dir: File): Boolean = {
!executeExpectingFailure(prefix, dir)
}
+
+ /** Creates and runs a Yosys script that creates and runs SAT on a miter
+ * circuit. Returns true if SAT succeeds, false otherwise
+ *
+ * The custom and reference Verilog files must not contain any modules with
+ * the same name otherwise Yosys will not be able to create a miter circuit
+ *
+ * @param customTop name of the DUT with custom transforms without the .v
+ * extension
+ * @param referenceTop name of the DUT without custom transforms without the
+ * .v extension
+ * @param testDir directory containing verilog files
+ * @param resets signals to set for SAT, format is
+ * (timestep, signal, value)
+ */
+ def yosysExpectSuccess(customTop: String,
+ referenceTop: String,
+ testDir: File,
+ resets: Seq[(Int, String, Int)] = Seq.empty): Boolean = {
+ !yosysExpectFailure(customTop, referenceTop, testDir, resets)
+ }
+
+ /** Creates and runs a Yosys script that creates and runs SAT on a miter
+ * circuit. Returns false if SAT succeeds, true otherwise
+ *
+ * The custom and reference Verilog files must not contain any modules with
+ * the same name otherwise Yosys will not be able to create a miter circuit
+ *
+ * @param customTop name of the DUT with custom transforms without the .v
+ * extension
+ * @param referenceTop name of the DUT without custom transforms without the
+ * .v extension
+ * @param testDir directory containing verilog files
+ * @param resets signals to set for SAT, format is
+ * (timestep, signal, value)
+ */
+ def yosysExpectFailure(customTop: String,
+ referenceTop: String,
+ testDir: File,
+ resets: Seq[(Int, String, Int)] = Seq.empty): Boolean = {
+
+ val setSignals = resets.map(_._2).toSet[String].map(s => s"-set in_$s 0").mkString(" ")
+ val setAtSignals = resets.map {
+ case (timestep, signal, value) => s"-set-at $timestep in_$signal $value"
+ }.mkString(" ")
+ val scriptFileName = s"${testDir.getAbsolutePath}/yosys_script"
+ val yosysScriptWriter = new PrintWriter(scriptFileName)
+ yosysScriptWriter.write(
+ s"""read_verilog ${testDir.getAbsolutePath}/$customTop.v
+ |read_verilog ${testDir.getAbsolutePath}/$referenceTop.v
+ |prep; proc; opt; memory
+ |miter -equiv -flatten $customTop $referenceTop miter
+ |hierarchy -top miter
+ |sat -verify -tempinduct -prove trigger 0 $setSignals $setAtSignals -seq 1 miter"""
+ .stripMargin)
+ yosysScriptWriter.close()
+
+ val resultFileName = testDir.getAbsolutePath + "/yosys_results"
+ val command = s"yosys -s $scriptFileName" #> new File(resultFileName)
+ command.! != 0
+ }
}