diff options
| author | Schuyler Eldridge | 2020-07-20 11:32:32 -0400 |
|---|---|---|
| committer | GitHub | 2020-07-20 15:32:32 +0000 |
| commit | d177add0df50bfd7059557b2b648d101489b7285 (patch) | |
| tree | fa5944c5c6a70583827b054b141833c4a97499ee /src | |
| parent | 1b9f4ddff4102fee72ae4dd8c111c82c32e42d5d (diff) | |
Make InferWidths thread safe (#1775)
Change the class-global, but private ConstraintSolver object inside
InferWidths to instead be constructed on each execute invocation. This
prevents issues with thread safety where running the same InferWidths
object at the same time would cause the ConstraintSolver to get
trampled on.
Signed-off-by: Schuyler Eldridge <schuyler.eldridge@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 4c4afca1..d481b713 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -14,7 +14,7 @@ import firrtl.options.Dependency object InferWidths { def apply(): InferWidths = new InferWidths() - def run(c: Circuit): Circuit = new InferWidths().run(c) + def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver) def execute(state: CircuitState): CircuitState = new InferWidths().execute(state) } @@ -73,11 +73,13 @@ class InferWidths extends Transform Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR override def invalidates(a: Transform) = false - private val constraintSolver = new ConstraintSolver() - val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) - private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { + private def addTypeConstraints + (r1: ReferenceTarget, r2: ReferenceTarget) + (t1: Type, t2: Type) + (implicit constraintSolver: ConstraintSolver) + : Unit = (t1,t2) match { case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) case (ClockType, ClockType) => @@ -105,14 +107,16 @@ class InferWidths extends Transform case (_, ResetType) => Nil } - private def addExpConstraints(e: Expression): Expression = e map addExpConstraints match { + private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver) + : Expression = e map addExpConstraints match { case m@Mux(p, tVal, fVal, t) => constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W") m case other => other } - private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addExpConstraints match { + private def addStmtConstraints(mt: ModuleTarget)(s: Statement)(implicit constraintSolver: ConstraintSolver) + : Statement = s map addExpConstraints match { case c: Connect => val n = get_size(c.loc.tpe) val locs = create_exps(c.loc) @@ -155,12 +159,12 @@ class InferWidths extends Transform c map addStmtConstraints(mt) case x => x map addStmtConstraints(mt) } - private def fixWidth(w: Width): Width = constraintSolver.get(w) match { + private def fixWidth(w: Width)(implicit constraintSolver: ConstraintSolver): Width = constraintSolver.get(w) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) case None => w case _ => sys.error("Shouldn't be here") } - private def fixType(t: Type): Type = t map fixType map fixWidth match { + private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t map fixType map fixWidth match { case IntervalType(l, u, p) => val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match { case (Some(x: Bound), Some(y: Bound)) => (x, y) @@ -173,12 +177,12 @@ class InferWidths extends Transform case FixedType(w, p) => FixedType(w, fixWidth(p)) case x => x } - private def fixStmt(s: Statement): Statement = s map fixStmt map fixType - private def fixPort(p: Port): Port = { + private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = s map fixStmt map fixType + private def fixPort(p: Port)(implicit constraintSolver: ConstraintSolver): Port = { Port(p.info, p.name, p.direction, fixType(p.tpe)) } - def run (c: Circuit): Circuit = { + def run (c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = { val ct = CircuitTarget(c.main) c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name))) constraintSolver.solve() @@ -190,6 +194,8 @@ class InferWidths extends Transform } def execute(state: CircuitState): CircuitState = { + implicit val constraintSolver = new ConstraintSolver() + val circuitName = state.circuit.main val typeMap = new collection.mutable.HashMap[ReferenceTarget, Type] |
