aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAngie2016-08-22 12:08:43 -0700
committerjackkoenig2016-09-06 00:17:18 -0700
commita47aa7f29ae191b912645c9d3f78bcb0c0072260 (patch)
tree86f2f507093a50cc329807ed6d23cbc890e55043 /src
parentc160906e9dbeec7bc2463ffed03d689897379514 (diff)
Added back support for conf writing.
* Conf file info is passed in through annotations. * A pass should have its own set of sub-arguments delimited by :
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/ReplSeqMem.scala260
-rw-r--r--src/main/scala/firrtl/passes/ReplaceMemMacros.scala9
2 files changed, 35 insertions, 234 deletions
diff --git a/src/main/scala/firrtl/passes/ReplSeqMem.scala b/src/main/scala/firrtl/passes/ReplSeqMem.scala
index 2bbcb926..72b69f3b 100644
--- a/src/main/scala/firrtl/passes/ReplSeqMem.scala
+++ b/src/main/scala/firrtl/passes/ReplSeqMem.scala
@@ -1,19 +1,11 @@
package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
-import scala.collection.mutable.{ArrayBuffer,HashMap}
-
import firrtl._
import firrtl.ir._
-import firrtl.Mappers._
-import firrtl.Utils._
import Annotations._
-import firrtl.PrimOps._
-import firrtl.WrappedExpression._
-
import java.io.Writer
-
-import scala.util.matching.Regex
+import AnalysisUtils._
sealed trait PassOption
case object InputConfigFileName extends PassOption
@@ -47,9 +39,20 @@ object PassConfigUtil {
}
-class OutputWriter(filename: String) {
+class ConfWriter(filename: String) {
val outputBuffer = new java.io.CharArrayWriter
- def append(s: String) = outputBuffer.append(s)
+ def append(m: DefMemory) = {
+ // legacy
+ val maskGran = getInfo(m.info,"maskGran")
+ val writers = m.writers map (x => if (maskGran == None) "write" else "mwrite")
+ val readers = List.fill(m.readers.length)("read")
+ val readwriters = m.readwriters map (x => if (maskGran == None) "rw" else "mrw")
+ val ports = (writers ++ readers ++ readwriters).mkString(",")
+ val maskGranConf = if (maskGran == None) "" else s"mask_gran ${maskGran.get}"
+ val width = bitWidth(m.dataType)
+ val conf = s"name ${m.name}_ext depth ${m.depth} width ${width} ports ${ports} ${maskGranConf} \n"
+ outputBuffer.append(conf)
+ }
def serialize = {
val outputFile = new java.io.PrintWriter(filename)
outputFile.write(outputBuffer.toString)
@@ -73,239 +76,36 @@ Required Arguments:
-c<compiler> Specify the target circuit
Optional Arguments:
- -i<filename> Specify the input configuration file
+ -i<filename> Specify the input configuration file (for additional optimizations)
"""
val passOptions = PassConfigUtil.getPassOptions(t,usage)
- val outputConfig = passOptions.getOrElse(OutputConfigFileName, throw new Exception("No output config file provided for ReplSeqMem!" + usage))
- val passCircuit = passOptions.getOrElse(PassCircuitName, throw new Exception("No circuit name specified for ReplSeqMem!" + usage))
+ val outputConfig = passOptions.getOrElse(
+ OutputConfigFileName,
+ throw new Exception("No output config file provided for ReplSeqMem!" + usage)
+ )
+ val passCircuit = passOptions.getOrElse(
+ PassCircuitName,
+ throw new Exception("No circuit name specified for ReplSeqMem!" + usage)
+ )
val target = CircuitName(passCircuit)
def duplicate(n: Named) = this.copy(t=t.replace("-c:"+passCircuit,"-c:"+n.name))
}
-class ReplSeqMemPass(out: OutputWriter) extends Pass {
-
- def name = "Replace Sequential Memories with Blackboxes + Configuration File"
-
- trait WritePortChar {
- def name: String
- def useMask: Boolean
- def maskGran: Option[BigInt]
- require( (useMask && (maskGran != None)) || (!useMask), "Must specify a mask granularity if write mask is desired" )
- }
-
- case class PortForWrite(
- name: String,
- useMask: Boolean = false,
- maskGran: Option[BigInt] = None
- ) extends WritePortChar
-
- case class PortForReadWrite(
- name: String,
- useMask: Boolean = false,
- maskGran: Option[BigInt] = None
- ) extends WritePortChar
-
- case class PortForRead(
- name: String
- )
-
- // vendor agnostic configuration
- case class SMem(
- m: DefMemory,
- // names of read ports
- readPorts: Seq[PortForRead],
- // write ports
- writePorts: Seq[PortForWrite],
- // read/write ports
- readWritePorts: Seq[PortForReadWrite]
- ){
- require (
- if (readWritePorts.isEmpty) writePorts.nonEmpty && readPorts.nonEmpty else writePorts.isEmpty && readPorts.isEmpty,
- "Need at least one set of read, write ports if no RW port is specified. A RW port must be standalone"
- )
- require (readWritePorts.length < 2, "Cannot have more than 1 readwrite port")
- def name = m.name
- def dataType = m.dataType
- def depth = m.depth
- def writeLatency = m.writeLatency
- def readLatency = m.readLatency
- def numReaders = readPorts.length
- def numWriters = writePorts.length
- def numRWriters = readWritePorts.length
- def rPortMap = readPorts.zipWithIndex map { case (p,i) => p -> s"R$i" }
- def wPortMap = writePorts.zipWithIndex map { case (p,i) => p.name -> s"W$i" }
- def rwPortMap = readWritePorts.zipWithIndex map { case (p,i) => p.name -> s"RW$i" }
- def width = bitWidth(dataType)
- def serialize = {
- // for backwards compatibility with old conf format
- val writers = writePorts map (x => if (x.useMask) "mwrite" else "write")
- val readers = List.fill(numReaders)("read")
- val readwriters = readWritePorts map (x => if(x.useMask) "mrw" else "rw")
- val ports = (writers ++ readers ++ readwriters).mkString(",")
- // old conf file only supported 1 mask_gran
- val maskGran = (writePorts ++ readWritePorts) map (_.maskGran.getOrElse(0))
- val maskGranConf = if (maskGran.head == 0) "" else s"mask_gran ${maskGran.head}"
- s"name ${name} depth ${depth} width ${width} ports ${ports} ${maskGranConf} \n"
- }
- def eq(m: SMem) = {
- // TODO: Condition on read under write
- val wpIndivEq = writePorts zip m.writePorts map {case(a,b) => a.maskGran == b.maskGran}
- val wpEq = wpIndivEq.foldLeft(true)(_ && _)
- val rwpIndivEq = readWritePorts zip m.readWritePorts map {case(a,b) => a.maskGran == b.maskGran}
- val rwpEq = rwpIndivEq.foldLeft(true)(_ && _)
- (dataType == m.dataType) &&
- (depth == m.depth) &&
- (writeLatency == m.writeLatency) &&
- (readLatency == m.readLatency) &&
- (numReaders == m.numReaders) &&
- (wpEq && rwpEq)
- }
- def getInterfacePorts = MemPortUtils.memToBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
- }
-
- def analyzeMemsInModule(m: Module): Seq[SMem] = {
-
- val connects = HashMap[String, Expression]()
- val mems = ArrayBuffer[SMem]()
-
- // swiped from InferRW
- def findConnects(s: Statement): Unit = s match {
- case s: Connect =>
- connects(s.loc.serialize) = s.expr
- case s: PartialConnect =>
- connects(s.loc.serialize) = s.expr
- case s: DefNode =>
- connects(s.name) = s.value
- case s: Block =>
- s.stmts foreach findConnects
- case _ =>
- }
-
- def findConnectOriginFromExp(e: Expression): Seq[Expression] = e match {
- // matches how wmode, wmask, write_en are assigned (from Chirrtl)
- // in case no ConstProp is performed before this pass
- case Mux(cond, tv, fv, _) if we(tv) == we(one) && we(fv) == we(zero) =>
- cond +: findConnectOrigin(cond.serialize)
- // visit connected nodes to references
- case _: WRef | _: SubField | _: SubIndex | _: SubAccess =>
- e +: findConnectOrigin(e.serialize)
- // backward searches until a PrimOp or Literal appears -->
- // Literal: you've reached origin
- // PrimOp: you're not simply doing propagation anymore
- // NOTE: not a catch-all!!!
- case _ => List(e)
- }
-
- // only capable of searching for origin in the same module
- def findConnectOrigin(node: String): Seq[Expression] = {
- if (connects contains node) findConnectOriginFromExp(connects(node))
- else Nil
- }
-
- // returns None if wen = wmask bits or wmask bits all = 1; otherwise returns # of mask bits
- def getMaskBits(wen: String, wmask: String): Option[Int] = {
- val wenOrigin = findConnectOrigin(wen)
- // find all mask bits
- val wmaskOrigin = connects.keys.toSeq filter (_.startsWith(wmask)) map findConnectOrigin
- val bitEq = wmaskOrigin map (wenOrigin intersect _) map (_.length > 0)
- // when all wmask bits are equal to wmode, wmask is redundant
- val eq = bitEq.foldLeft(true)(_ && _)
- val wmaskBitOne = wmaskOrigin map(_ contains one)
- // if all wmask bits = 1, then wmask is redundant
- val wmaskOne = wmaskBitOne.foldLeft(true)(_ && _)
- if (eq || wmaskOne) None else Some(wmaskOrigin.length)
- }
-
- def findMemInsts(s: Statement): Unit = s match {
- // only find smems
- case m: DefMemory if m.readLatency > 0 =>
- val dataBits = bitWidth(m.dataType)
- val rwPorts = m.readwriters map (w => {
- val maskBits = getMaskBits(s"${m.name}.$w.wmode",s"${m.name}.$w.wmask")
- if (maskBits == None) PortForReadWrite(name = w)
- else PortForReadWrite(name = w, useMask = true, maskGran = Some(dataBits/maskBits.get))
- })
- val wPorts = m.writers map (w => {
- val maskBits = getMaskBits(s"${m.name}.$w.en",s"${m.name}.$w.mask")
- if (maskBits == None) PortForWrite(name = w)
- else PortForWrite(name = w, useMask = true, maskGran = Some(dataBits/maskBits.get))
- })
- val smemInfo = SMem(
- m = m,
- readPorts = m.readers map(r => PortForRead(name = r)),
- writePorts = wPorts,
- readWritePorts = rwPorts
- )
- mems += smemInfo
- case b: Block => b.stmts foreach findMemInsts
- case _ =>
- }
- findConnects(m.body)
- findMemInsts(m.body)
- mems.toSeq
- }
-
- def run(c: Circuit) = {
- lazy val moduleNamespace = Namespace(c)
-
- val uniqueMems = ArrayBuffer[SMem]()
- val mems = ArrayBuffer[SMem]()
- def analyzeMemsInCircuit(c: Circuit) = {
- c.modules foreach {
- case m: Module => mems ++= analyzeMemsInModule(m)
- case m: ExtModule =>
- }
- mems map {m =>
- val memProto = uniqueMems.find(_.eq(m))
- if (memProto == None) {
- uniqueMems += m
- m.name -> m
- }
- else m.name -> memProto.get.copy(m=m.m)
- }
- }
- val memMap = analyzeMemsInCircuit(c)
- val newMods = mems map (m => ExtModule(m.m.info,m.name,m.getInterfacePorts))
-
- def replaceMemInstsInCircuit(c: Circuit) = {
- def replaceMemInstsInModule(m: Module) = {
- def findMemInsts(s: Statement): Statement = s match {
- case m: DefMemory if m.readLatency > 0 => WDefInstance(m.info, m.name, m.name, UnknownType)
- case b: Block => Block(b.stmts map findMemInsts)
- case s => s
- }
- m.copy(body = findMemInsts(m.body))
- }
- c.modules map {
- case m: Module => replaceMemInstsInModule(m)
- case m: ExtModule => m
- }
- }
-
- uniqueMems foreach { m =>
- moduleNamespace.newName(m.name)
- moduleNamespace.newName(m.name + "_ext")
- out.append(m.serialize)
- }
- out.serialize
- c.copy(modules = replaceMemInstsInCircuit(c) ++ newMods)
- }
-
-}
-
class ReplSeqMem(transID: TransID) extends Transform with LazyLogging {
def execute(circuit:Circuit, map: AnnotationMap) =
map get transID match {
case Some(p) => p get CircuitName(circuit.main) match {
case Some(ReplSeqMemAnnotation(t, _)) => {
- val outConfigFile = PassConfigUtil.getPassOptions(t).get(OutputConfigFileName).get
+ val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t).get(OutputConfigFileName).get)
TransformResult(
(
Seq(
Legalize,
- new ReplSeqMemPass(new OutputWriter(outConfigFile)),
+ AnnotateMemMacros,
+ UpdateDuplicateMemMacros,
+ new ReplaceMemMacros(outConfigFile),
RemoveEmpty,
CheckInitialization,
ResolveKinds, // Must be run for the transform to work!
@@ -326,8 +126,4 @@ class ReplSeqMem(transID: TransID) extends Transform with LazyLogging {
}
case _ => TransformResult(circuit, None, Some(map))
}
-}
-
-// Eliminate extra modules
-// Tag modules
-// connect internals \ No newline at end of file
+} \ No newline at end of file
diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
index fedc4c56..cc74a865 100644
--- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
@@ -8,9 +8,9 @@ import firrtl._
import firrtl.Utils._
import MemPortUtils._
-object ReplaceMemMacros extends Pass {
+class ReplaceMemMacros(writer: ConfWriter) extends Pass {
- def name = "Replace memories with black box wrappers (optimizes when write mask isn't needed)"
+ def name = "Replace memories with black box wrappers (optimizes when write mask isn't needed) + configuration file"
def run(c: Circuit) = {
@@ -57,6 +57,8 @@ object ReplaceMemMacros extends Pass {
case m: ExtModule => m
}
+ // print conf
+ writer.serialize
c.copy(modules = updatedMods ++ memMods.toSeq)
}
@@ -86,6 +88,9 @@ object ReplaceMemMacros extends Pass {
//println(wrapper.body.serialize)
val bb = ExtModule(m.info,bbName,bbioPorts)
+
+ // add to conf file
+ writer.append(m)
Seq(bb,wrapper)
}