diff options
| author | Donggyu | 2016-09-13 16:57:00 -0700 |
|---|---|---|
| committer | GitHub | 2016-09-13 16:57:00 -0700 |
| commit | 96340374f091d5258ca69ef7fc614910e1c2cbb7 (patch) | |
| tree | a283ed9716f10cee128a9a782dada088bba97d5f /src | |
| parent | ad36a1216f52bc01a27dac93cfd8cd42beb84c73 (diff) | |
| parent | 4cb46ca17da26c7ccc0b66a6be489a49fb2e9173 (diff) | |
Merge pull request #284 from ucb-bar/more_utils_cleanups
More utils cleanups
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 33 | ||||
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 182 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 12 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ConstProp.scala | 18 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferTypes.scala | 14 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 20 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/MemUtils.scala | 230 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/PadWidths.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveCHIRRTL.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplaceMemMacros.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Uniquify.scala | 7 |
12 files changed, 221 insertions, 310 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 1610655b..0af002b2 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -112,10 +112,10 @@ class VerilogEmitter extends Emitter { case (e: Literal) => v_print(e) case (e: VRandom) => w write s"{${e.nWords}{$$random}}" case (t: UIntType) => - val wx = long_BANG(t) - 1 + val wx = bitWidth(t) - 1 if (wx > 0) w write s"[$wx:0]" case (t: SIntType) => - val wx = long_BANG(t) - 1 + val wx = bitWidth(t) - 1 if (wx > 0) w write s"[$wx:0]" case ClockType => case (t: VectorType) => @@ -126,6 +126,7 @@ class VerilogEmitter extends Emitter { case (s: String) => w write s case (i: Int) => w write i.toString case (i: Long) => w write i.toString + case (i: BigInt) => w write i.toString case (t: VIndent) => w write " " case (s: Seq[Any]) => s foreach (emit(_, top + 1)) @@ -189,7 +190,7 @@ class VerilogEmitter extends Emitter { case Eq => Seq(cast_if(a0), " == ", cast_if(a1)) case Neq => Seq(cast_if(a0), " != ", cast_if(a1)) case Pad => - val w = long_BANG(a0.tpe) + val w = bitWidth(a0.tpe) val diff = (c0 - w) if (w == 0) Seq(a0) else doprim.tpe match { @@ -210,9 +211,9 @@ class VerilogEmitter extends Emitter { } case Shlw => Seq(cast(a0), " << ", c0) case Shl => Seq(cast(a0), " << ", c0) - case Shr if c0 >= long_BANG(a0.tpe) => + case Shr if c0 >= bitWidth(a0.tpe) => error("Verilog emitter does not support SHIFT_RIGHT >= arg width") - case Shr => Seq(a0,"[", long_BANG(a0.tpe) - 1, ":", c0, "]") + case Shr => Seq(a0,"[", bitWidth(a0.tpe) - 1, ":", c0, "]") case Neg => Seq("-{", cast(a0), "}") case Cvt => a0.tpe match { case (_: UIntType) => Seq("{1'b0,", cast(a0), "}") @@ -222,24 +223,24 @@ class VerilogEmitter extends Emitter { case And => Seq(cast_as(a0), " & ", cast_as(a1)) case Or => Seq(cast_as(a0), " | ", cast_as(a1)) case Xor => Seq(cast_as(a0), " ^ ", cast_as(a1)) - case Andr => (0 until long_BANG(doprim.tpe).toInt) map ( + case Andr => (0 until bitWidth(doprim.tpe).toInt) map ( Seq(cast(a0), "[", _, "]")) reduce (_ + " & " + _) - case Orr => (0 until long_BANG(doprim.tpe).toInt) map ( + case Orr => (0 until bitWidth(doprim.tpe).toInt) map ( Seq(cast(a0), "[", _, "]")) reduce (_ + " | " + _) - case Xorr => (0 until long_BANG(doprim.tpe).toInt) map ( + case Xorr => (0 until bitWidth(doprim.tpe).toInt) map ( Seq(cast(a0), "[", _, "]")) reduce (_ + " ^ " + _) case Cat => Seq("{", cast(a0), ",", cast(a1), "}") // If selecting zeroth bit and single-bit wire, just emit the wire - case Bits if c0 == 0 && c1 == 0 && long_BANG(a0.tpe) == 1 => Seq(a0) + case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == 1 => Seq(a0) case Bits if c0 == c1 => Seq(a0, "[", c0, "]") case Bits => Seq(a0, "[", c0, ":", c1, "]") case Head => - val w = long_BANG(a0.tpe) + val w = bitWidth(a0.tpe) val high = w - 1 val low = w - c0 Seq(a0, "[", high, ":", low, "]") case Tail => - val w = long_BANG(a0.tpe) + val w = bitWidth(a0.tpe) val low = w - c0 - 1 Seq(a0, "[", low, ":", 0, "]") } @@ -260,7 +261,7 @@ class VerilogEmitter extends Emitter { simlist += s s case (s: DefNode) => - val e = WRef(s.name, get_type(s), NodeKind(), MALE) + val e = WRef(s.name, s.value.tpe, NodeKind(), MALE) netlist(e) = s.value s case (s) => s @@ -350,11 +351,11 @@ class VerilogEmitter extends Emitter { // Then, return the correct number of bits selected from the random value def rand_string(t: Type) : Seq[Any] = { val nx = namespace.newTemp - val rand = VRandom(long_BANG(t)) + val rand = VRandom(bitWidth(t)) val tx = SIntType(IntWidth(rand.realWidth)) declare("reg",nx, tx) - initials += Seq(wref(nx, tx), " = ", VRandom(long_BANG(t)), ";") - Seq(nx, "[", long_BANG(t) - 1, ":0]") + initials += Seq(wref(nx, tx), " = ", VRandom(bitWidth(t)), ";") + Seq(nx, "[", bitWidth(t) - 1, ":0]") } def initialize(e: Expression) = { @@ -471,7 +472,7 @@ class VerilogEmitter extends Emitter { instantiate(s.name, s.module, es) s case (s: DefMemory) => - val mem = WRef(s.name, get_type(s), + val mem = WRef(s.name, MemPortUtils.memType(s), MemKind(s.readers ++ s.writers ++ s.readwriters), UNKNOWNGENDER) def mem_exp (p: String, f: String) = { val t1 = field_type(mem.tpe, p) diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index dc6dfadb..8d677104 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -135,8 +135,8 @@ object PrimOps extends LazyLogging { def t1 = e.args(0).tpe def t2 = e.args(1).tpe def t3 = e.args(2).tpe - def w1 = Utils.width_BANG(e.args(0).tpe) - def w2 = Utils.width_BANG(e.args(1).tpe) + def w1 = passes.getWidth(e.args(0).tpe) + def w2 = passes.getWidth(e.args(1).tpe) def c1 = IntWidth(e.consts(0)) def c2 = IntWidth(e.consts(1)) e copy (tpe = (e.op match { diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 572d1ccc..d9c74840 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -87,23 +87,16 @@ object Utils extends LazyLogging { def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 val BoolType = UIntType(IntWidth(1)) - val one = UIntLiteral(BigInt(1),IntWidth(1)) - val zero = UIntLiteral(BigInt(0),IntWidth(1)) - def uint (i:Int) : UIntLiteral = { - val num_bits = req_num_bits(i) - val w = IntWidth(scala.math.max(1,num_bits - 1)) - UIntLiteral(BigInt(i),w) + val one = UIntLiteral(BigInt(1), IntWidth(1)) + val zero = UIntLiteral(BigInt(0), IntWidth(1)) + def uint(i: Int): UIntLiteral = { + val num_bits = req_num_bits(i) + val w = IntWidth(scala.math.max(1, num_bits - 1)) + UIntLiteral(BigInt(i), w) } - def req_num_bits (i: Int) : Int = { - val ix = if (i < 0) ((-1 * i) - 1) else i - ceil_log2(ix + 1) + 1 - } - - def create_mask(dt: Type): Type = dt match { - case t: VectorType => VectorType(create_mask(t.tpe),t.size) - case t: BundleType => BundleType(t.fields.map (f => f.copy(tpe=create_mask(f.tpe)))) - case t: UIntType => BoolType - case t: SIntType => BoolType + def req_num_bits(i: Int): Int = { + val ix = if (i < 0) ((-1 * i) - 1) else i + ceil_log2(ix + 1) + 1 } def create_exps(n: String, t: Type): Seq[Expression] = @@ -113,14 +106,14 @@ object Utils extends LazyLogging { val e1s = create_exps(e.tval) val e2s = create_exps(e.fval) e1s zip e2s map {case (e1, e2) => - Mux(e.cond, e1, e2, mux_type_and_widths(e1,e2)) + Mux(e.cond, e1, e2, mux_type_and_widths(e1, e2)) } case (e: ValidIf) => create_exps(e.value) map (e1 => ValidIf(e.cond, e1, e1.tpe)) case (e) => e.tpe match { case (_: GroundType) => Seq(e) case (t: BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => exps ++ create_exps(WSubField(e, f.name, f.tpe,times(gender(e), f.flip)))) - case (t: VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + case (t: VectorType) => (0 until t.size foldLeft Seq[Expression]())((exps, i) => exps ++ create_exps(WSubIndex(e, i, t.tpe,gender(e)))) } } @@ -172,30 +165,30 @@ object Utils extends LazyLogging { } //============== TYPES ================ - def mux_type (e1:Expression, e2:Expression) : Type = mux_type(e1.tpe, e2.tpe) - def mux_type (t1:Type, t2:Type) : Type = (t1,t2) match { - case (t1:UIntType, t2:UIntType) => UIntType(UnknownWidth) - case (t1:SIntType, t2:SIntType) => SIntType(UnknownWidth) - case (t1:VectorType, t2:VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size) - case (t1:BundleType, t2:BundleType) => BundleType((t1.fields zip t2.fields) map { + def mux_type(e1: Expression, e2: Expression): Type = mux_type(e1.tpe, e2.tpe) + def mux_type(t1: Type, t2: Type): Type = (t1, t2) match { + case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth) + case (t1: SIntType, t2: SIntType) => SIntType(UnknownWidth) + case (t1: VectorType, t2: VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size) + case (t1: BundleType, t2: BundleType) => BundleType(t1.fields zip t2.fields map { case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe)) }) case _ => UnknownType } - def mux_type_and_widths (e1:Expression,e2:Expression) : Type = + def mux_type_and_widths(e1: Expression,e2: Expression): Type = mux_type_and_widths(e1.tpe, e2.tpe) - def mux_type_and_widths (t1:Type, t2:Type) : Type = { - def wmax (w1:Width, w2:Width) : Width = (w1,w2) match { - case (w1:IntWidth, w2:IntWidth) => IntWidth(w1.width max w2.width) + def mux_type_and_widths(t1: Type, t2: Type): Type = { + def wmax(w1: Width, w2: Width): Width = (w1, w2) match { + case (w1: IntWidth, w2: IntWidth) => IntWidth(w1.width max w2.width) case (w1, w2) => MaxWidth(Seq(w1, w2)) } (t1, t2) match { - case (t1:UIntType, t2:UIntType) => UIntType(wmax(t1.width, t2.width)) - case (t1:SIntType, t2:SIntType) => SIntType(wmax(t1.width, t2.width)) - case (t1:VectorType, t2:VectorType) => VectorType( + case (t1: UIntType, t2: UIntType) => UIntType(wmax(t1.width, t2.width)) + case (t1: SIntType, t2: SIntType) => SIntType(wmax(t1.width, t2.width)) + case (t1: VectorType, t2: VectorType) => VectorType( mux_type_and_widths(t1.tpe, t2.tpe), t1.size) - case (t1:BundleType, t2:BundleType) => BundleType((t1.fields zip t2.fields) map { - case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe, f2.tpe)) + case (t1: BundleType, t2: BundleType) => BundleType(t1.fields zip t2.fields map { + case (f1, f2) => Field(f1.name, f1.flip, mux_type_and_widths(f1.tpe, f2.tpe)) }) case _ => UnknownType } @@ -208,7 +201,7 @@ object Utils extends LazyLogging { case v: VectorType => v.tpe case v => UnknownType } - def field_type(v:Type, s: String) : Type = v match { + def field_type(v: Type, s: String) : Type = v match { case v: BundleType => v.fields find (_.name == s) match { case Some(f) => f.tpe case None => UnknownType @@ -216,22 +209,6 @@ object Utils extends LazyLogging { case v => UnknownType } -////===================================== - def width_BANG(t: Type) : Width = t match { - case g: GroundType => g.width - case t => error("No width!") - } - def width_BANG(e: Expression) : Width = width_BANG(e.tpe) - def long_BANG(t: Type): Long = t match { - case (g: GroundType) => g.width match { - case IntWidth(x) => x.toLong - case _ => error(s"Expecting IntWidth, got: ${g.width}") - } - case (t: BundleType) => (t.fields foldLeft 0)((w, f) => - w + long_BANG(f.tpe).toInt) - case (t: VectorType) => t.size * long_BANG(t.tpe) - } - // ================================= def error(str: String) = throw new FIRRTLException(str) @@ -306,15 +283,15 @@ object Utils extends LazyLogging { case FEMALE => Default } - def field_flip(v:Type, s:String) : Orientation = v match { - case (v:BundleType) => v.fields find (_.name == s) match { + def field_flip(v: Type, s: String): Orientation = v match { + case (v: BundleType) => v.fields find (_.name == s) match { case Some(ft) => ft.flip case None => Default } case v => Default } - def get_field(v:Type, s:String) : Field = v match { - case (v:BundleType) => v.fields find (_.name == s) match { + def get_field(v: Type, s: String): Field = v match { + case (v: BundleType) => v.fields find (_.name == s) match { case Some(ft) => ft case None => error("Shouldn't be here") } @@ -350,7 +327,7 @@ object Utils extends LazyLogging { case e: WSubAccess => kind(e.exp) case e => ExpKind() } - def gender (e: Expression): Gender = e match { + def gender(e: Expression): Gender = e match { case e: WRef => e.gender case e: WSubField => e.gender case e: WSubIndex => e.gender @@ -363,7 +340,7 @@ object Utils extends LazyLogging { case e: WInvalid => MALE case e => println(e); error("Shouldn't be here") } - def get_gender(s:Statement): Gender = s match { + def get_gender(s: Statement): Gender = s match { case s: DefWire => BIGENDER case s: DefRegister => BIGENDER case s: WDefInstance => MALE @@ -379,37 +356,6 @@ object Utils extends LazyLogging { case EmptyStmt => UNKNOWNGENDER } def get_gender(p: Port): Gender = if (p.direction == Input) MALE else FEMALE - def get_type(s: Statement): Type = s match { - case s: DefWire => s.tpe - case s: DefRegister => s.tpe - case s: DefNode => s.value.tpe - case s: DefMemory => - val depth = s.depth - val addr = Field("addr", Default, UIntType(IntWidth(scala.math.max(ceil_log2(depth), 1)))) - val en = Field("en", Default, BoolType) - val clk = Field("clk", Default, ClockType) - val def_data = Field("data", Default, s.dataType) - val rev_data = Field("data", Flip, s.dataType) - val mask = Field("mask", Default, create_mask(s.dataType)) - val wmode = Field("wmode", Default, UIntType(IntWidth(1))) - val rdata = Field("rdata", Flip, s.dataType) - val wdata = Field("wdata", Default, s.dataType) - val wmask = Field("wmask", Default, create_mask(s.dataType)) - val read_type = BundleType(Seq(rev_data, addr, en, clk)) - val write_type = BundleType(Seq(def_data, mask, addr, en, clk)) - val readwrite_type = BundleType(Seq(wmode, rdata, wdata, wmask, addr, en, clk)) - BundleType( - (s.readers map (Field(_, Flip, read_type))) ++ - (s.writers map (Field(_, Flip, write_type))) ++ - (s.readwriters map (Field(_, Flip, readwrite_type))) - ) - case s: WDefInstance => s.tpe - case _ => UnknownType - } - def get_name(s: Statement): String = s match { - case s: HasName => s.name - case _ => error("Shouldn't be here") - } def get_info(s: Statement): Info = s match { case s: HasInfo => s.info case _ => NoInfo @@ -492,68 +438,6 @@ object Utils extends LazyLogging { } } -// =============== RECURISVE MAPPERS =================== - def mapr (f: Width => Width, t:Type) : Type = { - def apply_t (t:Type) : Type = t map (apply_t) map (f) - apply_t(t) - } - def mapr (f: Width => Width, s:Statement) : Statement = { - def apply_t (t:Type) : Type = mapr(f,t) - def apply_e (e:Expression) : Expression = e map (apply_e) map (apply_t) map (f) - def apply_s (s:Statement) : Statement = s map (apply_s) map (apply_e) map (apply_t) - apply_s(s) - } - //def digits (s:String) : Boolean { - // val digits = "0123456789" - // var yes:Boolean = true - // for (c <- s) { - // if !digits.contains(c) : yes = false - // } - // yes - //} - //def generated (s:String) : Option[Int] = { - // (1 until s.length() - 1).find{ - // i => { - // val sub = s.substring(i + 1) - // s.substring(i,i).equals("_") & digits(sub) & !s.substring(i - 1,i-1).equals("_") - // } - // } - //} - //def get-sym-hash (m:InModule) : LinkedHashMap[String,Int] = { get-sym-hash(m,Seq()) } - //def get-sym-hash (m:InModule,keywords:Seq[String]) : LinkedHashMap[String,Int] = { - // val sym-hash = LinkedHashMap[String,Int]() - // for (k <- keywords) { sym-hash += (k -> 0) } - // def add-name (s:String) : String = { - // val sx = to-string(s) - // val ix = generated(sx) - // ix match { - // case (i:False) => { - // if (sym_hash.contains(s)) { - // val num = sym-hash(s) - // sym-hash += (s -> max(num,0)) - // } else { - // sym-hash += (s -> 0) - // } - // } - // case (i:Int) => { - // val name = sx.substring(0,i) - // val digit = to-int(substring(sx,i + 1)) - // if key?(sym-hash,name) : - // val num = sym-hash[name] - // sym-hash[name] = max(num,digit) - // else : - // sym-hash[name] = digit - // } - // s - // - // defn to-port (p:Port) : add-name(name(p)) - // defn to-stmt (s:Stmt) -> Stmt : - // map{to-stmt,_} $ map(add-name,s) - // - // to-stmt(body(m)) - // map(to-port,ports(m)) - // sym-hash - val v_keywords = Set( "alias", "always", "always_comb", "always_ff", "always_latch", "and", "assert", "assign", "assume", "attribute", "automatic", diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index bba3efe7..16b16ff7 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -572,12 +572,12 @@ object CheckWidths extends Pass { errors append new WidthTooSmall(info, mname, e.value) case _ => } - case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if long_BANG(a.tpe) <= hi => - errors append new BitsWidthException(info, mname, hi, long_BANG(a.tpe)) - case DoPrim(Head, Seq(a), Seq(n), _) if long_BANG(a.tpe) < n => - errors append new HeadWidthException(info, mname, n, long_BANG(a.tpe)) - case DoPrim(Tail, Seq(a), Seq(n), _) if long_BANG(a.tpe) <= n => - errors append new TailWidthException(info, mname, n, long_BANG(a.tpe)) + case DoPrim(Bits, Seq(a), Seq(hi, lo), _) if bitWidth(a.tpe) <= hi => + errors append new BitsWidthException(info, mname, hi, bitWidth(a.tpe)) + case DoPrim(Head, Seq(a), Seq(n), _) if bitWidth(a.tpe) < n => + errors append new HeadWidthException(info, mname, n, bitWidth(a.tpe)) + case DoPrim(Tail, Seq(a), Seq(n), _) if bitWidth(a.tpe) <= n => + errors append new TailWidthException(info, mname, n, bitWidth(a.tpe)) case _ => } e map check_width_w(info, mname) map check_width_e(info, mname) diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index a4d9078c..789f2e03 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -38,7 +38,7 @@ import annotation.tailrec object ConstProp extends Pass { def name = "Constant Propagation" - private def pad(e: Expression, t: Type) = (long_BANG(e.tpe), long_BANG(t)) match { + private def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e } @@ -62,7 +62,7 @@ object ConstProp extends Pass { def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, w) if v == 0 => UIntLiteral(0, w) case SIntLiteral(v, w) if v == 0 => UIntLiteral(0, w) - case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << long_BANG(rhs.tpe).toInt) - 1 => rhs + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs case _ => e } } @@ -72,7 +72,7 @@ object ConstProp extends Pass { def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == 0 => rhs case SIntLiteral(v, _) if v == 0 => asUInt(rhs, e.tpe) - case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << long_BANG(rhs.tpe).toInt) - 1 => lhs + case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs case _ => e } } @@ -89,7 +89,7 @@ object ConstProp extends Pass { object FoldEqual extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == 1 && w == 1 && long_BANG(rhs.tpe) == 1 => rhs + case UIntLiteral(v, IntWidth(w)) if v == 1 && w == 1 && bitWidth(rhs.tpe) == 1 => rhs case _ => e } } @@ -97,7 +97,7 @@ object ConstProp extends Pass { object FoldNotEqual extends FoldLogicalOp { def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, IntWidth(w)) if v == 0 && w == 1 && long_BANG(rhs.tpe) == 1 => rhs + case UIntLiteral(v, IntWidth(w)) if v == 0 && w == 1 && bitWidth(rhs.tpe) == 1 => rhs case _ => e } } @@ -226,7 +226,7 @@ object ConstProp extends Pass { case Pad => e.args(0) match { case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(e.consts(0))) case SIntLiteral(v, _) => SIntLiteral(v, IntWidth(e.consts(0))) - case _ if long_BANG(e.args(0).tpe) == e.consts(0) => e.args(0) + case _ if bitWidth(e.args(0).tpe) == e.consts(0) => e.args(0) case _ => e } case Bits => e.args(0) match { @@ -234,9 +234,9 @@ object ConstProp extends Pass { val hi = e.consts(0).toInt val lo = e.consts(1).toInt require(hi >= lo) - UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), width_BANG(e.tpe)) + UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) } - case x if long_BANG(e.tpe) == long_BANG(x.tpe) => x.tpe match { + case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { case t: UIntType => x case _ => asUInt(x, e.tpe) } @@ -253,7 +253,7 @@ object ConstProp extends Pass { private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { case _ if m.tval == m.fval => m.tval case (t: UIntLiteral, f: UIntLiteral) => - if (t.value == 1 && f.value == 0 && long_BANG(m.tpe) == 1) m.cond + if (t.value == 1 && f.value == 0 && bitWidth(m.tpe) == 1) m.cond else constPropMuxCond(m) case _ => constPropMuxCond(m) } diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index b36298e8..79200a58 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -66,20 +66,20 @@ object InferTypes extends Pass { types(s.name) = t s copy (tpe = t) case s: DefWire => - val t = remove_unknowns(get_type(s)) + val t = remove_unknowns(s.tpe) types(s.name) = t s copy (tpe = t) case s: DefNode => - val sx = s map infer_types_e(types) - val t = remove_unknowns(get_type(sx)) + val sx = (s map infer_types_e(types)).asInstanceOf[DefNode] + val t = remove_unknowns(sx.value.tpe) types(s.name) = t sx map infer_types_e(types) case s: DefRegister => - val t = remove_unknowns(get_type(s)) + val t = remove_unknowns(s.tpe) types(s.name) = t s copy (tpe = t) map infer_types_e(types) case s: DefMemory => - val t = remove_unknowns(get_type(s)) + val t = remove_unknowns(MemPortUtils.memType(s)) types(s.name) = t s copy (dataType = remove_unknowns(s.dataType)) case s => s map infer_types_s(types) map infer_types_e(types) @@ -128,10 +128,10 @@ object CInferTypes extends Pass { types(s.name) = s.tpe s case (s: DefNode) => - types(s.name) = get_type(s) + types(s.name) = s.value.tpe s case (s: DefMemory) => - types(s.name) = get_type(s) + types(s.name) = MemPortUtils.memType(s) s case (s: CDefMPort) => val t = types getOrElse(s.mem, UnknownType) diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 5a81c268..6b2ff6ed 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -214,8 +214,8 @@ object InferWidths extends Pass { def get_constraints_e(e: Expression): Expression = { e match { case (e: Mux) => v ++= Seq( - WGeq(width_BANG(e.cond), IntWidth(1)), - WGeq(IntWidth(1), width_BANG(e.cond)) + WGeq(getWidth(e.cond), IntWidth(1)), + WGeq(IntWidth(1), getWidth(e.cond)) ) case _ => } @@ -230,8 +230,8 @@ object InferWidths extends Pass { val exps = create_exps(s.expr) v ++= ((locs zip exps).zipWithIndex map {case ((locx, expx), i) => get_flip(s.loc.tpe, i, Default) match { - case Default => WGeq(width_BANG(locx), width_BANG(expx)) - case Flip => WGeq(width_BANG(expx), width_BANG(locx)) + case Default => WGeq(getWidth(locx), getWidth(expx)) + case Flip => WGeq(getWidth(expx), getWidth(locx)) } }) case (s: PartialConnect) => @@ -242,17 +242,17 @@ object InferWidths extends Pass { val locx = locs(x) val expx = exps(y) get_flip(s.loc.tpe, x, Default) match { - case Default => WGeq(width_BANG(locx), width_BANG(expx)) - case Flip => WGeq(width_BANG(expx), width_BANG(locx)) + case Default => WGeq(getWidth(locx), getWidth(expx)) + case Flip => WGeq(getWidth(expx), getWidth(locx)) } }) case (s:DefRegister) => v ++= (Seq( - WGeq(width_BANG(s.reset), IntWidth(1)), - WGeq(IntWidth(1), width_BANG(s.reset)) + WGeq(getWidth(s.reset), IntWidth(1)), + WGeq(IntWidth(1), getWidth(s.reset)) ) ++ get_constraints_t(s.tpe, s.init.tpe, Default)) case (s:Conditionally) => v ++= Seq( - WGeq(width_BANG(s.pred), IntWidth(1)), - WGeq(IntWidth(1), width_BANG(s.pred)) + WGeq(getWidth(s.pred), IntWidth(1)), + WGeq(IntWidth(1), getWidth(s.pred)) ) case _ => } diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala index adbf23e5..87033176 100644 --- a/src/main/scala/firrtl/passes/MemUtils.scala +++ b/src/main/scala/firrtl/passes/MemUtils.scala @@ -38,24 +38,23 @@ object seqCat { def apply(args: Seq[Expression]): Expression = args.length match { case 0 => error("Empty Seq passed to seqcat") case 1 => args(0) - case 2 => DoPrim(PrimOps.Cat, args, Seq.empty[BigInt], UIntType(UnknownWidth)) - case _ => { - val seqs = args.splitAt(args.length/2) - DoPrim(PrimOps.Cat, Seq(seqCat(seqs._1), seqCat(seqs._2)), Seq.empty[BigInt], UIntType(UnknownWidth)) - } + case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth)) + case _ => + val (high, low) = args splitAt (args.length / 2) + DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth)) } } object toBits { def apply(e: Expression): Expression = e match { - case ex: WRef => hiercat(ex, ex.tpe) - case ex: WSubField => hiercat(ex, ex.tpe) - case ex: WSubIndex => hiercat(ex, ex.tpe) + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex, ex.tpe) case t => error("Invalid operand expression for toBits!") } - def hiercat(e: Expression, dt: Type): Expression = dt match { - case t: VectorType => seqCat((0 until t.size).reverse.map(i => hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER), t.tpe))) - case t: BundleType => seqCat(t.fields.map(f => hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe))) + private def hiercat(e: Expression, dt: Type): Expression = dt match { + case t: VectorType => seqCat((0 until t.size) map (i => + hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER),t.tpe))) + case t: BundleType => seqCat(t.fields map (f => + hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe))) case t: GroundType => e case t => error("Unknown type encountered in toBits!") } @@ -64,23 +63,36 @@ object toBits { // TODO: make easier to understand object toBitMask { def apply(e: Expression, dataType: Type): Expression = e match { - case ex: WRef => hiermask(ex, ex.tpe, dataType) - case ex: WSubField => hiermask(ex, ex.tpe, dataType) - case ex: WSubIndex => hiermask(ex, ex.tpe, dataType) + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, ex.tpe, dataType) case t => error("Invalid operand expression for toBits!") } - def hiermask(e: Expression, maskType: Type, dataType: Type): Expression = (maskType, dataType) match { - case (mt: VectorType, dt: VectorType) => seqCat((0 until mt.size).reverse.map(i => hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe))) - case (mt: BundleType, dt: BundleType) => seqCat((mt.fields zip dt.fields).map { case (mf, df) => - hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) } ) - case (mt: UIntType, dt: GroundType) => seqCat(List.fill(bitWidth(dt).intValue)(e)) - case (mt, dt) => error("Invalid type for mask component!") + private def hiermask(e: Expression, maskType: Type, dataType: Type): Expression = + (maskType, dataType) match { + case (mt: VectorType, dt: VectorType) => + seqCat((0 until mt.size).reverse map { i => + hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe) + }) + case (mt: BundleType, dt: BundleType) => + seqCat((mt.fields zip dt.fields) map { case (mf, df) => + hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) + }) + case (mt: UIntType, dt: GroundType) => + seqCat(List.fill(bitWidth(dt).intValue)(e)) + case (mt, dt) => error("Invalid type for mask component!") + } +} + +object getWidth { + def apply(t: Type): Width = t match { + case t: GroundType => t.width + case _ => error("No width!") } + def apply(e: Expression): Width = apply(e.tpe) } object bitWidth { def apply(dt: Type): BigInt = widthOf(dt) - def widthOf(dt: Type): BigInt = dt match { + private def widthOf(dt: Type): BigInt = dt match { case t: VectorType => t.size * bitWidth(t.tpe) case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_) case GroundType(IntWidth(width)) => width @@ -91,43 +103,47 @@ object bitWidth { object fromBits { def apply(lhs: Expression, rhs: Expression): Statement = { val fbits = lhs match { - case ex: WRef => getPart(ex, ex.tpe, rhs, 0) - case ex: WSubField => getPart(ex, ex.tpe, rhs, 0) - case ex: WSubIndex => getPart(ex, ex.tpe, rhs, 0) - case t => error("Invalid LHS expression for fromBits!") + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => getPart(ex, ex.tpe, rhs, 0) + case _ => error("Invalid LHS expression for fromBits!") } Block(fbits._2) } - def getPartGround(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = { + private def getPartGround(lhs: Expression, + lhst: Type, + rhs: Expression, + offset: BigInt): (BigInt, Seq[Statement]) = { val intWidth = bitWidth(lhst) - val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset+intWidth-1, offset), UnknownType) + val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType) (offset + intWidth, Seq(Connect(NoInfo, lhs, sel))) } - def getPart(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = { + private def getPart(lhs: Expression, + lhst: Type, + rhs: Expression, + offset: BigInt): (BigInt, Seq[Statement]) = lhst match { - case t: VectorType => { - var currentOffset = offset - var stmts = Seq.empty[Statement] - for (i <- (0 until t.size)) { - val (tmpOffset, substmts) = getPart(WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER), t.tpe, rhs, currentOffset) - stmts = stmts ++ substmts - currentOffset = tmpOffset - } - (currentOffset, stmts) + case t: VectorType => (0 until t.size foldRight (offset, Seq[Statement]())) { + case (i, (curOffset, stmts)) => + val subidx = WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER) + val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) } - case t: BundleType => { - var currentOffset = offset - var stmts = Seq.empty[Statement] - for (f <- t.fields.reverse) { - val (tmpOffset, substmts) = getPart(WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER), f.tpe, rhs, currentOffset) - stmts = stmts ++ substmts - currentOffset = tmpOffset - } - (currentOffset, stmts) + case t: BundleType => (t.fields foldRight (offset, Seq[Statement]())) { + case (f, (curOffset, stmts)) => + val subfield = WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER) + val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) } case t: GroundType => getPartGround(lhs, t, rhs, offset) case t => error("Unknown type encountered in fromBits!") } +} + +object createMask { + def apply(dt: Type): Type = dt match { + case t: VectorType => VectorType(apply(t.tpe), t.size) + case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe)))) + case t: UIntType => BoolType + case t: SIntType => BoolType } } @@ -138,78 +154,88 @@ object MemPortUtils { def flattenType(t: Type) = UIntType(IntWidth(bitWidth(t))) def defaultPortSeq(mem: DefMemory) = Seq( - Field("addr", Default, UIntType(IntWidth(ceil_log2(mem.depth)))), - Field("en", Default, UIntType(IntWidth(1))), + Field("addr", Default, UIntType(IntWidth(ceil_log2(mem.depth) max 1))), + Field("en", Default, BoolType), Field("clk", Default, ClockType) ) - def getFillWMask(mem: DefMemory) = { - val maskGran = getInfo(mem.info, "maskGran") - if (maskGran == None) false - else maskGran.get == 1 - } + def getFillWMask(mem: DefMemory) = + getInfo(mem.info, "maskGran") match { + case None => false + case Some(maskGran) => maskGran == 1 + } - def rPortToBundle(mem: DefMemory) = BundleType(defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) - def rPortToFlattenBundle(mem: DefMemory) = BundleType(defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) + def rPortToBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) + def rPortToFlattenBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) - def wPortToBundle(mem: DefMemory) = { - val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, mem.dataType) - BundleType( - if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("mask", Default, create_mask(mem.dataType)) - else defaultSeq - ) - } - - def wPortToFlattenBundle(mem: DefMemory) = { - val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType)) - BundleType( - if (containsInfo(mem.info, "maskGran")) { - defaultSeq :+ { - if (getFillWMask(mem)) Field("mask", Default, flattenType(mem.dataType)) - else Field("mask", Default, flattenType(create_mask(mem.dataType))) - } - } - else defaultSeq - ) - } - // TODO: Don't use create_mask??? + def wPortToBundle(mem: DefMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++ + (if (!containsInfo(mem.info, "maskGran")) Nil + else Seq(Field("mask", Default, createMask(mem.dataType)))) + ) + def wPortToFlattenBundle(mem: DefMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ + (if (!containsInfo(mem.info, "maskGran")) Nil + else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType))) + else Seq(Field("mask", Default, flattenType(createMask(mem.dataType))))) + ) + // TODO: Don't use createMask??? - def rwPortToBundle(mem: DefMemory) = { - val defaultSeq = defaultPortSeq(mem) ++ Seq( - Field("wmode", Default, UIntType(IntWidth(1))), + def rwPortToBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("wmode", Default, BoolType), Field("wdata", Default, mem.dataType), Field("rdata", Flip, mem.dataType) + ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil + else Seq(Field("wmask", Default, createMask(mem.dataType))) ) - BundleType( - if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("wmask", Default, create_mask(mem.dataType)) - else defaultSeq - ) - } + ) - def rwPortToFlattenBundle(mem: DefMemory) = { - val defaultSeq = defaultPortSeq(mem) ++ Seq( + def rwPortToFlattenBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( Field("wmode", Default, UIntType(IntWidth(1))), Field("wdata", Default, flattenType(mem.dataType)), Field("rdata", Flip, flattenType(mem.dataType)) - ) - BundleType( - if (containsInfo(mem.info, "maskGran")) { - defaultSeq :+ { - if (getFillWMask(mem)) Field("wmask", Default, flattenType(mem.dataType)) - else Field("wmask", Default, flattenType(create_mask(mem.dataType))) - } - } - else defaultSeq + ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil + else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType))) + else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType)))) ) - } + ) def memToBundle(s: DefMemory) = BundleType( - s.readers.map(p => Field(p, Default, rPortToBundle(s))) ++ - s.writers.map(p => Field(p, Default, wPortToBundle(s))) ++ - s.readwriters.map(p => Field(p, Default, rwPortToBundle(s)))) + s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ + s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))) def memToFlattenBundle(s: DefMemory) = BundleType( - s.readers.map(p => Field(p, Default, rPortToFlattenBundle(s))) ++ - s.writers.map(p => Field(p, Default, wPortToFlattenBundle(s))) ++ - s.readwriters.map(p => Field(p, Default, rwPortToFlattenBundle(s)))) + s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ + s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))) + + // Todo: merge it with memToBundle + def memType(mem: DefMemory) = { + val rType = rPortToBundle(mem) + val wType = BundleType(defaultPortSeq(mem) ++ Seq( + Field("data", Default, mem.dataType), + Field("mask", Default, createMask(mem.dataType)))) + val rwType = BundleType(defaultPortSeq(mem) ++ Seq( + Field("rdata", Flip, mem.dataType), + Field("wmode", Default, UIntType(IntWidth(1))), + Field("wdata", Default, mem.dataType), + Field("wmask", Default, createMask(mem.dataType)))) + BundleType( + (mem.readers map (Field(_, Flip, rType))) ++ + (mem.writers map (Field(_, Flip, wType))) ++ + (mem.readwriters map (Field(_, Flip, rwType)))) + } + + def kind(s: DefMemory) = MemKind(s.readers ++ s.writers ++ s.readwriters) + def memPortField(s: DefMemory, p: String, f: String) = { + val mem = WRef(s.name, memType(s), kind(s), UNKNOWNGENDER) + val t1 = field_type(mem.tpe, p) + val t2 = field_type(t1, f) + WSubField(WSubField(mem, p, t1, UNKNOWNGENDER), f, t2, UNKNOWNGENDER) + } } diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 1a134d11..bef9ac33 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -4,12 +4,11 @@ package passes import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Mappers._ -import firrtl.Utils.long_BANG // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { def name = "Pad Widths" - private def width(t: Type): Int = long_BANG(t).toInt + private def width(t: Type): Int = bitWidth(t).toInt private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width private def fixup(i: Int)(e: Expression) = { diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 2bae92a7..ca860ab6 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -101,7 +101,7 @@ object RemoveCHIRRTL extends Pass { Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), wmode, taddr), zero) ) def set_write (vec: Seq[MPort], data: String, mask: String) = vec flatMap {r => - val tmask = create_mask(s.tpe) + val tmask = createMask(s.tpe) IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), data, tdata)) +: (create_exps(SubField(SubField(Reference(s.name, ut), r.name, ut), mask, tmask)) map (Connect(s.info, _, zero)) @@ -160,7 +160,7 @@ object RemoveCHIRRTL extends Pass { e map get_mask(refs) match { case e: Reference => refs get e.name match { case None => e - case Some(p) => SubField(p.exp, p.mask, create_mask(e.tpe)) + case Some(p) => SubField(p.exp, p.mask, createMask(e.tpe)) } case e => e } diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala index 54c522d7..7bb9c6c4 100644 --- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala @@ -117,7 +117,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { ) ) if (containsInfo(wrapperMem.info, "maskGran")) { - val wrapperMask = create_mask(wrapperMem.dataType) + val wrapperMask = createMask(wrapperMem.dataType) val fillWMask = getFillWMask(wrapperMem) val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) val rhs = { @@ -150,7 +150,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { ) ) if (containsInfo(wrapperMem.info, "maskGran")) { - val wrapperMask = create_mask(wrapperMem.dataType) + val wrapperMask = createMask(wrapperMem.dataType) val fillWMask = getFillWMask(wrapperMem) val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) val rhs = { diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index d034719a..758791b2 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -34,6 +34,7 @@ import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import MemPortUtils.memType /** Resolve name collisions that would occur in [[LowerTypes]] * @@ -228,7 +229,7 @@ object Uniquify extends Pass { case s: WDefInstance => Seq(Field(s.name, Default, s.tpe)) case s: DefMemory => s.dataType match { case (_: UIntType | _: SIntType) => - Seq(Field(s.name, Default, get_type(s))) + Seq(Field(s.name, Default, memType(s))) case tpe: BundleType => val newFields = tpe.fields map ( f => DefMemory(s.info, f.name, f.tpe, s.depth, s.writeLatency, @@ -241,7 +242,7 @@ object Uniquify extends Pass { ) flatMap (recStmtToType) Seq(Field(s.name, Default, BundleType(newFields))) } - case s: DefNode => Seq(Field(s.name, Default, get_type(s))) + case s: DefNode => Seq(Field(s.name, Default, s.value.tpe)) case s: Conditionally => recStmtToType(s.conseq) ++ recStmtToType(s.alt) case s: Block => (s.stmts map (recStmtToType)).flatten case s => Seq() @@ -305,7 +306,7 @@ object Uniquify extends Pass { val dataType = uniquifyNamesType(s.dataType, node.elts) val mem = s.copy(name = node.name, dataType = dataType) // Create new mapping to handle references to memory data fields - val uniqueMemMap = createNameMapping(get_type(s), get_type(mem)) + val uniqueMemMap = createNameMapping(memType(s), memType(mem)) nameMap(s.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap) mem } else { |
