aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-08-25 02:22:39 -0700
committerDonggyu Kim2016-09-08 13:08:23 -0700
commit9234b7063c348a4bdb9e4429cbe8caa7b37b5a4e (patch)
tree9fc7180a99becb10c3780bd0106b78b0e21415f0 /src
parent5b34491096c2ce49a3e44b638780467d8bf5e2cd (diff)
refactor InferWidths
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala505
1 files changed, 244 insertions, 261 deletions
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index ee54034b..5a81c268 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -27,14 +27,9 @@ MODIFICATIONS.
package firrtl.passes
-import com.typesafe.scalalogging.LazyLogging
-import java.nio.file.{Paths, Files}
-
// Datastructures
-import scala.collection.mutable.LinkedHashMap
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{LinkedHashMap, HashMap, HashSet, ArrayBuffer}
+import scala.collection.immutable.ListMap
import firrtl._
import firrtl.ir._
@@ -45,200 +40,253 @@ import firrtl.WrappedExpression._
object InferWidths extends Pass {
def name = "Infer Widths"
- var mname = ""
- def solve_constraints (l:Seq[WGeq]) : LinkedHashMap[String,Width] = {
- def unique (ls:Seq[Width]) : Seq[Width] = ls.map(w => new WrappedWidth(w)).distinct.map(_.w)
- def make_unique (ls:Seq[WGeq]) : LinkedHashMap[String,Width] = {
- val h = LinkedHashMap[String,Width]()
- for (g <- ls) {
- (g.loc) match {
- case (w:VarWidth) => {
- val n = w.name
- if (h.contains(n)) h(n) = MaxWidth(Seq(g.exp,h(n))) else h(n) = g.exp
- }
- case (w) => w
+
+ def solve_constraints(l: Seq[WGeq]): LinkedHashMap[String, Width] = {
+ def unique(ls: Seq[Width]) : Seq[Width] =
+ (ls map (new WrappedWidth(_))).distinct map (_.w)
+ def make_unique(ls: Seq[WGeq]): ListMap[String,Width] = {
+ (ls foldLeft ListMap[String, Width]())((h, g) => g.loc match {
+ case w: VarWidth => h get w.name match {
+ case None => h + (w.name -> g.exp)
+ case Some(p) => h + (w.name -> MaxWidth(Seq(g.exp, p)))
}
+ case _ => h
+ })
+ }
+ def simplify(w: Width): Width = w map simplify match {
+ case (w: MinWidth) => MinWidth(unique((w.args foldLeft Seq[Width]()){
+ case (res, w: MinWidth) => res ++ w.args
+ case (res, w) => res :+ w
+ }))
+ case (w: MaxWidth) => MaxWidth(unique((w.args foldLeft Seq[Width]()){
+ case (res, w: MaxWidth) => res ++ w.args
+ case (res, w) => res :+ w
+ }))
+ case (w: PlusWidth) => (w.arg1, w.arg2) match {
+ case (w1: IntWidth, w2 :IntWidth) => IntWidth(w1.width + w2.width)
+ case _ => w
+ }
+ case (w: MinusWidth) => (w.arg1, w.arg2) match {
+ case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width - w2.width)
+ case _ => w
+ }
+ case (w: ExpWidth) => w.arg1 match {
+ case (w1: IntWidth) => IntWidth(BigInt((math.pow(2, w1.width.toDouble) - 1).toLong))
+ case (w1) => w
}
- h
+ case _ => w
+ }
+
+ def substitute(h: LinkedHashMap[String, Width])(w: Width): Width = {
+ //;println-all-debug(["Substituting for [" w "]"])
+ val wx = simplify(w)
+ //;println-all-debug(["After Simplify: [" wx "]"])
+ (wx map substitute(h)) match {
+ //;("matched println-debugvarwidth!")
+ case w: VarWidth => h get w.name match {
+ case None => w
+ case Some(p) =>
+ //;println-debug("Contained!")
+ //;println-all-debug(["Width: " w])
+ //;println-all-debug(["Accessed: " h[name(w)]])
+ val t = simplify(substitute(h)(p))
+ h(w.name) = t
+ t
+ }
+ case w => w
+ //;println-all-debug(["not varwidth!" w])
+ }
+ }
+
+ def b_sub(h: LinkedHashMap[String, Width])(w: Width): Width = {
+ w map b_sub(h) match {
+ case w: VarWidth => h getOrElse (w.name, w)
+ case w => w
+ }
+ }
+
+ def remove_cycle(n: String)(w: Width): Width = {
+ //;println-all-debug(["Removing cycle for " n " inside " w])
+ (w map remove_cycle(n)) match {
+ case w: MaxWidth => MaxWidth(w.args filter {
+ case w: VarWidth => !(n equals w.name)
+ case w => true
+ })
+ case w: MinusWidth => w.arg1 match {
+ case v: VarWidth if n == v.name => v
+ case v => w
+ }
+ case w => w
+ }
+ //;println-all-debug(["After removing cycle for " n ", returning " wx])
+ }
+
+ def hasVarWidth(n: String)(w: Width): Boolean = {
+ var has = false
+ def rec(w: Width): Width = {
+ w match {
+ case w: VarWidth if w.name == n => has = true
+ case w =>
+ }
+ w map rec
+ }
+ rec(w)
+ has
}
- def simplify (w:Width) : Width = {
- (w map (simplify)) match {
- case (w:MinWidth) => {
- val v = ArrayBuffer[Width]()
- for (wx <- w.args) {
- (wx) match {
- case (wx:MinWidth) => for (x <- wx.args) { v += x }
- case (wx) => v += wx } }
- MinWidth(unique(v)) }
- case (w:MaxWidth) => {
- val v = ArrayBuffer[Width]()
- for (wx <- w.args) {
- (wx) match {
- case (wx:MaxWidth) => for (x <- wx.args) { v += x }
- case (wx) => v += wx } }
- MaxWidth(unique(v)) }
- case (w:PlusWidth) => {
- (w.arg1,w.arg2) match {
- case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width + w2.width)
- case (w1,w2) => w }}
- case (w:MinusWidth) => {
- (w.arg1,w.arg2) match {
- case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width - w2.width)
- case (w1,w2) => w }}
- case (w:ExpWidth) => {
- (w.arg1) match {
- case (w1:IntWidth) => IntWidth(BigInt((scala.math.pow(2,w1.width.toDouble) - 1).toLong))
- case (w1) => w }}
- case (w) => w } }
- def substitute (h:LinkedHashMap[String,Width])(w:Width) : Width = {
- //;println-all-debug(["Substituting for [" w "]"])
- val wx = simplify(w)
- //;println-all-debug(["After Simplify: [" wx "]"])
- (simplify(w) map (substitute(h))) match {
- case (w:VarWidth) => {
- //;("matched println-debugvarwidth!")
- if (h.contains(w.name)) {
- //;println-debug("Contained!")
- //;println-all-debug(["Width: " w])
- //;println-all-debug(["Accessed: " h[name(w)]])
- val t = simplify(substitute(h)(h(w.name)))
- //;val t = h[name(w)]
- //;println-all-debug(["Width after sub: " t])
- h(w.name) = t
- t
- } else w
- }
- case (w) => w
- //;println-all-debug(["not varwidth!" w])
- }
- }
- def b_sub (h:LinkedHashMap[String,Width])(w:Width) : Width = {
- (w map (b_sub(h))) match {
- case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w
- case (w) => w
- }
- }
- def remove_cycle (n:String)(w:Width) : Width = {
- //;println-all-debug(["Removing cycle for " n " inside " w])
- val wx = (w map (remove_cycle(n))) match {
- case (w:MaxWidth) => MaxWidth(w.args.filter{ w => {
- w match {
- case (w:VarWidth) => !(n equals w.name)
- case (w) => true
- }}})
- case (w:MinusWidth) => {
- w.arg1 match {
- case (v:VarWidth) => if (n == v.name) v else w
- case (v) => w }}
- case (w) => w
- }
- //;println-all-debug(["After removing cycle for " n ", returning " wx])
- wx
- }
- def self_rec (n:String,w:Width) : Boolean = {
- var has = false
- def look (w:Width) : Width = {
- (w map (look)) match {
- case (w:VarWidth) => if (w.name == n) has = true
- case (w) => w }
- w }
- look(w)
- has }
-
- //; Forward solve
- //; Returns a solved list where each constraint undergoes:
- //; 1) Continuous Solving (using triangular solving)
- //; 2) Remove Cycles
- //; 3) Move to solved if not self-recursive
- val u = make_unique(l)
-
- //println("======== UNIQUE CONSTRAINTS ========")
- //for (x <- u) { println(x) }
- //println("====================================")
-
-
- val f = LinkedHashMap[String,Width]()
- val o = ArrayBuffer[String]()
- for (x <- u) {
- //println("==== SOLUTIONS TABLE ====")
- //for (x <- f) println(x)
- //println("=========================")
- val (n, e) = (x._1, x._2)
- val e_sub = substitute(f)(e)
+ //; Forward solve
+ //; Returns a solved list where each constraint undergoes:
+ //; 1) Continuous Solving (using triangular solving)
+ //; 2) Remove Cycles
+ //; 3) Move to solved if not self-recursive
+ val u = make_unique(l)
+
+ //println("======== UNIQUE CONSTRAINTS ========")
+ //for (x <- u) { println(x) }
+ //println("====================================")
+
+ val f = LinkedHashMap[String, Width]()
+ val o = ArrayBuffer[String]()
+ for ((n, e) <- u) {
+ //println("==== SOLUTIONS TABLE ====")
+ //for (x <- f) println(x)
+ //println("=========================")
- //println("Solving " + n + " => " + e)
- //println("After Substitute: " + n + " => " + e_sub)
- //println("==== SOLUTIONS TABLE (Post Substitute) ====")
- //for (x <- f) println(x)
- //println("=========================")
+ val e_sub = substitute(f)(e)
- val ex = remove_cycle(n)(e_sub)
+ //println("Solving " + n + " => " + e)
+ //println("After Substitute: " + n + " => " + e_sub)
+ //println("==== SOLUTIONS TABLE (Post Substitute) ====")
+ //for (x <- f) println(x)
+ //println("=========================")
- //println("After Remove Cycle: " + n + " => " + ex)
- if (!self_rec(n,ex)) {
- //println("Not rec!: " + n + " => " + ex)
- //println("Adding [" + n + "=>" + ex + "] to Solutions Table")
- o += n
- f(n) = ex
- }
- }
-
- //println("Forward Solved Constraints")
- //for (x <- f) println(x)
-
- //; Backwards Solve
- val b = LinkedHashMap[String,Width]()
- for (i <- 0 until o.size) {
- val n = o(o.size - 1 - i)
- /*
- println("SOLVE BACK: [" + n + " => " + f(n) + "]")
- println("==== SOLUTIONS TABLE ====")
- for (x <- b) println(x)
- println("=========================")
- */
- val ex = simplify(b_sub(b)(f(n)))
- /*
- println("BACK RETURN: [" + n + " => " + ex + "]")
- */
- b(n) = ex
- /*
- println("==== SOLUTIONS TABLE (Post backsolve) ====")
- for (x <- b) println(x)
- println("=========================")
- */
- }
- b
+ val ex = remove_cycle(n)(e_sub)
+
+ //println("After Remove Cycle: " + n + " => " + ex)
+ if (!hasVarWidth(n)(ex)) {
+ //println("Not rec!: " + n + " => " + ex)
+ //println("Adding [" + n + "=>" + ex + "] to Solutions Table")
+ f(n) = ex
+ o += n
+ }
+ }
+
+ //println("Forward Solved Constraints")
+ //for (x <- f) println(x)
+
+ //; Backwards Solve
+ val b = LinkedHashMap[String, Width]()
+ for (i <- (o.size - 1) to 0 by -1) {
+ val n = o(i) // Should visit `o` backward
+ /*
+ println("SOLVE BACK: [" + n + " => " + f(n) + "]")
+ println("==== SOLUTIONS TABLE ====")
+ for (x <- b) println(x)
+ println("=========================")
+ */
+ val ex = simplify(b_sub(b)(f(n)))
+ /*
+ println("BACK RETURN: [" + n + " => " + ex + "]")
+ */
+ b(n) = ex
+ /*
+ println("==== SOLUTIONS TABLE (Post backsolve) ====")
+ for (x <- b) println(x)
+ println("=========================")
+ */
+ }
+ b
}
- def width_BANG (t:Type) : Width = {
- (t) match {
- case (t:UIntType) => t.width
- case (t:SIntType) => t.width
- case ClockType => IntWidth(1)
- case (t) => error("No width!"); IntWidth(-1) } }
- def width_BANG (e:Expression) : Width = width_BANG(e.tpe)
+ def run (c: Circuit): Circuit = {
+ val v = ArrayBuffer[WGeq]()
+
+ def get_constraints_t(t1: Type, t2: Type, f: Orientation): Seq[WGeq] = (t1,t2) match {
+ case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
+ case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
+ case (t1: BundleType, t2: BundleType) =>
+ (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
+ res ++ get_constraints_t(f1.tpe, f2.tpe, times(f1.flip, f))
+ }
+ case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe, f)
+ }
+
+ def get_constraints_e(e: Expression): Expression = {
+ e match {
+ case (e: Mux) => v ++= Seq(
+ WGeq(width_BANG(e.cond), IntWidth(1)),
+ WGeq(IntWidth(1), width_BANG(e.cond))
+ )
+ case _ =>
+ }
+ e map get_constraints_e
+ }
+
+ def get_constraints_s(s: Statement): Statement = {
+ s match {
+ case (s: Connect) =>
+ val n = get_size(s.loc.tpe)
+ val locs = create_exps(s.loc)
+ val exps = create_exps(s.expr)
+ v ++= ((locs zip exps).zipWithIndex map {case ((locx, expx), i) =>
+ get_flip(s.loc.tpe, i, Default) match {
+ case Default => WGeq(width_BANG(locx), width_BANG(expx))
+ case Flip => WGeq(width_BANG(expx), width_BANG(locx))
+ }
+ })
+ case (s: PartialConnect) =>
+ val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default)
+ val locs = create_exps(s.loc)
+ val exps = create_exps(s.expr)
+ v ++= (ls map {case (x, y) =>
+ val locx = locs(x)
+ val expx = exps(y)
+ get_flip(s.loc.tpe, x, Default) match {
+ case Default => WGeq(width_BANG(locx), width_BANG(expx))
+ case Flip => WGeq(width_BANG(expx), width_BANG(locx))
+ }
+ })
+ case (s:DefRegister) => v ++= (Seq(
+ WGeq(width_BANG(s.reset), IntWidth(1)),
+ WGeq(IntWidth(1), width_BANG(s.reset))
+ ) ++ get_constraints_t(s.tpe, s.init.tpe, Default))
+ case (s:Conditionally) => v ++= Seq(
+ WGeq(width_BANG(s.pred), IntWidth(1)),
+ WGeq(IntWidth(1), width_BANG(s.pred))
+ )
+ case _ =>
+ }
+ s map get_constraints_e map get_constraints_s
+ }
+
+ c.modules foreach (_ map get_constraints_s)
+
+ //println-debug("======== ALL CONSTRAINTS ========")
+ //for x in v do : println-debug(x)
+ //println-debug("=================================")
+ val h = solve_constraints(v)
+ //println-debug("======== SOLVED CONSTRAINTS ========")
+ //for x in h do : println-debug(x)
+ //println-debug("====================================")
- def reduce_var_widths(c: Circuit, h: LinkedHashMap[String,Width]): Circuit = {
def evaluate(w: Width): Width = {
def map2(a: Option[BigInt], b: Option[BigInt], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
- for (a_num <- a; b_num <- b) yield f(a_num, b_num)
+ for (a_num <- a; b_num <- b) yield f(a_num, b_num)
def reduceOptions(l: Seq[Option[BigInt]], f: (BigInt,BigInt) => BigInt): Option[BigInt] =
- l.reduce(map2(_, _, f))
+ l.reduce(map2(_, _, f))
// This function shouldn't be necessary
// Added as protection in case a constraint accidentally uses MinWidth/MaxWidth
// without any actual Widths. This should be elevated to an earlier error
def forceNonEmpty(in: Seq[Option[BigInt]], default: Option[BigInt]): Seq[Option[BigInt]] =
- if(in.isEmpty) Seq(default)
+ if (in.isEmpty) Seq(default)
else in
-
def solve(w: Width): Option[BigInt] = w match {
case (w: VarWidth) =>
for{
- v <- h.get(w.name) if !v.isInstanceOf[VarWidth]
- result <- solve(v)
+ v <- h.get(w.name) if !v.isInstanceOf[VarWidth]
+ result <- solve(v)
} yield result
case (w: MaxWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), Some(BigInt(0))), max)
case (w: MinWidth) => reduceOptions(forceNonEmpty(w.args.map(solve _), None), min)
@@ -249,97 +297,32 @@ object InferWidths extends Pass {
case (w) => println(w); error("Shouldn't be here"); None;
}
- val s = solve(w)
- (s) match {
+ solve(w) match {
+ case None => w
case Some(s) => IntWidth(s)
- case (s) => w
}
}
- def reduce_var_widths_w (w:Width) : Width = {
+ def reduce_var_widths_w(w: Width): Width = {
//println-all-debug(["REPLACE: " w])
- val wx = evaluate(w)
+ evaluate(w)
//println-all-debug(["WITH: " wx])
- wx
}
- def reduce_var_widths_s (s: Statement): Statement = {
- def onType(t: Type): Type = t map onType map reduce_var_widths_w
- s map reduce_var_widths_s map onType
+
+ def reduce_var_widths_t(t: Type): Type = {
+ t map reduce_var_widths_t map reduce_var_widths_w
}
-
- val modulesx = c.modules.map{ m => {
- val portsx = m.ports.map{ p => {
- Port(p.info,p.name,p.direction,mapr(reduce_var_widths_w _,p.tpe)) }}
- (m) match {
- case (m:ExtModule) => ExtModule(m.info,m.name,portsx)
- case (m:Module) =>
- mname = m.name
- Module(m.info,m.name,portsx,m.body map reduce_var_widths_s _) }}}
- InferTypes.run(Circuit(c.info,modulesx,c.main))
- }
-
- def run (c:Circuit): Circuit = {
- val v = ArrayBuffer[WGeq]()
- def constrain (w1:Width,w2:Width) : Unit = v += WGeq(w1,w2)
- def get_constraints_t (t1:Type,t2:Type,f:Orientation) : Unit = {
- (t1,t2) match {
- case (t1:UIntType,t2:UIntType) => constrain(t1.width,t2.width)
- case (t1:SIntType,t2:SIntType) => constrain(t1.width,t2.width)
- case (t1:BundleType,t2:BundleType) => {
- (t1.fields,t2.fields).zipped.foreach{ (f1,f2) => {
- get_constraints_t(f1.tpe,f2.tpe,times(f1.flip,f)) }}}
- case (t1:VectorType,t2:VectorType) => get_constraints_t(t1.tpe,t2.tpe,f) }}
- def get_constraints_e (e:Expression) : Expression = {
- (e map (get_constraints_e)) match {
- case (e:Mux) => {
- constrain(width_BANG(e.cond),IntWidth(1))
- constrain(IntWidth(1),width_BANG(e.cond))
- e }
- case (e) => e }}
- def get_constraints (s:Statement) : Statement = {
- (s map (get_constraints_e)) match {
- case (s:Connect) => {
- val n = get_size(s.loc.tpe)
- val ce_loc = create_exps(s.loc)
- val ce_exp = create_exps(s.expr)
- for (i <- 0 until n) {
- val locx = ce_loc(i)
- val expx = ce_exp(i)
- get_flip(s.loc.tpe,i,Default) match {
- case Default => constrain(width_BANG(locx),width_BANG(expx))
- case Flip => constrain(width_BANG(expx),width_BANG(locx)) }}
- s }
- case (s:PartialConnect) => {
- val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
- for (x <- ls) {
- val locx = create_exps(s.loc)(x._1)
- val expx = create_exps(s.expr)(x._2)
- get_flip(s.loc.tpe,x._1,Default) match {
- case Default => constrain(width_BANG(locx),width_BANG(expx))
- case Flip => constrain(width_BANG(expx),width_BANG(locx)) }}
- s }
- case (s:DefRegister) => {
- constrain(width_BANG(s.reset),IntWidth(1))
- constrain(IntWidth(1),width_BANG(s.reset))
- get_constraints_t(s.tpe,s.init.tpe,Default)
- s }
- case (s:Conditionally) => {
- v += WGeq(width_BANG(s.pred),IntWidth(1))
- v += WGeq(IntWidth(1),width_BANG(s.pred))
- s map (get_constraints) }
- case (s) => s map (get_constraints) }}
- for (m <- c.modules) {
- (m) match {
- case (m:Module) => mname = m.name; get_constraints(m.body)
- case (m) => false }}
- //println-debug("======== ALL CONSTRAINTS ========")
- //for x in v do : println-debug(x)
- //println-debug("=================================")
- val h = solve_constraints(v)
- //println-debug("======== SOLVED CONSTRAINTS ========")
- //for x in h do : println-debug(x)
- //println-debug("====================================")
- reduce_var_widths(Circuit(c.info,c.modules,c.main),h)
+ def reduce_var_widths_s(s: Statement): Statement = {
+ s map reduce_var_widths_s map reduce_var_widths_t
+ }
+
+ def reduce_var_widths_p(p: Port): Port = {
+ Port(p.info, p.name, p.direction, reduce_var_widths_t(p.tpe))
+ }
+
+ InferTypes.run(c.copy(modules = c.modules map (_
+ map reduce_var_widths_p
+ map reduce_var_widths_s)))
}
}