aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/InferWidths.scala
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/passes/InferWidths.scala
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/passes/InferWidths.scala')
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala190
1 files changed, 110 insertions, 80 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 3720523b..eae9690f 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -14,7 +14,7 @@ import firrtl.options.Dependency
object InferWidths {
def apply(): InferWidths = new InferWidths()
- def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver)
+ def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver)
def execute(state: CircuitState): CircuitState = new InferWidths().execute(state)
}
@@ -22,12 +22,14 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg
def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = {
val newLoc :: newExp :: Nil = Seq(loc, exp).map { target =>
renameMap.get(target) match {
- case None => Some(target)
- case Some(Seq()) => None
+ case None => Some(target)
+ case Some(Seq()) => None
case Some(Seq(one)) => Some(one)
case Some(many) =>
- throw new Exception(s"Target below is an AggregateType, which " +
- "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ throw new Exception(
+ s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()
+ )
}
}
@@ -60,28 +62,31 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg
*
* Uses firrtl.constraint package to infer widths
*/
-class InferWidths extends Transform
- with ResolvedAnnotationPaths
- with DependencyAPIMigration {
+class InferWidths extends Transform with ResolvedAnnotationPaths with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.ResolveFlows),
- Dependency[passes.InferBinaryPoints],
- Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR
+ Seq(
+ Dependency(passes.ResolveKinds),
+ Dependency(passes.InferTypes),
+ Dependency(passes.ResolveFlows),
+ Dependency[passes.InferBinaryPoints],
+ Dependency[passes.TrimIntervals]
+ ) ++ firrtl.stage.Forms.WorkingIR
override def invalidates(a: Transform) = false
val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
- private def addTypeConstraints
- (r1: ReferenceTarget, r2: ReferenceTarget)
- (t1: Type, t2: Type)
- (implicit constraintSolver: ConstraintSolver)
- : Unit = (t1,t2) match {
+ private def addTypeConstraints(
+ r1: ReferenceTarget,
+ r2: ReferenceTarget
+ )(t1: Type,
+ t2: Type
+ )(
+ implicit constraintSolver: ConstraintSolver
+ ): Unit = (t1, t2) match {
case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
- case (ClockType, ClockType) =>
+ case (ClockType, ClockType) =>
case (FixedType(w1, p1), FixedType(w2, p2)) =>
constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
@@ -93,101 +98,119 @@ class InferWidths extends Transform
constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
constraintSolver.addGeq(w2, w1, r1.prettyPrint(""), r2.prettyPrint(""))
case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields) foreach { case (f1, f2) =>
- (f1.flip, f2.flip) match {
- case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
- case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
- case _ => sys.error("Shouldn't be here")
- }
+ (t1.fields.zip(t2.fields)).foreach {
+ case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
+ }
}
case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
case (AsyncResetType, AsyncResetType) => Nil
- case (ResetType, _) => Nil
- case (_, ResetType) => Nil
+ case (ResetType, _) => Nil
+ case (_, ResetType) => Nil
}
- private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver)
- : Expression = e map addExpConstraints match {
- case m@Mux(p, tVal, fVal, t) =>
- constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
- m
- case other => other
- }
+ private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver): Expression =
+ e.map(addExpConstraints) match {
+ case m @ Mux(p, tVal, fVal, t) =>
+ constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
+ m
+ case other => other
+ }
- private def addStmtConstraints(mt: ModuleTarget)(s: Statement)(implicit constraintSolver: ConstraintSolver)
- : Statement = s map addExpConstraints match {
+ private def addStmtConstraints(
+ mt: ModuleTarget
+ )(s: Statement
+ )(
+ implicit constraintSolver: ConstraintSolver
+ ): Statement = s.map(addExpConstraints) match {
case c: Connect =>
val n = get_size(c.loc.tpe)
val locs = create_exps(c.loc)
val exps = create_exps(c.expr)
- (locs zip exps).foreach { case (loc, exp) =>
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
- }
+ (locs.zip(exps)).foreach {
+ case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
+ }
c
case pc: PartialConnect =>
val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
val locs = create_exps(pc.loc)
val exps = create_exps(pc.expr)
- ls foreach { case (x, y) =>
- val loc = locs(x)
- val exp = exps(y)
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
+ ls.foreach {
+ case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
pc
case r: DefRegister =>
- if (r.reset.tpe != AsyncResetType ) {
+ if (r.reset.tpe != AsyncResetType) {
addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1)))
}
addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
r
- case a@Attach(_, exprs) =>
- val widths = exprs map (e => (e, getWidth(e.tpe)))
+ case a @ Attach(_, exprs) =>
+ val widths = exprs.map(e => (e, getWidth(e.tpe)))
val maxWidth = IsMax(widths.map(x => width2constraint(x._2)))
- widths.foreach { case (e, w) =>
- constraintSolver.addGeq(w, CalcWidth(maxWidth), Target.asTarget(mt)(e).prettyPrint(""), mt.ref(a.serialize).prettyPrint(""))
+ widths.foreach {
+ case (e, w) =>
+ constraintSolver.addGeq(
+ w,
+ CalcWidth(maxWidth),
+ Target.asTarget(mt)(e).prettyPrint(""),
+ mt.ref(a.serialize).prettyPrint("")
+ )
}
a
case c: Conditionally =>
addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1)))
- c map addStmtConstraints(mt)
- case x => x map addStmtConstraints(mt)
+ c.map(addStmtConstraints(mt))
+ case x => x.map(addStmtConstraints(mt))
}
private def fixWidth(w: Width)(implicit constraintSolver: ConstraintSolver): Width = constraintSolver.get(w) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => w
- case _ => sys.error("Shouldn't be here")
+ case None => w
+ case _ => sys.error("Shouldn't be here")
}
- private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t map fixType map fixWidth match {
+ private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t.map(fixType).map(fixWidth) match {
case IntervalType(l, u, p) =>
val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match {
case (Some(x: Bound), Some(y: Bound)) => (x, y)
case (None, None) => (l, u)
- case x => sys.error(s"Shouldn't be here: $x")
-
+ case x => sys.error(s"Shouldn't be here: $x")
}
IntervalType(lx, ux, fixWidth(p))
case FixedType(w, p) => FixedType(w, fixWidth(p))
- case x => x
+ case x => x
}
- private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = s map fixStmt map fixType
+ private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement =
+ s.map(fixStmt).map(fixType)
private def fixPort(p: Port)(implicit constraintSolver: ConstraintSolver): Port = {
Port(p.info, p.name, p.direction, fixType(p.tpe))
}
- def run (c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = {
+ def run(c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = {
val ct = CircuitTarget(c.main)
- c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name)))
+ c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name))))
constraintSolver.solve()
- val ret = InferTypes.run(c.copy(modules = c.modules map (_
- map fixPort
- map fixStmt)))
+ val ret = InferTypes.run(
+ c.copy(modules =
+ c.modules.map(
+ _.map(fixPort)
+ .map(fixStmt)
+ )
+ )
+ )
constraintSolver.clear()
ret
}
@@ -200,15 +223,16 @@ class InferWidths extends Transform
def getDeclTypes(modName: String)(stmt: Statement): Unit = {
val pairOpt = stmt match {
- case w: DefWire => Some(w.name -> w.tpe)
- case r: DefRegister => Some(r.name -> r.tpe)
- case n: DefNode => Some(n.name -> n.value.tpe)
+ case w: DefWire => Some(w.name -> w.tpe)
+ case r: DefRegister => Some(r.name -> r.tpe)
+ case n: DefNode => Some(n.name -> n.value.tpe)
case i: WDefInstance => Some(i.name -> i.tpe)
- case m: DefMemory => Some(m.name -> MemPortUtils.memType(m))
+ case m: DefMemory => Some(m.name -> MemPortUtils.memType(m))
case other => None
}
- pairOpt.foreach { case (ref, tpe) =>
- typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe)
+ pairOpt.foreach {
+ case (ref, tpe) =>
+ typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe)
}
stmt.foreachStmt(getDeclTypes(modName))
}
@@ -223,14 +247,20 @@ class InferWidths extends Transform
}
state.annotations.foreach {
- case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
- val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target =>
- val baseType = typeMap.getOrElse(target.copy(component = Seq.empty),
- throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint()))
+ case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
+ val locType :: expType :: Nil = Seq(anno.loc, anno.exp).map { target =>
+ val baseType = typeMap.getOrElse(
+ target.copy(component = Seq.empty),
+ throw new Exception(
+ s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint()
+ )
+ )
val leafType = target.componentType(baseType)
if (leafType.isInstanceOf[AggregateType]) {
- throw new Exception(s"Target below is an AggregateType, which " +
- "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ throw new Exception(
+ s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()
+ )
}
leafType