aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorDonggyu Kim2016-07-27 14:56:11 -0700
committerDonggyu Kim2016-08-02 11:29:43 -0700
commitc951e7453303f7aaf0c281f88a76ae2ba017ed38 (patch)
tree7374e2570f84d73ca826e6722941828ee6fc5be2 /src/main
parent22350029c9a91c30abd849c17108f8bc24054a78 (diff)
make infer readwrite ports optional
turned on with '--inferRW <circuit name>'
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/Driver.scala7
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala3
-rw-r--r--src/main/scala/firrtl/passes/InferReadWrite.scala56
3 files changed, 60 insertions, 6 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala
index 59a2bb87..bd7210f4 100644
--- a/src/main/scala/firrtl/Driver.scala
+++ b/src/main/scala/firrtl/Driver.scala
@@ -44,6 +44,7 @@ Options:
Currently supported: high low verilog
--info-mode <mode> Specify Info Mode
Supported modes: ignore, use, gen, append
+ --inferRW <circuit> Enable readwrite port inference for the target circuit
"""
// Compiles circuit. First parses a circuit from an input file,
@@ -87,11 +88,15 @@ Options:
case _ => throw new Exception(s"Bad inline instance/module name: $value")
}
+ def handleInferRWOption(value: String) =
+ passes.InferReadWriteAnnotation(value, TransID(-1))
+
run(args: Array[String],
Map( "high" -> new HighFirrtlCompiler(),
"low" -> new LowFirrtlCompiler(),
"verilog" -> new VerilogCompiler()),
- Map("--inline" -> handleInlineOption _),
+ Map("--inline" -> handleInlineOption _,
+ "--inferRW" -> handleInferRWOption _),
usage
)
}
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index c27ffce7..4d7ddfe0 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -107,7 +107,6 @@ class HighFirrtlToMiddleFirrtl () extends Transform with SimpleRun {
passes.ExpandConnects,
passes.RemoveAccesses,
passes.ExpandWhens,
- passes.InferReadWrite,
passes.CheckInitialization,
passes.ConstProp,
passes.ResolveKinds,
@@ -190,6 +189,7 @@ class LowFirrtlCompiler extends Compiler {
new passes.InlineInstances(TransID(0)),
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
+ new passes.InferReadWrite(TransID(-1)),
new MiddleFirrtlToLowFirrtl(),
new EmitFirrtl(writer)
)
@@ -202,6 +202,7 @@ class VerilogCompiler extends Compiler {
new IRToWorkingIR(),
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
+ new passes.InferReadWrite(TransID(-1)),
new MiddleFirrtlToLowFirrtl(),
new passes.InlineInstances(TransID(0)),
new EmitVerilogFromLowFirrtl(writer)
diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala
index dd62a240..2378216d 100644
--- a/src/main/scala/firrtl/passes/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/InferReadWrite.scala
@@ -28,13 +28,26 @@ MODIFICATIONS.
package firrtl.passes
import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap}
+import com.typesafe.scalalogging.LazyLogging
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.PrimOps._
+import Annotations._
-object InferReadWrite extends Pass {
+case class InferReadWriteAnnotation(t: String, tID: TransID)
+ extends Annotation with Loose with Unstable {
+ val target = CircuitName(t)
+ def duplicate(n: Named) = this.copy(t=n.name)
+}
+
+// This pass examine the enable signals of the read & write ports of memories
+// whose readLatency is greater than 1 (usually SeqMem in Chisel).
+// If any product term of the enable signal of the read port is the complement
+// of any product term of the enable signal of the write port, then the readwrite
+// port is inferred.
+object InferReadWritePass extends Pass {
def name = "Infer ReadWrite Ports"
def inferReadWrite(m: Module) = {
@@ -45,6 +58,7 @@ object InferReadWrite extends Pass {
val zero = we(UIntLiteral(0, IntWidth(1)))
val one = we(UIntLiteral(1, IntWidth(1)))
+ // find all wire connections
def analyze(s: Statement): Unit = s match {
case s: Connect =>
connects(s.loc.serialize) = s.expr
@@ -62,10 +76,13 @@ object InferReadWrite extends Pass {
// No ConstProp yet...
case Mux(cond, tval, fval, _) if we(tval) == one && we(fval) == zero =>
cond +: getProductTerms(cond.serialize)
+ // Visit each term of AND operation
case DoPrim(op, args, consts, tpe) if op == And =>
e +: (args flatMap getProductTermsFromExp)
+ // Visit connected nodes to references
case _: WRef | _: SubField | _: SubIndex | _: SubAccess =>
e +: getProductTerms(e.serialize)
+ // Otherwise just return itselt
case _ =>
List(e)
}
@@ -74,13 +91,17 @@ object InferReadWrite extends Pass {
if (connects contains node) getProductTermsFromExp(connects(node)) else Nil
def checkComplement(a: Expression, b: Expression) = (a, b) match {
+ // b ?= Not(a)
case (_, DoPrim(op, args, _, _)) if op == Not =>
args.head.serialize == a.serialize
+ // a ?= Not(b)
case (DoPrim(op, args, _, _), _) if op == Not =>
args.head.serialize == b.serialize
+ // b ?= Eq(a, 0) or b ?= Eq(0, a)
case (_, DoPrim(op, args, _, _)) if op == Eq =>
args(0).serialize == a.serialize && we(args(1)) == zero ||
args(1).serialize == a.serialize && we(args(0)) == zero
+ // a ?= Eq(b, 0) or b ?= Eq(0, a)
case (DoPrim(op, args, _, _), _) if op == Eq =>
args(0).serialize == b.serialize && we(args(1)) == zero ||
args(1).serialize == b.serialize && we(args(0)) == zero
@@ -88,8 +109,8 @@ object InferReadWrite extends Pass {
}
def inferReadWrite(s: Statement): Statement = s map inferReadWrite match {
+ // infer readwrite ports only for non combinational memories
case mem: DefMemory if mem.readLatency > 0 =>
- var idx = 0
val bt = UIntType(IntWidth(1))
val ut = UnknownType
val ug = UNKNOWNGENDER
@@ -100,7 +121,13 @@ object InferReadWrite extends Pass {
val wp = getProductTerms(s"${mem.name}.$w.en")
val rp = getProductTerms(s"${mem.name}.$r.en")
if (wp exists (a => rp exists (b => checkComplement(a, b)))) {
- val rw = s"rw_$idx"
+ val allPorts = (mem.readers ++ mem.writers ++ mem.readwriters ++ readwriters).toSet
+ // Uniquify names by examining all ports of the memory
+ var rw = (for {
+ idx <- Stream from 0
+ newName = s"rw_$idx"
+ if !allPorts(newName)
+ } yield newName).head
val rw_exp = WSubField(WRef(mem.name, ut, NodeKind(), ug), rw, ut, ug)
readwriters += rw
readers += r
@@ -121,7 +148,6 @@ object InferReadWrite extends Pass {
stmts += Connect(NoInfo, WSubField(rw_exp, "addr", ut, FEMALE),
Mux(connects(s"${mem.name}.$w.en"), connects(s"${mem.name}.$w.addr"),
connects(s"${mem.name}.$r.addr"), ut))
- idx += 1
}
}
if (readwriters.isEmpty) mem else DefMemory(mem.info,
@@ -152,3 +178,25 @@ object InferReadWrite extends Pass {
case m: ExtModule => m
}, c.main)
}
+
+// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl"
+// To use this transform, circuit name should be annotated with its TransId.
+class InferReadWrite(transID: TransID) extends Transform with LazyLogging {
+ def execute(circuit:Circuit, map: AnnotationMap) =
+ map get transID match {
+ case Some(p) => p get CircuitName(circuit.main) match {
+ case Some(InferReadWriteAnnotation(_, _)) => TransformResult((Seq(
+ InferReadWritePass,
+ CheckInitialization,
+ ResolveKinds,
+ InferTypes,
+ ResolveGenders) foldLeft circuit){ (c, pass) =>
+ val x = Utils.time(pass.name)(pass run c)
+ logger debug x.serialize
+ x
+ }, None, Some(map))
+ case _ => TransformResult(circuit, None, Some(map))
+ }
+ case _ => TransformResult(circuit, None, Some(map))
+ }
+}