diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/InferBinaryPoints.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/InferBinaryPoints.scala | 98 |
1 files changed, 52 insertions, 46 deletions
diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala index a16205a7..f393d8a5 100644 --- a/src/main/scala/firrtl/passes/InferBinaryPoints.scala +++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala @@ -13,9 +13,7 @@ import firrtl.options.Dependency class InferBinaryPoints extends Pass { override def prerequisites = - Seq( Dependency(ResolveKinds), - Dependency(InferTypes), - Dependency(ResolveFlows) ) + Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows)) override def optionalPrerequisiteOf = Seq.empty @@ -23,12 +21,12 @@ class InferBinaryPoints extends Pass { private val constraintSolver = new ConstraintSolver() - private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { - case (UIntType(w1), UIntType(w2)) => - case (SIntType(w1), SIntType(w2)) => - case (ClockType, ClockType) => - case (ResetType, _) => - case (_, ResetType) => + private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1, t2) match { + case (UIntType(w1), UIntType(w2)) => + case (SIntType(w1), SIntType(w2)) => + case (ClockType, ClockType) => + case (ResetType, _) => + case (_, ResetType) => case (AsyncResetType, AsyncResetType) => case (FixedType(w1, p1), FixedType(w2, p2)) => constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) @@ -36,78 +34,86 @@ class InferBinaryPoints extends Pass { constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) case (AnalogType(w1), AnalogType(w2)) => case (t1: BundleType, t2: BundleType) => - (t1.fields zip t2.fields) foreach { case (f1, f2) => - (f1.flip, f2.flip) match { - case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) - case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) - case _ => sys.error("Shouldn't be here") - } + (t1.fields.zip(t2.fields)).foreach { + case (f1, f2) => + (f1.flip, f2.flip) match { + case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) + case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) + case _ => sys.error("Shouldn't be here") + } } case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe) case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other") } - private def addDecConstraints(t: Type): Type = t map addDecConstraints - private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match { + private def addDecConstraints(t: Type): Type = t.map(addDecConstraints) + private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s.map(addDecConstraints) match { case c: Connect => val n = get_size(c.loc.tpe) val locs = create_exps(c.loc) val exps = create_exps(c.expr) - (locs zip exps) foreach { case (loc, exp) => - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } + (locs.zip(exps)).foreach { + case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } c case pc: PartialConnect => val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) val locs = create_exps(pc.loc) val exps = create_exps(pc.expr) - ls foreach { case (x, y) => - val loc = locs(x) - val exp = exps(y) - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } + ls.foreach { + case (x, y) => + val loc = locs(x) + val exp = exps(y) + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } pc case r: DefRegister => - addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) + addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) r - case x => x map addStmtConstraints(mt) + case x => x.map(addStmtConstraints(mt)) } private def fixWidth(w: Width): Width = constraintSolver.get(w) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => w - case _ => sys.error("Shouldn't be here") + case None => w + case _ => sys.error("Shouldn't be here") } - private def fixType(t: Type): Type = t map fixType map fixWidth match { + private def fixType(t: Type): Type = t.map(fixType).map(fixWidth) match { case IntervalType(l, u, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => p - case _ => sys.error("Shouldn't be here") + case None => p + case _ => sys.error("Shouldn't be here") } IntervalType(l, u, px) case FixedType(w, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => p - case _ => sys.error("Shouldn't be here") + case None => p + case _ => sys.error("Shouldn't be here") } FixedType(w, px) case x => x } - private def fixStmt(s: Statement): Statement = s map fixStmt map fixType - private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe)) - def run (c: Circuit): Circuit = { + private def fixStmt(s: Statement): Statement = s.map(fixStmt).map(fixType) + private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe)) + def run(c: Circuit): Circuit = { val ct = CircuitTarget(c.main) - c.modules foreach (m => m map addStmtConstraints(ct.module(m.name))) - c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)}) + c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name)))) + c.modules.foreach(_.ports.foreach { p => addDecConstraints(p.tpe) }) constraintSolver.solve() - InferTypes.run(c.copy(modules = c.modules map (_ - map fixPort - map fixStmt))) + InferTypes.run( + c.copy(modules = + c.modules.map( + _.map(fixPort) + .map(fixStmt) + ) + ) + ) } } |
