diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 4c4afca1..d481b713 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -14,7 +14,7 @@ import firrtl.options.Dependency object InferWidths { def apply(): InferWidths = new InferWidths() - def run(c: Circuit): Circuit = new InferWidths().run(c) + def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver) def execute(state: CircuitState): CircuitState = new InferWidths().execute(state) } @@ -73,11 +73,13 @@ class InferWidths extends Transform Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR override def invalidates(a: Transform) = false - private val constraintSolver = new ConstraintSolver() - val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) - private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { + private def addTypeConstraints + (r1: ReferenceTarget, r2: ReferenceTarget) + (t1: Type, t2: Type) + (implicit constraintSolver: ConstraintSolver) + : Unit = (t1,t2) match { case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) case (ClockType, ClockType) => @@ -105,14 +107,16 @@ class InferWidths extends Transform case (_, ResetType) => Nil } - private def addExpConstraints(e: Expression): Expression = e map addExpConstraints match { + private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver) + : Expression = e map addExpConstraints match { case m@Mux(p, tVal, fVal, t) => constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W") m case other => other } - private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addExpConstraints match { + private def addStmtConstraints(mt: ModuleTarget)(s: Statement)(implicit constraintSolver: ConstraintSolver) + : Statement = s map addExpConstraints match { case c: Connect => val n = get_size(c.loc.tpe) val locs = create_exps(c.loc) @@ -155,12 +159,12 @@ class InferWidths extends Transform c map addStmtConstraints(mt) case x => x map addStmtConstraints(mt) } - private def fixWidth(w: Width): Width = constraintSolver.get(w) match { + private def fixWidth(w: Width)(implicit constraintSolver: ConstraintSolver): Width = constraintSolver.get(w) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) case None => w case _ => sys.error("Shouldn't be here") } - private def fixType(t: Type): Type = t map fixType map fixWidth match { + private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t map fixType map fixWidth match { case IntervalType(l, u, p) => val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match { case (Some(x: Bound), Some(y: Bound)) => (x, y) @@ -173,12 +177,12 @@ class InferWidths extends Transform case FixedType(w, p) => FixedType(w, fixWidth(p)) case x => x } - private def fixStmt(s: Statement): Statement = s map fixStmt map fixType - private def fixPort(p: Port): Port = { + private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = s map fixStmt map fixType + private def fixPort(p: Port)(implicit constraintSolver: ConstraintSolver): Port = { Port(p.info, p.name, p.direction, fixType(p.tpe)) } - def run (c: Circuit): Circuit = { + def run (c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = { val ct = CircuitTarget(c.main) c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name))) constraintSolver.solve() @@ -190,6 +194,8 @@ class InferWidths extends Transform } def execute(state: CircuitState): CircuitState = { + implicit val constraintSolver = new ConstraintSolver() + val circuitName = state.circuit.main val typeMap = new collection.mutable.HashMap[ReferenceTarget, Type] |
