From 0d5fa689a45693bf6db9bc6d9dc3f150bc3ff4b8 Mon Sep 17 00:00:00 2001 From: Angie Date: Fri, 19 Aug 2016 17:00:11 -0700 Subject: Added starter code for SMem replacement --- src/main/scala/firrtl/Driver.scala | 12 +- src/main/scala/firrtl/LoweringCompilers.scala | 2 + src/main/scala/firrtl/passes/ReplSeqMem.scala | 276 ++++++++++++++++++++++++++ 3 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 src/main/scala/firrtl/passes/ReplSeqMem.scala (limited to 'src/main') diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 5969562f..79f2fdaf 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -57,6 +57,12 @@ Optional Arguments: Supported modes: ignore, use, gen, append --inferRW Enable readwrite port inference for the target circuit --inline | Inline a module (e.g. "MyModule") or instance (e.g. "MyModule.myinstance") + + --replSeqMem -c::-i:-o + *** Replace sequential memories with blackboxes + configuration file + *** Input configuration file optional + *** Note: sub-arguments to --replSeqMem should be delimited by : and not white space! + [--help|-h] Print usage string """ @@ -74,12 +80,16 @@ Optional Arguments: def handleInferRWOption(value: String) = passes.InferReadWriteAnnotation(value, TransID(-1)) + def handleReplSeqMem(value: String) = + passes.ReplSeqMemAnnotation(value, TransID(-2)) + run(args: Array[String], Map( "high" -> new HighFirrtlCompiler(), "low" -> new LowFirrtlCompiler(), "verilog" -> new VerilogCompiler()), Map("--inline" -> handleInlineOption _, - "--inferRW" -> handleInferRWOption _), + "--inferRW" -> handleInferRWOption _, + "--replSeqMem" -> handleReplSeqMem _), usage ) } diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 7c239b10..f9a5864c 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -189,6 +189,7 @@ class LowFirrtlCompiler extends Compiler { new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), new passes.InferReadWrite(TransID(-1)), + new passes.ReplSeqMem(TransID(-2)), new MiddleFirrtlToLowFirrtl(), new EmitFirrtl(writer) ) @@ -202,6 +203,7 @@ class VerilogCompiler extends Compiler { new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), new passes.InferReadWrite(TransID(-1)), + new passes.ReplSeqMem(TransID(-2)), new MiddleFirrtlToLowFirrtl(), new passes.InlineInstances(TransID(0)), new EmitVerilogFromLowFirrtl(writer) diff --git a/src/main/scala/firrtl/passes/ReplSeqMem.scala b/src/main/scala/firrtl/passes/ReplSeqMem.scala new file mode 100644 index 00000000..3457febb --- /dev/null +++ b/src/main/scala/firrtl/passes/ReplSeqMem.scala @@ -0,0 +1,276 @@ +package firrtl.passes + +import com.typesafe.scalalogging.LazyLogging +import scala.collection.mutable.{ArrayBuffer,HashMap} + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.Utils._ +import Annotations._ +import firrtl.PrimOps._ +import firrtl.WrappedExpression._ + +import java.io.Writer + +import scala.util.matching.Regex + +case class ReplSeqMemAnnotation(t: String, tID: TransID) + extends Annotation with Loose with Unstable { + + val usage = """ +[Optional] ReplSeqMem + Pass to replace sequential memories with blackboxes + configuration file + +Usage: + --replSeqMem -c::-i:-o + *** Note: sub-arguments to --replSeqMem should be delimited by : and not white space! + +Required Arguments: + -o Specify the output configuration file + -c Specify the target circuit + +Optional Arguments: + -i Specify the input configuration file +""" + + sealed trait PassOption + case object InputConfigFileName extends PassOption + case object OutputConfigFileName extends PassOption + case object PassCircuitName extends PassOption + + type PassOptionMap = Map[PassOption, String] + + // can't use space to delimit sub arguments (otherwise, Driver.scala will throw error) + val passArgList = t.split(":").toList + + def nextPassOption(map: PassOptionMap, list: List[String]): PassOptionMap = { + list match { + case Nil => map + case "-i" :: value :: tail => + nextPassOption(map + (InputConfigFileName -> value), tail) + case "-o" :: value :: tail => + nextPassOption(map + (OutputConfigFileName -> value), tail) + case "-c" :: value :: tail => + nextPassOption(map + (PassCircuitName -> value), tail) + case option :: tail => + throw new Exception("Unknown option " + option + usage) + } + } + + val passOptions = nextPassOption(Map[PassOption, String](), passArgList) + val inputConfig = passOptions.getOrElse(InputConfigFileName, throw new Exception("No input config file provided for ReplSeqMem!" + usage)) + val outputConfig = passOptions.getOrElse(OutputConfigFileName, throw new Exception("No output config file provided for ReplSeqMem!" + usage)) + val passCircuit = passOptions.getOrElse(PassCircuitName, throw new Exception("No circuit name specified for ReplSeqMem!" + usage)) + + val target = CircuitName(passCircuit) + def duplicate(n: Named) = this.copy(t=t.replace("-c:"+passCircuit,"-c:"+n.name)) + +} + +object ReplSeqMem extends Pass { + + def name = "Replace Sequential Memories with Blackboxes + Configuration File" + + trait WritePortChar { + def name: String + def useMask: Boolean + def maskGran: Option[BigInt] + require( (useMask && (maskGran != None)) || (!useMask), "Must specify a mask granularity if write mask is desired" ) + } + + case class PortForWrite( + name: String, + useMask: Boolean = false, + maskGran: Option[BigInt] = None + ) extends WritePortChar + + case class PortForReadWrite( + name: String, + useMask: Boolean = false, + maskGran: Option[BigInt] = None + ) extends WritePortChar + + case class PortForRead( + name: String + ) + + // vendor agnostic configuration + case class SMem( + m: DefMemory, + // names of read ports + readPorts: Seq[PortForRead], + // write ports + writePorts: Seq[PortForWrite], + // read/write ports + readWritePorts: Seq[PortForReadWrite] + ){ + require ( + if (readWritePorts.isEmpty) writePorts.nonEmpty && readPorts.nonEmpty else writePorts.isEmpty && readPorts.isEmpty, + "Need at least one set of read, write ports if no RW port is specified. A RW port must be standalone" + ) + require (readWritePorts.length < 2, "Cannot have more than 1 readwrite port") + def name = m.name + def dataType = m.dataType + def depth = m.depth + def writeLatency = m.writeLatency + def readLatency = m.readLatency + def numReaders = readPorts.length + def numWriters = writePorts.length + def numRWriters = readWritePorts.length + def rPortMap = readPorts.zipWithIndex map { case (p,i) => p -> s"R$i" } + def wPortMap = writePorts.zipWithIndex map { case (p,i) => p.name -> s"W$i" } + def rwPortMap = readWritePorts.zipWithIndex map { case (p,i) => p.name -> s"RW$i" } + def width = bitWidth(dataType) + def serialize = { + // for backwards compatibility with old conf format + val writers = writePorts map (x => if (x.useMask) "mwrite" else "write") + val readers = List.fill(numReaders)("read") + val readwriters = readWritePorts map (x => if(x.useMask) "mrw" else "rw") + val ports = (writers ++ readers ++ readwriters).mkString(",") + // old conf file only supported 1 mask_gran + val maskGran = (writePorts ++ readWritePorts) map (_.maskGran.getOrElse(0)) + val maskGranConf = if (maskGran.head == 0) "" else s"mask_gran ${maskGran.head}" + s"name ${name} depth ${depth} width ${width} ports ${ports} ${maskGranConf} \n" + } + def eq(m: SMem) = { + // TODO: Condition on read under write + val wpIndivEq = writePorts zip m.writePorts map {case(a,b) => a.maskGran == b.maskGran} + val wpEq = wpIndivEq.foldLeft(true)(_ && _) + val rwpIndivEq = readWritePorts zip m.readWritePorts map {case(a,b) => a.maskGran == b.maskGran} + val rwpEq = rwpIndivEq.foldLeft(true)(_ && _) + (dataType == m.dataType) && + (depth == m.depth) && + (writeLatency == m.writeLatency) && + (readLatency == m.readLatency) && + (numReaders == m.numReaders) && + (wpEq && rwpEq) + } + } + + def analyzeMemsInModule(m: Module): Seq[SMem] = { + + val connects = HashMap[String, Expression]() + val mems = ArrayBuffer[SMem]() + + // swiped from InferRW + def findConnects(s: Statement): Unit = s match { + case s: Connect => + connects(s.loc.serialize) = s.expr + case s: PartialConnect => + connects(s.loc.serialize) = s.expr + case s: DefNode => + connects(s.name) = s.value + case s: Block => + s.stmts foreach findConnects + case _ => + } + + def findConnectOriginFromExp(e: Expression): Seq[Expression] = e match { + // matches how wmode, wmask, write_en are assigned (from Chirrtl) + // in case no ConstProp is performed before this pass + case Mux(cond, tv, fv, _) if we(tv) == we(one) && we(fv) == we(zero) => + cond +: findConnectOrigin(cond.serialize) + // visit connected nodes to references + case _: WRef | _: SubField | _: SubIndex | _: SubAccess => + e +: findConnectOrigin(e.serialize) + // backward searches until a PrimOp or Literal appears --> + // Literal: you've reached origin + // PrimOp: you're not simply doing propagation anymore + // NOTE: not a catch-all!!! + case _ => List(e) + } + + // only capable of searching for origin in the same module + def findConnectOrigin(node: String): Seq[Expression] = { + if (connects contains node) findConnectOriginFromExp(connects(node)) + else Nil + } + + // returns None if wen = wmask bits or wmask bits all = 1; otherwise returns # of mask bits + def getMaskBits(wen: String, wmask: String): Option[Int] = { + val wenOrigin = findConnectOrigin(wen) + // find all mask bits + val wmaskOrigin = connects.keys.toSeq filter (_.startsWith(wmask)) map findConnectOrigin + val bitEq = wmaskOrigin map (wenOrigin intersect _) map (_.length > 0) + // when all wmask bits are equal to wmode, wmask is redundant + val eq = bitEq.foldLeft(true)(_ && _) + val wmaskBitOne = wmaskOrigin map(_ contains one) + // if all wmask bits = 1, then wmask is redundant + val wmaskOne = wmaskBitOne.foldLeft(true)(_ && _) + if (eq || wmaskOne) None else Some(wmaskOrigin.length) + } + + def findMemInsts(s: Statement): Unit = s match { + // only find smems + case m: DefMemory if m.readLatency > 0 => + val dataBits = bitWidth(m.dataType) + val rwPorts = m.readwriters map (w => { + val maskBits = getMaskBits(s"${m.name}.$w.wmode",s"${m.name}.$w.wmask") + if (maskBits == None) PortForReadWrite(name = w) + else PortForReadWrite(name = w, useMask = true, maskGran = Some(dataBits/maskBits.get)) + }) + val wPorts = m.writers map (w => { + val maskBits = getMaskBits(s"${m.name}.$w.en",s"${m.name}.$w.mask") + if (maskBits == None) PortForWrite(name = w) + else PortForWrite(name = w, useMask = true, maskGran = Some(dataBits/maskBits.get)) + }) + val smemInfo = SMem( + m = m, + readPorts = m.readers map(r => PortForRead(name = r)), + writePorts = wPorts, + readWritePorts = rwPorts + ) + mems += smemInfo + case b: Block => b.stmts foreach findMemInsts + case _ => + } + findConnects(m.body) + findMemInsts(m.body) + mems.toSeq + } + + def run(c: Circuit) = { + val uniqueMems = ArrayBuffer[SMem]() + def analyzeMemsInCircuit(c: Circuit) = { + val mems = ArrayBuffer[SMem]() + c.modules foreach { _ match { + case m: Module => mems ++= analyzeMemsInModule(m) + case m: ExtModule => + }} + mems map {m => + val memProto = uniqueMems.find(_.eq(m)) + if (memProto == None) { + uniqueMems += m + m.name -> m + } + else m.name -> memProto.get.copy(m=m.m) + } + } + val memMap = analyzeMemsInCircuit(c) + println(memMap) + c + } + +} + +class ReplSeqMem(transID: TransID) extends Transform with LazyLogging { + def execute(circuit:Circuit, map: AnnotationMap) = + map get transID match { + case Some(p) => p get CircuitName(circuit.main) match { + case Some(ReplSeqMemAnnotation(_, _)) => TransformResult((Seq( + Legalize, + ReplSeqMem, + CheckInitialization, + ResolveKinds, + InferTypes, + ResolveGenders) foldLeft circuit){ (c, pass) => + val x = Utils.time(pass.name)(pass run c) + logger debug x.serialize + x + }, None, Some(map)) + case _ => TransformResult(circuit, None, Some(map)) + } + case _ => TransformResult(circuit, None, Some(map)) + } +} \ No newline at end of file -- cgit v1.2.3