From de45e93c43201c5b757d681fb922a57564462a08 Mon Sep 17 00:00:00 2001 From: Angie Wang Date: Sun, 23 Oct 2016 19:32:18 -0700 Subject: Fix bitmask (#346) * toBitMask cat direction should be consistent with data * minor comment updates * moved remaining mem passes/utils to memlib * changed again so that data, mask are consistent. data element 0, bit 0 = LSB (on RHS) when concatenated --- src/main/scala/firrtl/passes/InferReadWrite.scala | 186 ------------------ src/main/scala/firrtl/passes/MemUtils.scala | 215 --------------------- .../scala/firrtl/passes/VerilogMemDelays.scala | 184 ------------------ .../firrtl/passes/memlib/InferReadWrite.scala | 186 ++++++++++++++++++ src/main/scala/firrtl/passes/memlib/MemUtils.scala | 190 ++++++++++++++++++ .../firrtl/passes/memlib/ReplaceMemMacros.scala | 3 +- .../firrtl/passes/memlib/VerilogMemDelays.scala | 184 ++++++++++++++++++ 7 files changed, 562 insertions(+), 586 deletions(-) delete mode 100644 src/main/scala/firrtl/passes/InferReadWrite.scala delete mode 100644 src/main/scala/firrtl/passes/MemUtils.scala delete mode 100644 src/main/scala/firrtl/passes/VerilogMemDelays.scala create mode 100644 src/main/scala/firrtl/passes/memlib/InferReadWrite.scala create mode 100644 src/main/scala/firrtl/passes/memlib/MemUtils.scala create mode 100644 src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala (limited to 'src') diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala deleted file mode 100644 index 9adbdd95..00000000 --- a/src/main/scala/firrtl/passes/InferReadWrite.scala +++ /dev/null @@ -1,186 +0,0 @@ -/* -Copyright (c) 2014 - 2016 The Regents of the University of -California (Regents). All Rights Reserved. Redistribution and use in -source and binary forms, with or without modification, are permitted -provided that the following conditions are met: - * Redistributions of source code must retain the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer. - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer in the documentation and/or other materials - provided with the distribution. - * Neither the name of the Regents nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. -IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, -SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, -ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF -REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF -ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION -TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR -MODIFICATIONS. -*/ - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Mappers._ -import firrtl.PrimOps._ -import firrtl.Utils.{one, zero, BoolType} -import firrtl.passes.memlib._ -import MemPortUtils.memPortField -import AnalysisUtils.{Connects, getConnects, getOrigin} -import WrappedExpression.weq -import Annotations._ - -case class InferReadWriteAnnotation(t: String, tID: TransID) - extends Annotation with Loose with Unstable { - val target = CircuitName(t) - def duplicate(n: Named) = this.copy(t=n.name) -} - -// This pass examine the enable signals of the read & write ports of memories -// whose readLatency is greater than 1 (usually SeqMem in Chisel). -// If any product term of the enable signal of the read port is the complement -// of any product term of the enable signal of the write port, then the readwrite -// port is inferred. -object InferReadWritePass extends Pass { - def name = "Infer ReadWrite Ports" - - type Netlist = collection.mutable.HashMap[String, Expression] - type Statements = collection.mutable.ArrayBuffer[Statement] - type PortSet = collection.mutable.HashSet[String] - - private implicit def toString(e: Expression): String = e.serialize - - def getProductTerms(connects: Connects)(e: Expression): Seq[Expression] = e match { - // No ConstProp yet... - case Mux(cond, tval, fval, _) if weq(tval, one) && weq(fval, zero) => - getProductTerms(connects)(cond) - // Visit each term of AND operation - case DoPrim(op, args, consts, tpe) if op == And => - e +: (args flatMap getProductTerms(connects)) - // Visit connected nodes to references - case _: WRef | _: WSubField | _: WSubIndex => connects get e match { - case None => Seq(e) - case Some(ex) => e +: getProductTerms(connects)(ex) - } - // Otherwise just return itself - case _ => Seq(e) - } - - def checkComplement(a: Expression, b: Expression) = (a, b) match { - // b ?= Not(a) - case (_, DoPrim(Not, args, _, _)) => weq(args.head, a) - // a ?= Not(b) - case (DoPrim(Not, args, _, _), _) => weq(args.head, b) - // b ?= Eq(a, 0) or b ?= Eq(0, a) - case (_, DoPrim(Eq, args, _, _)) => - weq(args.head, a) && weq(args(1), zero) || - weq(args(1), a) && weq(args.head, zero) - // a ?= Eq(b, 0) or b ?= Eq(0, a) - case (DoPrim(Eq, args, _, _), _) => - weq(args.head, b) && weq(args(1), zero) || - weq(args(1), b) && weq(args.head, zero) - case _ => false - } - - - def replaceExp(repl: Netlist)(e: Expression): Expression = - e map replaceExp(repl) match { - case ex: WSubField => repl getOrElse (ex.serialize, ex) - case ex => ex - } - - def replaceStmt(repl: Netlist)(s: Statement): Statement = - s map replaceStmt(repl) map replaceExp(repl) match { - case Connect(_, EmptyExpression, _) => EmptyStmt - case sx => sx - } - - def inferReadWriteStmt(connects: Connects, - repl: Netlist, - stmts: Statements) - (s: Statement): Statement = s match { - // infer readwrite ports only for non combinational memories - case mem: DefMemory if mem.readLatency > 0 => - val ut = UnknownType - val ug = UNKNOWNGENDER - val readers = new PortSet - val writers = new PortSet - val readwriters = collection.mutable.ArrayBuffer[String]() - val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) - for (w <- mem.writers ; r <- mem.readers) { - val wp = getProductTerms(connects)(memPortField(mem, w, "en")) - val rp = getProductTerms(connects)(memPortField(mem, r, "en")) - val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) - val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) - if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) { - val rw = namespace newName "rw" - val rwExp = createSubField(createRef(mem.name), rw) - readwriters += rw - readers += r - writers += w - repl(memPortField(mem, r, "clk")) = EmptyExpression - repl(memPortField(mem, r, "en")) = EmptyExpression - repl(memPortField(mem, r, "addr")) = EmptyExpression - repl(memPortField(mem, r, "data")) = createSubField(rwExp, "rdata") - repl(memPortField(mem, w, "clk")) = EmptyExpression - repl(memPortField(mem, w, "en")) = createSubField(rwExp, "wmode") - repl(memPortField(mem, w, "addr")) = EmptyExpression - repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata") - repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask") - stmts += Connect(NoInfo, createSubField(rwExp, "clk"), wclk) - stmts += Connect(NoInfo, createSubField(rwExp, "en"), - DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), - connects(memPortField(mem, w, "en"))), Nil, BoolType)) - stmts += Connect(NoInfo, createSubField(rwExp, "addr"), - Mux(connects(memPortField(mem, w, "en")), - connects(memPortField(mem, w, "addr")), - connects(memPortField(mem, r, "addr")), UnknownType)) - } - } - if (readwriters.isEmpty) mem else mem copy ( - readers = mem.readers filterNot readers, - writers = mem.writers filterNot writers, - readwriters = mem.readwriters ++ readwriters) - case sx => sx map inferReadWriteStmt(connects, repl, stmts) - } - - def inferReadWrite(m: DefModule) = { - val connects = getConnects(m) - val repl = new Netlist - val stmts = new Statements - (m map inferReadWriteStmt(connects, repl, stmts) - map replaceStmt(repl)) match { - case m: ExtModule => m - case m: Module => m copy (body = Block(m.body +: stmts)) - } - } - - def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite) -} - -// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" -// To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite(transID: TransID) extends Transform with SimpleRun { - def passSeq = Seq( - InferReadWritePass, - CheckInitialization, - InferTypes, - ResolveKinds, - ResolveGenders - ) - def execute(c: Circuit, map: AnnotationMap) = map get transID match { - case Some(p) => p get CircuitName(c.main) match { - case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq) - case _ => sys.error("Unexpected annotation for InferReadWrite") - } - case _ => TransformResult(c) - } -} diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala deleted file mode 100644 index 8cd58afb..00000000 --- a/src/main/scala/firrtl/passes/MemUtils.scala +++ /dev/null @@ -1,215 +0,0 @@ -/* - Copyright (c) 2014 - 2016 The Regents of the University of - California (Regents). All Rights Reserved. Redistribution and use in - source and binary forms, with or without modification, are permitted - provided that the following conditions are met: - * Redistributions of source code must retain the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer. - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer in the documentation and/or other materials - provided with the distribution. - * Neither the name of the Regents nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. - IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, - SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, - ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF - REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF - ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION - TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR - MODIFICATIONS. - */ - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Utils._ -import firrtl.PrimOps._ - -object seqCat { - def apply(args: Seq[Expression]): Expression = args.length match { - case 0 => error("Empty Seq passed to seqcat") - case 1 => args.head - case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth)) - case _ => - val (high, low) = args splitAt (args.length / 2) - DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth)) - } -} - -/** Given an expression, return an expression consisting of all sub-expressions - * concatenated (or flattened). - */ -object toBits { - def apply(e: Expression): Expression = e match { - case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex) - case t => error("Invalid operand expression for toBits!") - } - private def hiercat(e: Expression): Expression = e.tpe match { - case t: VectorType => seqCat((0 until t.size) map (i => - hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER)))) - case t: BundleType => seqCat(t.fields map (f => - hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER)))) - case t: GroundType => e - case t => error("Unknown type encountered in toBits!") - } -} - -/** Given a mask, return a bitmask corresponding to the desired datatype. - * Requirements: - * - The mask type and datatype must be equivalent, except any ground type in - * datatype must be matched by a 1-bit wide UIntType. - * - The mask must be a reference, subfield, or subindex - * The bitmask is a series of concatenations of the single mask bit over the - * length of the corresponding ground type, e.g.: - *{{{ - * wire mask: {x: UInt<1>, y: UInt<1>} - * wire data: {x: UInt<2>, y: SInt<2>} - * // this would return: - * cat(cat(mask.x, mask.x), cat(mask.y, mask.y)) - * }}} - */ -object toBitMask { - def apply(mask: Expression, dataType: Type): Expression = mask match { - case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType) - case t => error("Invalid operand expression for toBits!") - } - private def hiermask(mask: Expression, dataType: Type): Expression = - (mask.tpe, dataType) match { - case (mt: VectorType, dt: VectorType) => - seqCat((0 until mt.size).reverse map { i => - hiermask(WSubIndex(mask, i, mt.tpe, UNKNOWNGENDER), dt.tpe) - }) - case (mt: BundleType, dt: BundleType) => - seqCat((mt.fields zip dt.fields) map { case (mf, df) => - hiermask(WSubField(mask, mf.name, mf.tpe, UNKNOWNGENDER), df.tpe) - }) - case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) => - seqCat(List.fill(bitWidth(dt).intValue)(mask)) - case (mt, dt) => error("Invalid type for mask component!") - } -} - -object getWidth { - def apply(t: Type): Width = t match { - case t: GroundType => t.width - case _ => error("No width!") - } - def apply(e: Expression): Width = apply(e.tpe) -} - -object bitWidth { - def apply(dt: Type): BigInt = widthOf(dt) - private def widthOf(dt: Type): BigInt = dt match { - case t: VectorType => t.size * bitWidth(t.tpe) - case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_) - case GroundType(IntWidth(width)) => width - case t => error("Unknown type encountered in bitWidth!") - } -} - -object fromBits { - def apply(lhs: Expression, rhs: Expression): Statement = { - val fbits = lhs match { - case ex @ (_: WRef | _: WSubField | _: WSubIndex) => getPart(ex, ex.tpe, rhs, 0) - case _ => error("Invalid LHS expression for fromBits!") - } - Block(fbits._2) - } - private def getPartGround(lhs: Expression, - lhst: Type, - rhs: Expression, - offset: BigInt): (BigInt, Seq[Statement]) = { - val intWidth = bitWidth(lhst) - val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType) - (offset + intWidth, Seq(Connect(NoInfo, lhs, sel))) - } - private def getPart(lhs: Expression, - lhst: Type, - rhs: Expression, - offset: BigInt): (BigInt, Seq[Statement]) = - lhst match { - case t: VectorType => (0 until t.size foldRight (offset, Seq[Statement]())) { - case (i, (curOffset, stmts)) => - val subidx = WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER) - val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset) - (tmpOffset, stmts ++ substmts) - } - case t: BundleType => (t.fields foldRight (offset, Seq[Statement]())) { - case (f, (curOffset, stmts)) => - val subfield = WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER) - val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset) - (tmpOffset, stmts ++ substmts) - } - case t: GroundType => getPartGround(lhs, t, rhs, offset) - case t => error("Unknown type encountered in fromBits!") - } -} - -object createMask { - def apply(dt: Type): Type = dt match { - case t: VectorType => VectorType(apply(t.tpe), t.size) - case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe)))) - case t: GroundType => BoolType - } -} - -object createRef { - def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind) = WRef(n, t, k, UNKNOWNGENDER) -} - -object createSubField { - def apply(exp: Expression, n: String) = WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER) -} - -object connectFields { - def apply(lref: Expression, lname: String, rref: Expression, rname: String): Connect = - Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname)) -} - -object flattenType { - def apply(t: Type) = UIntType(IntWidth(bitWidth(t))) -} - -object MemPortUtils { - type MemPortMap = collection.mutable.HashMap[String, Expression] - type Memories = collection.mutable.ArrayBuffer[DefMemory] - type Modules = collection.mutable.ArrayBuffer[DefModule] - - def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq( - Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))), - Field("en", Default, BoolType), - Field("clk", Default, ClockType) - ) - - // Todo: merge it with memToBundle - def memType(mem: DefMemory): Type = { - val rType = BundleType(defaultPortSeq(mem) :+ - Field("data", Flip, mem.dataType)) - val wType = BundleType(defaultPortSeq(mem) ++ Seq( - Field("data", Default, mem.dataType), - Field("mask", Default, createMask(mem.dataType)))) - val rwType = BundleType(defaultPortSeq(mem) ++ Seq( - Field("rdata", Flip, mem.dataType), - Field("wmode", Default, BoolType), - Field("wdata", Default, mem.dataType), - Field("wmask", Default, createMask(mem.dataType)))) - BundleType( - (mem.readers map (Field(_, Flip, rType))) ++ - (mem.writers map (Field(_, Flip, wType))) ++ - (mem.readwriters map (Field(_, Flip, rwType)))) - } - - def memPortField(s: DefMemory, p: String, f: String): Expression = { - val mem = WRef(s.name, memType(s), MemKind, UNKNOWNGENDER) - val t1 = field_type(mem.tpe, p) - val t2 = field_type(t1, f) - WSubField(WSubField(mem, p, t1, UNKNOWNGENDER), f, t2, UNKNOWNGENDER) - } -} diff --git a/src/main/scala/firrtl/passes/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/VerilogMemDelays.scala deleted file mode 100644 index 2f7126b4..00000000 --- a/src/main/scala/firrtl/passes/VerilogMemDelays.scala +++ /dev/null @@ -1,184 +0,0 @@ -/* -Copyright (c) 2014 - 2016 The Regents of the University of -California (Regents). All Rights Reserved. Redistribution and use in -source and binary forms, with or without modification, are permitted -provided that the following conditions are met: - * Redistributions of source code must retain the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer. - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer in the documentation and/or other materials - provided with the distribution. - * Neither the name of the Regents nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. -IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, -SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, -ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF -REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF -ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION -TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR -MODIFICATIONS. -*/ - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Utils._ -import firrtl.Mappers._ -import firrtl.PrimOps._ -import MemPortUtils._ - -/** This pass generates delay reigsters for memories for verilog */ -object VerilogMemDelays extends Pass { - def name = "Verilog Memory Delays" - val ug = UNKNOWNGENDER - type Netlist = collection.mutable.HashMap[String, Expression] - implicit def expToString(e: Expression): String = e.serialize - private def NOT(e: Expression) = DoPrim(Not, Seq(e), Nil, BoolType) - private def AND(e1: Expression, e2: Expression) = DoPrim(And, Seq(e1, e2), Nil, BoolType) - - def buildNetlist(netlist: Netlist)(s: Statement): Statement = { - s match { - case Connect(_, loc, expr) => kind(loc) match { - case MemKind => netlist(loc) = expr - case _ => - } - case _ => - } - s map buildNetlist(netlist) - } - - def memDelayStmt( - netlist: Netlist, - namespace: Namespace, - repl: Netlist) - (s: Statement): Statement = s map memDelayStmt(netlist, namespace, repl) match { - case sx: DefMemory => - val ports = (sx.readers ++ sx.writers).toSet - def newPortName(rw: String, p: String) = (for { - idx <- Stream from 0 - newName = s"${rw}_${p}_$idx" - if !ports(newName) - } yield newName).head - val rwMap = (sx.readwriters map (rw => - rw -> (newPortName(rw, "r"), newPortName(rw, "w")))).toMap - // 1. readwrite ports are split into read & write ports - // 2. memories are transformed into combinational - // because latency pipes are added for longer latencies - val mem = sx copy ( - readers = sx.readers ++ (sx.readwriters map (rw => rwMap(rw)._1)), - writers = sx.writers ++ (sx.readwriters map (rw => rwMap(rw)._2)), - readwriters = Nil, readLatency = 0, writeLatency = 1) - def pipe(e: Expression, // Expression to be piped - n: Int, // pipe depth - clk: Expression, // clock expression - cond: Expression // condition for pipes - ): (Expression, Seq[Statement]) = { - // returns - // 1) reference to the last pipe register - // 2) pipe registers and connects - val node = DefNode(NoInfo, namespace.newTemp, netlist(e)) - val wref = WRef(node.name, e.tpe, NodeKind, MALE) - ((0 until n) foldLeft (wref, Seq[Statement](node))){case ((ex, stmts), i) => - val name = namespace newName s"${LowerTypes.loweredName(e)}_pipe_$i" - val exx = WRef(name, e.tpe, RegKind, ug) - (exx, stmts ++ Seq(DefRegister(NoInfo, name, e.tpe, clk, zero, exx)) ++ - (if (i < n - 1 && WrappedExpression.weq(cond, one)) Seq(Connect(NoInfo, exx, ex)) else { - val condn = namespace newName s"${LowerTypes.loweredName(e)}_en" - val condx = WRef(condn, BoolType, NodeKind, FEMALE) - Seq(DefNode(NoInfo, condn, cond), - Connect(NoInfo, condx, cond), - Connect(NoInfo, exx, Mux(condx, ex, exx, e.tpe))) - }) - ) - } - } - def readPortConnects(reader: String, - clk: Expression, - en: Expression, - addr: Expression) = Seq( - Connect(NoInfo, memPortField(mem, reader, "clk"), clk), - // connect latency pipes to read ports - Connect(NoInfo, memPortField(mem, reader, "en"), en), - Connect(NoInfo, memPortField(mem, reader, "addr"), addr) - ) - def writePortConnects(writer: String, - clk: Expression, - en: Expression, - mask: Expression, - addr: Expression, - data: Expression) = Seq( - Connect(NoInfo, memPortField(mem, writer, "clk"), clk), - // connect latency pipes to write ports - Connect(NoInfo, memPortField(mem, writer, "en"), en), - Connect(NoInfo, memPortField(mem, writer, "mask"), mask), - Connect(NoInfo, memPortField(mem, writer, "addr"), addr), - Connect(NoInfo, memPortField(mem, writer, "data"), data) - ) - - - Block(mem +: ((sx.readers flatMap {reader => - // generate latency pipes for read ports (enable & addr) - val clk = netlist(memPortField(sx, reader, "clk")) - val (en, ss1) = pipe(memPortField(sx, reader, "en"), sx.readLatency - 1, clk, one) - val (addr, ss2) = pipe(memPortField(sx, reader, "addr"), sx.readLatency, clk, en) - ss1 ++ ss2 ++ readPortConnects(reader, clk, en, addr) - }) ++ (sx.writers flatMap {writer => - // generate latency pipes for write ports (enable, mask, addr, data) - val clk = netlist(memPortField(sx, writer, "clk")) - val (en, ss1) = pipe(memPortField(sx, writer, "en"), sx.writeLatency - 1, clk, one) - val (mask, ss2) = pipe(memPortField(sx, writer, "mask"), sx.writeLatency - 1, clk, one) - val (addr, ss3) = pipe(memPortField(sx, writer, "addr"), sx.writeLatency - 1, clk, one) - val (data, ss4) = pipe(memPortField(sx, writer, "data"), sx.writeLatency - 1, clk, one) - ss1 ++ ss2 ++ ss3 ++ ss4 ++ writePortConnects(writer, clk, en, mask, addr, data) - }) ++ (sx.readwriters flatMap {readwriter => - val (reader, writer) = rwMap(readwriter) - val clk = netlist(memPortField(sx, readwriter, "clk")) - // generate latency pipes for readwrite ports (enable, addr, wmode, wmask, wdata) - val (en, ss1) = pipe(memPortField(sx, readwriter, "en"), sx.readLatency - 1, clk, one) - val (wmode, ss2) = pipe(memPortField(sx, readwriter, "wmode"), sx.writeLatency - 1, clk, one) - val (wmask, ss3) = pipe(memPortField(sx, readwriter, "wmask"), sx.writeLatency - 1, clk, one) - val (wdata, ss4) = pipe(memPortField(sx, readwriter, "wdata"), sx.writeLatency - 1, clk, one) - val (raddr, ss5) = pipe(memPortField(sx, readwriter, "addr"), sx.readLatency, clk, AND(en, NOT(wmode))) - val (waddr, ss6) = pipe(memPortField(sx, readwriter, "addr"), sx.writeLatency - 1, clk, one) - repl(memPortField(sx, readwriter, "rdata")) = memPortField(mem, reader, "data") - ss1 ++ ss2 ++ ss3 ++ ss4 ++ ss5 ++ ss6 ++ - readPortConnects(reader, clk, en, raddr) ++ - writePortConnects(writer, clk, AND(en, wmode), wmask, waddr, wdata) - }))) - case sx: Connect => kind(sx.loc) match { - case MemKind => EmptyStmt - case _ => sx - } - case sx => sx - } - - def replaceExp(repl: Netlist)(e: Expression): Expression = e match { - case ex: WSubField => repl get ex match { - case Some(exx) => exx - case None => ex - } - case ex => ex map replaceExp(repl) - } - - def replaceStmt(repl: Netlist)(s: Statement): Statement = - s map replaceStmt(repl) map replaceExp(repl) - - def memDelayMod(m: DefModule): DefModule = { - val netlist = new Netlist - val namespace = Namespace(m) - val repl = new Netlist - (m map buildNetlist(netlist) - map memDelayStmt(netlist, namespace, repl) - map replaceStmt(repl)) - } - - def run(c: Circuit): Circuit = - c copy (modules = c.modules map memDelayMod) -} diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala new file mode 100644 index 00000000..9adbdd95 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -0,0 +1,186 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtl.passes + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.PrimOps._ +import firrtl.Utils.{one, zero, BoolType} +import firrtl.passes.memlib._ +import MemPortUtils.memPortField +import AnalysisUtils.{Connects, getConnects, getOrigin} +import WrappedExpression.weq +import Annotations._ + +case class InferReadWriteAnnotation(t: String, tID: TransID) + extends Annotation with Loose with Unstable { + val target = CircuitName(t) + def duplicate(n: Named) = this.copy(t=n.name) +} + +// This pass examine the enable signals of the read & write ports of memories +// whose readLatency is greater than 1 (usually SeqMem in Chisel). +// If any product term of the enable signal of the read port is the complement +// of any product term of the enable signal of the write port, then the readwrite +// port is inferred. +object InferReadWritePass extends Pass { + def name = "Infer ReadWrite Ports" + + type Netlist = collection.mutable.HashMap[String, Expression] + type Statements = collection.mutable.ArrayBuffer[Statement] + type PortSet = collection.mutable.HashSet[String] + + private implicit def toString(e: Expression): String = e.serialize + + def getProductTerms(connects: Connects)(e: Expression): Seq[Expression] = e match { + // No ConstProp yet... + case Mux(cond, tval, fval, _) if weq(tval, one) && weq(fval, zero) => + getProductTerms(connects)(cond) + // Visit each term of AND operation + case DoPrim(op, args, consts, tpe) if op == And => + e +: (args flatMap getProductTerms(connects)) + // Visit connected nodes to references + case _: WRef | _: WSubField | _: WSubIndex => connects get e match { + case None => Seq(e) + case Some(ex) => e +: getProductTerms(connects)(ex) + } + // Otherwise just return itself + case _ => Seq(e) + } + + def checkComplement(a: Expression, b: Expression) = (a, b) match { + // b ?= Not(a) + case (_, DoPrim(Not, args, _, _)) => weq(args.head, a) + // a ?= Not(b) + case (DoPrim(Not, args, _, _), _) => weq(args.head, b) + // b ?= Eq(a, 0) or b ?= Eq(0, a) + case (_, DoPrim(Eq, args, _, _)) => + weq(args.head, a) && weq(args(1), zero) || + weq(args(1), a) && weq(args.head, zero) + // a ?= Eq(b, 0) or b ?= Eq(0, a) + case (DoPrim(Eq, args, _, _), _) => + weq(args.head, b) && weq(args(1), zero) || + weq(args(1), b) && weq(args.head, zero) + case _ => false + } + + + def replaceExp(repl: Netlist)(e: Expression): Expression = + e map replaceExp(repl) match { + case ex: WSubField => repl getOrElse (ex.serialize, ex) + case ex => ex + } + + def replaceStmt(repl: Netlist)(s: Statement): Statement = + s map replaceStmt(repl) map replaceExp(repl) match { + case Connect(_, EmptyExpression, _) => EmptyStmt + case sx => sx + } + + def inferReadWriteStmt(connects: Connects, + repl: Netlist, + stmts: Statements) + (s: Statement): Statement = s match { + // infer readwrite ports only for non combinational memories + case mem: DefMemory if mem.readLatency > 0 => + val ut = UnknownType + val ug = UNKNOWNGENDER + val readers = new PortSet + val writers = new PortSet + val readwriters = collection.mutable.ArrayBuffer[String]() + val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) + for (w <- mem.writers ; r <- mem.readers) { + val wp = getProductTerms(connects)(memPortField(mem, w, "en")) + val rp = getProductTerms(connects)(memPortField(mem, r, "en")) + val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) + val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) + if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) { + val rw = namespace newName "rw" + val rwExp = createSubField(createRef(mem.name), rw) + readwriters += rw + readers += r + writers += w + repl(memPortField(mem, r, "clk")) = EmptyExpression + repl(memPortField(mem, r, "en")) = EmptyExpression + repl(memPortField(mem, r, "addr")) = EmptyExpression + repl(memPortField(mem, r, "data")) = createSubField(rwExp, "rdata") + repl(memPortField(mem, w, "clk")) = EmptyExpression + repl(memPortField(mem, w, "en")) = createSubField(rwExp, "wmode") + repl(memPortField(mem, w, "addr")) = EmptyExpression + repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata") + repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask") + stmts += Connect(NoInfo, createSubField(rwExp, "clk"), wclk) + stmts += Connect(NoInfo, createSubField(rwExp, "en"), + DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), + connects(memPortField(mem, w, "en"))), Nil, BoolType)) + stmts += Connect(NoInfo, createSubField(rwExp, "addr"), + Mux(connects(memPortField(mem, w, "en")), + connects(memPortField(mem, w, "addr")), + connects(memPortField(mem, r, "addr")), UnknownType)) + } + } + if (readwriters.isEmpty) mem else mem copy ( + readers = mem.readers filterNot readers, + writers = mem.writers filterNot writers, + readwriters = mem.readwriters ++ readwriters) + case sx => sx map inferReadWriteStmt(connects, repl, stmts) + } + + def inferReadWrite(m: DefModule) = { + val connects = getConnects(m) + val repl = new Netlist + val stmts = new Statements + (m map inferReadWriteStmt(connects, repl, stmts) + map replaceStmt(repl)) match { + case m: ExtModule => m + case m: Module => m copy (body = Block(m.body +: stmts)) + } + } + + def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite) +} + +// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" +// To use this transform, circuit name should be annotated with its TransId. +class InferReadWrite(transID: TransID) extends Transform with SimpleRun { + def passSeq = Seq( + InferReadWritePass, + CheckInitialization, + InferTypes, + ResolveKinds, + ResolveGenders + ) + def execute(c: Circuit, map: AnnotationMap) = map get transID match { + case Some(p) => p get CircuitName(c.main) match { + case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq) + case _ => sys.error("Unexpected annotation for InferReadWrite") + } + case _ => TransformResult(c) + } +} diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala new file mode 100644 index 00000000..22650c7a --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala @@ -0,0 +1,190 @@ +// See LICENSE for license details. + +package firrtl.passes + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.PrimOps._ + +object seqCat { + def apply(args: Seq[Expression]): Expression = args.length match { + case 0 => error("Empty Seq passed to seqcat") + case 1 => args.head + case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth)) + case _ => + val (high, low) = args splitAt (args.length / 2) + DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth)) + } +} + +/** Given an expression, return an expression consisting of all sub-expressions + * concatenated (or flattened). + */ +object toBits { + def apply(e: Expression): Expression = e match { + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex) + case t => error("Invalid operand expression for toBits!") + } + private def hiercat(e: Expression): Expression = e.tpe match { + case t: VectorType => seqCat((0 until t.size).reverse map (i => + hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER)))) + case t: BundleType => seqCat(t.fields map (f => + hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER)))) + case t: GroundType => e + case t => error("Unknown type encountered in toBits!") + } +} + +/** Given a mask, return a bitmask corresponding to the desired datatype. + * Requirements: + * - The mask type and datatype must be equivalent, except any ground type in + * datatype must be matched by a 1-bit wide UIntType. + * - The mask must be a reference, subfield, or subindex + * The bitmask is a series of concatenations of the single mask bit over the + * length of the corresponding ground type, e.g.: + *{{{ + * wire mask: {x: UInt<1>, y: UInt<1>} + * wire data: {x: UInt<2>, y: SInt<2>} + * // this would return: + * cat(cat(mask.x, mask.x), cat(mask.y, mask.y)) + * }}} + */ +object toBitMask { + def apply(mask: Expression, dataType: Type): Expression = mask match { + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType) + case t => error("Invalid operand expression for toBits!") + } + private def hiermask(mask: Expression, dataType: Type): Expression = + (mask.tpe, dataType) match { + case (mt: VectorType, dt: VectorType) => + seqCat((0 until mt.size).reverse map { i => + hiermask(WSubIndex(mask, i, mt.tpe, UNKNOWNGENDER), dt.tpe) + }) + case (mt: BundleType, dt: BundleType) => + seqCat((mt.fields zip dt.fields) map { case (mf, df) => + hiermask(WSubField(mask, mf.name, mf.tpe, UNKNOWNGENDER), df.tpe) + }) + case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) => + seqCat(List.fill(bitWidth(dt).intValue)(mask)) + case (mt, dt) => error("Invalid type for mask component!") + } +} + +object getWidth { + def apply(t: Type): Width = t match { + case t: GroundType => t.width + case _ => error("No width!") + } + def apply(e: Expression): Width = apply(e.tpe) +} + +object bitWidth { + def apply(dt: Type): BigInt = widthOf(dt) + private def widthOf(dt: Type): BigInt = dt match { + case t: VectorType => t.size * bitWidth(t.tpe) + case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_) + case GroundType(IntWidth(width)) => width + case t => error("Unknown type encountered in bitWidth!") + } +} + +object fromBits { + def apply(lhs: Expression, rhs: Expression): Statement = { + val fbits = lhs match { + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => getPart(ex, ex.tpe, rhs, 0) + case _ => error("Invalid LHS expression for fromBits!") + } + Block(fbits._2) + } + private def getPartGround(lhs: Expression, + lhst: Type, + rhs: Expression, + offset: BigInt): (BigInt, Seq[Statement]) = { + val intWidth = bitWidth(lhst) + val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType) + (offset + intWidth, Seq(Connect(NoInfo, lhs, sel))) + } + private def getPart(lhs: Expression, + lhst: Type, + rhs: Expression, + offset: BigInt): (BigInt, Seq[Statement]) = + lhst match { + case t: VectorType => (0 until t.size foldRight (offset, Seq[Statement]())) { + case (i, (curOffset, stmts)) => + val subidx = WSubIndex(lhs, i, t.tpe, UNKNOWNGENDER) + val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) + } + case t: BundleType => (t.fields foldRight (offset, Seq[Statement]())) { + case (f, (curOffset, stmts)) => + val subfield = WSubField(lhs, f.name, f.tpe, UNKNOWNGENDER) + val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) + } + case t: GroundType => getPartGround(lhs, t, rhs, offset) + case t => error("Unknown type encountered in fromBits!") + } +} + +object createMask { + def apply(dt: Type): Type = dt match { + case t: VectorType => VectorType(apply(t.tpe), t.size) + case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe)))) + case t: GroundType => BoolType + } +} + +object createRef { + def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind) = WRef(n, t, k, UNKNOWNGENDER) +} + +object createSubField { + def apply(exp: Expression, n: String) = WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER) +} + +object connectFields { + def apply(lref: Expression, lname: String, rref: Expression, rname: String): Connect = + Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname)) +} + +object flattenType { + def apply(t: Type) = UIntType(IntWidth(bitWidth(t))) +} + +object MemPortUtils { + type MemPortMap = collection.mutable.HashMap[String, Expression] + type Memories = collection.mutable.ArrayBuffer[DefMemory] + type Modules = collection.mutable.ArrayBuffer[DefModule] + + def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq( + Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))), + Field("en", Default, BoolType), + Field("clk", Default, ClockType) + ) + + // Todo: merge it with memToBundle + def memType(mem: DefMemory): Type = { + val rType = BundleType(defaultPortSeq(mem) :+ + Field("data", Flip, mem.dataType)) + val wType = BundleType(defaultPortSeq(mem) ++ Seq( + Field("data", Default, mem.dataType), + Field("mask", Default, createMask(mem.dataType)))) + val rwType = BundleType(defaultPortSeq(mem) ++ Seq( + Field("rdata", Flip, mem.dataType), + Field("wmode", Default, BoolType), + Field("wdata", Default, mem.dataType), + Field("wmask", Default, createMask(mem.dataType)))) + BundleType( + (mem.readers map (Field(_, Flip, rType))) ++ + (mem.writers map (Field(_, Flip, wType))) ++ + (mem.readwriters map (Field(_, Flip, rwType)))) + } + + def memPortField(s: DefMemory, p: String, f: String): Expression = { + val mem = WRef(s.name, memType(s), MemKind, UNKNOWNGENDER) + val t1 = field_type(mem.tpe, p) + val t2 = field_type(t1, f) + WSubField(WSubField(mem, p, t1, UNKNOWNGENDER), f, t2, UNKNOWNGENDER) + } +} diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index 3139ef21..5fc99b9c 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -109,7 +109,8 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] = Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f)) - // Connects the clk, en, and addr fields from the wrapperPort to the bbPort + // Generates mask bits (concatenates an aggregate to ground type) + // depending on mask granularity (# bits = data width / mask granularity) def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression = if (fillMask) toBitMask(mask, dataType) else toBits(mask) diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala new file mode 100644 index 00000000..2f7126b4 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -0,0 +1,184 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtl.passes + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.PrimOps._ +import MemPortUtils._ + +/** This pass generates delay reigsters for memories for verilog */ +object VerilogMemDelays extends Pass { + def name = "Verilog Memory Delays" + val ug = UNKNOWNGENDER + type Netlist = collection.mutable.HashMap[String, Expression] + implicit def expToString(e: Expression): String = e.serialize + private def NOT(e: Expression) = DoPrim(Not, Seq(e), Nil, BoolType) + private def AND(e1: Expression, e2: Expression) = DoPrim(And, Seq(e1, e2), Nil, BoolType) + + def buildNetlist(netlist: Netlist)(s: Statement): Statement = { + s match { + case Connect(_, loc, expr) => kind(loc) match { + case MemKind => netlist(loc) = expr + case _ => + } + case _ => + } + s map buildNetlist(netlist) + } + + def memDelayStmt( + netlist: Netlist, + namespace: Namespace, + repl: Netlist) + (s: Statement): Statement = s map memDelayStmt(netlist, namespace, repl) match { + case sx: DefMemory => + val ports = (sx.readers ++ sx.writers).toSet + def newPortName(rw: String, p: String) = (for { + idx <- Stream from 0 + newName = s"${rw}_${p}_$idx" + if !ports(newName) + } yield newName).head + val rwMap = (sx.readwriters map (rw => + rw -> (newPortName(rw, "r"), newPortName(rw, "w")))).toMap + // 1. readwrite ports are split into read & write ports + // 2. memories are transformed into combinational + // because latency pipes are added for longer latencies + val mem = sx copy ( + readers = sx.readers ++ (sx.readwriters map (rw => rwMap(rw)._1)), + writers = sx.writers ++ (sx.readwriters map (rw => rwMap(rw)._2)), + readwriters = Nil, readLatency = 0, writeLatency = 1) + def pipe(e: Expression, // Expression to be piped + n: Int, // pipe depth + clk: Expression, // clock expression + cond: Expression // condition for pipes + ): (Expression, Seq[Statement]) = { + // returns + // 1) reference to the last pipe register + // 2) pipe registers and connects + val node = DefNode(NoInfo, namespace.newTemp, netlist(e)) + val wref = WRef(node.name, e.tpe, NodeKind, MALE) + ((0 until n) foldLeft (wref, Seq[Statement](node))){case ((ex, stmts), i) => + val name = namespace newName s"${LowerTypes.loweredName(e)}_pipe_$i" + val exx = WRef(name, e.tpe, RegKind, ug) + (exx, stmts ++ Seq(DefRegister(NoInfo, name, e.tpe, clk, zero, exx)) ++ + (if (i < n - 1 && WrappedExpression.weq(cond, one)) Seq(Connect(NoInfo, exx, ex)) else { + val condn = namespace newName s"${LowerTypes.loweredName(e)}_en" + val condx = WRef(condn, BoolType, NodeKind, FEMALE) + Seq(DefNode(NoInfo, condn, cond), + Connect(NoInfo, condx, cond), + Connect(NoInfo, exx, Mux(condx, ex, exx, e.tpe))) + }) + ) + } + } + def readPortConnects(reader: String, + clk: Expression, + en: Expression, + addr: Expression) = Seq( + Connect(NoInfo, memPortField(mem, reader, "clk"), clk), + // connect latency pipes to read ports + Connect(NoInfo, memPortField(mem, reader, "en"), en), + Connect(NoInfo, memPortField(mem, reader, "addr"), addr) + ) + def writePortConnects(writer: String, + clk: Expression, + en: Expression, + mask: Expression, + addr: Expression, + data: Expression) = Seq( + Connect(NoInfo, memPortField(mem, writer, "clk"), clk), + // connect latency pipes to write ports + Connect(NoInfo, memPortField(mem, writer, "en"), en), + Connect(NoInfo, memPortField(mem, writer, "mask"), mask), + Connect(NoInfo, memPortField(mem, writer, "addr"), addr), + Connect(NoInfo, memPortField(mem, writer, "data"), data) + ) + + + Block(mem +: ((sx.readers flatMap {reader => + // generate latency pipes for read ports (enable & addr) + val clk = netlist(memPortField(sx, reader, "clk")) + val (en, ss1) = pipe(memPortField(sx, reader, "en"), sx.readLatency - 1, clk, one) + val (addr, ss2) = pipe(memPortField(sx, reader, "addr"), sx.readLatency, clk, en) + ss1 ++ ss2 ++ readPortConnects(reader, clk, en, addr) + }) ++ (sx.writers flatMap {writer => + // generate latency pipes for write ports (enable, mask, addr, data) + val clk = netlist(memPortField(sx, writer, "clk")) + val (en, ss1) = pipe(memPortField(sx, writer, "en"), sx.writeLatency - 1, clk, one) + val (mask, ss2) = pipe(memPortField(sx, writer, "mask"), sx.writeLatency - 1, clk, one) + val (addr, ss3) = pipe(memPortField(sx, writer, "addr"), sx.writeLatency - 1, clk, one) + val (data, ss4) = pipe(memPortField(sx, writer, "data"), sx.writeLatency - 1, clk, one) + ss1 ++ ss2 ++ ss3 ++ ss4 ++ writePortConnects(writer, clk, en, mask, addr, data) + }) ++ (sx.readwriters flatMap {readwriter => + val (reader, writer) = rwMap(readwriter) + val clk = netlist(memPortField(sx, readwriter, "clk")) + // generate latency pipes for readwrite ports (enable, addr, wmode, wmask, wdata) + val (en, ss1) = pipe(memPortField(sx, readwriter, "en"), sx.readLatency - 1, clk, one) + val (wmode, ss2) = pipe(memPortField(sx, readwriter, "wmode"), sx.writeLatency - 1, clk, one) + val (wmask, ss3) = pipe(memPortField(sx, readwriter, "wmask"), sx.writeLatency - 1, clk, one) + val (wdata, ss4) = pipe(memPortField(sx, readwriter, "wdata"), sx.writeLatency - 1, clk, one) + val (raddr, ss5) = pipe(memPortField(sx, readwriter, "addr"), sx.readLatency, clk, AND(en, NOT(wmode))) + val (waddr, ss6) = pipe(memPortField(sx, readwriter, "addr"), sx.writeLatency - 1, clk, one) + repl(memPortField(sx, readwriter, "rdata")) = memPortField(mem, reader, "data") + ss1 ++ ss2 ++ ss3 ++ ss4 ++ ss5 ++ ss6 ++ + readPortConnects(reader, clk, en, raddr) ++ + writePortConnects(writer, clk, AND(en, wmode), wmask, waddr, wdata) + }))) + case sx: Connect => kind(sx.loc) match { + case MemKind => EmptyStmt + case _ => sx + } + case sx => sx + } + + def replaceExp(repl: Netlist)(e: Expression): Expression = e match { + case ex: WSubField => repl get ex match { + case Some(exx) => exx + case None => ex + } + case ex => ex map replaceExp(repl) + } + + def replaceStmt(repl: Netlist)(s: Statement): Statement = + s map replaceStmt(repl) map replaceExp(repl) + + def memDelayMod(m: DefModule): DefModule = { + val netlist = new Netlist + val namespace = Namespace(m) + val repl = new Netlist + (m map buildNetlist(netlist) + map memDelayStmt(netlist, namespace, repl) + map replaceStmt(repl)) + } + + def run(c: Circuit): Circuit = + c copy (modules = c.modules map memDelayMod) +} -- cgit v1.2.3