aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-09-02 17:07:46 -0700
committerDonggyu Kim2016-09-13 13:32:47 -0700
commitb8ee3179ed8070211c95ecbcceda0f7dbf635a13 (patch)
treeaf33cd813833f2b79b2fc52ed2af0ec9e4338ce5 /src
parentad36a1216f52bc01a27dac93cfd8cd42beb84c73 (diff)
clean up MemUtils
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Utils.scala29
-rw-r--r--src/main/scala/firrtl/passes/MemUtils.scala222
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala4
-rw-r--r--src/main/scala/firrtl/passes/ReplaceMemMacros.scala4
4 files changed, 135 insertions, 124 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 572d1ccc..2c84528f 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -87,23 +87,16 @@ object Utils extends LazyLogging {
def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a
def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1
val BoolType = UIntType(IntWidth(1))
- val one = UIntLiteral(BigInt(1),IntWidth(1))
- val zero = UIntLiteral(BigInt(0),IntWidth(1))
- def uint (i:Int) : UIntLiteral = {
- val num_bits = req_num_bits(i)
- val w = IntWidth(scala.math.max(1,num_bits - 1))
- UIntLiteral(BigInt(i),w)
+ val one = UIntLiteral(BigInt(1), IntWidth(1))
+ val zero = UIntLiteral(BigInt(0), IntWidth(1))
+ def uint(i: Int): UIntLiteral = {
+ val num_bits = req_num_bits(i)
+ val w = IntWidth(scala.math.max(1, num_bits - 1))
+ UIntLiteral(BigInt(i), w)
}
- def req_num_bits (i: Int) : Int = {
- val ix = if (i < 0) ((-1 * i) - 1) else i
- ceil_log2(ix + 1) + 1
- }
-
- def create_mask(dt: Type): Type = dt match {
- case t: VectorType => VectorType(create_mask(t.tpe),t.size)
- case t: BundleType => BundleType(t.fields.map (f => f.copy(tpe=create_mask(f.tpe))))
- case t: UIntType => BoolType
- case t: SIntType => BoolType
+ def req_num_bits(i: Int): Int = {
+ val ix = if (i < 0) ((-1 * i) - 1) else i
+ ceil_log2(ix + 1) + 1
}
def create_exps(n: String, t: Type): Seq[Expression] =
@@ -390,11 +383,11 @@ object Utils extends LazyLogging {
val clk = Field("clk", Default, ClockType)
val def_data = Field("data", Default, s.dataType)
val rev_data = Field("data", Flip, s.dataType)
- val mask = Field("mask", Default, create_mask(s.dataType))
+ val mask = Field("mask", Default, passes.createMask(s.dataType))
val wmode = Field("wmode", Default, UIntType(IntWidth(1)))
val rdata = Field("rdata", Flip, s.dataType)
val wdata = Field("wdata", Default, s.dataType)
- val wmask = Field("wmask", Default, create_mask(s.dataType))
+ val wmask = Field("wmask", Default, passes.createMask(s.dataType))
val read_type = BundleType(Seq(rev_data, addr, en, clk))
val write_type = BundleType(Seq(def_data, mask, addr, en, clk))
val readwrite_type = BundleType(Seq(wmode, rdata, wdata, wmask, addr, en, clk))
diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala
index adbf23e5..09be2b38 100644
--- a/src/main/scala/firrtl/passes/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/MemUtils.scala
@@ -38,24 +38,23 @@ object seqCat {
def apply(args: Seq[Expression]): Expression = args.length match {
case 0 => error("Empty Seq passed to seqcat")
case 1 => args(0)
- case 2 => DoPrim(PrimOps.Cat, args, Seq.empty[BigInt], UIntType(UnknownWidth))
- case _ => {
- val seqs = args.splitAt(args.length/2)
- DoPrim(PrimOps.Cat, Seq(seqCat(seqs._1), seqCat(seqs._2)), Seq.empty[BigInt], UIntType(UnknownWidth))
- }
+ case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth))
+ case _ =>
+ val (high, low) = args splitAt (args.length / 2)
+ DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth))
}
}
object toBits {
def apply(e: Expression): Expression = e match {
- case ex: WRef => hiercat(ex, ex.tpe)
- case ex: WSubField => hiercat(ex, ex.tpe)
- case ex: WSubIndex => hiercat(ex, ex.tpe)
+ case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex, ex.tpe)
case t => error("Invalid operand expression for toBits!")
}
- def hiercat(e: Expression, dt: Type): Expression = dt match {
- case t: VectorType => seqCat((0 until t.size).reverse.map(i => hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER), t.tpe)))
- case t: BundleType => seqCat(t.fields.map(f => hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe)))
+ private def hiercat(e: Expression, dt: Type): Expression = dt match {
+ case t: VectorType => seqCat((0 until t.size) map (i =>
+ hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER),t.tpe)))
+ case t: BundleType => seqCat(t.fields map (f =>
+ hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe)))
case t: GroundType => e
case t => error("Unknown type encountered in toBits!")
}
@@ -64,23 +63,28 @@ object toBits {
// TODO: make easier to understand
object toBitMask {
def apply(e: Expression, dataType: Type): Expression = e match {
- case ex: WRef => hiermask(ex, ex.tpe, dataType)
- case ex: WSubField => hiermask(ex, ex.tpe, dataType)
- case ex: WSubIndex => hiermask(ex, ex.tpe, dataType)
+ case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, ex.tpe, dataType)
case t => error("Invalid operand expression for toBits!")
}
- def hiermask(e: Expression, maskType: Type, dataType: Type): Expression = (maskType, dataType) match {
- case (mt: VectorType, dt: VectorType) => seqCat((0 until mt.size).reverse.map(i => hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe)))
- case (mt: BundleType, dt: BundleType) => seqCat((mt.fields zip dt.fields).map { case (mf, df) =>
- hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) } )
- case (mt: UIntType, dt: GroundType) => seqCat(List.fill(bitWidth(dt).intValue)(e))
- case (mt, dt) => error("Invalid type for mask component!")
- }
+ private def hiermask(e: Expression, maskType: Type, dataType: Type): Expression =
+ (maskType, dataType) match {
+ case (mt: VectorType, dt: VectorType) =>
+ seqCat((0 until mt.size).reverse map { i =>
+ hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe)
+ })
+ case (mt: BundleType, dt: BundleType) =>
+ seqCat((mt.fields zip dt.fields) map { case (mf, df) =>
+ hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe)
+ })
+ case (mt: UIntType, dt: GroundType) =>
+ seqCat(List.fill(bitWidth(dt).intValue)(e))
+ case (mt, dt) => error("Invalid type for mask component!")
+ }
}
object bitWidth {
def apply(dt: Type): BigInt = widthOf(dt)
- def widthOf(dt: Type): BigInt = dt match {
+ private def widthOf(dt: Type): BigInt = dt match {
case t: VectorType => t.size * bitWidth(t.tpe)
case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_)
case GroundType(IntWidth(width)) => width
@@ -91,43 +95,47 @@ object bitWidth {
object fromBits {
def apply(lhs: Expression, rhs: Expression): Statement = {
val fbits = lhs match {
- case ex: WRef => getPart(ex, ex.tpe, rhs, 0)
- case ex: WSubField => getPart(ex, ex.tpe, rhs, 0)
- case ex: WSubIndex => getPart(ex, ex.tpe, rhs, 0)
- case t => error("Invalid LHS expression for fromBits!")
+ case ex @ (_: WRef | _: WSubField | _: WSubIndex) => getPart(ex, ex.tpe, rhs, 0)
+ case _ => error("Invalid LHS expression for fromBits!")
}
Block(fbits._2)
}
- def getPartGround(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = {
+ private def getPartGround(lhs: Expression,
+ lhst: Type,
+ rhs: Expression,
+ offset: BigInt): (BigInt, Seq[Statement]) = {
val intWidth = bitWidth(lhst)
- val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset+intWidth-1, offset), UnknownType)
+ val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType)
(offset + intWidth, Seq(Connect(NoInfo, lhs, sel)))
}
- def getPart(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = {
+ private def getPart(lhs: Expression,
+ lhst: Type,
+ rhs: Expression,
+ offset: BigInt): (BigInt, Seq[Statement]) =
lhst match {
- case t: VectorType => {
- var currentOffset = offset
- var stmts = Seq.empty[Statement]
- for (i <- (0 until t.size)) {
- val (tmpOffset, substmts) = getPart(WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER), t.tpe, rhs, currentOffset)
- stmts = stmts ++ substmts
- currentOffset = tmpOffset
- }
- (currentOffset, stmts)
+ case t: VectorType => (0 until t.size foldRight (offset, Seq[Statement]())) {
+ case (i, (curOffset, stmts)) =>
+ val subidx = WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER)
+ val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset)
+ (tmpOffset, stmts ++ substmts)
}
- case t: BundleType => {
- var currentOffset = offset
- var stmts = Seq.empty[Statement]
- for (f <- t.fields.reverse) {
- val (tmpOffset, substmts) = getPart(WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER), f.tpe, rhs, currentOffset)
- stmts = stmts ++ substmts
- currentOffset = tmpOffset
- }
- (currentOffset, stmts)
+ case t: BundleType => (t.fields foldRight (offset, Seq[Statement]())) {
+ case (f, (curOffset, stmts)) =>
+ val subfield = WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER)
+ val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset)
+ (tmpOffset, stmts ++ substmts)
}
case t: GroundType => getPartGround(lhs, t, rhs, offset)
case t => error("Unknown type encountered in fromBits!")
}
+}
+
+object createMask {
+ def apply(dt: Type): Type = dt match {
+ case t: VectorType => VectorType(apply(t.tpe), t.size)
+ case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe))))
+ case t: UIntType => BoolType
+ case t: SIntType => BoolType
}
}
@@ -139,77 +147,87 @@ object MemPortUtils {
def defaultPortSeq(mem: DefMemory) = Seq(
Field("addr", Default, UIntType(IntWidth(ceil_log2(mem.depth)))),
- Field("en", Default, UIntType(IntWidth(1))),
+ Field("en", Default, BoolType),
Field("clk", Default, ClockType)
)
- def getFillWMask(mem: DefMemory) = {
- val maskGran = getInfo(mem.info, "maskGran")
- if (maskGran == None) false
- else maskGran.get == 1
- }
+ def getFillWMask(mem: DefMemory) =
+ getInfo(mem.info, "maskGran") match {
+ case None => false
+ case Some(maskGran) => maskGran == 1
+ }
- def rPortToBundle(mem: DefMemory) = BundleType(defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
- def rPortToFlattenBundle(mem: DefMemory) = BundleType(defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
+ def rPortToBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
+ def rPortToFlattenBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
- def wPortToBundle(mem: DefMemory) = {
- val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)
- BundleType(
- if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("mask", Default, create_mask(mem.dataType))
- else defaultSeq
- )
- }
-
- def wPortToFlattenBundle(mem: DefMemory) = {
- val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))
- BundleType(
- if (containsInfo(mem.info, "maskGran")) {
- defaultSeq :+ {
- if (getFillWMask(mem)) Field("mask", Default, flattenType(mem.dataType))
- else Field("mask", Default, flattenType(create_mask(mem.dataType)))
- }
- }
- else defaultSeq
- )
- }
- // TODO: Don't use create_mask???
+ def wPortToBundle(mem: DefMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++
+ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else Seq(Field("mask", Default, createMask(mem.dataType))))
+ )
+ def wPortToFlattenBundle(mem: DefMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++
+ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType)))
+ else Seq(Field("mask", Default, flattenType(createMask(mem.dataType)))))
+ )
+ // TODO: Don't use createMask???
- def rwPortToBundle(mem: DefMemory) = {
- val defaultSeq = defaultPortSeq(mem) ++ Seq(
- Field("wmode", Default, UIntType(IntWidth(1))),
+ def rwPortToBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("wmode", Default, BoolType),
Field("wdata", Default, mem.dataType),
Field("rdata", Flip, mem.dataType)
+ ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else Seq(Field("wmask", Default, createMask(mem.dataType)))
)
- BundleType(
- if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("wmask", Default, create_mask(mem.dataType))
- else defaultSeq
- )
- }
+ )
- def rwPortToFlattenBundle(mem: DefMemory) = {
- val defaultSeq = defaultPortSeq(mem) ++ Seq(
+ def rwPortToFlattenBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
Field("wmode", Default, UIntType(IntWidth(1))),
Field("wdata", Default, flattenType(mem.dataType)),
Field("rdata", Flip, flattenType(mem.dataType))
- )
- BundleType(
- if (containsInfo(mem.info, "maskGran")) {
- defaultSeq :+ {
- if (getFillWMask(mem)) Field("wmask", Default, flattenType(mem.dataType))
- else Field("wmask", Default, flattenType(create_mask(mem.dataType)))
- }
- }
- else defaultSeq
+ ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType)))
+ else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType))))
)
- }
+ )
def memToBundle(s: DefMemory) = BundleType(
- s.readers.map(p => Field(p, Default, rPortToBundle(s))) ++
- s.writers.map(p => Field(p, Default, wPortToBundle(s))) ++
- s.readwriters.map(p => Field(p, Default, rwPortToBundle(s))))
+ s.readers.map(Field(_, Default, rPortToBundle(s))) ++
+ s.writers.map(Field(_, Default, wPortToBundle(s))) ++
+ s.readwriters.map(Field(_, Default, rwPortToBundle(s))))
def memToFlattenBundle(s: DefMemory) = BundleType(
- s.readers.map(p => Field(p, Default, rPortToFlattenBundle(s))) ++
- s.writers.map(p => Field(p, Default, wPortToFlattenBundle(s))) ++
- s.readwriters.map(p => Field(p, Default, rwPortToFlattenBundle(s))))
+ s.readers.map(Field(_, Default, rPortToFlattenBundle(s))) ++
+ s.writers.map(Field(_, Default, wPortToFlattenBundle(s))) ++
+ s.readwriters.map(Field(_, Default, rwPortToFlattenBundle(s))))
+
+ // Todo: merge it with memToBundle
+ def memType(mem: DefMemory) = {
+ val rType = rPortToBundle(mem)
+ val wType = BundleType(defaultPortSeq(mem) ++ Seq(
+ Field("data", Default, mem.dataType),
+ Field("mask", Default, createMask(mem.dataType))))
+ val rwType = BundleType(defaultPortSeq(mem) ++ Seq(
+ Field("rdata", Flip, mem.dataType),
+ Field("wmode", Default, UIntType(IntWidth(1))),
+ Field("wdata", Default, mem.dataType),
+ Field("wmask", Default, createMask(mem.dataType))))
+ BundleType(
+ (mem.readers map (Field(_, Flip, rType))) ++
+ (mem.writers map (Field(_, Flip, wType))) ++
+ (mem.readwriters map (Field(_, Flip, rwType))))
+ }
+
+ def kind(s: DefMemory) = MemKind(s.readers ++ s.writers ++ s.readwriters)
+ def memPortField(s: DefMemory, p: String, f: String) = {
+ val mem = WRef(s.name, memType(s), kind(s), UNKNOWNGENDER)
+ val t1 = field_type(mem.tpe, p)
+ val t2 = field_type(t1, f)
+ WSubField(WSubField(mem, p, t1, UNKNOWNGENDER), f, t2, UNKNOWNGENDER)
+ }
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 2bae92a7..ca860ab6 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -101,7 +101,7 @@ object RemoveCHIRRTL extends Pass {
Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), wmode, taddr), zero)
)
def set_write (vec: Seq[MPort], data: String, mask: String) = vec flatMap {r =>
- val tmask = create_mask(s.tpe)
+ val tmask = createMask(s.tpe)
IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), data, tdata)) +:
(create_exps(SubField(SubField(Reference(s.name, ut), r.name, ut), mask, tmask))
map (Connect(s.info, _, zero))
@@ -160,7 +160,7 @@ object RemoveCHIRRTL extends Pass {
e map get_mask(refs) match {
case e: Reference => refs get e.name match {
case None => e
- case Some(p) => SubField(p.exp, p.mask, create_mask(e.tpe))
+ case Some(p) => SubField(p.exp, p.mask, createMask(e.tpe))
}
case e => e
}
diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
index 54c522d7..7bb9c6c4 100644
--- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
@@ -117,7 +117,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
)
)
if (containsInfo(wrapperMem.info, "maskGran")) {
- val wrapperMask = create_mask(wrapperMem.dataType)
+ val wrapperMask = createMask(wrapperMem.dataType)
val fillWMask = getFillWMask(wrapperMem)
val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask)
val rhs = {
@@ -150,7 +150,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
)
)
if (containsInfo(wrapperMem.info, "maskGran")) {
- val wrapperMask = create_mask(wrapperMem.dataType)
+ val wrapperMask = createMask(wrapperMem.dataType)
val fillWMask = getFillWMask(wrapperMem)
val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask)
val rhs = {