aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/InferBinaryPoints.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/InferBinaryPoints.scala')
-rw-r--r--src/main/scala/firrtl/passes/InferBinaryPoints.scala98
1 files changed, 52 insertions, 46 deletions
diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
index a16205a7..f393d8a5 100644
--- a/src/main/scala/firrtl/passes/InferBinaryPoints.scala
+++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
@@ -13,9 +13,7 @@ import firrtl.options.Dependency
class InferBinaryPoints extends Pass {
override def prerequisites =
- Seq( Dependency(ResolveKinds),
- Dependency(InferTypes),
- Dependency(ResolveFlows) )
+ Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows))
override def optionalPrerequisiteOf = Seq.empty
@@ -23,12 +21,12 @@ class InferBinaryPoints extends Pass {
private val constraintSolver = new ConstraintSolver()
- private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match {
- case (UIntType(w1), UIntType(w2)) =>
- case (SIntType(w1), SIntType(w2)) =>
- case (ClockType, ClockType) =>
- case (ResetType, _) =>
- case (_, ResetType) =>
+ private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1, t2) match {
+ case (UIntType(w1), UIntType(w2)) =>
+ case (SIntType(w1), SIntType(w2)) =>
+ case (ClockType, ClockType) =>
+ case (ResetType, _) =>
+ case (_, ResetType) =>
case (AsyncResetType, AsyncResetType) =>
case (FixedType(w1, p1), FixedType(w2, p2)) =>
constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
@@ -36,78 +34,86 @@ class InferBinaryPoints extends Pass {
constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
case (AnalogType(w1), AnalogType(w2)) =>
case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields) foreach { case (f1, f2) =>
- (f1.flip, f2.flip) match {
- case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
- case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
- case _ => sys.error("Shouldn't be here")
- }
+ (t1.fields.zip(t2.fields)).foreach {
+ case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
+ }
}
case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other")
}
- private def addDecConstraints(t: Type): Type = t map addDecConstraints
- private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match {
+ private def addDecConstraints(t: Type): Type = t.map(addDecConstraints)
+ private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s.map(addDecConstraints) match {
case c: Connect =>
val n = get_size(c.loc.tpe)
val locs = create_exps(c.loc)
val exps = create_exps(c.expr)
- (locs zip exps) foreach { case (loc, exp) =>
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
+ (locs.zip(exps)).foreach {
+ case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
c
case pc: PartialConnect =>
val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
val locs = create_exps(pc.loc)
val exps = create_exps(pc.expr)
- ls foreach { case (x, y) =>
- val loc = locs(x)
- val exp = exps(y)
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
+ ls.foreach {
+ case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
pc
case r: DefRegister =>
- addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
+ addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
r
- case x => x map addStmtConstraints(mt)
+ case x => x.map(addStmtConstraints(mt))
}
private def fixWidth(w: Width): Width = constraintSolver.get(w) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => w
- case _ => sys.error("Shouldn't be here")
+ case None => w
+ case _ => sys.error("Shouldn't be here")
}
- private def fixType(t: Type): Type = t map fixType map fixWidth match {
+ private def fixType(t: Type): Type = t.map(fixType).map(fixWidth) match {
case IntervalType(l, u, p) =>
val px = constraintSolver.get(p) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => p
- case _ => sys.error("Shouldn't be here")
+ case None => p
+ case _ => sys.error("Shouldn't be here")
}
IntervalType(l, u, px)
case FixedType(w, p) =>
val px = constraintSolver.get(p) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => p
- case _ => sys.error("Shouldn't be here")
+ case None => p
+ case _ => sys.error("Shouldn't be here")
}
FixedType(w, px)
case x => x
}
- private def fixStmt(s: Statement): Statement = s map fixStmt map fixType
- private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe))
- def run (c: Circuit): Circuit = {
+ private def fixStmt(s: Statement): Statement = s.map(fixStmt).map(fixType)
+ private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe))
+ def run(c: Circuit): Circuit = {
val ct = CircuitTarget(c.main)
- c.modules foreach (m => m map addStmtConstraints(ct.module(m.name)))
- c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)})
+ c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name))))
+ c.modules.foreach(_.ports.foreach { p => addDecConstraints(p.tpe) })
constraintSolver.solve()
- InferTypes.run(c.copy(modules = c.modules map (_
- map fixPort
- map fixStmt)))
+ InferTypes.run(
+ c.copy(modules =
+ c.modules.map(
+ _.map(fixPort)
+ .map(fixStmt)
+ )
+ )
+ )
}
}