diff options
| author | albertchen-sifive | 2018-07-20 14:36:30 -0700 |
|---|---|---|
| committer | Adam Izraelevitz | 2018-07-20 14:36:30 -0700 |
| commit | 7dff927840a30893facae957595a8e88ea62509a (patch) | |
| tree | 08210d9b2936fc4606ae8a0fe1c9f12a8c7c673e /src/main | |
| parent | 897dad039a12a49b3c4ae833fbf0d02087b26ed5 (diff) | |
Constant prop add (#849)
* add FoldADD to const prop, add yosys miter tests
* add option for verilog compiler without optimizations
* rename FoldLogicalOp to FoldCommutativeOp
* add GetNamespace and RenameModules, GetNamespace stores namespace as a ModuleNamespaceAnnotation
* add constant propagation for Tail DoPrims
* add scaladocs for MinimumLowFirrtlOptimization and yosysExpectFalure/Success, add constant propagation for Head DoPrim
* add legalize pass to MinimumLowFirrtlOptimizations, use constPropBitExtract in legalize pass
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 16 | ||||
| -rw-r--r-- | src/main/scala/firrtl/analyses/GetNamespace.scala | 22 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 13 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 61 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/RenameModules.scala | 50 | ||||
| -rw-r--r-- | src/main/scala/firrtl/util/BackendCompilationUtilities.scala | 61 |
6 files changed, 194 insertions, 29 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index c4230b90..ba686f08 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -111,6 +111,15 @@ class LowFirrtlOptimization extends CoreTransform { passes.CommonSubexpressionElimination, new firrtl.transforms.DeadCodeElimination) } +/** Runs runs only the optimization passes needed for Verilog emission */ +class MinimumLowFirrtlOptimization extends CoreTransform { + def inputForm = LowForm + def outputForm = LowForm + def transforms = Seq( + passes.Legalize, + passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter + passes.SplitExpressions) +} import CompilerUtils.getLoweringTransforms @@ -142,3 +151,10 @@ class VerilogCompiler extends Compiler { def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, LowForm) ++ Seq(new LowFirrtlOptimization, new BlackBoxSourceHelper) } + +/** Emits Verilog without optimizations */ +class MinimumVerilogCompiler extends Compiler { + def emitter = new VerilogEmitter + def transforms: Seq[Transform] = getLoweringTransforms(ChirrtlForm, LowForm) ++ + Seq(new MinimumLowFirrtlOptimization, new BlackBoxSourceHelper) +} diff --git a/src/main/scala/firrtl/analyses/GetNamespace.scala b/src/main/scala/firrtl/analyses/GetNamespace.scala new file mode 100644 index 00000000..5ab096b7 --- /dev/null +++ b/src/main/scala/firrtl/analyses/GetNamespace.scala @@ -0,0 +1,22 @@ +// See LICENSE for license details. + +package firrtl.analyses + +import firrtl.annotations.NoTargetAnnotation +import firrtl.{CircuitState, LowForm, Namespace, Transform} + +case class ModuleNamespaceAnnotation(namespace: Namespace) extends NoTargetAnnotation + +/** Create a namespace with this circuit + * + * namespace is used by RenameModules to get unique names + */ +class GetNamespace extends Transform { + def inputForm: LowForm.type = LowForm + def outputForm: LowForm.type = LowForm + + def execute(state: CircuitState): CircuitState = { + val namespace = Namespace(state.circuit) + state.copy(annotations = new ModuleNamespaceAnnotation(namespace) +: state.annotations) + } +} diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 9bbde5f6..5e5aa26a 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -3,12 +3,12 @@ package firrtl.passes import com.typesafe.scalalogging.LazyLogging - import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.PrimOps._ +import firrtl.transforms.ConstantPropagation import scala.collection.mutable @@ -198,14 +198,9 @@ object Legalize extends Pass { e } } - private def legalizeBits(expr: DoPrim): Expression = { - lazy val (hi, low) = (expr.consts.head, expr.consts(1)) - lazy val mask = (BigInt(1) << (hi - low + 1).toInt) - 1 - lazy val width = IntWidth(hi - low + 1) + private def legalizeBitExtract(expr: DoPrim): Expression = { expr.args.head match { - case UIntLiteral(value, _) => UIntLiteral((value >> low.toInt) & mask, width) - case SIntLiteral(value, _) => SIntLiteral((value >> low.toInt) & mask, width) - //case FixedLiteral + case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr) case _ => expr } } @@ -236,7 +231,7 @@ object Legalize extends Pass { case prim: DoPrim => prim.op match { case Shr => legalizeShiftRight(prim) case Pad => legalizePad(prim) - case Bits => legalizeBits(prim) + case Bits | Head | Tail => legalizeBitExtract(prim) case _ => prim } case e => e // respect pre-order traversal diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 4a4f41d1..0d30446c 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -17,12 +17,33 @@ import annotation.tailrec import collection.mutable object ConstantPropagation { + private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) + /** Pads e to the width of t */ def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e } + def constPropBitExtract(e: DoPrim) = { + val arg = e.args.head + val (hi, lo) = e.op match { + case Bits => (e.consts.head.toInt, e.consts(1).toInt) + case Tail => ((bitWidth(arg.tpe) - 1 - e.consts.head).toInt, 0) + case Head => ((bitWidth(arg.tpe) - 1).toInt, (bitWidth(arg.tpe) - e.consts.head).toInt) + } + + arg match { + case lit: Literal => + require(hi >= lo) + UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) + case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { + case t: UIntType => x + case _ => asUInt(x, e.tpe) + } + case _ => e + } + } } class ConstantPropagation extends Transform { @@ -30,9 +51,7 @@ class ConstantPropagation extends Transform { def inputForm = LowForm def outputForm = LowForm - private def asUInt(e: Expression, t: Type) = DoPrim(AsUInt, Seq(e), Seq(), t) - - trait FoldLogicalOp { + trait FoldCommutativeOp { def fold(c1: Literal, c2: Literal): Expression def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression @@ -44,7 +63,19 @@ class ConstantPropagation extends Transform { } } - object FoldAND extends FoldLogicalOp { + object FoldADD extends FoldCommutativeOp { + def fold(c1: Literal, c2: Literal) = (c1, c2) match { + case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) + case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) + } + def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { + case UIntLiteral(v, w) if v == BigInt(0) => rhs + case SIntLiteral(v, w) if v == BigInt(0) => rhs + case _ => e + } + } + + object FoldAND extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) @@ -54,7 +85,7 @@ class ConstantPropagation extends Transform { } } - object FoldOR extends FoldLogicalOp { + object FoldOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs @@ -64,7 +95,7 @@ class ConstantPropagation extends Transform { } } - object FoldXOR extends FoldLogicalOp { + object FoldXOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs @@ -73,7 +104,7 @@ class ConstantPropagation extends Transform { } } - object FoldEqual extends FoldLogicalOp { + object FoldEqual extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs @@ -81,7 +112,7 @@ class ConstantPropagation extends Transform { } } - object FoldNotEqual extends FoldLogicalOp { + object FoldNotEqual extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs @@ -189,6 +220,7 @@ class ConstantPropagation extends Transform { case Shl => foldShiftLeft(e) case Shr => foldShiftRight(e) case Cat => foldConcat(e) + case Add => FoldADD(e) case And => FoldAND(e) case Or => FoldOR(e) case Xor => FoldXOR(e) @@ -215,18 +247,7 @@ class ConstantPropagation extends Transform { case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head case _ => e } - case Bits => e.args.head match { - case lit: Literal => - val hi = e.consts.head.toInt - val lo = e.consts(1).toInt - require(hi >= lo) - UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) - case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { - case t: UIntType => x - case _ => asUInt(x, e.tpe) - } - case _ => e - } + case (Bits | Head | Tail) => constPropBitExtract(e) case _ => e } diff --git a/src/main/scala/firrtl/transforms/RenameModules.scala b/src/main/scala/firrtl/transforms/RenameModules.scala new file mode 100644 index 00000000..af17dda9 --- /dev/null +++ b/src/main/scala/firrtl/transforms/RenameModules.scala @@ -0,0 +1,50 @@ +// See LICENSE for license details. + +package firrtl.transforms + +import firrtl.analyses.{InstanceGraph, ModuleNamespaceAnnotation} +import firrtl.ir._ +import firrtl._ + +import scala.collection.mutable + +/** Rename Modules + * + * using namespace created by [[analyses.GetNamespace]], create unique names for modules + */ +class RenameModules extends Transform { + def inputForm: LowForm.type = LowForm + def outputForm: LowForm.type = LowForm + + def collectNameMapping(namespace: Namespace, moduleNameMap: mutable.HashMap[String, String])(mod: DefModule): Unit = { + val newName = namespace.newName(mod.name) + moduleNameMap.put(mod.name, newName) + } + + def onStmt(moduleNameMap: mutable.HashMap[String, String])(stmt: Statement): Statement = stmt match { + case inst: WDefInstance if moduleNameMap.contains(inst.module) => inst.copy(module = moduleNameMap(inst.module)) + case other => other.mapStmt(onStmt(moduleNameMap)) + } + + def execute(state: CircuitState): CircuitState = { + val namespace = state.annotations.collectFirst { + case m: ModuleNamespaceAnnotation => m + }.map(_.namespace) + + if (namespace.isEmpty) { + logger.warn("Skipping Rename Modules") + state + } else { + val moduleOrder = new InstanceGraph(state.circuit).moduleOrder.reverse + val nameMappings = new mutable.HashMap[String, String]() + moduleOrder.foreach(collectNameMapping(namespace.get, nameMappings)) + + val modulesx = state.circuit.modules.map { + case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings) + case ext: ExtModule => ext + } + + state.copy(circuit = state.circuit.copy(modules = modulesx, main = nameMappings(state.circuit.main))) + } + } +} diff --git a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala index 0c5ab12f..d3d34e87 100644 --- a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala +++ b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala @@ -152,4 +152,65 @@ trait BackendCompilationUtilities { def executeExpectingSuccess(prefix: String, dir: File): Boolean = { !executeExpectingFailure(prefix, dir) } + + /** Creates and runs a Yosys script that creates and runs SAT on a miter + * circuit. Returns true if SAT succeeds, false otherwise + * + * The custom and reference Verilog files must not contain any modules with + * the same name otherwise Yosys will not be able to create a miter circuit + * + * @param customTop name of the DUT with custom transforms without the .v + * extension + * @param referenceTop name of the DUT without custom transforms without the + * .v extension + * @param testDir directory containing verilog files + * @param resets signals to set for SAT, format is + * (timestep, signal, value) + */ + def yosysExpectSuccess(customTop: String, + referenceTop: String, + testDir: File, + resets: Seq[(Int, String, Int)] = Seq.empty): Boolean = { + !yosysExpectFailure(customTop, referenceTop, testDir, resets) + } + + /** Creates and runs a Yosys script that creates and runs SAT on a miter + * circuit. Returns false if SAT succeeds, true otherwise + * + * The custom and reference Verilog files must not contain any modules with + * the same name otherwise Yosys will not be able to create a miter circuit + * + * @param customTop name of the DUT with custom transforms without the .v + * extension + * @param referenceTop name of the DUT without custom transforms without the + * .v extension + * @param testDir directory containing verilog files + * @param resets signals to set for SAT, format is + * (timestep, signal, value) + */ + def yosysExpectFailure(customTop: String, + referenceTop: String, + testDir: File, + resets: Seq[(Int, String, Int)] = Seq.empty): Boolean = { + + val setSignals = resets.map(_._2).toSet[String].map(s => s"-set in_$s 0").mkString(" ") + val setAtSignals = resets.map { + case (timestep, signal, value) => s"-set-at $timestep in_$signal $value" + }.mkString(" ") + val scriptFileName = s"${testDir.getAbsolutePath}/yosys_script" + val yosysScriptWriter = new PrintWriter(scriptFileName) + yosysScriptWriter.write( + s"""read_verilog ${testDir.getAbsolutePath}/$customTop.v + |read_verilog ${testDir.getAbsolutePath}/$referenceTop.v + |prep; proc; opt; memory + |miter -equiv -flatten $customTop $referenceTop miter + |hierarchy -top miter + |sat -verify -tempinduct -prove trigger 0 $setSignals $setAtSignals -seq 1 miter""" + .stripMargin) + yosysScriptWriter.close() + + val resultFileName = testDir.getAbsolutePath + "/yosys_results" + val command = s"yosys -s $scriptFileName" #> new File(resultFileName) + command.! != 0 + } } |
