aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorazidar2016-02-04 18:02:13 -0800
committerazidar2016-02-09 18:57:07 -0800
commitbf900917c50a440632dbcaae17bcfe9613d14452 (patch)
tree09ca1e2b58bbfc7b32cfe88f1a5cbc70e954027f /src
parent69f0ac34b9fd81b9bca932d32b01c522781a64f6 (diff)
Added Lower Types.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Compiler.scala10
-rw-r--r--src/main/scala/firrtl/Utils.scala48
-rw-r--r--src/main/scala/firrtl/WIR.scala101
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala551
4 files changed, 673 insertions, 37 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index 8facc27d..34feab99 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -21,19 +21,23 @@ object FIRRTLCompiler extends Compiler {
object VerilogCompiler extends Compiler {
// Copied from Stanza implementation
val passes = Seq(
- CheckHighForm,
- Resolve,
+ //CheckHighForm,
ToWorkingIR,
ResolveKinds,
InferTypes,
ResolveGenders,
+ InferWidths,
PullMuxes,
ExpandConnects,
RemoveAccesses,
ExpandWhens,
CheckInitialization,
ConstProp,
- Resolve,
+ ToWorkingIR,
+ ResolveKinds,
+ InferTypes,
+ ResolveGenders,
+ InferWidths,
LowerTypes
)
def run(c: Circuit, w: Writer)
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 406e393c..b9fb49c6 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -27,7 +27,11 @@ object Utils {
import scala.reflect._
def as[O: ClassTag]: Option[O] = x match {
case o: O => Some(o)
- case _ => None } }
+ case _ => None }
+ def typeof[O: ClassTag]: Boolean = x match {
+ case o: O => true
+ case _ => false }
+ }
implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x)
def ceil_log2(x: BigInt): BigInt = (x-1).bitLength
def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt
@@ -182,7 +186,9 @@ object Utils {
case (w1,w2) => MaxWidth(Seq(w1,w2))
}
}
- if (equals(t1,t2)) {
+ val wt1 = new WrappedType(t1)
+ val wt2 = new WrappedType(t2)
+ if (wt1 == wt2) {
(t1,t2) match {
case (t1:UIntType,t2:UIntType) => UIntType(wmax(t1.width,t2.width))
case (t1:SIntType,t2:SIntType) => SIntType(wmax(t1.width,t2.width))
@@ -340,6 +346,12 @@ object Utils {
case REVERSE => DEFAULT
}
}
+ def to_dir (g:Gender) : Direction = {
+ g match {
+ case MALE => INPUT
+ case FEMALE => OUTPUT
+ }
+ }
def to_gender (d:Direction) : Gender = {
d match {
case INPUT => MALE
@@ -414,10 +426,10 @@ object Utils {
}
def gender (e:Expression) : Gender = {
e match {
- case e:WRef => gender(e)
- case e:WSubField => gender(e)
- case e:WSubIndex => gender(e)
- case e:WSubAccess => gender(e)
+ case e:WRef => e.gender
+ case e:WSubField => e.gender
+ case e:WSubIndex => e.gender
+ case e:WSubAccess => e.gender
case e:PrimOp => MALE
case e:UIntValue => MALE
case e:SIntValue => MALE
@@ -608,6 +620,29 @@ object Utils {
case w => w
}
}
+ def stMap (f: String => String, c:Stmt) : Stmt = {
+ c match {
+ case (c:DefWire) => DefWire(c.info,f(c.name),c.tpe)
+ case (c:DefPoison) => DefPoison(c.info,f(c.name),c.tpe)
+ case (c:DefRegister) => DefRegister(c.info,f(c.name), c.tpe, c.clock, c.reset, c.init)
+ case (c:DefMemory) => DefMemory(c.info,f(c.name), c.data_type, c.depth, c.write_latency, c.read_latency, c.readers, c.writers, c.readwriters)
+ case (c:DefNode) => DefNode(c.info,f(c.name),c.value)
+ case (c:DefInstance) => DefInstance(c.info,f(c.name), c.module)
+ case (c) => c
+ }
+ }
+ def mapr (f: Width => Width, t:Type) : Type = {
+ def apply_t (t:Type) : Type = wMap(f,tMap(apply_t _,t))
+ apply_t(t)
+ }
+ def mapr (f: Width => Width, s:Stmt) : Stmt = {
+ def apply_t (t:Type) : Type = mapr(f,t)
+ def apply_e (e:Expression) : Expression =
+ wMap(f,tMap(apply_t _,eMap(apply_e _,e)))
+ def apply_s (s:Stmt) : Stmt =
+ tMap(apply_t _,eMap(apply_e _,sMap(apply_s _,s)))
+ apply_s(s)
+ }
val ONE = IntWidth(1)
//def digits (s:String) : Boolean {
// val digits = "0123456789"
@@ -803,6 +838,7 @@ object Utils {
val s = w match {
case w:UnknownWidth => "" //"?"
case w: IntWidth => s"<${w.width.toString}>"
+ case w: VarWidth => s"<${w.name}>"
}
s + debug(w)
}
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 35fcb93a..26e3a131 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -69,6 +69,103 @@ case class MinusWidth(arg1:Width,arg2:Width) extends Width
case class MaxWidth(args:Seq[Width]) extends Width
case class MinWidth(args:Seq[Width]) extends Width
case class ExpWidth(arg1:Width) extends Width
-//case class IntWidth(width: BigInt) extends Width
-//case object UnknownWidth extends Width
+
+class WrappedType (val t:Type) {
+ def wt (tx:Type) = new WrappedType(tx)
+ override def equals (o:Any) : Boolean = {
+ o match {
+ case (t2:WrappedType) => {
+ (t,t2.t) match {
+ case (t1:UIntType,t2:UIntType) => true
+ case (t1:SIntType,t2:SIntType) => true
+ case (t1:ClockType,t2:ClockType) => true
+ case (t1:VectorType,t2:VectorType) => (wt(t1.tpe) == wt(t2.tpe) && t1.size == t2.size)
+ case (t1:BundleType,t2:BundleType) => {
+ var ret = true
+ (t1.fields,t2.fields).zipped.foreach{ (f1,f2) => {
+ if (f1.flip != f2.flip) ret = false
+ if (f1.name != f2.name) ret = false
+ if (wt(f1.tpe) != wt(f2.tpe)) ret = false
+ }}
+ ret
+ }
+ case (t1,t2) => false
+ }
+ }
+ case _ => false
+ }
+ }
+}
+class WrappedWidth (val w:Width) {
+ override def toString = {
+ w match {
+ case (w:VarWidth) => w.name
+ case (w:MaxWidth) => "max(" + w.args.map(_.toString).reduce(_ + _) + ")"
+ case (w:MinWidth) => "min(" + w.args.map(_.toString).reduce(_ + _) + ")"
+ case (w:PlusWidth) => "(" + w.arg1 + " + " + w.arg2 + ")"
+ case (w:MinusWidth) => "(" + w.arg1 + " - " + w.arg2 + ")"
+ case (w:ExpWidth) => "exp(" + w.arg1 + ")"
+ case (w:IntWidth) => w.width.toString
+ case (w:UnknownWidth) => "?"
+ }
+ }
+ def eq (w1:Width,w2:Width) : Boolean = {
+ (new WrappedWidth(w1)) == (new WrappedWidth(w2))
+ }
+ override def equals (o:Any) : Boolean = {
+ o match {
+ case (w2:WrappedWidth) => {
+ (w,w2.w) match {
+ case (w1:VarWidth,w2:VarWidth) => w1.name.equals(w2.name)
+ case (w1:MaxWidth,w2:MaxWidth) => {
+ var ret = true
+ if (w1.args.size != w2.args.size) ret = false
+ else {
+ for (a1 <- w1.args) {
+ var found = false
+ for (a2 <- w2.args) { if (eq(a1,a2)) found = true }
+ if (found == false) ret = false
+ }
+ }
+ ret
+ }
+ case (w1:MinWidth,w2:MinWidth) => {
+ var ret = true
+ if (w1.args.size != w2.args.size) ret = false
+ else {
+ for (a1 <- w1.args) {
+ var found = false
+ for (a2 <- w2.args) { if (eq(a1,a2)) found = true }
+ if (found == false) ret = false
+ }
+ }
+ ret
+ }
+ case (w1:IntWidth,w2:IntWidth) => w1.width == w2.width
+ case (w1:PlusWidth,w2:PlusWidth) =>
+ (w1.arg1 == w2.arg1 && w1.arg2 == w2.arg2) || (w1.arg1 == w2.arg2 && w1.arg2 == w2.arg1)
+ case (w1:MinusWidth,w2:MinusWidth) =>
+ (w1.arg1 == w2.arg1 && w1.arg2 == w2.arg2) || (w1.arg1 == w2.arg2 && w1.arg2 == w2.arg1)
+ case (w1:ExpWidth,w2:ExpWidth) => w1.arg1 == w2.arg1
+ case (w1:UnknownWidth,w2:UnknownWidth) => true
+ case (w1,w2) => false
+ }
+ }
+ case _ => false
+ }
+ }
+}
+
+trait Constraint
+class WGeq(val loc:Width,val exp:Width) extends Constraint {
+ override def toString = {
+ val wloc = new WrappedWidth(loc)
+ val wexp = new WrappedWidth(exp)
+ wloc.toString + " >= " + wexp.toString
+ }
+}
+object WGeq {
+ def apply (loc:Width,exp:Width) = new WGeq(loc,exp)
+}
+
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 7cd4fdcf..d3d7027a 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -44,7 +44,7 @@ trait StanzaPass extends LazyLogging {
}
object PassUtils extends LazyLogging {
- val listOfPasses: Seq[Pass] = Seq(ToWorkingIR,ResolveKinds,ResolveGenders,PullMuxes,ExpandConnects,RemoveAccesses,ExpandWhens)
+ val listOfPasses: Seq[Pass] = Seq(ToWorkingIR,ResolveKinds,InferTypes,ResolveGenders,InferWidths,PullMuxes,ExpandConnects,RemoveAccesses,ExpandWhens,LowerTypes)
lazy val mapNameToPass: Map[String, Pass] = listOfPasses.map(p => p.name -> p).toMap
def executePasses(c: Circuit, passes: Seq[Pass]): Circuit = {
@@ -52,9 +52,11 @@ object PassUtils extends LazyLogging {
else {
val p = passes.head
val name = p.name
- logger.debug(c.serialize())
logger.debug(s"Starting ${name}")
- executePasses(p.run(c), passes.tail)
+ val x = p.run(c)
+ logger.debug(x.serialize())
+ logger.debug(s"Finished ${name}")
+ executePasses(x, passes.tail)
}
}
}
@@ -172,12 +174,6 @@ object InferTypes extends Pass {
}
}
def remove_unknowns (t:Type): Type = mapr(remove_unknowns_w _,t)
- def mapr (f: Width => Width, t:Type) : Type = {
- def apply_t (t:Type) : Type = {
- wMap(f,tMap(apply_t _,t))
- }
- apply_t(t)
- }
def run (c:Circuit): Circuit = {
val module_types = HashMap[String,Type]()
def infer_types (m:Module) : Module = {
@@ -329,9 +325,291 @@ object CheckGenders extends Pass with StanzaPass {
def run (c:Circuit): Circuit = stanzaPass(c, "check-genders")
}
-object InferWidths extends Pass with StanzaPass {
+object InferWidths extends Pass {
def name = "Infer Widths"
- def run (c:Circuit): Circuit = stanzaPass(c, "infer-widths")
+ var mname = ""
+ def solve_constraints (l:Seq[WGeq]) : HashMap[String,Width] = {
+ def unique (ls:Seq[Width]) : Seq[Width] = ls.map(w => new WrappedWidth(w)).distinct.map(_.w)
+ def make_unique (ls:Seq[WGeq]) : HashMap[String,Width] = {
+ val h = HashMap[String,Width]()
+ for (g <- ls) {
+ (g.loc) match {
+ case (w:VarWidth) => {
+ val n = w.name
+ if (h.contains(n)) h(n) = MaxWidth(Seq(g.exp,h(n))) else h(n) = g.exp
+ }
+ case (w) => w
+ }
+ }
+ h
+ }
+ def simplify (w:Width) : Width = {
+ (wMap(simplify _,w)) match {
+ case (w:MinWidth) => {
+ val v = ArrayBuffer[Width]()
+ for (wx <- w.args) {
+ (wx) match {
+ case (wx:MinWidth) => for (x <- wx.args) { v += x }
+ case (wx) => v += wx } }
+ MinWidth(unique(v)) }
+ case (w:MaxWidth) => {
+ val v = ArrayBuffer[Width]()
+ for (wx <- w.args) {
+ (wx) match {
+ case (wx:MaxWidth) => for (x <- wx.args) { v += x }
+ case (wx) => v += wx } }
+ MaxWidth(unique(v)) }
+ case (w:PlusWidth) => {
+ (w.arg1,w.arg2) match {
+ case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width + w2.width)
+ case (w1,w2) => w }}
+ case (w:MinusWidth) => {
+ (w.arg1,w.arg2) match {
+ case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width - w2.width)
+ case (w1,w2) => w }}
+ case (w:ExpWidth) => {
+ (w.arg1) match {
+ case (w1:IntWidth) => IntWidth((2 ^ w1.width) - 1)
+ case (w1) => w }}
+ case (w) => w } }
+ def substitute (h:HashMap[String,Width])(w:Width) : Width = {
+ //;println-all-debug(["Substituting for [" w "]"])
+ val wx = simplify(w)
+ //;println-all-debug(["After Simplify: [" wx "]"])
+ (wMap(substitute(h) _,simplify(w))) match {
+ case (w:VarWidth) => {
+ //;("matched println-debugvarwidth!")
+ if (h.contains(w.name)) {
+ //;println-debug("Contained!")
+ //;println-all-debug(["Width: " w])
+ //;println-all-debug(["Accessed: " h[name(w)]])
+ val t = simplify(substitute(h)(h(w.name)))
+ //;val t = h[name(w)]
+ //;println-all-debug(["Width after sub: " t])
+ h(w.name) = t
+ t
+ } else w
+ }
+ case (w) => w
+ //;println-all-debug(["not varwidth!" w])
+ }
+ }
+ def b_sub (h:HashMap[String,Width])(w:Width) : Width = {
+ (wMap(b_sub(h) _,w)) match {
+ case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w
+ case (w) => w
+ }
+ }
+ def remove_cycle (n:String)(w:Width) : Width = {
+ //;println-all-debug(["Removing cycle for " n " inside " w])
+ val wx = (wMap(remove_cycle(n) _,w)) match {
+ case (w:MaxWidth) => MaxWidth(w.args.filter{ w => {
+ w match {
+ case (w:VarWidth) => n equals w.name
+ case (w) => false
+ }}})
+ case (w:MinusWidth) => {
+ w.arg1 match {
+ case (v:VarWidth) => if (n == v.name) v else w
+ case (v) => w }}
+ case (w) => w
+ }
+ //;println-all-debug(["After removing cycle for " n ", returning " wx])
+ wx
+ }
+ def self_rec (n:String,w:Width) : Boolean = {
+ var has = false
+ def look (w:Width) : Width = {
+ (wMap(look _,w)) match {
+ case (w:VarWidth) => if (w.name == n) has = true
+ case (w) => w }
+ w }
+ look(w)
+ has }
+
+ //; Forward solve
+ //; Returns a solved list where each constraint undergoes:
+ //; 1) Continuous Solving (using triangular solving)
+ //; 2) Remove Cycles
+ //; 3) Move to solved if not self-recursive
+ val u = make_unique(l)
+ //println-debug("======== UNIQUE CONSTRAINTS ========")
+ //for (x <- u) { println-debug(x) }
+ //println-debug("====================================")
+
+ val f = HashMap[String,Width]()
+ val o = ArrayBuffer[String]()
+ for (x <- u) {
+ //println-debug("==== SOLUTIONS TABLE ====")
+ //for x in f do : println-debug(x)
+ //println-debug("=========================")
+
+ val (n, e) = (x._1, x._2)
+
+ val e_sub = substitute(f)(e)
+ //println-debug(["Solving " n " => " e])
+ //println-debug(["After Substitute: " n " => " e-sub])
+ //println-debug("==== SOLUTIONS TABLE (Post Substitute) ====")
+ //for x in f do : println-debug(x)
+ //println-debug("=========================")
+ val ex = remove_cycle(n)(e_sub)
+ //;println-debug(["After Remove Cycle: " n " => " ex])
+ if (!self_rec(n,ex)) {
+ //;println-all-debug(["Not rec!: " n " => " ex])
+ //;println-all-debug(["Adding [" n "=>" ex "] to Solutions Table"])
+ o += n
+ f(n) = ex
+ }
+ }
+
+ //println-debug("Forward Solved Constraints")
+ //for x in f do : println-debug(x)
+
+ //; Backwards Solve
+ val b = HashMap[String,Width]()
+ for (i <- 0 until o.size) {
+ val n = o(o.size - 1 - i)
+ //println-all-debug(["SOLVE BACK: [" n " => " f[n] "]"])
+ //println-debug("==== SOLUTIONS TABLE ====")
+ //for x in b do : println-debug(x)
+ //println-debug("=========================")
+ val ex = simplify(b_sub(b)(f(n)))
+ //println-all-debug(["BACK RETURN: [" n " => " ex "]"])
+ b(n) = ex
+ //println-debug("==== SOLUTIONS TABLE (Post backsolve) ====")
+ //for x in b do : println-debug(x)
+ //println-debug("=========================")
+ }
+ b
+ }
+
+ def width_BANG (t:Type) : Width = {
+ (t) match {
+ case (t:UIntType) => t.width
+ case (t:SIntType) => t.width
+ case (t:ClockType) => IntWidth(1)
+ case (t) => error("No width!"); IntWidth(-1) } }
+ def width_BANG (e:Expression) : Width = width_BANG(tpe(e))
+ def reduce_var_widths (c:Circuit,h:HashMap[String,Width]) : Circuit = {
+ def evaluate (w:Width) : Width = {
+ def apply_2 (a:Option[BigInt],b:Option[BigInt], f: (BigInt,BigInt) => BigInt) : Option[BigInt] = {
+ (a,b) match {
+ case (a:Some[BigInt],b:Some[BigInt]) => Some(f(a.get,b.get))
+ case (a,b) => None } }
+ def apply_1 (a:Option[BigInt], f: (BigInt) => BigInt) : Option[BigInt] = {
+ (a) match {
+ case (a:Some[BigInt]) => Some(f(a.get))
+ case (a) => None } }
+ def apply_l (l:Seq[Option[BigInt]],f:(BigInt,BigInt) => BigInt) : Option[BigInt] = {
+ if (l.size == 0) Some(BigInt(0)) else apply_2(l.head,apply_l(l.tail,f),f)
+ }
+ def max (a:BigInt,b:BigInt) : BigInt = if (a >= b) a else b
+ def min (a:BigInt,b:BigInt) : BigInt = if (a >= b) b else a
+ def solve (w:Width) : Option[BigInt] = {
+ (w) match {
+ case (w:VarWidth) => {
+ val wx = h.get(w.name)
+ (wx) match {
+ case (wx:Some[Width]) => {
+ wx.get match {
+ case (v:VarWidth) => None
+ case (v) => solve(v) }}
+ case (None) => None }}
+ case (w:MaxWidth) => apply_l(w.args.map(solve _),max)
+ case (w:MinWidth) => apply_l(w.args.map(solve _),min)
+ case (w:PlusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ + _})
+ case (w:MinusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ - _})
+ case (w:ExpWidth) => apply_2(Some(BigInt(2)),solve(w.arg1),{(x,y) => (x ^ y) - BigInt(1)})
+ case (w:IntWidth) => Some(w.width)
+ case (w) => println(w); error("Shouldn't be here"); None;
+ }
+ }
+ val s = solve(w)
+ (s) match {
+ case (s:Some[BigInt]) => IntWidth(s.get)
+ case (s) => w }
+ }
+
+ def reduce_var_widths_w (w:Width) : Width = {
+ //println-all-debug(["REPLACE: " w])
+ val wx = evaluate(w)
+ //println-all-debug(["WITH: " wx])
+ wx
+ }
+
+ val modulesx = c.modules.map{ m => {
+ val portsx = m.ports.map{ p => {
+ Port(p.info,p.name,p.direction,mapr(reduce_var_widths_w _,p.tpe)) }}
+ (m) match {
+ case (m:ExModule) => ExModule(m.info,m.name,portsx)
+ case (m:InModule) => mname = m.name; InModule(m.info,m.name,portsx,mapr(reduce_var_widths_w _,m.body)) }}}
+ Circuit(c.info,modulesx,c.main)
+ }
+
+ def run (c:Circuit): Circuit = {
+ val v = ArrayBuffer[WGeq]()
+ def constrain (w1:Width,w2:Width) : Unit = v += WGeq(w1,w2)
+ def get_constraints_t (t1:Type,t2:Type,f:Flip) : Unit = {
+ (t1,t2) match {
+ case (t1:UIntType,t2:UIntType) => constrain(t1.width,t2.width)
+ case (t1:SIntType,t2:SIntType) => constrain(t1.width,t2.width)
+ case (t1:BundleType,t2:BundleType) => {
+ (t1.fields,t2.fields).zipped.foreach{ (f1,f2) => {
+ get_constraints_t(f1.tpe,f2.tpe,times(f1.flip,f)) }}}
+ case (t1:VectorType,t2:VectorType) => get_constraints_t(t1.tpe,t2.tpe,f) }}
+ def get_constraints_e (e:Expression) : Expression = {
+ (eMap(get_constraints_e _,e)) match {
+ case (e:Mux) => {
+ constrain(width_BANG(e.cond),ONE)
+ constrain(ONE,width_BANG(e.cond))
+ e }
+ case (e) => e }}
+ def get_constraints (s:Stmt) : Stmt = {
+ (eMap(get_constraints_e _,s)) match {
+ case (s:Connect) => {
+ val n = get_size(tpe(s.loc))
+ val ce_loc = create_exps(s.loc)
+ val ce_exp = create_exps(s.exp)
+ for (i <- 0 until n) {
+ val locx = ce_loc(i)
+ val expx = ce_exp(i)
+ get_flip(tpe(s.loc),i,DEFAULT) match {
+ case DEFAULT => constrain(width_BANG(locx),width_BANG(expx))
+ case REVERSE => constrain(width_BANG(expx),width_BANG(locx)) }}
+ s }
+ case (s:BulkConnect) => {
+ val ls = get_valid_points(tpe(s.loc),tpe(s.exp),DEFAULT,DEFAULT)
+ for (x <- ls) {
+ val locx = create_exps(s.loc)(x._1)
+ val expx = create_exps(s.exp)(x._2)
+ get_flip(tpe(s.loc),x._1,DEFAULT) match {
+ case DEFAULT => constrain(width_BANG(locx),width_BANG(expx))
+ case REVERSE => constrain(width_BANG(expx),width_BANG(locx)) }}
+ s }
+ case (s:DefRegister) => {
+ constrain(width_BANG(s.reset),ONE)
+ constrain(ONE,width_BANG(s.reset))
+ get_constraints_t(s.tpe,tpe(s.init),DEFAULT)
+ s }
+ case (s:Conditionally) => {
+ v += WGeq(width_BANG(s.pred),ONE)
+ v += WGeq(ONE,width_BANG(s.pred))
+ sMap(get_constraints _,s) }
+ case (s) => sMap(get_constraints _,s) }}
+
+ for (m <- c.modules) {
+ (m) match {
+ case (m:InModule) => mname = m.name; get_constraints(m.body)
+ case (m) => false }}
+ //println-debug("======== ALL CONSTRAINTS ========")
+ //for x in v do : println-debug(x)
+ //println-debug("=================================")
+ val h = solve_constraints(v)
+ //println-debug("======== SOLVED CONSTRAINTS ========")
+ //for x in h do : println-debug(x)
+ //println-debug("====================================")
+ reduce_var_widths(Circuit(c.info,c.modules,c.main),h)
+ }
}
object CheckWidths extends Pass with StanzaPass {
@@ -827,37 +1105,258 @@ object ExpandWhens extends Pass with StanzaPass {
}
object CheckInitialization extends Pass with StanzaPass {
- def name = "Check Initialization"
- def run (c:Circuit): Circuit = stanzaPass(c, "check-init")
+ def name = "Check Initialization"
+ def run (c:Circuit): Circuit = stanzaPass(c, "check-init")
}
object ConstProp extends Pass with StanzaPass {
- def name = "Constant Propogation"
- def run (c:Circuit): Circuit = stanzaPass(c, "const-prop")
+ def name = "Constant Propogation"
+ def run (c:Circuit): Circuit = stanzaPass(c, "const-prop")
}
object LoToVerilog extends Pass with StanzaPass {
- def name = "Lo To Verilog"
- def run (c:Circuit): Circuit = stanzaPass(c, "lo-to-verilog")
+ def name = "Lo To Verilog"
+ def run (c:Circuit): Circuit = stanzaPass(c, "lo-to-verilog")
}
object VerilogWrap extends Pass with StanzaPass {
- def name = "Verilog Wrap"
- def run (c:Circuit): Circuit = stanzaPass(c, "verilog-wrap")
+ def name = "Verilog Wrap"
+ def run (c:Circuit): Circuit = stanzaPass(c, "verilog-wrap")
}
object SplitExp extends Pass with StanzaPass {
- def name = "Split Expressions"
- def run (c:Circuit): Circuit = stanzaPass(c, "split-expressions")
+ def name = "Split Expressions"
+ def run (c:Circuit): Circuit = stanzaPass(c, "split-expressions")
}
object VerilogRename extends Pass with StanzaPass {
- def name = "Verilog Rename"
- def run (c:Circuit): Circuit = stanzaPass(c, "verilog-rename")
+ def name = "Verilog Rename"
+ def run (c:Circuit): Circuit = stanzaPass(c, "verilog-rename")
}
-object LowerTypes extends Pass with StanzaPass {
- def name = "Lower Types"
- def run (c:Circuit): Circuit = stanzaPass(c, "lower-types")
+object LowerTypes extends Pass {
+ def name = "Lower Types"
+ var mname = ""
+ def is_ground (t:Type) : Boolean = {
+ (t) match {
+ case (_:UIntType|_:SIntType) => true
+ case (t) => false
+ }
+ }
+ def data (ex:Expression) : Boolean = {
+ (kind(ex)) match {
+ case (k:MemKind) => (ex) match {
+ case (_:WRef|_:WSubIndex) => false
+ case (ex:WSubField) => {
+ var yes = ex.name match {
+ case "rdata" => true
+ case "data" => true
+ case "mask" => true
+ case _ => false
+ }
+ yes && ((ex.exp) match {
+ case (e:WSubField) => kind(e).as[MemKind].get.ports.contains(e.name) && (e.exp.typeof[WRef])
+ case (e) => false
+ })
+ }
+ case (ex) => false
+ }
+ case (k) => false
+ }
+ }
+ def expand_name (e:Expression) : Seq[String] = {
+ val names = ArrayBuffer[String]()
+ def expand_name_e (e:Expression) : Expression = {
+ (eMap(expand_name_e _,e)) match {
+ case (e:WRef) => names += e.name
+ case (e:WSubField) => names += e.name
+ case (e:WSubIndex) => names += e.value.toString
+ }
+ e
+ }
+ expand_name_e(e)
+ names
+ }
+ def lower_other_mem (e:Expression, dt:Type) : Seq[Expression] = {
+ val names = expand_name(e)
+ if (names.size < 3) error("Shouldn't be here")
+ create_exps(names(0),dt).map{ x => {
+ var base = lowered_name(x)
+ for (i <- 0 until names.size) {
+ if (i >= 3) base = base + "_" + names(i)
+ }
+ val m = WRef(base, UnknownType(), kind(e), UNKNOWNGENDER)
+ val p = WSubField(m,names(1),UnknownType(),UNKNOWNGENDER)
+ WSubField(p,names(2),UnknownType(),UNKNOWNGENDER)
+ }}
+ }
+ def lower_data_mem (e:Expression) : Expression = {
+ val names = expand_name(e)
+ if (names.size < 3) error("Shouldn't be here")
+ else {
+ var base = names(0)
+ for (i <- 0 until names.size) {
+ if (i >= 3) base = base + "_" + names(i)
+ }
+ val m = WRef(base, UnknownType(), kind(e), UNKNOWNGENDER)
+ val p = WSubField(m,names(1),UnknownType(),UNKNOWNGENDER)
+ WSubField(p,names(2),UnknownType(),UNKNOWNGENDER)
+ }
+ }
+ def merge (a:String,b:String,x:String) : String = a + x + b
+ def lowered_name (e:Expression) : String = {
+ (e) match {
+ case (e:WRef) => e.name
+ case (e:WSubField) => lowered_name(e.exp) + "_" + e.name
+ case (e:WSubIndex) => lowered_name(e.exp) + "_" + e.value
+ }
+ }
+ def root_ref (e:Expression) : WRef = {
+ (e) match {
+ case (e:WRef) => e
+ case (e:WSubField) => root_ref(e.exp)
+ case (e:WSubIndex) => root_ref(e.exp)
+ case (e:WSubAccess) => root_ref(e.exp)
+ }
+ }
+
+ //;------------- Pass ------------------
+
+ def lower_types (m:Module) : Module = {
+ val mdt = HashMap[String,Type]()
+ mname = m.name
+ def lower_types (s:Stmt) : Stmt = {
+ def lower_mem (e:Expression) : Seq[Expression] = {
+ val names = expand_name(e)
+ if (Seq("data","mask","rdata").contains(names(2))) Seq(lower_data_mem(e))
+ else lower_other_mem(e,mdt(root_ref(e).name))
+ }
+ def lower_types_e (e:Expression) : Expression = {
+ e match {
+ case (_:WRef|_:UIntValue|_:SIntValue) => e
+ case (_:WSubField|_:WSubIndex) => {
+ (kind(e)) match {
+ case (k:InstanceKind) => {
+ val names = expand_name(e)
+ var n = names(1)
+ for (i <- 0 until names.size) {
+ if (i > 1) n = n + "_" + names(i)
+ }
+ WSubField(root_ref(e),n,tpe(e),gender(e))
+ }
+ case (k:MemKind) => {
+ if (gender(e) != FEMALE) lower_mem(e)(0)
+ else e
+ }
+ case (k) => WRef(lowered_name(e),tpe(e),kind(e),gender(e))
+ }
+ }
+ case (e:DoPrim) => eMap(lower_types_e _,e)
+ case (e:Mux) => eMap(lower_types_e _,e)
+ case (e:ValidIf) => eMap(lower_types_e _,e)
+ }
+ }
+ (s) match {
+ case (s:DefWire) => {
+ if (is_ground(s.tpe)) s else {
+ val es = create_exps(s.name,s.tpe)
+ val stmts = (es, 0 until es.size).zipped.map{ (e,i) => {
+ DefWire(s.info,lowered_name(e),tpe(e))
+ }}
+ Begin(stmts)
+ }
+ }
+ case (s:DefPoison) => {
+ if (is_ground(s.tpe)) s else {
+ val es = create_exps(s.name,s.tpe)
+ val stmts = (es, 0 until es.size).zipped.map{ (e,i) => {
+ DefPoison(s.info,lowered_name(e),tpe(e))
+ }}
+ Begin(stmts)
+ }
+ }
+ case (s:DefRegister) => {
+ if (is_ground(s.tpe)) s else {
+ val es = create_exps(s.name,s.tpe)
+ val inits = create_exps(s.init)
+ val stmts = (es, 0 until es.size).zipped.map{ (e,i) => {
+ val init = lower_types_e(inits(i))
+ DefRegister(s.info,lowered_name(e),tpe(e),s.clock,s.reset,init)
+ }}
+ Begin(stmts)
+ }
+ }
+ case (s:WDefInstance) => {
+ val fieldsx = s.tpe.as[BundleType].get.fields.flatMap{ f => {
+ val es = create_exps(WRef(f.name,f.tpe,ExpKind(),times(f.flip,MALE)))
+ es.map{ e => {
+ gender(e) match {
+ case MALE => Field(lowered_name(e),DEFAULT,f.tpe)
+ case FEMALE => Field(lowered_name(e),REVERSE,f.tpe)
+ }
+ }}
+ }}
+ WDefInstance(s.info,s.name,s.module,BundleType(fieldsx))
+ }
+ case (s:DefMemory) => {
+ mdt(s.name) = s.data_type
+ if (is_ground(s.data_type)) s else {
+ val es = create_exps(s.name,s.data_type)
+ val stmts = es.map{ e => {
+ DefMemory(s.info,lowered_name(e),tpe(e),s.depth,s.write_latency,s.read_latency,s.readers,s.writers,s.readwriters)
+ }}
+ Begin(stmts)
+ }
+ }
+ case (s:IsInvalid) => {
+ val sx = eMap(lower_types_e _,s).as[IsInvalid].get
+ kind(sx.exp) match {
+ case (k:MemKind) => {
+ val es = lower_mem(sx.exp)
+ Begin(es.map(e => {IsInvalid(sx.info,e)}))
+ }
+ case (_) => sx
+ }
+ }
+ case (s:Connect) => {
+ val sx = eMap(lower_types_e _,s).as[Connect].get
+ kind(sx.loc) match {
+ case (k:MemKind) => {
+ val es = lower_mem(sx.loc)
+ Begin(es.map(e => {Connect(sx.info,e,sx.exp)}))
+ }
+ case (_) => sx
+ }
+ }
+ case (s:DefNode) => {
+ val locs = create_exps(s.name,tpe(s.value))
+ val n = locs.size
+ val nodes = ArrayBuffer[Stmt]()
+ val exps = create_exps(s.value)
+ for (i <- 0 until n) {
+ val locx = locs(i)
+ val expx = exps(i)
+ nodes += DefNode(s.info,lowered_name(locx),lower_types_e(expx))
+ }
+ if (n == 1) nodes(0) else Begin(nodes)
+ }
+ case (s) => eMap(lower_types_e _,sMap(lower_types _,s))
+ }
+ }
+
+ val portsx = m.ports.flatMap{ p => {
+ val es = create_exps(WRef(p.name,p.tpe,PortKind(),to_gender(p.direction)))
+ es.map(e => { Port(p.info,lowered_name(e),to_dir(gender(e)),tpe(e)) })
+ }}
+ (m) match {
+ case (m:ExModule) => ExModule(m.info,m.name,portsx)
+ case (m:InModule) => InModule(m.info,m.name,portsx,lower_types(m.body))
+ }
+ }
+
+ def run (c:Circuit) : Circuit = {
+ val modulesx = c.modules.map(m => lower_types(m))
+ Circuit(c.info,modulesx,c.main)
+ }
}