aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSchuyler Eldridge2020-07-20 11:32:32 -0400
committerGitHub2020-07-20 15:32:32 +0000
commitd177add0df50bfd7059557b2b648d101489b7285 (patch)
treefa5944c5c6a70583827b054b141833c4a97499ee /src
parent1b9f4ddff4102fee72ae4dd8c111c82c32e42d5d (diff)
Make InferWidths thread safe (#1775)
Change the class-global, but private ConstraintSolver object inside InferWidths to instead be constructed on each execute invocation. This prevents issues with thread safety where running the same InferWidths object at the same time would cause the ConstraintSolver to get trampled on. Signed-off-by: Schuyler Eldridge <schuyler.eldridge@ibm.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala28
1 files changed, 17 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 4c4afca1..d481b713 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -14,7 +14,7 @@ import firrtl.options.Dependency
object InferWidths {
def apply(): InferWidths = new InferWidths()
- def run(c: Circuit): Circuit = new InferWidths().run(c)
+ def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver)
def execute(state: CircuitState): CircuitState = new InferWidths().execute(state)
}
@@ -73,11 +73,13 @@ class InferWidths extends Transform
Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR
override def invalidates(a: Transform) = false
- private val constraintSolver = new ConstraintSolver()
-
val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
- private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match {
+ private def addTypeConstraints
+ (r1: ReferenceTarget, r2: ReferenceTarget)
+ (t1: Type, t2: Type)
+ (implicit constraintSolver: ConstraintSolver)
+ : Unit = (t1,t2) match {
case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
case (ClockType, ClockType) =>
@@ -105,14 +107,16 @@ class InferWidths extends Transform
case (_, ResetType) => Nil
}
- private def addExpConstraints(e: Expression): Expression = e map addExpConstraints match {
+ private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver)
+ : Expression = e map addExpConstraints match {
case m@Mux(p, tVal, fVal, t) =>
constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
m
case other => other
}
- private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addExpConstraints match {
+ private def addStmtConstraints(mt: ModuleTarget)(s: Statement)(implicit constraintSolver: ConstraintSolver)
+ : Statement = s map addExpConstraints match {
case c: Connect =>
val n = get_size(c.loc.tpe)
val locs = create_exps(c.loc)
@@ -155,12 +159,12 @@ class InferWidths extends Transform
c map addStmtConstraints(mt)
case x => x map addStmtConstraints(mt)
}
- private def fixWidth(w: Width): Width = constraintSolver.get(w) match {
+ private def fixWidth(w: Width)(implicit constraintSolver: ConstraintSolver): Width = constraintSolver.get(w) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
case None => w
case _ => sys.error("Shouldn't be here")
}
- private def fixType(t: Type): Type = t map fixType map fixWidth match {
+ private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t map fixType map fixWidth match {
case IntervalType(l, u, p) =>
val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match {
case (Some(x: Bound), Some(y: Bound)) => (x, y)
@@ -173,12 +177,12 @@ class InferWidths extends Transform
case FixedType(w, p) => FixedType(w, fixWidth(p))
case x => x
}
- private def fixStmt(s: Statement): Statement = s map fixStmt map fixType
- private def fixPort(p: Port): Port = {
+ private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = s map fixStmt map fixType
+ private def fixPort(p: Port)(implicit constraintSolver: ConstraintSolver): Port = {
Port(p.info, p.name, p.direction, fixType(p.tpe))
}
- def run (c: Circuit): Circuit = {
+ def run (c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = {
val ct = CircuitTarget(c.main)
c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name)))
constraintSolver.solve()
@@ -190,6 +194,8 @@ class InferWidths extends Transform
}
def execute(state: CircuitState): CircuitState = {
+ implicit val constraintSolver = new ConstraintSolver()
+
val circuitName = state.circuit.main
val typeMap = new collection.mutable.HashMap[ReferenceTarget, Type]