aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAngie2017-02-22 20:03:27 -0800
committerAdam Izraelevitz2017-02-23 14:55:00 -0800
commit1f9fd2f9b9e9a0117b0dd65524c9dcb767c02778 (patch)
treeb4ef1d8fdadd89e942e321cf4495a127cb9bfe59 /src
parent1d652352b752502dd6d130aeb85981df214d7021 (diff)
move more general utils out of memutils, mov WIR helpers to WIR.scala and update uses
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/PrimOps.scala4
-rw-r--r--src/main/scala/firrtl/Utils.scala106
-rw-r--r--src/main/scala/firrtl/WIR.scala7
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala16
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemUtils.scala118
-rw-r--r--src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala4
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala30
7 files changed, 140 insertions, 145 deletions
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala
index fa523e12..1ca005d7 100644
--- a/src/main/scala/firrtl/PrimOps.scala
+++ b/src/main/scala/firrtl/PrimOps.scala
@@ -120,8 +120,8 @@ object PrimOps extends LazyLogging {
def t1 = e.args.head.tpe
def t2 = e.args(1).tpe
def t3 = e.args(2).tpe
- def w1 = passes.getWidth(e.args.head.tpe)
- def w2 = passes.getWidth(e.args(1).tpe)
+ def w1 = getWidth(e.args.head.tpe)
+ def w2 = getWidth(e.args(1).tpe)
def p1 = t1 match { case FixedType(w, p) => p } //Intentional
def p2 = t2 match { case FixedType(w, p) => p } //Intentional
def c1 = IntWidth(e.consts.head)
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index fe6472a8..e3f0fee1 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -12,6 +12,112 @@ import scala.collection.mutable.{StringBuilder, ArrayBuffer, LinkedHashMap, Hash
import java.io.PrintWriter
import logger.LazyLogging
+object seqCat {
+ def apply(args: Seq[Expression]): Expression = args.length match {
+ case 0 => error("Empty Seq passed to seqcat")
+ case 1 => args.head
+ case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth))
+ case _ =>
+ val (high, low) = args splitAt (args.length / 2)
+ DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth))
+ }
+}
+
+/** Given an expression, return an expression consisting of all sub-expressions
+ * concatenated (or flattened).
+ */
+object toBits {
+ def apply(e: Expression): Expression = e match {
+ case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex)
+ case t => error("Invalid operand expression for toBits!")
+ }
+ private def hiercat(e: Expression): Expression = e.tpe match {
+ case t: VectorType => seqCat((0 until t.size).reverse map (i =>
+ hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER))))
+ case t: BundleType => seqCat(t.fields map (f =>
+ hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER))))
+ case t: GroundType => DoPrim(AsUInt, Seq(e), Seq.empty, UnknownType)
+ case t => error("Unknown type encountered in toBits!")
+ }
+}
+
+object getWidth {
+ def apply(t: Type): Width = t match {
+ case t: GroundType => t.width
+ case _ => error("No width!")
+ }
+ def apply(e: Expression): Width = apply(e.tpe)
+}
+
+object bitWidth {
+ def apply(dt: Type): BigInt = widthOf(dt)
+ private def widthOf(dt: Type): BigInt = dt match {
+ case t: VectorType => t.size * bitWidth(t.tpe)
+ case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_)
+ case GroundType(IntWidth(width)) => width
+ case t => error("Unknown type encountered in bitWidth!")
+ }
+}
+
+object castRhs {
+ def apply(lhst: Type, rhs: Expression) = {
+ lhst match {
+ case _: SIntType => DoPrim(AsSInt, Seq(rhs), Seq.empty, lhst)
+ case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(rhs), Seq(p), lhst)
+ case ClockType => DoPrim(AsClock, Seq(rhs), Seq.empty, lhst)
+ case _: UIntType => rhs
+ }
+ }
+}
+
+object fromBits {
+ def apply(lhs: Expression, rhs: Expression): Statement = {
+ val fbits = lhs match {
+ case ex @ (_: WRef | _: WSubField | _: WSubIndex) => getPart(ex, ex.tpe, rhs, 0)
+ case _ => error("Invalid LHS expression for fromBits!")
+ }
+ Block(fbits._2)
+ }
+ private def getPartGround(lhs: Expression,
+ lhst: Type,
+ rhs: Expression,
+ offset: BigInt): (BigInt, Seq[Statement]) = {
+ val intWidth = bitWidth(lhst)
+ val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType)
+ val rhsConnect = castRhs(lhst, sel)
+ (offset + intWidth, Seq(Connect(NoInfo, lhs, rhsConnect)))
+ }
+ private def getPart(lhs: Expression,
+ lhst: Type,
+ rhs: Expression,
+ offset: BigInt): (BigInt, Seq[Statement]) =
+ lhst match {
+ case t: VectorType => (0 until t.size foldLeft (offset, Seq[Statement]())) {
+ case ((curOffset, stmts), i) =>
+ val subidx = WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER)
+ val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset)
+ (tmpOffset, stmts ++ substmts)
+ }
+ case t: BundleType => (t.fields foldRight (offset, Seq[Statement]())) {
+ case (f, (curOffset, stmts)) =>
+ val subfield = WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER)
+ val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset)
+ (tmpOffset, stmts ++ substmts)
+ }
+ case t: GroundType => getPartGround(lhs, t, rhs, offset)
+ case t => error("Unknown type encountered in fromBits!")
+ }
+}
+
+object connectFields {
+ def apply(lref: Expression, lname: String, rref: Expression, rname: String): Connect =
+ Connect(NoInfo, WSubField(lref, lname), WSubField(rref, rname))
+}
+
+object flattenType {
+ def apply(t: Type) = UIntType(IntWidth(bitWidth(t)))
+}
+
class FIRRTLException(str: String) extends Exception(str)
object Utils extends LazyLogging {
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index bc256b68..946906fa 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -35,6 +35,7 @@ object WRef {
def apply(wire: DefWire): WRef = new WRef(wire.name, wire.tpe, WireKind, UNKNOWNGENDER)
/** Creates a WRef from a Register */
def apply(reg: DefRegister): WRef = new WRef(reg.name, reg.tpe, RegKind, UNKNOWNGENDER)
+ def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind): WRef = new WRef(n, t, k, UNKNOWNGENDER)
}
case class WSubField(exp: Expression, name: String, tpe: Type, gender: Gender) extends Expression {
def serialize: String = s"${exp.serialize}.$name"
@@ -42,6 +43,9 @@ case class WSubField(exp: Expression, name: String, tpe: Type, gender: Gender) e
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
}
+object WSubField {
+ def apply(exp: Expression, n: String): WSubField = new WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER)
+}
case class WSubIndex(exp: Expression, value: Int, tpe: Type, gender: Gender) extends Expression {
def serialize: String = s"${exp.serialize}[$value]"
def mapExpr(f: Expression => Expression): Expression = this.copy(exp = f(exp))
@@ -83,6 +87,9 @@ case class WDefInstance(info: Info, name: String, module: String, tpe: Type) ext
def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
def mapString(f: String => String): Statement = this.copy(name = f(name))
}
+object WDefInstance {
+ def apply(name: String, module: String): WDefInstance = new WDefInstance(NoInfo, name, module, UnknownType)
+}
case class WDefInstanceConnector(
info: Info,
name: String,
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index 2501ba04..554c1f0d 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -101,24 +101,24 @@ object InferReadWritePass extends Pass {
val rclk = getOrigin(connects)(memPortField(mem, r, "clk"))
if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) {
val rw = namespace newName "rw"
- val rwExp = createSubField(createRef(mem.name), rw)
+ val rwExp = WSubField(WRef(mem.name), rw)
readwriters += rw
readers += r
writers += w
repl(memPortField(mem, r, "clk")) = EmptyExpression
repl(memPortField(mem, r, "en")) = EmptyExpression
repl(memPortField(mem, r, "addr")) = EmptyExpression
- repl(memPortField(mem, r, "data")) = createSubField(rwExp, "rdata")
+ repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata")
repl(memPortField(mem, w, "clk")) = EmptyExpression
- repl(memPortField(mem, w, "en")) = createSubField(rwExp, "wmode")
+ repl(memPortField(mem, w, "en")) = WSubField(rwExp, "wmode")
repl(memPortField(mem, w, "addr")) = EmptyExpression
- repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata")
- repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask")
- stmts += Connect(NoInfo, createSubField(rwExp, "clk"), wclk)
- stmts += Connect(NoInfo, createSubField(rwExp, "en"),
+ repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata")
+ repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask")
+ stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk)
+ stmts += Connect(NoInfo, WSubField(rwExp, "en"),
DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
connects(memPortField(mem, w, "en"))), Nil, BoolType))
- stmts += Connect(NoInfo, createSubField(rwExp, "addr"),
+ stmts += Connect(NoInfo, WSubField(rwExp, "addr"),
Mux(connects(memPortField(mem, w, "en")),
connects(memPortField(mem, w, "addr")),
connects(memPortField(mem, r, "addr")), UnknownType))
diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
index c7eb4539..40f81555 100644
--- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
@@ -7,35 +7,6 @@ import firrtl.ir._
import firrtl.Utils._
import firrtl.PrimOps._
-object seqCat {
- def apply(args: Seq[Expression]): Expression = args.length match {
- case 0 => error("Empty Seq passed to seqcat")
- case 1 => args.head
- case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth))
- case _ =>
- val (high, low) = args splitAt (args.length / 2)
- DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth))
- }
-}
-
-/** Given an expression, return an expression consisting of all sub-expressions
- * concatenated (or flattened).
- */
-object toBits {
- def apply(e: Expression): Expression = e match {
- case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex)
- case t => error("Invalid operand expression for toBits!")
- }
- private def hiercat(e: Expression): Expression = e.tpe match {
- case t: VectorType => seqCat((0 until t.size).reverse map (i =>
- hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER))))
- case t: BundleType => seqCat(t.fields map (f =>
- hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER))))
- case t: GroundType => DoPrim(AsUInt, Seq(e), Seq.empty, UnknownType)
- case t => error("Unknown type encountered in toBits!")
- }
-}
-
/** Given a mask, return a bitmask corresponding to the desired datatype.
* Requirements:
* - The mask type and datatype must be equivalent, except any ground type in
@@ -71,74 +42,6 @@ object toBitMask {
}
}
-object getWidth {
- def apply(t: Type): Width = t match {
- case t: GroundType => t.width
- case _ => error("No width!")
- }
- def apply(e: Expression): Width = apply(e.tpe)
-}
-
-object bitWidth {
- def apply(dt: Type): BigInt = widthOf(dt)
- private def widthOf(dt: Type): BigInt = dt match {
- case t: VectorType => t.size * bitWidth(t.tpe)
- case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_)
- case GroundType(IntWidth(width)) => width
- case t => error("Unknown type encountered in bitWidth!")
- }
-}
-
-object castRhs {
- def apply(lhst: Type, rhs: Expression) = {
- lhst match {
- case _: SIntType => DoPrim(AsSInt, Seq(rhs), Seq.empty, lhst)
- case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(rhs), Seq(p), lhst)
- case ClockType => DoPrim(AsClock, Seq(rhs), Seq.empty, lhst)
- case _: UIntType => rhs
- }
- }
-}
-
-object fromBits {
- def apply(lhs: Expression, rhs: Expression): Statement = {
- val fbits = lhs match {
- case ex @ (_: WRef | _: WSubField | _: WSubIndex) => getPart(ex, ex.tpe, rhs, 0)
- case _ => error("Invalid LHS expression for fromBits!")
- }
- Block(fbits._2)
- }
- private def getPartGround(lhs: Expression,
- lhst: Type,
- rhs: Expression,
- offset: BigInt): (BigInt, Seq[Statement]) = {
- val intWidth = bitWidth(lhst)
- val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType)
- val rhsConnect = castRhs(lhst, sel)
- (offset + intWidth, Seq(Connect(NoInfo, lhs, rhsConnect)))
- }
- private def getPart(lhs: Expression,
- lhst: Type,
- rhs: Expression,
- offset: BigInt): (BigInt, Seq[Statement]) =
- lhst match {
- case t: VectorType => (0 until t.size foldLeft (offset, Seq[Statement]())) {
- case ((curOffset, stmts), i) =>
- val subidx = WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER)
- val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset)
- (tmpOffset, stmts ++ substmts)
- }
- case t: BundleType => (t.fields foldRight (offset, Seq[Statement]())) {
- case (f, (curOffset, stmts)) =>
- val subfield = WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER)
- val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset)
- (tmpOffset, stmts ++ substmts)
- }
- case t: GroundType => getPartGround(lhs, t, rhs, offset)
- case t => error("Unknown type encountered in fromBits!")
- }
-}
-
object createMask {
def apply(dt: Type): Type = dt match {
case t: VectorType => VectorType(apply(t.tpe), t.size)
@@ -147,27 +50,6 @@ object createMask {
}
}
-object createRef {
- def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind) = WRef(n, t, k, UNKNOWNGENDER)
-}
-
-object createSubField {
- def apply(exp: Expression, n: String) = WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER)
-}
-
-object createInstance {
- def apply(name: String, module: String) = WDefInstance(NoInfo, name, module, UnknownType)
-}
-
-object connectFields {
- def apply(lref: Expression, lname: String, rref: Expression, rname: String): Connect =
- Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname))
-}
-
-object flattenType {
- def apply(t: Type) = UIntType(IntWidth(bitWidth(t)))
-}
-
object MemPortUtils {
type MemPortMap = collection.mutable.HashMap[String, Expression]
type Memories = collection.mutable.ArrayBuffer[DefMemory]
diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
index 81242810..57c301b1 100644
--- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
+++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
@@ -43,8 +43,8 @@ object RenameAnnotatedMemoryPorts extends Pass {
def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit =
for ((p, i) <- ports.zipWithIndex; f <- fields) {
- val newPort = createSubField(createRef(m.name), newPortKind + i)
- val field = createSubField(newPort, f)
+ val newPort = WSubField(WRef(m.name), newPortKind + i)
+ val field = WSubField(newPort, f)
memPortMap(s"${m.name}.$p.$f") = field
}
updateMemPortMap(m.readers, rFields, "R")
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
index 44dad557..03fd5ffa 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
@@ -101,14 +101,14 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform {
// Creates a type with the write/readwrite masks omitted if necessary
val bbIoType = memToFlattenBundle(m)
val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
- val bbRef = createRef(m.name, bbIoType)
+ val bbRef = WRef(m.name, bbIoType)
val hasMask = m.maskGran.isDefined
val fillMask = getFillWMask(m)
- def portRef(p: String) = createRef(p, field_type(wrapperIoType, p))
+ def portRef(p: String) = WRef(p, field_type(wrapperIoType, p))
val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++
- (m.readers flatMap (r => adaptReader(portRef(r), createSubField(bbRef, r)))) ++
- (m.writers flatMap (w => adaptWriter(portRef(w), createSubField(bbRef, w), hasMask, fillMask))) ++
- (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), createSubField(bbRef, rw), hasMask, fillMask)))
+ (m.readers flatMap (r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++
+ (m.writers flatMap (w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++
+ (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask)))
val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts))
val bb = ExtModule(NoInfo, m.name, bbIoPorts, m.name, Seq.empty)
// TODO: Annotate? -- use actual annotation map
@@ -130,34 +130,34 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform {
def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] =
defaultConnects(wrapperPort, bbPort) :+
- fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data"))
+ fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data"))
def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = {
- val wrapperData = createSubField(wrapperPort, "data")
+ val wrapperData = WSubField(wrapperPort, "data")
val defaultSeq = defaultConnects(wrapperPort, bbPort) :+
- Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData))
+ Connect(NoInfo, WSubField(bbPort, "data"), toBits(wrapperData))
hasMask match {
case false => defaultSeq
case true => defaultSeq :+ Connect(
NoInfo,
- createSubField(bbPort, "mask"),
- maskBits(createSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
+ WSubField(bbPort, "mask"),
+ maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
)
}
}
def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = {
- val wrapperWData = createSubField(wrapperPort, "wdata")
+ val wrapperWData = WSubField(wrapperPort, "wdata")
val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq(
- fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")),
+ fromBits(WSubField(wrapperPort, "rdata"), WSubField(bbPort, "rdata")),
connectFields(bbPort, "wmode", wrapperPort, "wmode"),
- Connect(NoInfo, createSubField(bbPort, "wdata"), toBits(wrapperWData)))
+ Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData)))
hasMask match {
case false => defaultSeq
case true => defaultSeq :+ Connect(
NoInfo,
- createSubField(bbPort, "wmask"),
- maskBits(createSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
+ WSubField(bbPort, "wmask"),
+ maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
)
}
}