From bf900917c50a440632dbcaae17bcfe9613d14452 Mon Sep 17 00:00:00 2001 From: azidar Date: Thu, 4 Feb 2016 18:02:13 -0800 Subject: Added Lower Types. --- src/main/scala/firrtl/Compiler.scala | 10 +- src/main/scala/firrtl/Utils.scala | 48 ++- src/main/scala/firrtl/WIR.scala | 101 +++++- src/main/scala/firrtl/passes/Passes.scala | 551 ++++++++++++++++++++++++++++-- test/parser/gcd.fir | 2 +- 5 files changed, 674 insertions(+), 38 deletions(-) diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 8facc27d..34feab99 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -21,19 +21,23 @@ object FIRRTLCompiler extends Compiler { object VerilogCompiler extends Compiler { // Copied from Stanza implementation val passes = Seq( - CheckHighForm, - Resolve, + //CheckHighForm, ToWorkingIR, ResolveKinds, InferTypes, ResolveGenders, + InferWidths, PullMuxes, ExpandConnects, RemoveAccesses, ExpandWhens, CheckInitialization, ConstProp, - Resolve, + ToWorkingIR, + ResolveKinds, + InferTypes, + ResolveGenders, + InferWidths, LowerTypes ) def run(c: Circuit, w: Writer) diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 406e393c..b9fb49c6 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -27,7 +27,11 @@ object Utils { import scala.reflect._ def as[O: ClassTag]: Option[O] = x match { case o: O => Some(o) - case _ => None } } + case _ => None } + def typeof[O: ClassTag]: Boolean = x match { + case o: O => true + case _ => false } + } implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x) def ceil_log2(x: BigInt): BigInt = (x-1).bitLength def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt @@ -182,7 +186,9 @@ object Utils { case (w1,w2) => MaxWidth(Seq(w1,w2)) } } - if (equals(t1,t2)) { + val wt1 = new WrappedType(t1) + val wt2 = new WrappedType(t2) + if (wt1 == wt2) { (t1,t2) match { case (t1:UIntType,t2:UIntType) => UIntType(wmax(t1.width,t2.width)) case (t1:SIntType,t2:SIntType) => SIntType(wmax(t1.width,t2.width)) @@ -340,6 +346,12 @@ object Utils { case REVERSE => DEFAULT } } + def to_dir (g:Gender) : Direction = { + g match { + case MALE => INPUT + case FEMALE => OUTPUT + } + } def to_gender (d:Direction) : Gender = { d match { case INPUT => MALE @@ -414,10 +426,10 @@ object Utils { } def gender (e:Expression) : Gender = { e match { - case e:WRef => gender(e) - case e:WSubField => gender(e) - case e:WSubIndex => gender(e) - case e:WSubAccess => gender(e) + case e:WRef => e.gender + case e:WSubField => e.gender + case e:WSubIndex => e.gender + case e:WSubAccess => e.gender case e:PrimOp => MALE case e:UIntValue => MALE case e:SIntValue => MALE @@ -608,6 +620,29 @@ object Utils { case w => w } } + def stMap (f: String => String, c:Stmt) : Stmt = { + c match { + case (c:DefWire) => DefWire(c.info,f(c.name),c.tpe) + case (c:DefPoison) => DefPoison(c.info,f(c.name),c.tpe) + case (c:DefRegister) => DefRegister(c.info,f(c.name), c.tpe, c.clock, c.reset, c.init) + case (c:DefMemory) => DefMemory(c.info,f(c.name), c.data_type, c.depth, c.write_latency, c.read_latency, c.readers, c.writers, c.readwriters) + case (c:DefNode) => DefNode(c.info,f(c.name),c.value) + case (c:DefInstance) => DefInstance(c.info,f(c.name), c.module) + case (c) => c + } + } + def mapr (f: Width => Width, t:Type) : Type = { + def apply_t (t:Type) : Type = wMap(f,tMap(apply_t _,t)) + apply_t(t) + } + def mapr (f: Width => Width, s:Stmt) : Stmt = { + def apply_t (t:Type) : Type = mapr(f,t) + def apply_e (e:Expression) : Expression = + wMap(f,tMap(apply_t _,eMap(apply_e _,e))) + def apply_s (s:Stmt) : Stmt = + tMap(apply_t _,eMap(apply_e _,sMap(apply_s _,s))) + apply_s(s) + } val ONE = IntWidth(1) //def digits (s:String) : Boolean { // val digits = "0123456789" @@ -803,6 +838,7 @@ object Utils { val s = w match { case w:UnknownWidth => "" //"?" case w: IntWidth => s"<${w.width.toString}>" + case w: VarWidth => s"<${w.name}>" } s + debug(w) } diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 35fcb93a..26e3a131 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -69,6 +69,103 @@ case class MinusWidth(arg1:Width,arg2:Width) extends Width case class MaxWidth(args:Seq[Width]) extends Width case class MinWidth(args:Seq[Width]) extends Width case class ExpWidth(arg1:Width) extends Width -//case class IntWidth(width: BigInt) extends Width -//case object UnknownWidth extends Width + +class WrappedType (val t:Type) { + def wt (tx:Type) = new WrappedType(tx) + override def equals (o:Any) : Boolean = { + o match { + case (t2:WrappedType) => { + (t,t2.t) match { + case (t1:UIntType,t2:UIntType) => true + case (t1:SIntType,t2:SIntType) => true + case (t1:ClockType,t2:ClockType) => true + case (t1:VectorType,t2:VectorType) => (wt(t1.tpe) == wt(t2.tpe) && t1.size == t2.size) + case (t1:BundleType,t2:BundleType) => { + var ret = true + (t1.fields,t2.fields).zipped.foreach{ (f1,f2) => { + if (f1.flip != f2.flip) ret = false + if (f1.name != f2.name) ret = false + if (wt(f1.tpe) != wt(f2.tpe)) ret = false + }} + ret + } + case (t1,t2) => false + } + } + case _ => false + } + } +} +class WrappedWidth (val w:Width) { + override def toString = { + w match { + case (w:VarWidth) => w.name + case (w:MaxWidth) => "max(" + w.args.map(_.toString).reduce(_ + _) + ")" + case (w:MinWidth) => "min(" + w.args.map(_.toString).reduce(_ + _) + ")" + case (w:PlusWidth) => "(" + w.arg1 + " + " + w.arg2 + ")" + case (w:MinusWidth) => "(" + w.arg1 + " - " + w.arg2 + ")" + case (w:ExpWidth) => "exp(" + w.arg1 + ")" + case (w:IntWidth) => w.width.toString + case (w:UnknownWidth) => "?" + } + } + def eq (w1:Width,w2:Width) : Boolean = { + (new WrappedWidth(w1)) == (new WrappedWidth(w2)) + } + override def equals (o:Any) : Boolean = { + o match { + case (w2:WrappedWidth) => { + (w,w2.w) match { + case (w1:VarWidth,w2:VarWidth) => w1.name.equals(w2.name) + case (w1:MaxWidth,w2:MaxWidth) => { + var ret = true + if (w1.args.size != w2.args.size) ret = false + else { + for (a1 <- w1.args) { + var found = false + for (a2 <- w2.args) { if (eq(a1,a2)) found = true } + if (found == false) ret = false + } + } + ret + } + case (w1:MinWidth,w2:MinWidth) => { + var ret = true + if (w1.args.size != w2.args.size) ret = false + else { + for (a1 <- w1.args) { + var found = false + for (a2 <- w2.args) { if (eq(a1,a2)) found = true } + if (found == false) ret = false + } + } + ret + } + case (w1:IntWidth,w2:IntWidth) => w1.width == w2.width + case (w1:PlusWidth,w2:PlusWidth) => + (w1.arg1 == w2.arg1 && w1.arg2 == w2.arg2) || (w1.arg1 == w2.arg2 && w1.arg2 == w2.arg1) + case (w1:MinusWidth,w2:MinusWidth) => + (w1.arg1 == w2.arg1 && w1.arg2 == w2.arg2) || (w1.arg1 == w2.arg2 && w1.arg2 == w2.arg1) + case (w1:ExpWidth,w2:ExpWidth) => w1.arg1 == w2.arg1 + case (w1:UnknownWidth,w2:UnknownWidth) => true + case (w1,w2) => false + } + } + case _ => false + } + } +} + +trait Constraint +class WGeq(val loc:Width,val exp:Width) extends Constraint { + override def toString = { + val wloc = new WrappedWidth(loc) + val wexp = new WrappedWidth(exp) + wloc.toString + " >= " + wexp.toString + } +} +object WGeq { + def apply (loc:Width,exp:Width) = new WGeq(loc,exp) +} + diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 7cd4fdcf..d3d7027a 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -44,7 +44,7 @@ trait StanzaPass extends LazyLogging { } object PassUtils extends LazyLogging { - val listOfPasses: Seq[Pass] = Seq(ToWorkingIR,ResolveKinds,ResolveGenders,PullMuxes,ExpandConnects,RemoveAccesses,ExpandWhens) + val listOfPasses: Seq[Pass] = Seq(ToWorkingIR,ResolveKinds,InferTypes,ResolveGenders,InferWidths,PullMuxes,ExpandConnects,RemoveAccesses,ExpandWhens,LowerTypes) lazy val mapNameToPass: Map[String, Pass] = listOfPasses.map(p => p.name -> p).toMap def executePasses(c: Circuit, passes: Seq[Pass]): Circuit = { @@ -52,9 +52,11 @@ object PassUtils extends LazyLogging { else { val p = passes.head val name = p.name - logger.debug(c.serialize()) logger.debug(s"Starting ${name}") - executePasses(p.run(c), passes.tail) + val x = p.run(c) + logger.debug(x.serialize()) + logger.debug(s"Finished ${name}") + executePasses(x, passes.tail) } } } @@ -172,12 +174,6 @@ object InferTypes extends Pass { } } def remove_unknowns (t:Type): Type = mapr(remove_unknowns_w _,t) - def mapr (f: Width => Width, t:Type) : Type = { - def apply_t (t:Type) : Type = { - wMap(f,tMap(apply_t _,t)) - } - apply_t(t) - } def run (c:Circuit): Circuit = { val module_types = HashMap[String,Type]() def infer_types (m:Module) : Module = { @@ -329,9 +325,291 @@ object CheckGenders extends Pass with StanzaPass { def run (c:Circuit): Circuit = stanzaPass(c, "check-genders") } -object InferWidths extends Pass with StanzaPass { +object InferWidths extends Pass { def name = "Infer Widths" - def run (c:Circuit): Circuit = stanzaPass(c, "infer-widths") + var mname = "" + def solve_constraints (l:Seq[WGeq]) : HashMap[String,Width] = { + def unique (ls:Seq[Width]) : Seq[Width] = ls.map(w => new WrappedWidth(w)).distinct.map(_.w) + def make_unique (ls:Seq[WGeq]) : HashMap[String,Width] = { + val h = HashMap[String,Width]() + for (g <- ls) { + (g.loc) match { + case (w:VarWidth) => { + val n = w.name + if (h.contains(n)) h(n) = MaxWidth(Seq(g.exp,h(n))) else h(n) = g.exp + } + case (w) => w + } + } + h + } + def simplify (w:Width) : Width = { + (wMap(simplify _,w)) match { + case (w:MinWidth) => { + val v = ArrayBuffer[Width]() + for (wx <- w.args) { + (wx) match { + case (wx:MinWidth) => for (x <- wx.args) { v += x } + case (wx) => v += wx } } + MinWidth(unique(v)) } + case (w:MaxWidth) => { + val v = ArrayBuffer[Width]() + for (wx <- w.args) { + (wx) match { + case (wx:MaxWidth) => for (x <- wx.args) { v += x } + case (wx) => v += wx } } + MaxWidth(unique(v)) } + case (w:PlusWidth) => { + (w.arg1,w.arg2) match { + case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width + w2.width) + case (w1,w2) => w }} + case (w:MinusWidth) => { + (w.arg1,w.arg2) match { + case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width - w2.width) + case (w1,w2) => w }} + case (w:ExpWidth) => { + (w.arg1) match { + case (w1:IntWidth) => IntWidth((2 ^ w1.width) - 1) + case (w1) => w }} + case (w) => w } } + def substitute (h:HashMap[String,Width])(w:Width) : Width = { + //;println-all-debug(["Substituting for [" w "]"]) + val wx = simplify(w) + //;println-all-debug(["After Simplify: [" wx "]"]) + (wMap(substitute(h) _,simplify(w))) match { + case (w:VarWidth) => { + //;("matched println-debugvarwidth!") + if (h.contains(w.name)) { + //;println-debug("Contained!") + //;println-all-debug(["Width: " w]) + //;println-all-debug(["Accessed: " h[name(w)]]) + val t = simplify(substitute(h)(h(w.name))) + //;val t = h[name(w)] + //;println-all-debug(["Width after sub: " t]) + h(w.name) = t + t + } else w + } + case (w) => w + //;println-all-debug(["not varwidth!" w]) + } + } + def b_sub (h:HashMap[String,Width])(w:Width) : Width = { + (wMap(b_sub(h) _,w)) match { + case (w:VarWidth) => if (h.contains(w.name)) h(w.name) else w + case (w) => w + } + } + def remove_cycle (n:String)(w:Width) : Width = { + //;println-all-debug(["Removing cycle for " n " inside " w]) + val wx = (wMap(remove_cycle(n) _,w)) match { + case (w:MaxWidth) => MaxWidth(w.args.filter{ w => { + w match { + case (w:VarWidth) => n equals w.name + case (w) => false + }}}) + case (w:MinusWidth) => { + w.arg1 match { + case (v:VarWidth) => if (n == v.name) v else w + case (v) => w }} + case (w) => w + } + //;println-all-debug(["After removing cycle for " n ", returning " wx]) + wx + } + def self_rec (n:String,w:Width) : Boolean = { + var has = false + def look (w:Width) : Width = { + (wMap(look _,w)) match { + case (w:VarWidth) => if (w.name == n) has = true + case (w) => w } + w } + look(w) + has } + + //; Forward solve + //; Returns a solved list where each constraint undergoes: + //; 1) Continuous Solving (using triangular solving) + //; 2) Remove Cycles + //; 3) Move to solved if not self-recursive + val u = make_unique(l) + //println-debug("======== UNIQUE CONSTRAINTS ========") + //for (x <- u) { println-debug(x) } + //println-debug("====================================") + + val f = HashMap[String,Width]() + val o = ArrayBuffer[String]() + for (x <- u) { + //println-debug("==== SOLUTIONS TABLE ====") + //for x in f do : println-debug(x) + //println-debug("=========================") + + val (n, e) = (x._1, x._2) + + val e_sub = substitute(f)(e) + //println-debug(["Solving " n " => " e]) + //println-debug(["After Substitute: " n " => " e-sub]) + //println-debug("==== SOLUTIONS TABLE (Post Substitute) ====") + //for x in f do : println-debug(x) + //println-debug("=========================") + val ex = remove_cycle(n)(e_sub) + //;println-debug(["After Remove Cycle: " n " => " ex]) + if (!self_rec(n,ex)) { + //;println-all-debug(["Not rec!: " n " => " ex]) + //;println-all-debug(["Adding [" n "=>" ex "] to Solutions Table"]) + o += n + f(n) = ex + } + } + + //println-debug("Forward Solved Constraints") + //for x in f do : println-debug(x) + + //; Backwards Solve + val b = HashMap[String,Width]() + for (i <- 0 until o.size) { + val n = o(o.size - 1 - i) + //println-all-debug(["SOLVE BACK: [" n " => " f[n] "]"]) + //println-debug("==== SOLUTIONS TABLE ====") + //for x in b do : println-debug(x) + //println-debug("=========================") + val ex = simplify(b_sub(b)(f(n))) + //println-all-debug(["BACK RETURN: [" n " => " ex "]"]) + b(n) = ex + //println-debug("==== SOLUTIONS TABLE (Post backsolve) ====") + //for x in b do : println-debug(x) + //println-debug("=========================") + } + b + } + + def width_BANG (t:Type) : Width = { + (t) match { + case (t:UIntType) => t.width + case (t:SIntType) => t.width + case (t:ClockType) => IntWidth(1) + case (t) => error("No width!"); IntWidth(-1) } } + def width_BANG (e:Expression) : Width = width_BANG(tpe(e)) + def reduce_var_widths (c:Circuit,h:HashMap[String,Width]) : Circuit = { + def evaluate (w:Width) : Width = { + def apply_2 (a:Option[BigInt],b:Option[BigInt], f: (BigInt,BigInt) => BigInt) : Option[BigInt] = { + (a,b) match { + case (a:Some[BigInt],b:Some[BigInt]) => Some(f(a.get,b.get)) + case (a,b) => None } } + def apply_1 (a:Option[BigInt], f: (BigInt) => BigInt) : Option[BigInt] = { + (a) match { + case (a:Some[BigInt]) => Some(f(a.get)) + case (a) => None } } + def apply_l (l:Seq[Option[BigInt]],f:(BigInt,BigInt) => BigInt) : Option[BigInt] = { + if (l.size == 0) Some(BigInt(0)) else apply_2(l.head,apply_l(l.tail,f),f) + } + def max (a:BigInt,b:BigInt) : BigInt = if (a >= b) a else b + def min (a:BigInt,b:BigInt) : BigInt = if (a >= b) b else a + def solve (w:Width) : Option[BigInt] = { + (w) match { + case (w:VarWidth) => { + val wx = h.get(w.name) + (wx) match { + case (wx:Some[Width]) => { + wx.get match { + case (v:VarWidth) => None + case (v) => solve(v) }} + case (None) => None }} + case (w:MaxWidth) => apply_l(w.args.map(solve _),max) + case (w:MinWidth) => apply_l(w.args.map(solve _),min) + case (w:PlusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ + _}) + case (w:MinusWidth) => apply_2(solve(w.arg1),solve(w.arg2),{_ - _}) + case (w:ExpWidth) => apply_2(Some(BigInt(2)),solve(w.arg1),{(x,y) => (x ^ y) - BigInt(1)}) + case (w:IntWidth) => Some(w.width) + case (w) => println(w); error("Shouldn't be here"); None; + } + } + val s = solve(w) + (s) match { + case (s:Some[BigInt]) => IntWidth(s.get) + case (s) => w } + } + + def reduce_var_widths_w (w:Width) : Width = { + //println-all-debug(["REPLACE: " w]) + val wx = evaluate(w) + //println-all-debug(["WITH: " wx]) + wx + } + + val modulesx = c.modules.map{ m => { + val portsx = m.ports.map{ p => { + Port(p.info,p.name,p.direction,mapr(reduce_var_widths_w _,p.tpe)) }} + (m) match { + case (m:ExModule) => ExModule(m.info,m.name,portsx) + case (m:InModule) => mname = m.name; InModule(m.info,m.name,portsx,mapr(reduce_var_widths_w _,m.body)) }}} + Circuit(c.info,modulesx,c.main) + } + + def run (c:Circuit): Circuit = { + val v = ArrayBuffer[WGeq]() + def constrain (w1:Width,w2:Width) : Unit = v += WGeq(w1,w2) + def get_constraints_t (t1:Type,t2:Type,f:Flip) : Unit = { + (t1,t2) match { + case (t1:UIntType,t2:UIntType) => constrain(t1.width,t2.width) + case (t1:SIntType,t2:SIntType) => constrain(t1.width,t2.width) + case (t1:BundleType,t2:BundleType) => { + (t1.fields,t2.fields).zipped.foreach{ (f1,f2) => { + get_constraints_t(f1.tpe,f2.tpe,times(f1.flip,f)) }}} + case (t1:VectorType,t2:VectorType) => get_constraints_t(t1.tpe,t2.tpe,f) }} + def get_constraints_e (e:Expression) : Expression = { + (eMap(get_constraints_e _,e)) match { + case (e:Mux) => { + constrain(width_BANG(e.cond),ONE) + constrain(ONE,width_BANG(e.cond)) + e } + case (e) => e }} + def get_constraints (s:Stmt) : Stmt = { + (eMap(get_constraints_e _,s)) match { + case (s:Connect) => { + val n = get_size(tpe(s.loc)) + val ce_loc = create_exps(s.loc) + val ce_exp = create_exps(s.exp) + for (i <- 0 until n) { + val locx = ce_loc(i) + val expx = ce_exp(i) + get_flip(tpe(s.loc),i,DEFAULT) match { + case DEFAULT => constrain(width_BANG(locx),width_BANG(expx)) + case REVERSE => constrain(width_BANG(expx),width_BANG(locx)) }} + s } + case (s:BulkConnect) => { + val ls = get_valid_points(tpe(s.loc),tpe(s.exp),DEFAULT,DEFAULT) + for (x <- ls) { + val locx = create_exps(s.loc)(x._1) + val expx = create_exps(s.exp)(x._2) + get_flip(tpe(s.loc),x._1,DEFAULT) match { + case DEFAULT => constrain(width_BANG(locx),width_BANG(expx)) + case REVERSE => constrain(width_BANG(expx),width_BANG(locx)) }} + s } + case (s:DefRegister) => { + constrain(width_BANG(s.reset),ONE) + constrain(ONE,width_BANG(s.reset)) + get_constraints_t(s.tpe,tpe(s.init),DEFAULT) + s } + case (s:Conditionally) => { + v += WGeq(width_BANG(s.pred),ONE) + v += WGeq(ONE,width_BANG(s.pred)) + sMap(get_constraints _,s) } + case (s) => sMap(get_constraints _,s) }} + + for (m <- c.modules) { + (m) match { + case (m:InModule) => mname = m.name; get_constraints(m.body) + case (m) => false }} + //println-debug("======== ALL CONSTRAINTS ========") + //for x in v do : println-debug(x) + //println-debug("=================================") + val h = solve_constraints(v) + //println-debug("======== SOLVED CONSTRAINTS ========") + //for x in h do : println-debug(x) + //println-debug("====================================") + reduce_var_widths(Circuit(c.info,c.modules,c.main),h) + } } object CheckWidths extends Pass with StanzaPass { @@ -827,37 +1105,258 @@ object ExpandWhens extends Pass with StanzaPass { } object CheckInitialization extends Pass with StanzaPass { - def name = "Check Initialization" - def run (c:Circuit): Circuit = stanzaPass(c, "check-init") + def name = "Check Initialization" + def run (c:Circuit): Circuit = stanzaPass(c, "check-init") } object ConstProp extends Pass with StanzaPass { - def name = "Constant Propogation" - def run (c:Circuit): Circuit = stanzaPass(c, "const-prop") + def name = "Constant Propogation" + def run (c:Circuit): Circuit = stanzaPass(c, "const-prop") } object LoToVerilog extends Pass with StanzaPass { - def name = "Lo To Verilog" - def run (c:Circuit): Circuit = stanzaPass(c, "lo-to-verilog") + def name = "Lo To Verilog" + def run (c:Circuit): Circuit = stanzaPass(c, "lo-to-verilog") } object VerilogWrap extends Pass with StanzaPass { - def name = "Verilog Wrap" - def run (c:Circuit): Circuit = stanzaPass(c, "verilog-wrap") + def name = "Verilog Wrap" + def run (c:Circuit): Circuit = stanzaPass(c, "verilog-wrap") } object SplitExp extends Pass with StanzaPass { - def name = "Split Expressions" - def run (c:Circuit): Circuit = stanzaPass(c, "split-expressions") + def name = "Split Expressions" + def run (c:Circuit): Circuit = stanzaPass(c, "split-expressions") } object VerilogRename extends Pass with StanzaPass { - def name = "Verilog Rename" - def run (c:Circuit): Circuit = stanzaPass(c, "verilog-rename") + def name = "Verilog Rename" + def run (c:Circuit): Circuit = stanzaPass(c, "verilog-rename") } -object LowerTypes extends Pass with StanzaPass { - def name = "Lower Types" - def run (c:Circuit): Circuit = stanzaPass(c, "lower-types") +object LowerTypes extends Pass { + def name = "Lower Types" + var mname = "" + def is_ground (t:Type) : Boolean = { + (t) match { + case (_:UIntType|_:SIntType) => true + case (t) => false + } + } + def data (ex:Expression) : Boolean = { + (kind(ex)) match { + case (k:MemKind) => (ex) match { + case (_:WRef|_:WSubIndex) => false + case (ex:WSubField) => { + var yes = ex.name match { + case "rdata" => true + case "data" => true + case "mask" => true + case _ => false + } + yes && ((ex.exp) match { + case (e:WSubField) => kind(e).as[MemKind].get.ports.contains(e.name) && (e.exp.typeof[WRef]) + case (e) => false + }) + } + case (ex) => false + } + case (k) => false + } + } + def expand_name (e:Expression) : Seq[String] = { + val names = ArrayBuffer[String]() + def expand_name_e (e:Expression) : Expression = { + (eMap(expand_name_e _,e)) match { + case (e:WRef) => names += e.name + case (e:WSubField) => names += e.name + case (e:WSubIndex) => names += e.value.toString + } + e + } + expand_name_e(e) + names + } + def lower_other_mem (e:Expression, dt:Type) : Seq[Expression] = { + val names = expand_name(e) + if (names.size < 3) error("Shouldn't be here") + create_exps(names(0),dt).map{ x => { + var base = lowered_name(x) + for (i <- 0 until names.size) { + if (i >= 3) base = base + "_" + names(i) + } + val m = WRef(base, UnknownType(), kind(e), UNKNOWNGENDER) + val p = WSubField(m,names(1),UnknownType(),UNKNOWNGENDER) + WSubField(p,names(2),UnknownType(),UNKNOWNGENDER) + }} + } + def lower_data_mem (e:Expression) : Expression = { + val names = expand_name(e) + if (names.size < 3) error("Shouldn't be here") + else { + var base = names(0) + for (i <- 0 until names.size) { + if (i >= 3) base = base + "_" + names(i) + } + val m = WRef(base, UnknownType(), kind(e), UNKNOWNGENDER) + val p = WSubField(m,names(1),UnknownType(),UNKNOWNGENDER) + WSubField(p,names(2),UnknownType(),UNKNOWNGENDER) + } + } + def merge (a:String,b:String,x:String) : String = a + x + b + def lowered_name (e:Expression) : String = { + (e) match { + case (e:WRef) => e.name + case (e:WSubField) => lowered_name(e.exp) + "_" + e.name + case (e:WSubIndex) => lowered_name(e.exp) + "_" + e.value + } + } + def root_ref (e:Expression) : WRef = { + (e) match { + case (e:WRef) => e + case (e:WSubField) => root_ref(e.exp) + case (e:WSubIndex) => root_ref(e.exp) + case (e:WSubAccess) => root_ref(e.exp) + } + } + + //;------------- Pass ------------------ + + def lower_types (m:Module) : Module = { + val mdt = HashMap[String,Type]() + mname = m.name + def lower_types (s:Stmt) : Stmt = { + def lower_mem (e:Expression) : Seq[Expression] = { + val names = expand_name(e) + if (Seq("data","mask","rdata").contains(names(2))) Seq(lower_data_mem(e)) + else lower_other_mem(e,mdt(root_ref(e).name)) + } + def lower_types_e (e:Expression) : Expression = { + e match { + case (_:WRef|_:UIntValue|_:SIntValue) => e + case (_:WSubField|_:WSubIndex) => { + (kind(e)) match { + case (k:InstanceKind) => { + val names = expand_name(e) + var n = names(1) + for (i <- 0 until names.size) { + if (i > 1) n = n + "_" + names(i) + } + WSubField(root_ref(e),n,tpe(e),gender(e)) + } + case (k:MemKind) => { + if (gender(e) != FEMALE) lower_mem(e)(0) + else e + } + case (k) => WRef(lowered_name(e),tpe(e),kind(e),gender(e)) + } + } + case (e:DoPrim) => eMap(lower_types_e _,e) + case (e:Mux) => eMap(lower_types_e _,e) + case (e:ValidIf) => eMap(lower_types_e _,e) + } + } + (s) match { + case (s:DefWire) => { + if (is_ground(s.tpe)) s else { + val es = create_exps(s.name,s.tpe) + val stmts = (es, 0 until es.size).zipped.map{ (e,i) => { + DefWire(s.info,lowered_name(e),tpe(e)) + }} + Begin(stmts) + } + } + case (s:DefPoison) => { + if (is_ground(s.tpe)) s else { + val es = create_exps(s.name,s.tpe) + val stmts = (es, 0 until es.size).zipped.map{ (e,i) => { + DefPoison(s.info,lowered_name(e),tpe(e)) + }} + Begin(stmts) + } + } + case (s:DefRegister) => { + if (is_ground(s.tpe)) s else { + val es = create_exps(s.name,s.tpe) + val inits = create_exps(s.init) + val stmts = (es, 0 until es.size).zipped.map{ (e,i) => { + val init = lower_types_e(inits(i)) + DefRegister(s.info,lowered_name(e),tpe(e),s.clock,s.reset,init) + }} + Begin(stmts) + } + } + case (s:WDefInstance) => { + val fieldsx = s.tpe.as[BundleType].get.fields.flatMap{ f => { + val es = create_exps(WRef(f.name,f.tpe,ExpKind(),times(f.flip,MALE))) + es.map{ e => { + gender(e) match { + case MALE => Field(lowered_name(e),DEFAULT,f.tpe) + case FEMALE => Field(lowered_name(e),REVERSE,f.tpe) + } + }} + }} + WDefInstance(s.info,s.name,s.module,BundleType(fieldsx)) + } + case (s:DefMemory) => { + mdt(s.name) = s.data_type + if (is_ground(s.data_type)) s else { + val es = create_exps(s.name,s.data_type) + val stmts = es.map{ e => { + DefMemory(s.info,lowered_name(e),tpe(e),s.depth,s.write_latency,s.read_latency,s.readers,s.writers,s.readwriters) + }} + Begin(stmts) + } + } + case (s:IsInvalid) => { + val sx = eMap(lower_types_e _,s).as[IsInvalid].get + kind(sx.exp) match { + case (k:MemKind) => { + val es = lower_mem(sx.exp) + Begin(es.map(e => {IsInvalid(sx.info,e)})) + } + case (_) => sx + } + } + case (s:Connect) => { + val sx = eMap(lower_types_e _,s).as[Connect].get + kind(sx.loc) match { + case (k:MemKind) => { + val es = lower_mem(sx.loc) + Begin(es.map(e => {Connect(sx.info,e,sx.exp)})) + } + case (_) => sx + } + } + case (s:DefNode) => { + val locs = create_exps(s.name,tpe(s.value)) + val n = locs.size + val nodes = ArrayBuffer[Stmt]() + val exps = create_exps(s.value) + for (i <- 0 until n) { + val locx = locs(i) + val expx = exps(i) + nodes += DefNode(s.info,lowered_name(locx),lower_types_e(expx)) + } + if (n == 1) nodes(0) else Begin(nodes) + } + case (s) => eMap(lower_types_e _,sMap(lower_types _,s)) + } + } + + val portsx = m.ports.flatMap{ p => { + val es = create_exps(WRef(p.name,p.tpe,PortKind(),to_gender(p.direction))) + es.map(e => { Port(p.info,lowered_name(e),to_dir(gender(e)),tpe(e)) }) + }} + (m) match { + case (m:ExModule) => ExModule(m.info,m.name,portsx) + case (m:InModule) => InModule(m.info,m.name,portsx,lower_types(m.body)) + } + } + + def run (c:Circuit) : Circuit = { + val modulesx = c.modules.map(m => lower_types(m)) + Circuit(c.info,modulesx,c.main) + } } diff --git a/test/parser/gcd.fir b/test/parser/gcd.fir index e0958a7a..45a048f2 100644 --- a/test/parser/gcd.fir +++ b/test/parser/gcd.fir @@ -4,7 +4,7 @@ circuit GCD : input e : UInt<1> input clk : Clock input reset : UInt<1> - output z : UInt<16> + output z : UInt output v : UInt<1> input a : UInt<16> input b : UInt<16> -- cgit v1.2.3