diff options
| author | Donggyu Kim | 2016-09-02 17:07:46 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-13 13:32:47 -0700 |
| commit | b8ee3179ed8070211c95ecbcceda0f7dbf635a13 (patch) | |
| tree | af33cd813833f2b79b2fc52ed2af0ec9e4338ce5 /src | |
| parent | ad36a1216f52bc01a27dac93cfd8cd42beb84c73 (diff) | |
clean up MemUtils
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 29 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/MemUtils.scala | 222 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveCHIRRTL.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplaceMemMacros.scala | 4 |
4 files changed, 135 insertions, 124 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 572d1ccc..2c84528f 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -87,23 +87,16 @@ object Utils extends LazyLogging { def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 val BoolType = UIntType(IntWidth(1)) - val one = UIntLiteral(BigInt(1),IntWidth(1)) - val zero = UIntLiteral(BigInt(0),IntWidth(1)) - def uint (i:Int) : UIntLiteral = { - val num_bits = req_num_bits(i) - val w = IntWidth(scala.math.max(1,num_bits - 1)) - UIntLiteral(BigInt(i),w) + val one = UIntLiteral(BigInt(1), IntWidth(1)) + val zero = UIntLiteral(BigInt(0), IntWidth(1)) + def uint(i: Int): UIntLiteral = { + val num_bits = req_num_bits(i) + val w = IntWidth(scala.math.max(1, num_bits - 1)) + UIntLiteral(BigInt(i), w) } - def req_num_bits (i: Int) : Int = { - val ix = if (i < 0) ((-1 * i) - 1) else i - ceil_log2(ix + 1) + 1 - } - - def create_mask(dt: Type): Type = dt match { - case t: VectorType => VectorType(create_mask(t.tpe),t.size) - case t: BundleType => BundleType(t.fields.map (f => f.copy(tpe=create_mask(f.tpe)))) - case t: UIntType => BoolType - case t: SIntType => BoolType + def req_num_bits(i: Int): Int = { + val ix = if (i < 0) ((-1 * i) - 1) else i + ceil_log2(ix + 1) + 1 } def create_exps(n: String, t: Type): Seq[Expression] = @@ -390,11 +383,11 @@ object Utils extends LazyLogging { val clk = Field("clk", Default, ClockType) val def_data = Field("data", Default, s.dataType) val rev_data = Field("data", Flip, s.dataType) - val mask = Field("mask", Default, create_mask(s.dataType)) + val mask = Field("mask", Default, passes.createMask(s.dataType)) val wmode = Field("wmode", Default, UIntType(IntWidth(1))) val rdata = Field("rdata", Flip, s.dataType) val wdata = Field("wdata", Default, s.dataType) - val wmask = Field("wmask", Default, create_mask(s.dataType)) + val wmask = Field("wmask", Default, passes.createMask(s.dataType)) val read_type = BundleType(Seq(rev_data, addr, en, clk)) val write_type = BundleType(Seq(def_data, mask, addr, en, clk)) val readwrite_type = BundleType(Seq(wmode, rdata, wdata, wmask, addr, en, clk)) diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala index adbf23e5..09be2b38 100644 --- a/src/main/scala/firrtl/passes/MemUtils.scala +++ b/src/main/scala/firrtl/passes/MemUtils.scala @@ -38,24 +38,23 @@ object seqCat { def apply(args: Seq[Expression]): Expression = args.length match { case 0 => error("Empty Seq passed to seqcat") case 1 => args(0) - case 2 => DoPrim(PrimOps.Cat, args, Seq.empty[BigInt], UIntType(UnknownWidth)) - case _ => { - val seqs = args.splitAt(args.length/2) - DoPrim(PrimOps.Cat, Seq(seqCat(seqs._1), seqCat(seqs._2)), Seq.empty[BigInt], UIntType(UnknownWidth)) - } + case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth)) + case _ => + val (high, low) = args splitAt (args.length / 2) + DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth)) } } object toBits { def apply(e: Expression): Expression = e match { - case ex: WRef => hiercat(ex, ex.tpe) - case ex: WSubField => hiercat(ex, ex.tpe) - case ex: WSubIndex => hiercat(ex, ex.tpe) + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex, ex.tpe) case t => error("Invalid operand expression for toBits!") } - def hiercat(e: Expression, dt: Type): Expression = dt match { - case t: VectorType => seqCat((0 until t.size).reverse.map(i => hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER), t.tpe))) - case t: BundleType => seqCat(t.fields.map(f => hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe))) + private def hiercat(e: Expression, dt: Type): Expression = dt match { + case t: VectorType => seqCat((0 until t.size) map (i => + hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER),t.tpe))) + case t: BundleType => seqCat(t.fields map (f => + hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe))) case t: GroundType => e case t => error("Unknown type encountered in toBits!") } @@ -64,23 +63,28 @@ object toBits { // TODO: make easier to understand object toBitMask { def apply(e: Expression, dataType: Type): Expression = e match { - case ex: WRef => hiermask(ex, ex.tpe, dataType) - case ex: WSubField => hiermask(ex, ex.tpe, dataType) - case ex: WSubIndex => hiermask(ex, ex.tpe, dataType) + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, ex.tpe, dataType) case t => error("Invalid operand expression for toBits!") } - def hiermask(e: Expression, maskType: Type, dataType: Type): Expression = (maskType, dataType) match { - case (mt: VectorType, dt: VectorType) => seqCat((0 until mt.size).reverse.map(i => hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe))) - case (mt: BundleType, dt: BundleType) => seqCat((mt.fields zip dt.fields).map { case (mf, df) => - hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) } ) - case (mt: UIntType, dt: GroundType) => seqCat(List.fill(bitWidth(dt).intValue)(e)) - case (mt, dt) => error("Invalid type for mask component!") - } + private def hiermask(e: Expression, maskType: Type, dataType: Type): Expression = + (maskType, dataType) match { + case (mt: VectorType, dt: VectorType) => + seqCat((0 until mt.size).reverse map { i => + hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe) + }) + case (mt: BundleType, dt: BundleType) => + seqCat((mt.fields zip dt.fields) map { case (mf, df) => + hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) + }) + case (mt: UIntType, dt: GroundType) => + seqCat(List.fill(bitWidth(dt).intValue)(e)) + case (mt, dt) => error("Invalid type for mask component!") + } } object bitWidth { def apply(dt: Type): BigInt = widthOf(dt) - def widthOf(dt: Type): BigInt = dt match { + private def widthOf(dt: Type): BigInt = dt match { case t: VectorType => t.size * bitWidth(t.tpe) case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_) case GroundType(IntWidth(width)) => width @@ -91,43 +95,47 @@ object bitWidth { object fromBits { def apply(lhs: Expression, rhs: Expression): Statement = { val fbits = lhs match { - case ex: WRef => getPart(ex, ex.tpe, rhs, 0) - case ex: WSubField => getPart(ex, ex.tpe, rhs, 0) - case ex: WSubIndex => getPart(ex, ex.tpe, rhs, 0) - case t => error("Invalid LHS expression for fromBits!") + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => getPart(ex, ex.tpe, rhs, 0) + case _ => error("Invalid LHS expression for fromBits!") } Block(fbits._2) } - def getPartGround(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = { + private def getPartGround(lhs: Expression, + lhst: Type, + rhs: Expression, + offset: BigInt): (BigInt, Seq[Statement]) = { val intWidth = bitWidth(lhst) - val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset+intWidth-1, offset), UnknownType) + val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType) (offset + intWidth, Seq(Connect(NoInfo, lhs, sel))) } - def getPart(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = { + private def getPart(lhs: Expression, + lhst: Type, + rhs: Expression, + offset: BigInt): (BigInt, Seq[Statement]) = lhst match { - case t: VectorType => { - var currentOffset = offset - var stmts = Seq.empty[Statement] - for (i <- (0 until t.size)) { - val (tmpOffset, substmts) = getPart(WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER), t.tpe, rhs, currentOffset) - stmts = stmts ++ substmts - currentOffset = tmpOffset - } - (currentOffset, stmts) + case t: VectorType => (0 until t.size foldRight (offset, Seq[Statement]())) { + case (i, (curOffset, stmts)) => + val subidx = WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER) + val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) } - case t: BundleType => { - var currentOffset = offset - var stmts = Seq.empty[Statement] - for (f <- t.fields.reverse) { - val (tmpOffset, substmts) = getPart(WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER), f.tpe, rhs, currentOffset) - stmts = stmts ++ substmts - currentOffset = tmpOffset - } - (currentOffset, stmts) + case t: BundleType => (t.fields foldRight (offset, Seq[Statement]())) { + case (f, (curOffset, stmts)) => + val subfield = WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER) + val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) } case t: GroundType => getPartGround(lhs, t, rhs, offset) case t => error("Unknown type encountered in fromBits!") } +} + +object createMask { + def apply(dt: Type): Type = dt match { + case t: VectorType => VectorType(apply(t.tpe), t.size) + case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe)))) + case t: UIntType => BoolType + case t: SIntType => BoolType } } @@ -139,77 +147,87 @@ object MemPortUtils { def defaultPortSeq(mem: DefMemory) = Seq( Field("addr", Default, UIntType(IntWidth(ceil_log2(mem.depth)))), - Field("en", Default, UIntType(IntWidth(1))), + Field("en", Default, BoolType), Field("clk", Default, ClockType) ) - def getFillWMask(mem: DefMemory) = { - val maskGran = getInfo(mem.info, "maskGran") - if (maskGran == None) false - else maskGran.get == 1 - } + def getFillWMask(mem: DefMemory) = + getInfo(mem.info, "maskGran") match { + case None => false + case Some(maskGran) => maskGran == 1 + } - def rPortToBundle(mem: DefMemory) = BundleType(defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) - def rPortToFlattenBundle(mem: DefMemory) = BundleType(defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) + def rPortToBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) + def rPortToFlattenBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) - def wPortToBundle(mem: DefMemory) = { - val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, mem.dataType) - BundleType( - if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("mask", Default, create_mask(mem.dataType)) - else defaultSeq - ) - } - - def wPortToFlattenBundle(mem: DefMemory) = { - val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType)) - BundleType( - if (containsInfo(mem.info, "maskGran")) { - defaultSeq :+ { - if (getFillWMask(mem)) Field("mask", Default, flattenType(mem.dataType)) - else Field("mask", Default, flattenType(create_mask(mem.dataType))) - } - } - else defaultSeq - ) - } - // TODO: Don't use create_mask??? + def wPortToBundle(mem: DefMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++ + (if (!containsInfo(mem.info, "maskGran")) Nil + else Seq(Field("mask", Default, createMask(mem.dataType)))) + ) + def wPortToFlattenBundle(mem: DefMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ + (if (!containsInfo(mem.info, "maskGran")) Nil + else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType))) + else Seq(Field("mask", Default, flattenType(createMask(mem.dataType))))) + ) + // TODO: Don't use createMask??? - def rwPortToBundle(mem: DefMemory) = { - val defaultSeq = defaultPortSeq(mem) ++ Seq( - Field("wmode", Default, UIntType(IntWidth(1))), + def rwPortToBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("wmode", Default, BoolType), Field("wdata", Default, mem.dataType), Field("rdata", Flip, mem.dataType) + ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil + else Seq(Field("wmask", Default, createMask(mem.dataType))) ) - BundleType( - if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("wmask", Default, create_mask(mem.dataType)) - else defaultSeq - ) - } + ) - def rwPortToFlattenBundle(mem: DefMemory) = { - val defaultSeq = defaultPortSeq(mem) ++ Seq( + def rwPortToFlattenBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( Field("wmode", Default, UIntType(IntWidth(1))), Field("wdata", Default, flattenType(mem.dataType)), Field("rdata", Flip, flattenType(mem.dataType)) - ) - BundleType( - if (containsInfo(mem.info, "maskGran")) { - defaultSeq :+ { - if (getFillWMask(mem)) Field("wmask", Default, flattenType(mem.dataType)) - else Field("wmask", Default, flattenType(create_mask(mem.dataType))) - } - } - else defaultSeq + ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil + else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType))) + else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType)))) ) - } + ) def memToBundle(s: DefMemory) = BundleType( - s.readers.map(p => Field(p, Default, rPortToBundle(s))) ++ - s.writers.map(p => Field(p, Default, wPortToBundle(s))) ++ - s.readwriters.map(p => Field(p, Default, rwPortToBundle(s)))) + s.readers.map(Field(_, Default, rPortToBundle(s))) ++ + s.writers.map(Field(_, Default, wPortToBundle(s))) ++ + s.readwriters.map(Field(_, Default, rwPortToBundle(s)))) def memToFlattenBundle(s: DefMemory) = BundleType( - s.readers.map(p => Field(p, Default, rPortToFlattenBundle(s))) ++ - s.writers.map(p => Field(p, Default, wPortToFlattenBundle(s))) ++ - s.readwriters.map(p => Field(p, Default, rwPortToFlattenBundle(s)))) + s.readers.map(Field(_, Default, rPortToFlattenBundle(s))) ++ + s.writers.map(Field(_, Default, wPortToFlattenBundle(s))) ++ + s.readwriters.map(Field(_, Default, rwPortToFlattenBundle(s)))) + + // Todo: merge it with memToBundle + def memType(mem: DefMemory) = { + val rType = rPortToBundle(mem) + val wType = BundleType(defaultPortSeq(mem) ++ Seq( + Field("data", Default, mem.dataType), + Field("mask", Default, createMask(mem.dataType)))) + val rwType = BundleType(defaultPortSeq(mem) ++ Seq( + Field("rdata", Flip, mem.dataType), + Field("wmode", Default, UIntType(IntWidth(1))), + Field("wdata", Default, mem.dataType), + Field("wmask", Default, createMask(mem.dataType)))) + BundleType( + (mem.readers map (Field(_, Flip, rType))) ++ + (mem.writers map (Field(_, Flip, wType))) ++ + (mem.readwriters map (Field(_, Flip, rwType)))) + } + + def kind(s: DefMemory) = MemKind(s.readers ++ s.writers ++ s.readwriters) + def memPortField(s: DefMemory, p: String, f: String) = { + val mem = WRef(s.name, memType(s), kind(s), UNKNOWNGENDER) + val t1 = field_type(mem.tpe, p) + val t2 = field_type(t1, f) + WSubField(WSubField(mem, p, t1, UNKNOWNGENDER), f, t2, UNKNOWNGENDER) + } } diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 2bae92a7..ca860ab6 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -101,7 +101,7 @@ object RemoveCHIRRTL extends Pass { Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), wmode, taddr), zero) ) def set_write (vec: Seq[MPort], data: String, mask: String) = vec flatMap {r => - val tmask = create_mask(s.tpe) + val tmask = createMask(s.tpe) IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), data, tdata)) +: (create_exps(SubField(SubField(Reference(s.name, ut), r.name, ut), mask, tmask)) map (Connect(s.info, _, zero)) @@ -160,7 +160,7 @@ object RemoveCHIRRTL extends Pass { e map get_mask(refs) match { case e: Reference => refs get e.name match { case None => e - case Some(p) => SubField(p.exp, p.mask, create_mask(e.tpe)) + case Some(p) => SubField(p.exp, p.mask, createMask(e.tpe)) } case e => e } diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala index 54c522d7..7bb9c6c4 100644 --- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala @@ -117,7 +117,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { ) ) if (containsInfo(wrapperMem.info, "maskGran")) { - val wrapperMask = create_mask(wrapperMem.dataType) + val wrapperMask = createMask(wrapperMem.dataType) val fillWMask = getFillWMask(wrapperMem) val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) val rhs = { @@ -150,7 +150,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { ) ) if (containsInfo(wrapperMem.info, "maskGran")) { - val wrapperMask = create_mask(wrapperMem.dataType) + val wrapperMask = createMask(wrapperMem.dataType) val fillWMask = getFillWMask(wrapperMem) val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) val rhs = { |
