aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala
diff options
context:
space:
mode:
authorAngie2016-08-19 17:00:11 -0700
committerjackkoenig2016-09-06 00:17:17 -0700
commit0d5fa689a45693bf6db9bc6d9dc3f150bc3ff4b8 (patch)
treefa2f12be17f3d1c2075a3af67e40a2b8aeaa2f55 /src/main/scala
parentc1ca57452af8adc00bef92e2ddf8984c8cde5620 (diff)
Added starter code for SMem replacement
Diffstat (limited to 'src/main/scala')
-rw-r--r--src/main/scala/firrtl/Driver.scala12
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala2
-rw-r--r--src/main/scala/firrtl/passes/ReplSeqMem.scala276
3 files changed, 289 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala
index 5969562f..79f2fdaf 100644
--- a/src/main/scala/firrtl/Driver.scala
+++ b/src/main/scala/firrtl/Driver.scala
@@ -57,6 +57,12 @@ Optional Arguments:
Supported modes: ignore, use, gen, append
--inferRW <circuit> Enable readwrite port inference for the target circuit
--inline <module>|<instance> Inline a module (e.g. "MyModule") or instance (e.g. "MyModule.myinstance")
+
+ --replSeqMem -c:<circuit>:-i<filename>:-o<filename>
+ *** Replace sequential memories with blackboxes + configuration file
+ *** Input configuration file optional
+ *** Note: sub-arguments to --replSeqMem should be delimited by : and not white space!
+
[--help|-h] Print usage string
"""
@@ -74,12 +80,16 @@ Optional Arguments:
def handleInferRWOption(value: String) =
passes.InferReadWriteAnnotation(value, TransID(-1))
+ def handleReplSeqMem(value: String) =
+ passes.ReplSeqMemAnnotation(value, TransID(-2))
+
run(args: Array[String],
Map( "high" -> new HighFirrtlCompiler(),
"low" -> new LowFirrtlCompiler(),
"verilog" -> new VerilogCompiler()),
Map("--inline" -> handleInlineOption _,
- "--inferRW" -> handleInferRWOption _),
+ "--inferRW" -> handleInferRWOption _,
+ "--replSeqMem" -> handleReplSeqMem _),
usage
)
}
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 7c239b10..f9a5864c 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -189,6 +189,7 @@ class LowFirrtlCompiler extends Compiler {
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
new passes.InferReadWrite(TransID(-1)),
+ new passes.ReplSeqMem(TransID(-2)),
new MiddleFirrtlToLowFirrtl(),
new EmitFirrtl(writer)
)
@@ -202,6 +203,7 @@ class VerilogCompiler extends Compiler {
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
new passes.InferReadWrite(TransID(-1)),
+ new passes.ReplSeqMem(TransID(-2)),
new MiddleFirrtlToLowFirrtl(),
new passes.InlineInstances(TransID(0)),
new EmitVerilogFromLowFirrtl(writer)
diff --git a/src/main/scala/firrtl/passes/ReplSeqMem.scala b/src/main/scala/firrtl/passes/ReplSeqMem.scala
new file mode 100644
index 00000000..3457febb
--- /dev/null
+++ b/src/main/scala/firrtl/passes/ReplSeqMem.scala
@@ -0,0 +1,276 @@
+package firrtl.passes
+
+import com.typesafe.scalalogging.LazyLogging
+import scala.collection.mutable.{ArrayBuffer,HashMap}
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.Utils._
+import Annotations._
+import firrtl.PrimOps._
+import firrtl.WrappedExpression._
+
+import java.io.Writer
+
+import scala.util.matching.Regex
+
+case class ReplSeqMemAnnotation(t: String, tID: TransID)
+ extends Annotation with Loose with Unstable {
+
+ val usage = """
+[Optional] ReplSeqMem
+ Pass to replace sequential memories with blackboxes + configuration file
+
+Usage:
+ --replSeqMem -c:<circuit>:-i<filename>:-o<filename>
+ *** Note: sub-arguments to --replSeqMem should be delimited by : and not white space!
+
+Required Arguments:
+ -o<filename> Specify the output configuration file
+ -c<compiler> Specify the target circuit
+
+Optional Arguments:
+ -i<filename> Specify the input configuration file
+"""
+
+ sealed trait PassOption
+ case object InputConfigFileName extends PassOption
+ case object OutputConfigFileName extends PassOption
+ case object PassCircuitName extends PassOption
+
+ type PassOptionMap = Map[PassOption, String]
+
+ // can't use space to delimit sub arguments (otherwise, Driver.scala will throw error)
+ val passArgList = t.split(":").toList
+
+ def nextPassOption(map: PassOptionMap, list: List[String]): PassOptionMap = {
+ list match {
+ case Nil => map
+ case "-i" :: value :: tail =>
+ nextPassOption(map + (InputConfigFileName -> value), tail)
+ case "-o" :: value :: tail =>
+ nextPassOption(map + (OutputConfigFileName -> value), tail)
+ case "-c" :: value :: tail =>
+ nextPassOption(map + (PassCircuitName -> value), tail)
+ case option :: tail =>
+ throw new Exception("Unknown option " + option + usage)
+ }
+ }
+
+ val passOptions = nextPassOption(Map[PassOption, String](), passArgList)
+ val inputConfig = passOptions.getOrElse(InputConfigFileName, throw new Exception("No input config file provided for ReplSeqMem!" + usage))
+ val outputConfig = passOptions.getOrElse(OutputConfigFileName, throw new Exception("No output config file provided for ReplSeqMem!" + usage))
+ val passCircuit = passOptions.getOrElse(PassCircuitName, throw new Exception("No circuit name specified for ReplSeqMem!" + usage))
+
+ val target = CircuitName(passCircuit)
+ def duplicate(n: Named) = this.copy(t=t.replace("-c:"+passCircuit,"-c:"+n.name))
+
+}
+
+object ReplSeqMem extends Pass {
+
+ def name = "Replace Sequential Memories with Blackboxes + Configuration File"
+
+ trait WritePortChar {
+ def name: String
+ def useMask: Boolean
+ def maskGran: Option[BigInt]
+ require( (useMask && (maskGran != None)) || (!useMask), "Must specify a mask granularity if write mask is desired" )
+ }
+
+ case class PortForWrite(
+ name: String,
+ useMask: Boolean = false,
+ maskGran: Option[BigInt] = None
+ ) extends WritePortChar
+
+ case class PortForReadWrite(
+ name: String,
+ useMask: Boolean = false,
+ maskGran: Option[BigInt] = None
+ ) extends WritePortChar
+
+ case class PortForRead(
+ name: String
+ )
+
+ // vendor agnostic configuration
+ case class SMem(
+ m: DefMemory,
+ // names of read ports
+ readPorts: Seq[PortForRead],
+ // write ports
+ writePorts: Seq[PortForWrite],
+ // read/write ports
+ readWritePorts: Seq[PortForReadWrite]
+ ){
+ require (
+ if (readWritePorts.isEmpty) writePorts.nonEmpty && readPorts.nonEmpty else writePorts.isEmpty && readPorts.isEmpty,
+ "Need at least one set of read, write ports if no RW port is specified. A RW port must be standalone"
+ )
+ require (readWritePorts.length < 2, "Cannot have more than 1 readwrite port")
+ def name = m.name
+ def dataType = m.dataType
+ def depth = m.depth
+ def writeLatency = m.writeLatency
+ def readLatency = m.readLatency
+ def numReaders = readPorts.length
+ def numWriters = writePorts.length
+ def numRWriters = readWritePorts.length
+ def rPortMap = readPorts.zipWithIndex map { case (p,i) => p -> s"R$i" }
+ def wPortMap = writePorts.zipWithIndex map { case (p,i) => p.name -> s"W$i" }
+ def rwPortMap = readWritePorts.zipWithIndex map { case (p,i) => p.name -> s"RW$i" }
+ def width = bitWidth(dataType)
+ def serialize = {
+ // for backwards compatibility with old conf format
+ val writers = writePorts map (x => if (x.useMask) "mwrite" else "write")
+ val readers = List.fill(numReaders)("read")
+ val readwriters = readWritePorts map (x => if(x.useMask) "mrw" else "rw")
+ val ports = (writers ++ readers ++ readwriters).mkString(",")
+ // old conf file only supported 1 mask_gran
+ val maskGran = (writePorts ++ readWritePorts) map (_.maskGran.getOrElse(0))
+ val maskGranConf = if (maskGran.head == 0) "" else s"mask_gran ${maskGran.head}"
+ s"name ${name} depth ${depth} width ${width} ports ${ports} ${maskGranConf} \n"
+ }
+ def eq(m: SMem) = {
+ // TODO: Condition on read under write
+ val wpIndivEq = writePorts zip m.writePorts map {case(a,b) => a.maskGran == b.maskGran}
+ val wpEq = wpIndivEq.foldLeft(true)(_ && _)
+ val rwpIndivEq = readWritePorts zip m.readWritePorts map {case(a,b) => a.maskGran == b.maskGran}
+ val rwpEq = rwpIndivEq.foldLeft(true)(_ && _)
+ (dataType == m.dataType) &&
+ (depth == m.depth) &&
+ (writeLatency == m.writeLatency) &&
+ (readLatency == m.readLatency) &&
+ (numReaders == m.numReaders) &&
+ (wpEq && rwpEq)
+ }
+ }
+
+ def analyzeMemsInModule(m: Module): Seq[SMem] = {
+
+ val connects = HashMap[String, Expression]()
+ val mems = ArrayBuffer[SMem]()
+
+ // swiped from InferRW
+ def findConnects(s: Statement): Unit = s match {
+ case s: Connect =>
+ connects(s.loc.serialize) = s.expr
+ case s: PartialConnect =>
+ connects(s.loc.serialize) = s.expr
+ case s: DefNode =>
+ connects(s.name) = s.value
+ case s: Block =>
+ s.stmts foreach findConnects
+ case _ =>
+ }
+
+ def findConnectOriginFromExp(e: Expression): Seq[Expression] = e match {
+ // matches how wmode, wmask, write_en are assigned (from Chirrtl)
+ // in case no ConstProp is performed before this pass
+ case Mux(cond, tv, fv, _) if we(tv) == we(one) && we(fv) == we(zero) =>
+ cond +: findConnectOrigin(cond.serialize)
+ // visit connected nodes to references
+ case _: WRef | _: SubField | _: SubIndex | _: SubAccess =>
+ e +: findConnectOrigin(e.serialize)
+ // backward searches until a PrimOp or Literal appears -->
+ // Literal: you've reached origin
+ // PrimOp: you're not simply doing propagation anymore
+ // NOTE: not a catch-all!!!
+ case _ => List(e)
+ }
+
+ // only capable of searching for origin in the same module
+ def findConnectOrigin(node: String): Seq[Expression] = {
+ if (connects contains node) findConnectOriginFromExp(connects(node))
+ else Nil
+ }
+
+ // returns None if wen = wmask bits or wmask bits all = 1; otherwise returns # of mask bits
+ def getMaskBits(wen: String, wmask: String): Option[Int] = {
+ val wenOrigin = findConnectOrigin(wen)
+ // find all mask bits
+ val wmaskOrigin = connects.keys.toSeq filter (_.startsWith(wmask)) map findConnectOrigin
+ val bitEq = wmaskOrigin map (wenOrigin intersect _) map (_.length > 0)
+ // when all wmask bits are equal to wmode, wmask is redundant
+ val eq = bitEq.foldLeft(true)(_ && _)
+ val wmaskBitOne = wmaskOrigin map(_ contains one)
+ // if all wmask bits = 1, then wmask is redundant
+ val wmaskOne = wmaskBitOne.foldLeft(true)(_ && _)
+ if (eq || wmaskOne) None else Some(wmaskOrigin.length)
+ }
+
+ def findMemInsts(s: Statement): Unit = s match {
+ // only find smems
+ case m: DefMemory if m.readLatency > 0 =>
+ val dataBits = bitWidth(m.dataType)
+ val rwPorts = m.readwriters map (w => {
+ val maskBits = getMaskBits(s"${m.name}.$w.wmode",s"${m.name}.$w.wmask")
+ if (maskBits == None) PortForReadWrite(name = w)
+ else PortForReadWrite(name = w, useMask = true, maskGran = Some(dataBits/maskBits.get))
+ })
+ val wPorts = m.writers map (w => {
+ val maskBits = getMaskBits(s"${m.name}.$w.en",s"${m.name}.$w.mask")
+ if (maskBits == None) PortForWrite(name = w)
+ else PortForWrite(name = w, useMask = true, maskGran = Some(dataBits/maskBits.get))
+ })
+ val smemInfo = SMem(
+ m = m,
+ readPorts = m.readers map(r => PortForRead(name = r)),
+ writePorts = wPorts,
+ readWritePorts = rwPorts
+ )
+ mems += smemInfo
+ case b: Block => b.stmts foreach findMemInsts
+ case _ =>
+ }
+ findConnects(m.body)
+ findMemInsts(m.body)
+ mems.toSeq
+ }
+
+ def run(c: Circuit) = {
+ val uniqueMems = ArrayBuffer[SMem]()
+ def analyzeMemsInCircuit(c: Circuit) = {
+ val mems = ArrayBuffer[SMem]()
+ c.modules foreach { _ match {
+ case m: Module => mems ++= analyzeMemsInModule(m)
+ case m: ExtModule =>
+ }}
+ mems map {m =>
+ val memProto = uniqueMems.find(_.eq(m))
+ if (memProto == None) {
+ uniqueMems += m
+ m.name -> m
+ }
+ else m.name -> memProto.get.copy(m=m.m)
+ }
+ }
+ val memMap = analyzeMemsInCircuit(c)
+ println(memMap)
+ c
+ }
+
+}
+
+class ReplSeqMem(transID: TransID) extends Transform with LazyLogging {
+ def execute(circuit:Circuit, map: AnnotationMap) =
+ map get transID match {
+ case Some(p) => p get CircuitName(circuit.main) match {
+ case Some(ReplSeqMemAnnotation(_, _)) => TransformResult((Seq(
+ Legalize,
+ ReplSeqMem,
+ CheckInitialization,
+ ResolveKinds,
+ InferTypes,
+ ResolveGenders) foldLeft circuit){ (c, pass) =>
+ val x = Utils.time(pass.name)(pass run c)
+ logger debug x.serialize
+ x
+ }, None, Some(map))
+ case _ => TransformResult(circuit, None, Some(map))
+ }
+ case _ => TransformResult(circuit, None, Some(map))
+ }
+} \ No newline at end of file