aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala28
1 files changed, 17 insertions, 11 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 4c4afca1..d481b713 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -14,7 +14,7 @@ import firrtl.options.Dependency
object InferWidths {
def apply(): InferWidths = new InferWidths()
- def run(c: Circuit): Circuit = new InferWidths().run(c)
+ def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver)
def execute(state: CircuitState): CircuitState = new InferWidths().execute(state)
}
@@ -73,11 +73,13 @@ class InferWidths extends Transform
Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR
override def invalidates(a: Transform) = false
- private val constraintSolver = new ConstraintSolver()
-
val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
- private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match {
+ private def addTypeConstraints
+ (r1: ReferenceTarget, r2: ReferenceTarget)
+ (t1: Type, t2: Type)
+ (implicit constraintSolver: ConstraintSolver)
+ : Unit = (t1,t2) match {
case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
case (ClockType, ClockType) =>
@@ -105,14 +107,16 @@ class InferWidths extends Transform
case (_, ResetType) => Nil
}
- private def addExpConstraints(e: Expression): Expression = e map addExpConstraints match {
+ private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver)
+ : Expression = e map addExpConstraints match {
case m@Mux(p, tVal, fVal, t) =>
constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
m
case other => other
}
- private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addExpConstraints match {
+ private def addStmtConstraints(mt: ModuleTarget)(s: Statement)(implicit constraintSolver: ConstraintSolver)
+ : Statement = s map addExpConstraints match {
case c: Connect =>
val n = get_size(c.loc.tpe)
val locs = create_exps(c.loc)
@@ -155,12 +159,12 @@ class InferWidths extends Transform
c map addStmtConstraints(mt)
case x => x map addStmtConstraints(mt)
}
- private def fixWidth(w: Width): Width = constraintSolver.get(w) match {
+ private def fixWidth(w: Width)(implicit constraintSolver: ConstraintSolver): Width = constraintSolver.get(w) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
case None => w
case _ => sys.error("Shouldn't be here")
}
- private def fixType(t: Type): Type = t map fixType map fixWidth match {
+ private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t map fixType map fixWidth match {
case IntervalType(l, u, p) =>
val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match {
case (Some(x: Bound), Some(y: Bound)) => (x, y)
@@ -173,12 +177,12 @@ class InferWidths extends Transform
case FixedType(w, p) => FixedType(w, fixWidth(p))
case x => x
}
- private def fixStmt(s: Statement): Statement = s map fixStmt map fixType
- private def fixPort(p: Port): Port = {
+ private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = s map fixStmt map fixType
+ private def fixPort(p: Port)(implicit constraintSolver: ConstraintSolver): Port = {
Port(p.info, p.name, p.direction, fixType(p.tpe))
}
- def run (c: Circuit): Circuit = {
+ def run (c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = {
val ct = CircuitTarget(c.main)
c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name)))
constraintSolver.solve()
@@ -190,6 +194,8 @@ class InferWidths extends Transform
}
def execute(state: CircuitState): CircuitState = {
+ implicit val constraintSolver = new ConstraintSolver()
+
val circuitName = state.circuit.main
val typeMap = new collection.mutable.HashMap[ReferenceTarget, Type]