aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorJack Koenig2017-06-29 14:20:09 -0700
committerGitHub2017-06-29 14:20:09 -0700
commit905cac96053caf4b6c87ac0b9c8addf313d1085c (patch)
tree7d0bcf384f63e0176acdd70f9524369bb5bb4ce0 /src/main
parent8eb69dd91e58915f8dad5e42da0a3fe686c628d8 (diff)
parenta0aeafa3d591f9bcc14eca6d8a41eb2155f1b5b0 (diff)
Merge pull request #617 from freechipsproject/const-prop-regs
Improvements to Constant Propagation and Testing
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala6
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala (renamed from src/main/scala/firrtl/passes/ConstProp.scala)35
2 files changed, 30 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 66ae1673..8dd9b180 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -98,12 +98,12 @@ class LowFirrtlOptimization extends CoreTransform {
def outputForm = LowForm
def transforms = Seq(
passes.RemoveValidIf,
- passes.ConstProp,
+ new firrtl.transforms.ConstantPropagation,
passes.PadWidths,
- passes.ConstProp,
+ new firrtl.transforms.ConstantPropagation,
passes.Legalize,
passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter
- passes.ConstProp,
+ new firrtl.transforms.ConstantPropagation,
passes.SplitExpressions,
passes.CommonSubexpressionElimination,
new firrtl.transforms.DeadCodeElimination)
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index f2aa1a03..efe06e9b 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -1,8 +1,10 @@
// See LICENSE for license details.
-package firrtl.passes
+package firrtl
+package transforms
import firrtl._
+import firrtl.annotations._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
@@ -10,7 +12,10 @@ import firrtl.PrimOps._
import annotation.tailrec
-object ConstProp extends Pass {
+class ConstantPropagation extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
case (we, wt) if we == wt => e
@@ -239,7 +244,7 @@ object ConstProp extends Pass {
// 2. Propagate references again for backwards reference (Wires)
// TODO Replacing all wires with nodes makes the second pass unnecessary
@tailrec
- private def constPropModule(m: Module): Module = {
+ private def constPropModule(m: Module, dontTouches: Set[String]): Module = {
var nPropagated = 0L
val nodeMap = collection.mutable.HashMap[String, Expression]()
@@ -272,8 +277,8 @@ object ConstProp extends Pass {
def constPropStmt(s: Statement): Statement = {
val stmtx = s map constPropStmt map constPropExpression
stmtx match {
- case x: DefNode => nodeMap(x.name) = x.value
- case Connect(_, WRef(wname, wtpe, WireKind, _), expr) =>
+ case x: DefNode if !dontTouches.contains(x.name) => nodeMap(x.name) = x.value
+ case Connect(_, WRef(wname, wtpe, WireKind, _), expr) if !dontTouches.contains(wname) =>
val exprx = constPropExpression(pad(expr, wtpe))
nodeMap(wname) = exprx
case _ =>
@@ -282,14 +287,28 @@ object ConstProp extends Pass {
}
val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body)))
- if (nPropagated > 0) constPropModule(res) else res
+ if (nPropagated > 0) constPropModule(res, dontTouches) else res
}
- def run(c: Circuit): Circuit = {
+ private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = {
val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => constPropModule(m)
+ case m: Module => constPropModule(m, dontTouchMap.getOrElse(m.name, Set.empty))
}
Circuit(c.info, modulesx, c.main)
}
+
+ def execute(state: CircuitState): CircuitState = {
+ val dontTouches: Seq[(String, String)] = state.annotations match {
+ case Some(aMap) => aMap.annotations.collect {
+ case DontTouchAnnotation(ComponentName(c, ModuleName(m, _))) => m -> c
+ }
+ case None => Seq.empty
+ }
+ // Map from module name to component names
+ val dontTouchMap: Map[String, Set[String]] =
+ dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet)
+
+ state.copy(circuit = run(state.circuit, dontTouchMap))
+ }
}