diff options
| author | Angie | 2017-02-22 20:03:27 -0800 |
|---|---|---|
| committer | Adam Izraelevitz | 2017-02-23 14:55:00 -0800 |
| commit | 1f9fd2f9b9e9a0117b0dd65524c9dcb767c02778 (patch) | |
| tree | b4ef1d8fdadd89e942e321cf4495a127cb9bfe59 /src | |
| parent | 1d652352b752502dd6d130aeb85981df214d7021 (diff) | |
move more general utils out of memutils, mov WIR helpers to WIR.scala and update uses
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 106 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 7 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/InferReadWrite.scala | 16 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/MemUtils.scala | 118 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala | 30 |
7 files changed, 140 insertions, 145 deletions
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index fa523e12..1ca005d7 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -120,8 +120,8 @@ object PrimOps extends LazyLogging { def t1 = e.args.head.tpe def t2 = e.args(1).tpe def t3 = e.args(2).tpe - def w1 = passes.getWidth(e.args.head.tpe) - def w2 = passes.getWidth(e.args(1).tpe) + def w1 = getWidth(e.args.head.tpe) + def w2 = getWidth(e.args(1).tpe) def p1 = t1 match { case FixedType(w, p) => p } //Intentional def p2 = t2 match { case FixedType(w, p) => p } //Intentional def c1 = IntWidth(e.consts.head) diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index fe6472a8..e3f0fee1 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -12,6 +12,112 @@ import scala.collection.mutable.{StringBuilder, ArrayBuffer, LinkedHashMap, Hash import java.io.PrintWriter import logger.LazyLogging +object seqCat { + def apply(args: Seq[Expression]): Expression = args.length match { + case 0 => error("Empty Seq passed to seqcat") + case 1 => args.head + case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth)) + case _ => + val (high, low) = args splitAt (args.length / 2) + DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth)) + } +} + +/** Given an expression, return an expression consisting of all sub-expressions + * concatenated (or flattened). + */ +object toBits { + def apply(e: Expression): Expression = e match { + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex) + case t => error("Invalid operand expression for toBits!") + } + private def hiercat(e: Expression): Expression = e.tpe match { + case t: VectorType => seqCat((0 until t.size).reverse map (i => + hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER)))) + case t: BundleType => seqCat(t.fields map (f => + hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER)))) + case t: GroundType => DoPrim(AsUInt, Seq(e), Seq.empty, UnknownType) + case t => error("Unknown type encountered in toBits!") + } +} + +object getWidth { + def apply(t: Type): Width = t match { + case t: GroundType => t.width + case _ => error("No width!") + } + def apply(e: Expression): Width = apply(e.tpe) +} + +object bitWidth { + def apply(dt: Type): BigInt = widthOf(dt) + private def widthOf(dt: Type): BigInt = dt match { + case t: VectorType => t.size * bitWidth(t.tpe) + case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_) + case GroundType(IntWidth(width)) => width + case t => error("Unknown type encountered in bitWidth!") + } +} + +object castRhs { + def apply(lhst: Type, rhs: Expression) = { + lhst match { + case _: SIntType => DoPrim(AsSInt, Seq(rhs), Seq.empty, lhst) + case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(rhs), Seq(p), lhst) + case ClockType => DoPrim(AsClock, Seq(rhs), Seq.empty, lhst) + case _: UIntType => rhs + } + } +} + +object fromBits { + def apply(lhs: Expression, rhs: Expression): Statement = { + val fbits = lhs match { + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => getPart(ex, ex.tpe, rhs, 0) + case _ => error("Invalid LHS expression for fromBits!") + } + Block(fbits._2) + } + private def getPartGround(lhs: Expression, + lhst: Type, + rhs: Expression, + offset: BigInt): (BigInt, Seq[Statement]) = { + val intWidth = bitWidth(lhst) + val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType) + val rhsConnect = castRhs(lhst, sel) + (offset + intWidth, Seq(Connect(NoInfo, lhs, rhsConnect))) + } + private def getPart(lhs: Expression, + lhst: Type, + rhs: Expression, + offset: BigInt): (BigInt, Seq[Statement]) = + lhst match { + case t: VectorType => (0 until t.size foldLeft (offset, Seq[Statement]())) { + case ((curOffset, stmts), i) => + val subidx = WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER) + val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) + } + case t: BundleType => (t.fields foldRight (offset, Seq[Statement]())) { + case (f, (curOffset, stmts)) => + val subfield = WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER) + val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) + } + case t: GroundType => getPartGround(lhs, t, rhs, offset) + case t => error("Unknown type encountered in fromBits!") + } +} + +object connectFields { + def apply(lref: Expression, lname: String, rref: Expression, rname: String): Connect = + Connect(NoInfo, WSubField(lref, lname), WSubField(rref, rname)) +} + +object flattenType { + def apply(t: Type) = UIntType(IntWidth(bitWidth(t))) +} + class FIRRTLException(str: String) extends Exception(str) object Utils extends LazyLogging { diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index bc256b68..946906fa 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -35,6 +35,7 @@ object WRef { def apply(wire: DefWire): WRef = new WRef(wire.name, wire.tpe, WireKind, UNKNOWNGENDER) /** Creates a WRef from a Register */ def apply(reg: DefRegister): WRef = new WRef(reg.name, reg.tpe, RegKind, UNKNOWNGENDER) + def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind): WRef = new WRef(n, t, k, UNKNOWNGENDER) } case class WSubField(exp: Expression, name: String, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${exp.serialize}.$name" @@ -42,6 +43,9 @@ case class WSubField(exp: Expression, name: String, tpe: Type, gender: Gender) e def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) def mapWidth(f: Width => Width): Expression = this } +object WSubField { + def apply(exp: Expression, n: String): WSubField = new WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER) +} case class WSubIndex(exp: Expression, value: Int, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${exp.serialize}[$value]" def mapExpr(f: Expression => Expression): Expression = this.copy(exp = f(exp)) @@ -83,6 +87,9 @@ case class WDefInstance(info: Info, name: String, module: String, tpe: Type) ext def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) def mapString(f: String => String): Statement = this.copy(name = f(name)) } +object WDefInstance { + def apply(name: String, module: String): WDefInstance = new WDefInstance(NoInfo, name, module, UnknownType) +} case class WDefInstanceConnector( info: Info, name: String, diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 2501ba04..554c1f0d 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -101,24 +101,24 @@ object InferReadWritePass extends Pass { val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) { val rw = namespace newName "rw" - val rwExp = createSubField(createRef(mem.name), rw) + val rwExp = WSubField(WRef(mem.name), rw) readwriters += rw readers += r writers += w repl(memPortField(mem, r, "clk")) = EmptyExpression repl(memPortField(mem, r, "en")) = EmptyExpression repl(memPortField(mem, r, "addr")) = EmptyExpression - repl(memPortField(mem, r, "data")) = createSubField(rwExp, "rdata") + repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata") repl(memPortField(mem, w, "clk")) = EmptyExpression - repl(memPortField(mem, w, "en")) = createSubField(rwExp, "wmode") + repl(memPortField(mem, w, "en")) = WSubField(rwExp, "wmode") repl(memPortField(mem, w, "addr")) = EmptyExpression - repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata") - repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask") - stmts += Connect(NoInfo, createSubField(rwExp, "clk"), wclk) - stmts += Connect(NoInfo, createSubField(rwExp, "en"), + repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata") + repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask") + stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk) + stmts += Connect(NoInfo, WSubField(rwExp, "en"), DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType)) - stmts += Connect(NoInfo, createSubField(rwExp, "addr"), + stmts += Connect(NoInfo, WSubField(rwExp, "addr"), Mux(connects(memPortField(mem, w, "en")), connects(memPortField(mem, w, "addr")), connects(memPortField(mem, r, "addr")), UnknownType)) diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala index c7eb4539..40f81555 100644 --- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala @@ -7,35 +7,6 @@ import firrtl.ir._ import firrtl.Utils._ import firrtl.PrimOps._ -object seqCat { - def apply(args: Seq[Expression]): Expression = args.length match { - case 0 => error("Empty Seq passed to seqcat") - case 1 => args.head - case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth)) - case _ => - val (high, low) = args splitAt (args.length / 2) - DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth)) - } -} - -/** Given an expression, return an expression consisting of all sub-expressions - * concatenated (or flattened). - */ -object toBits { - def apply(e: Expression): Expression = e match { - case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex) - case t => error("Invalid operand expression for toBits!") - } - private def hiercat(e: Expression): Expression = e.tpe match { - case t: VectorType => seqCat((0 until t.size).reverse map (i => - hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER)))) - case t: BundleType => seqCat(t.fields map (f => - hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER)))) - case t: GroundType => DoPrim(AsUInt, Seq(e), Seq.empty, UnknownType) - case t => error("Unknown type encountered in toBits!") - } -} - /** Given a mask, return a bitmask corresponding to the desired datatype. * Requirements: * - The mask type and datatype must be equivalent, except any ground type in @@ -71,74 +42,6 @@ object toBitMask { } } -object getWidth { - def apply(t: Type): Width = t match { - case t: GroundType => t.width - case _ => error("No width!") - } - def apply(e: Expression): Width = apply(e.tpe) -} - -object bitWidth { - def apply(dt: Type): BigInt = widthOf(dt) - private def widthOf(dt: Type): BigInt = dt match { - case t: VectorType => t.size * bitWidth(t.tpe) - case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_) - case GroundType(IntWidth(width)) => width - case t => error("Unknown type encountered in bitWidth!") - } -} - -object castRhs { - def apply(lhst: Type, rhs: Expression) = { - lhst match { - case _: SIntType => DoPrim(AsSInt, Seq(rhs), Seq.empty, lhst) - case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(rhs), Seq(p), lhst) - case ClockType => DoPrim(AsClock, Seq(rhs), Seq.empty, lhst) - case _: UIntType => rhs - } - } -} - -object fromBits { - def apply(lhs: Expression, rhs: Expression): Statement = { - val fbits = lhs match { - case ex @ (_: WRef | _: WSubField | _: WSubIndex) => getPart(ex, ex.tpe, rhs, 0) - case _ => error("Invalid LHS expression for fromBits!") - } - Block(fbits._2) - } - private def getPartGround(lhs: Expression, - lhst: Type, - rhs: Expression, - offset: BigInt): (BigInt, Seq[Statement]) = { - val intWidth = bitWidth(lhst) - val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType) - val rhsConnect = castRhs(lhst, sel) - (offset + intWidth, Seq(Connect(NoInfo, lhs, rhsConnect))) - } - private def getPart(lhs: Expression, - lhst: Type, - rhs: Expression, - offset: BigInt): (BigInt, Seq[Statement]) = - lhst match { - case t: VectorType => (0 until t.size foldLeft (offset, Seq[Statement]())) { - case ((curOffset, stmts), i) => - val subidx = WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER) - val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset) - (tmpOffset, stmts ++ substmts) - } - case t: BundleType => (t.fields foldRight (offset, Seq[Statement]())) { - case (f, (curOffset, stmts)) => - val subfield = WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER) - val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset) - (tmpOffset, stmts ++ substmts) - } - case t: GroundType => getPartGround(lhs, t, rhs, offset) - case t => error("Unknown type encountered in fromBits!") - } -} - object createMask { def apply(dt: Type): Type = dt match { case t: VectorType => VectorType(apply(t.tpe), t.size) @@ -147,27 +50,6 @@ object createMask { } } -object createRef { - def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind) = WRef(n, t, k, UNKNOWNGENDER) -} - -object createSubField { - def apply(exp: Expression, n: String) = WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER) -} - -object createInstance { - def apply(name: String, module: String) = WDefInstance(NoInfo, name, module, UnknownType) -} - -object connectFields { - def apply(lref: Expression, lname: String, rref: Expression, rname: String): Connect = - Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname)) -} - -object flattenType { - def apply(t: Type) = UIntType(IntWidth(bitWidth(t))) -} - object MemPortUtils { type MemPortMap = collection.mutable.HashMap[String, Expression] type Memories = collection.mutable.ArrayBuffer[DefMemory] diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala index 81242810..57c301b1 100644 --- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala +++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala @@ -43,8 +43,8 @@ object RenameAnnotatedMemoryPorts extends Pass { def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit = for ((p, i) <- ports.zipWithIndex; f <- fields) { - val newPort = createSubField(createRef(m.name), newPortKind + i) - val field = createSubField(newPort, f) + val newPort = WSubField(WRef(m.name), newPortKind + i) + val field = WSubField(newPort, f) memPortMap(s"${m.name}.$p.$f") = field } updateMemPortMap(m.readers, rFields, "R") diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index 44dad557..03fd5ffa 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -101,14 +101,14 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform { // Creates a type with the write/readwrite masks omitted if necessary val bbIoType = memToFlattenBundle(m) val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) - val bbRef = createRef(m.name, bbIoType) + val bbRef = WRef(m.name, bbIoType) val hasMask = m.maskGran.isDefined val fillMask = getFillWMask(m) - def portRef(p: String) = createRef(p, field_type(wrapperIoType, p)) + def portRef(p: String) = WRef(p, field_type(wrapperIoType, p)) val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++ - (m.readers flatMap (r => adaptReader(portRef(r), createSubField(bbRef, r)))) ++ - (m.writers flatMap (w => adaptWriter(portRef(w), createSubField(bbRef, w), hasMask, fillMask))) ++ - (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), createSubField(bbRef, rw), hasMask, fillMask))) + (m.readers flatMap (r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++ + (m.writers flatMap (w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++ + (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask))) val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts)) val bb = ExtModule(NoInfo, m.name, bbIoPorts, m.name, Seq.empty) // TODO: Annotate? -- use actual annotation map @@ -130,34 +130,34 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform { def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = defaultConnects(wrapperPort, bbPort) :+ - fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data")) + fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data")) def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { - val wrapperData = createSubField(wrapperPort, "data") + val wrapperData = WSubField(wrapperPort, "data") val defaultSeq = defaultConnects(wrapperPort, bbPort) :+ - Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData)) + Connect(NoInfo, WSubField(bbPort, "data"), toBits(wrapperData)) hasMask match { case false => defaultSeq case true => defaultSeq :+ Connect( NoInfo, - createSubField(bbPort, "mask"), - maskBits(createSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) + WSubField(bbPort, "mask"), + maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) ) } } def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { - val wrapperWData = createSubField(wrapperPort, "wdata") + val wrapperWData = WSubField(wrapperPort, "wdata") val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq( - fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")), + fromBits(WSubField(wrapperPort, "rdata"), WSubField(bbPort, "rdata")), connectFields(bbPort, "wmode", wrapperPort, "wmode"), - Connect(NoInfo, createSubField(bbPort, "wdata"), toBits(wrapperWData))) + Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData))) hasMask match { case false => defaultSeq case true => defaultSeq :+ Connect( NoInfo, - createSubField(bbPort, "wmask"), - maskBits(createSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) + WSubField(bbPort, "wmask"), + maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) ) } } |
