aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/InferWidths.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/InferWidths.scala')
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala48
1 files changed, 30 insertions, 18 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 8fd5eef1..ac786386 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -200,14 +200,19 @@ object InferWidths extends Pass {
def run (c: Circuit): Circuit = {
val v = ArrayBuffer[WGeq]()
- def get_constraints_t(t1: Type, t2: Type, f: Orientation): Seq[WGeq] = (t1,t2) match {
+ def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
+ case (ClockType, ClockType) => Nil
+ case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
case (t1: BundleType, t2: BundleType) =>
(t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
- res ++ get_constraints_t(f1.tpe, f2.tpe, times(f1.flip, f))
+ res ++ (f1.flip match {
+ case Default => get_constraints_t(f1.tpe, f2.tpe)
+ case Flip => get_constraints_t(f2.tpe, f1.tpe)
+ })
}
- case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe, f)
+ case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
}
def get_constraints_e(e: Expression): Expression = {
@@ -221,38 +226,44 @@ object InferWidths extends Pass {
e map get_constraints_e
}
+ def get_constraints_declared_type (t: Type): Type = t match {
+ case FixedType(_, p) =>
+ v += WGeq(p,IntWidth(0))
+ t
+ case _ => t map get_constraints_declared_type
+ }
+
def get_constraints_s(s: Statement): Statement = {
- s match {
+ s map get_constraints_declared_type match {
case (s: Connect) =>
val n = get_size(s.loc.tpe)
val locs = create_exps(s.loc)
val exps = create_exps(s.expr)
- v ++= ((locs zip exps).zipWithIndex map {case ((locx, expx), i) =>
+ v ++= ((locs zip exps).zipWithIndex flatMap {case ((locx, expx), i) =>
get_flip(s.loc.tpe, i, Default) match {
- case Default => WGeq(getWidth(locx), getWidth(expx))
- case Flip => WGeq(getWidth(expx), getWidth(locx))
+ case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx))
+ case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx))
}
})
case (s: PartialConnect) =>
val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default)
val locs = create_exps(s.loc)
val exps = create_exps(s.expr)
- v ++= (ls map {case (x, y) =>
+ v ++= (ls flatMap {case (x, y) =>
val locx = locs(x)
val expx = exps(y)
get_flip(s.loc.tpe, x, Default) match {
- case Default => WGeq(getWidth(locx), getWidth(expx))
- case Flip => WGeq(getWidth(expx), getWidth(locx))
+ case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx))
+ case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx))
}
})
- case (s:DefRegister) => v ++= (Seq(
- WGeq(getWidth(s.reset), IntWidth(1)),
- WGeq(IntWidth(1), getWidth(s.reset))
- ) ++ get_constraints_t(s.tpe, s.init.tpe, Default))
- case (s:Conditionally) => v ++= Seq(
- WGeq(getWidth(s.pred), IntWidth(1)),
- WGeq(IntWidth(1), getWidth(s.pred))
- )
+ case (s: DefRegister) => v ++= (
+ get_constraints_t(s.reset.tpe, UIntType(IntWidth(1))) ++
+ get_constraints_t(UIntType(IntWidth(1)), s.reset.tpe) ++
+ get_constraints_t(s.tpe, s.init.tpe))
+ case (s:Conditionally) => v ++=
+ get_constraints_t(s.pred.tpe, UIntType(IntWidth(1))) ++
+ get_constraints_t(UIntType(IntWidth(1)), s.pred.tpe)
case (s: Attach) =>
v += WGeq(getWidth(s.source), MaxWidth(s.exprs map (e => getWidth(e.tpe))))
case _ =>
@@ -261,6 +272,7 @@ object InferWidths extends Pass {
}
c.modules foreach (_ map get_constraints_s)
+ c.modules foreach (_.ports foreach {p => get_constraints_declared_type(p.tpe)})
//println("======== ALL CONSTRAINTS ========")
//for(x <- v) println(x)