aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-08-02 12:24:13 -0700
committerGitHub2016-08-02 12:24:13 -0700
commitdc7a1470e1a64643c387e328030059735d8d2c4c (patch)
tree29b0699b2fa9e3e18de99b3b39fd1d41ba24775b /src
parent6505168958e44bde9ba6828c0f7c03a04528fdec (diff)
parentc951e7453303f7aaf0c281f88a76ae2ba017ed38 (diff)
Merge pull request #203 from ucb-bar/fix_mem_infer
Fix mem infer
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Driver.scala7
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala2
-rw-r--r--src/main/scala/firrtl/passes/InferReadWrite.scala202
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala74
-rw-r--r--src/test/scala/firrtlTests/InferReadWriteSpec.scala104
5 files changed, 364 insertions, 25 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala
index 59a2bb87..bd7210f4 100644
--- a/src/main/scala/firrtl/Driver.scala
+++ b/src/main/scala/firrtl/Driver.scala
@@ -44,6 +44,7 @@ Options:
Currently supported: high low verilog
--info-mode <mode> Specify Info Mode
Supported modes: ignore, use, gen, append
+ --inferRW <circuit> Enable readwrite port inference for the target circuit
"""
// Compiles circuit. First parses a circuit from an input file,
@@ -87,11 +88,15 @@ Options:
case _ => throw new Exception(s"Bad inline instance/module name: $value")
}
+ def handleInferRWOption(value: String) =
+ passes.InferReadWriteAnnotation(value, TransID(-1))
+
run(args: Array[String],
Map( "high" -> new HighFirrtlCompiler(),
"low" -> new LowFirrtlCompiler(),
"verilog" -> new VerilogCompiler()),
- Map("--inline" -> handleInlineOption _),
+ Map("--inline" -> handleInlineOption _,
+ "--inferRW" -> handleInferRWOption _),
usage
)
}
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 95398356..7c239b10 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -188,6 +188,7 @@ class LowFirrtlCompiler extends Compiler {
new passes.InlineInstances(TransID(0)),
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
+ new passes.InferReadWrite(TransID(-1)),
new MiddleFirrtlToLowFirrtl(),
new EmitFirrtl(writer)
)
@@ -200,6 +201,7 @@ class VerilogCompiler extends Compiler {
new IRToWorkingIR(),
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
+ new passes.InferReadWrite(TransID(-1)),
new MiddleFirrtlToLowFirrtl(),
new passes.InlineInstances(TransID(0)),
new EmitVerilogFromLowFirrtl(writer)
diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala
new file mode 100644
index 00000000..2378216d
--- /dev/null
+++ b/src/main/scala/firrtl/passes/InferReadWrite.scala
@@ -0,0 +1,202 @@
+/*
+Copyright (c) 2014 - 2016 The Regents of the University of
+California (Regents). All Rights Reserved. Redistribution and use in
+source and binary forms, with or without modification, are permitted
+provided that the following conditions are met:
+ * Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer in the documentation and/or other materials
+ provided with the distribution.
+ * Neither the name of the Regents nor the names of its contributors
+ may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS,
+ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
+REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF
+ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION
+TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
+MODIFICATIONS.
+*/
+
+package firrtl.passes
+
+import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap}
+import com.typesafe.scalalogging.LazyLogging
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.PrimOps._
+import Annotations._
+
+case class InferReadWriteAnnotation(t: String, tID: TransID)
+ extends Annotation with Loose with Unstable {
+ val target = CircuitName(t)
+ def duplicate(n: Named) = this.copy(t=n.name)
+}
+
+// This pass examine the enable signals of the read & write ports of memories
+// whose readLatency is greater than 1 (usually SeqMem in Chisel).
+// If any product term of the enable signal of the read port is the complement
+// of any product term of the enable signal of the write port, then the readwrite
+// port is inferred.
+object InferReadWritePass extends Pass {
+ def name = "Infer ReadWrite Ports"
+
+ def inferReadWrite(m: Module) = {
+ import WrappedExpression.we
+ val connects = HashMap[String, Expression]()
+ val repl = HashMap[String, Expression]()
+ val stmts = ArrayBuffer[Statement]()
+ val zero = we(UIntLiteral(0, IntWidth(1)))
+ val one = we(UIntLiteral(1, IntWidth(1)))
+
+ // find all wire connections
+ def analyze(s: Statement): Unit = s match {
+ case s: Connect =>
+ connects(s.loc.serialize) = s.expr
+ case s: PartialConnect =>
+ connects(s.loc.serialize) = s.expr
+ case s: DefNode =>
+ connects(s.name) = s.value
+ case s: Block =>
+ s.stmts foreach analyze
+ case _ =>
+ }
+
+ def getProductTermsFromExp(e: Expression): Seq[Expression] =
+ e match {
+ // No ConstProp yet...
+ case Mux(cond, tval, fval, _) if we(tval) == one && we(fval) == zero =>
+ cond +: getProductTerms(cond.serialize)
+ // Visit each term of AND operation
+ case DoPrim(op, args, consts, tpe) if op == And =>
+ e +: (args flatMap getProductTermsFromExp)
+ // Visit connected nodes to references
+ case _: WRef | _: SubField | _: SubIndex | _: SubAccess =>
+ e +: getProductTerms(e.serialize)
+ // Otherwise just return itselt
+ case _ =>
+ List(e)
+ }
+
+ def getProductTerms(node: String): Seq[Expression] =
+ if (connects contains node) getProductTermsFromExp(connects(node)) else Nil
+
+ def checkComplement(a: Expression, b: Expression) = (a, b) match {
+ // b ?= Not(a)
+ case (_, DoPrim(op, args, _, _)) if op == Not =>
+ args.head.serialize == a.serialize
+ // a ?= Not(b)
+ case (DoPrim(op, args, _, _), _) if op == Not =>
+ args.head.serialize == b.serialize
+ // b ?= Eq(a, 0) or b ?= Eq(0, a)
+ case (_, DoPrim(op, args, _, _)) if op == Eq =>
+ args(0).serialize == a.serialize && we(args(1)) == zero ||
+ args(1).serialize == a.serialize && we(args(0)) == zero
+ // a ?= Eq(b, 0) or b ?= Eq(0, a)
+ case (DoPrim(op, args, _, _), _) if op == Eq =>
+ args(0).serialize == b.serialize && we(args(1)) == zero ||
+ args(1).serialize == b.serialize && we(args(0)) == zero
+ case _ => false
+ }
+
+ def inferReadWrite(s: Statement): Statement = s map inferReadWrite match {
+ // infer readwrite ports only for non combinational memories
+ case mem: DefMemory if mem.readLatency > 0 =>
+ val bt = UIntType(IntWidth(1))
+ val ut = UnknownType
+ val ug = UNKNOWNGENDER
+ val readers = HashSet[String]()
+ val writers = HashSet[String]()
+ val readwriters = ArrayBuffer[String]()
+ for (w <- mem.writers ; r <- mem.readers) {
+ val wp = getProductTerms(s"${mem.name}.$w.en")
+ val rp = getProductTerms(s"${mem.name}.$r.en")
+ if (wp exists (a => rp exists (b => checkComplement(a, b)))) {
+ val allPorts = (mem.readers ++ mem.writers ++ mem.readwriters ++ readwriters).toSet
+ // Uniquify names by examining all ports of the memory
+ var rw = (for {
+ idx <- Stream from 0
+ newName = s"rw_$idx"
+ if !allPorts(newName)
+ } yield newName).head
+ val rw_exp = WSubField(WRef(mem.name, ut, NodeKind(), ug), rw, ut, ug)
+ readwriters += rw
+ readers += r
+ writers += w
+ repl(s"${mem.name}.$r.en") = EmptyExpression
+ repl(s"${mem.name}.$r.clk") = EmptyExpression
+ repl(s"${mem.name}.$r.addr") = EmptyExpression
+ repl(s"${mem.name}.$r.data") = WSubField(rw_exp, "rdata", mem.dataType, MALE)
+ repl(s"${mem.name}.$w.en") = WSubField(rw_exp, "wmode", bt, FEMALE)
+ repl(s"${mem.name}.$w.clk") = EmptyExpression
+ repl(s"${mem.name}.$w.addr") = EmptyExpression
+ repl(s"${mem.name}.$w.data") = WSubField(rw_exp, "data", mem.dataType, FEMALE)
+ repl(s"${mem.name}.$w.mask") = WSubField(rw_exp, "mask", ut, FEMALE)
+ stmts += Connect(NoInfo, WSubField(rw_exp, "clk", ClockType, FEMALE),
+ WRef("clk", ClockType, NodeKind(), MALE))
+ stmts += Connect(NoInfo, WSubField(rw_exp, "en", bt, FEMALE),
+ DoPrim(Or, List(connects(s"${mem.name}.$r.en"), connects(s"${mem.name}.$w.en")), Nil, bt))
+ stmts += Connect(NoInfo, WSubField(rw_exp, "addr", ut, FEMALE),
+ Mux(connects(s"${mem.name}.$w.en"), connects(s"${mem.name}.$w.addr"),
+ connects(s"${mem.name}.$r.addr"), ut))
+ }
+ }
+ if (readwriters.isEmpty) mem else DefMemory(mem.info,
+ mem.name, mem.dataType, mem.depth, mem.writeLatency, mem.readLatency,
+ mem.readers filterNot readers, mem.writers filterNot writers,
+ mem.readwriters ++ readwriters)
+ case s => s
+ }
+
+ def replaceExp(e: Expression): Expression =
+ e map replaceExp match {
+ case e: WSubField => repl getOrElse (e.serialize, e)
+ case e => e
+ }
+
+ def replaceStmt(s: Statement): Statement =
+ s map replaceStmt map replaceExp match {
+ case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt
+ case s => s
+ }
+
+ analyze(m.body)
+ Module(m.info, m.name, m.ports, Block((m.body map inferReadWrite map replaceStmt) +: stmts.toSeq))
+ }
+
+ def run (c:Circuit) = Circuit(c.info, c.modules map {
+ case m: Module => inferReadWrite(m)
+ case m: ExtModule => m
+ }, c.main)
+}
+
+// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl"
+// To use this transform, circuit name should be annotated with its TransId.
+class InferReadWrite(transID: TransID) extends Transform with LazyLogging {
+ def execute(circuit:Circuit, map: AnnotationMap) =
+ map get transID match {
+ case Some(p) => p get CircuitName(circuit.main) match {
+ case Some(InferReadWriteAnnotation(_, _)) => TransformResult((Seq(
+ InferReadWritePass,
+ CheckInitialization,
+ ResolveKinds,
+ InferTypes,
+ ResolveGenders) foldLeft circuit){ (c, pass) =>
+ val x = Utils.time(pass.name)(pass run c)
+ logger debug x.serialize
+ x
+ }, None, Some(map))
+ case _ => TransformResult(circuit, None, Some(map))
+ }
+ case _ => TransformResult(circuit, None, Some(map))
+ }
+}
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 1a40b7c5..bd9563dc 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -1085,6 +1085,7 @@ object RemoveCHIRRTL extends Pass {
def remove_chirrtl_m (m:Module) : Module = {
val hash = LinkedHashMap[String,MPorts]()
val repl = LinkedHashMap[String,DataRef]()
+ val raddrs = HashMap[String, Expression]()
val ut = UnknownType
val mport_types = LinkedHashMap[String,Type]()
def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]())
@@ -1168,11 +1169,15 @@ object RemoveCHIRRTL extends Pass {
ens += "en"
masks += "mask"
}
- case _ => {
+ case MRead => {
repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","blah",false)
addrs += "addr"
clks += "clk"
- ens += "en"
+ s.exps(0) match {
+ case e: Reference =>
+ raddrs(e.name) = SubField(SubField(Reference(s.mem,ut),s.name,ut),"en",ut)
+ case _=>
+ }
}
}
val stmts = ArrayBuffer[Statement]()
@@ -1192,23 +1197,25 @@ object RemoveCHIRRTL extends Pass {
}
def remove_chirrtl_s (s:Statement) : Statement = {
var has_write_mport = false
+ var has_read_mport: Option[Expression] = None
var has_readwrite_mport:Option[Expression] = None
def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = {
- (e) match {
- case (e:Reference) => {
- if (repl.contains(e.name)) {
- val vt = repl(e.name)
- g match {
- case MALE => SubField(vt.exp,vt.male,e.tpe)
- case FEMALE => {
- has_write_mport = true
- if (vt.rdwrite == true)
- has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1))))
- SubField(vt.exp,vt.female,e.tpe)
- }
+ (e) match {
+ case (e:Reference) if repl contains e.name =>
+ val vt = repl(e.name)
+ g match {
+ case MALE => SubField(vt.exp,vt.male,e.tpe)
+ case FEMALE => {
+ has_write_mport = true
+ if (vt.rdwrite)
+ has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1))))
+ SubField(vt.exp,vt.female,e.tpe)
}
- } else e
- }
+ }
+ case (e:Reference) if g == FEMALE && (raddrs contains e.name) =>
+ has_read_mport = Some(raddrs(e.name))
+ e
+ case (e:Reference) => e
case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.expr),remove_chirrtl_e(MALE)(e.index),e.tpe)
case (e) => e map (remove_chirrtl_e(g))
}
@@ -1225,20 +1232,35 @@ object RemoveCHIRRTL extends Pass {
case (e) => e
}
}
- (s) match {
+ (s) match {
+ case (s:DefNode) => {
+ val stmts = ArrayBuffer[Statement]()
+ val valuex = remove_chirrtl_e(MALE)(s.value)
+ stmts += DefNode(s.info,s.name,valuex)
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(s.info,en,one)
+ }
+ if (stmts.size > 1) Block(stmts)
+ else stmts(0)
+ }
case (s:Connect) => {
val stmts = ArrayBuffer[Statement]()
val rocx = remove_chirrtl_e(MALE)(s.expr)
val locx = remove_chirrtl_e(FEMALE)(s.loc)
stmts += Connect(s.info,locx,rocx)
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(s.info,en,one)
+ }
if (has_write_mport) {
val e = get_mask(s.loc)
for (x <- create_exps(e) ) {
stmts += Connect(s.info,x,one)
}
- if (has_readwrite_mport != None) {
- val wmode = has_readwrite_mport.get
- stmts += Connect(s.info,wmode,one)
+ has_readwrite_mport match {
+ case None =>
+ case Some(wmode) => stmts += Connect(s.info,wmode,one)
}
}
if (stmts.size > 1) Block(stmts)
@@ -1249,16 +1271,20 @@ object RemoveCHIRRTL extends Pass {
val locx = remove_chirrtl_e(FEMALE)(s.loc)
val rocx = remove_chirrtl_e(MALE)(s.expr)
stmts += PartialConnect(s.info,locx,rocx)
- if (has_write_mport != false) {
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(s.info,en,one)
+ }
+ if (has_write_mport) {
val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default)
val locs = create_exps(get_mask(s.loc))
for (x <- ls ) {
val locx = locs(x._1)
stmts += Connect(s.info,locx,one)
}
- if (has_readwrite_mport != None) {
- val wmode = has_readwrite_mport.get
- stmts += Connect(s.info,wmode,one)
+ has_readwrite_mport match {
+ case None =>
+ case Some(wmode) => stmts += Connect(s.info,wmode,one)
}
}
if (stmts.size > 1) Block(stmts)
diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
new file mode 100644
index 00000000..93f73741
--- /dev/null
+++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
@@ -0,0 +1,104 @@
+/*
+Copyright (c) 2014 - 2016 The Regents of the University of
+California (Regents). All Rights Reserved. Redistribution and use in
+source and binary forms, with or without modification, are permitted
+provided that the following conditions are met:
+ * Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer in the documentation and/or other materials
+ provided with the distribution.
+ * Neither the name of the Regents nor the names of its contributors
+ may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS,
+ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
+REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF
+ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION
+TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
+MODIFICATIONS.
+*/
+
+package firrtlTests
+
+import firrtl._
+import firrtl.ir._
+import firrtl.passes._
+import firrtl.Mappers._
+import Annotations._
+
+class InferReadWriteSpec extends SimpleTransformSpec {
+ object InferReadWriteCheckPass extends Pass {
+ val name = "Check Infer ReadWrite Ports"
+ var foundReadWrite = false
+ def findReadWrite(s: Statement): Unit = s match {
+ case s: DefMemory if s.readLatency > 0 =>
+ foundReadWrite = s.name == "mem" && s.readwriters.size == 1
+ case s: Block =>
+ s.stmts foreach findReadWrite
+ case _ =>
+ }
+
+ def run (c: Circuit) = {
+ val errors = new Errors
+ c.modules foreach {
+ case m: Module => findReadWrite(m.body)
+ case m: ExtModule => m
+ }
+ if (!foundReadWrite) {
+ errors append new PassException("Readwrite ports are not found!")
+ }
+ errors.trigger
+ c
+ }
+ }
+
+ object InferReadWriteCheck extends Transform with SimpleRun {
+ def execute (c: Circuit, map: AnnotationMap) =
+ run(c, Seq(InferReadWriteCheckPass))
+ }
+
+ def transforms (writer: java.io.Writer) = Seq(
+ new Chisel3ToHighFirrtl(),
+ new IRToWorkingIR(),
+ new ResolveAndCheck(),
+ new HighFirrtlToMiddleFirrtl(),
+ new InferReadWrite(TransID(-1)),
+ InferReadWriteCheck,
+ new EmitFirrtl(writer)
+ )
+
+ "Infer ReadWrite Ports" should "infer readwrite ports" in {
+ val input = """
+circuit sram6t :
+ module sram6t :
+ input clk : Clock
+ input reset : UInt<1>
+ output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>}
+
+ io is invalid
+ smem mem : UInt<32>[128]
+ node T_0 = eq(io.wen, UInt<1>("h00"))
+ node T_1 = and(io.en, T_0)
+ wire T_2 : UInt
+ T_2 is invalid
+ when T_1 :
+ T_2 <= io.raddr
+ read mport T_3 = mem[T_2], clk
+ io.rdata <= T_3
+ node T_4 = and(io.en, io.wen)
+ when T_4 :
+ write mport T_5 = mem[io.waddr], clk
+ T_5 <= io.wdata
+""".stripMargin
+
+ val annotaitonMap = AnnotationMap(Seq(InferReadWriteAnnotation("sram6t", TransID(-1))))
+ compile(parse(input), annotaitonMap, new java.io.StringWriter)
+ }
+}