aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-09-16 01:02:39 -0700
committerDonggyu Kim2016-09-21 13:17:02 -0700
commit56f1014669638de90fa1c58007aaf4c16b9876ef (patch)
treee902a48cb61f44b715a215a02cf9aa6b9a68a235 /src
parent350ffd7bbc1b014b9d9b256da4181c59bf0419e3 (diff)
refactor InferReadWrite
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Namespace.scala15
-rw-r--r--src/main/scala/firrtl/passes/InferReadWrite.scala240
-rw-r--r--src/test/scala/firrtlTests/InferReadWriteSpec.scala2
3 files changed, 127 insertions, 130 deletions
diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala
index 952670cf..1e922673 100644
--- a/src/main/scala/firrtl/Namespace.scala
+++ b/src/main/scala/firrtl/Namespace.scala
@@ -57,8 +57,6 @@ class Namespace private {
}
object Namespace {
- def apply(): Namespace = new Namespace
-
// Initializes a namespace from a Module
def apply(m: DefModule): Namespace = {
val namespace = new Namespace
@@ -69,7 +67,7 @@ object Namespace {
case s: Block => s.stmts flatMap buildNamespaceStmt
case _ => Nil
}
- namespace.namespace ++= (m.ports collect { case dec: IsDeclaration => dec.name })
+ namespace.namespace ++= m.ports map (_.name)
m match {
case in: Module =>
namespace.namespace ++= buildNamespaceStmt(in.body)
@@ -82,9 +80,14 @@ object Namespace {
/** Initializes a [[Namespace]] for [[ir.Module]] names in a [[ir.Circuit]] */
def apply(c: Circuit): Namespace = {
val namespace = new Namespace
- c.modules foreach { m =>
- namespace.namespace += m.name
- }
+ namespace.namespace ++= c.modules map (_.name)
+ namespace
+ }
+
+ /** Initializes a [[Namespace]] from arbitrary strings **/
+ def apply(names: Seq[String] = Nil): Namespace = {
+ val namespace = new Namespace
+ namespace.namespace ++= names
namespace
}
}
diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala
index 38933103..ec996fdb 100644
--- a/src/main/scala/firrtl/passes/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/InferReadWrite.scala
@@ -27,13 +27,14 @@ MODIFICATIONS.
package firrtl.passes
-import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap}
-import com.typesafe.scalalogging.LazyLogging
-
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.PrimOps._
+import firrtl.Utils.{one, zero, BoolType}
+import MemPortUtils.memPortField
+import AnalysisUtils.{Connects, getConnects}
+import WrappedExpression.weq
import Annotations._
case class InferReadWriteAnnotation(t: String, tID: TransID)
@@ -50,140 +51,133 @@ case class InferReadWriteAnnotation(t: String, tID: TransID)
object InferReadWritePass extends Pass {
def name = "Infer ReadWrite Ports"
- def inferReadWrite(m: Module) = {
- import AnalysisUtils._
- import WrappedExpression.we
- val connects = getConnects(m)
- val repl = HashMap[String, Expression]()
- val stmts = ArrayBuffer[Statement]()
- val zero = we(UIntLiteral(0, IntWidth(1)))
- val one = we(UIntLiteral(1, IntWidth(1)))
+ type Netlist = collection.mutable.HashMap[String, Expression]
+ type Statements = collection.mutable.ArrayBuffer[Statement]
+ type PortSet = collection.mutable.HashSet[String]
- def getProductTermsFromExp(e: Expression): Seq[Expression] =
- e match {
- // No ConstProp yet...
- case Mux(cond, tval, fval, _) if we(tval) == one && we(fval) == zero =>
- cond +: getProductTerms(cond.serialize)
- // Visit each term of AND operation
- case DoPrim(op, args, consts, tpe) if op == And =>
- e +: (args flatMap getProductTermsFromExp)
- // Visit connected nodes to references
- case _: WRef | _: SubField | _: SubIndex | _: SubAccess =>
- e +: getProductTerms(e.serialize)
- // Otherwise just return itselt
- case _ =>
- List(e)
- }
+ private implicit def toString(e: Expression) = e.serialize
+
+ def getProductTerms(connects: Connects)(e: Expression): Seq[Expression] = e match {
+ // No ConstProp yet...
+ case Mux(cond, tval, fval, _) if weq(tval, one) && weq(fval, zero) =>
+ getProductTerms(connects)(cond)
+ // Visit each term of AND operation
+ case DoPrim(op, args, consts, tpe) if op == And =>
+ e +: (args flatMap getProductTerms(connects))
+ // Visit connected nodes to references
+ case _: WRef | _: WSubField | _: WSubIndex => connects get e match {
+ case None => Seq(e)
+ case Some(ex) => e +: getProductTerms(connects)(ex)
+ }
+ // Otherwise just return itself
+ case _ => Seq(e)
+ }
+
+ def checkComplement(a: Expression, b: Expression) = (a, b) match {
+ // b ?= Not(a)
+ case (_, DoPrim(Not, args, _, _)) => weq(args.head, a)
+ // a ?= Not(b)
+ case (DoPrim(Not, args, _, _), _) => weq(args.head, b)
+ // b ?= Eq(a, 0) or b ?= Eq(0, a)
+ case (_, DoPrim(Eq, args, _, _)) =>
+ weq(args(0), a) && weq(args(1), zero) ||
+ weq(args(1), a) && weq(args(0), zero)
+ // a ?= Eq(b, 0) or b ?= Eq(0, a)
+ case (DoPrim(Eq, args, _, _), _) =>
+ weq(args(0), b) && weq(args(1), zero) ||
+ weq(args(1), b) && weq(args(0), zero)
+ case _ => false
+ }
- def getProductTerms(node: String): Seq[Expression] =
- if (connects contains node) getProductTermsFromExp(connects(node)) else Nil
- def checkComplement(a: Expression, b: Expression) = (a, b) match {
- // b ?= Not(a)
- case (_, DoPrim(op, args, _, _)) if op == Not =>
- args.head.serialize == a.serialize
- // a ?= Not(b)
- case (DoPrim(op, args, _, _), _) if op == Not =>
- args.head.serialize == b.serialize
- // b ?= Eq(a, 0) or b ?= Eq(0, a)
- case (_, DoPrim(op, args, _, _)) if op == Eq =>
- args(0).serialize == a.serialize && we(args(1)) == zero ||
- args(1).serialize == a.serialize && we(args(0)) == zero
- // a ?= Eq(b, 0) or b ?= Eq(0, a)
- case (DoPrim(op, args, _, _), _) if op == Eq =>
- args(0).serialize == b.serialize && we(args(1)) == zero ||
- args(1).serialize == b.serialize && we(args(0)) == zero
- case _ => false
+ def replaceExp(repl: Netlist)(e: Expression): Expression =
+ e map replaceExp(repl) match {
+ case e: WSubField => repl getOrElse (e.serialize, e)
+ case e => e
}
- def inferReadWrite(s: Statement): Statement = s map inferReadWrite match {
- // infer readwrite ports only for non combinational memories
- case mem: DefMemory if mem.readLatency > 0 =>
- val bt = UIntType(IntWidth(1))
- val ut = UnknownType
- val ug = UNKNOWNGENDER
- val readers = HashSet[String]()
- val writers = HashSet[String]()
- val readwriters = ArrayBuffer[String]()
- for (w <- mem.writers ; r <- mem.readers) {
- val wp = getProductTerms(s"${mem.name}.$w.en")
- val rp = getProductTerms(s"${mem.name}.$r.en")
- if (wp exists (a => rp exists (b => checkComplement(a, b)))) {
- val allPorts = (mem.readers ++ mem.writers ++ mem.readwriters ++ readwriters).toSet
- // Uniquify names by examining all ports of the memory
- var rw = (for {
- idx <- Stream from 0
- newName = s"rw_$idx"
- if !allPorts(newName)
- } yield newName).head
- val rw_exp = WSubField(WRef(mem.name, ut, MemKind, ug), rw, ut, ug)
- readwriters += rw
- readers += r
- writers += w
- repl(s"${mem.name}.$r.en") = EmptyExpression
- repl(s"${mem.name}.$r.clk") = EmptyExpression
- repl(s"${mem.name}.$r.addr") = EmptyExpression
- repl(s"${mem.name}.$r.data") = WSubField(rw_exp, "rdata", mem.dataType, MALE)
- repl(s"${mem.name}.$w.en") = WSubField(rw_exp, "wmode", bt, FEMALE)
- repl(s"${mem.name}.$w.clk") = EmptyExpression
- repl(s"${mem.name}.$w.addr") = EmptyExpression
- repl(s"${mem.name}.$w.data") = WSubField(rw_exp, "wdata", mem.dataType, FEMALE)
- repl(s"${mem.name}.$w.mask") = WSubField(rw_exp, "wmask", ut, FEMALE)
- stmts += Connect(NoInfo, WSubField(rw_exp, "clk", ClockType, FEMALE),
- WRef("clk", ClockType, NodeKind, MALE))
- stmts += Connect(NoInfo, WSubField(rw_exp, "en", bt, FEMALE),
- DoPrim(Or, List(connects(s"${mem.name}.$r.en"), connects(s"${mem.name}.$w.en")), Nil, bt))
- stmts += Connect(NoInfo, WSubField(rw_exp, "addr", ut, FEMALE),
- Mux(connects(s"${mem.name}.$w.en"), connects(s"${mem.name}.$w.addr"),
- connects(s"${mem.name}.$r.addr"), ut))
- }
- }
- if (readwriters.isEmpty) mem else DefMemory(mem.info,
- mem.name, mem.dataType, mem.depth, mem.writeLatency, mem.readLatency,
- mem.readers filterNot readers, mem.writers filterNot writers,
- mem.readwriters ++ readwriters)
+ def replaceStmt(repl: Netlist)(s: Statement): Statement =
+ s map replaceStmt(repl) map replaceExp(repl) match {
+ case Connect(_, EmptyExpression, _) => EmptyStmt
case s => s
}
-
- def replaceExp(e: Expression): Expression =
- e map replaceExp match {
- case e: WSubField => repl getOrElse (e.serialize, e)
- case e => e
+
+ def inferReadWriteStmt(connects: Connects,
+ repl: Netlist,
+ stmts: Statements)
+ (s: Statement): Statement = s match {
+ // infer readwrite ports only for non combinational memories
+ case mem: DefMemory if mem.readLatency > 0 =>
+ val ut = UnknownType
+ val ug = UNKNOWNGENDER
+ val readers = new PortSet
+ val writers = new PortSet
+ val readwriters = collection.mutable.ArrayBuffer[String]()
+ val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters)
+ for (w <- mem.writers ; r <- mem.readers) {
+ val wp = getProductTerms(connects)(memPortField(mem, w, "en"))
+ val rp = getProductTerms(connects)(memPortField(mem, r, "en"))
+ if (wp exists (a => rp exists (b => checkComplement(a, b)))) {
+ val rw = namespace newName "rw"
+ val rwExp = createSubField(createRef(mem.name), rw)
+ readwriters += rw
+ readers += r
+ writers += w
+ repl(memPortField(mem, r, "clk")) = EmptyExpression
+ repl(memPortField(mem, r, "en")) = EmptyExpression
+ repl(memPortField(mem, r, "addr")) = EmptyExpression
+ repl(memPortField(mem, r, "data")) = createSubField(rwExp, "rdata")
+ repl(memPortField(mem, w, "clk")) = EmptyExpression
+ repl(memPortField(mem, w, "en")) = createSubField(rwExp, "wmode")
+ repl(memPortField(mem, w, "addr")) = EmptyExpression
+ repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata")
+ repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask")
+ stmts += Connect(NoInfo, createSubField(rwExp, "clk"), createRef("clk")) // TODO: fix it
+ stmts += Connect(NoInfo, createSubField(rwExp, "en"),
+ DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
+ connects(memPortField(mem, w, "en"))), Nil, BoolType))
+ stmts += Connect(NoInfo, createSubField(rwExp, "addr"),
+ Mux(connects(memPortField(mem, w, "en")),
+ connects(memPortField(mem, w, "addr")),
+ connects(memPortField(mem, r, "addr")), UnknownType))
+ }
}
+ if (readwriters.isEmpty) mem else mem copy (
+ readers = mem.readers filterNot readers,
+ writers = mem.writers filterNot writers,
+ readwriters = mem.readwriters ++ readwriters)
+ case s => s map inferReadWriteStmt(connects, repl, stmts)
+ }
- def replaceStmt(s: Statement): Statement =
- s map replaceStmt map replaceExp match {
- case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt
- case s => s
- }
-
- Module(m.info, m.name, m.ports, Block((m.body map inferReadWrite map replaceStmt) +: stmts.toSeq))
+ def inferReadWrite(m: DefModule) = {
+ val connects = getConnects(m)
+ val repl = new Netlist
+ val stmts = new Statements
+ (m map inferReadWriteStmt(connects, repl, stmts)
+ map replaceStmt(repl)) match {
+ case m: ExtModule => m
+ case m: Module => m copy (body = Block(m.body +: stmts))
+ }
}
- def run (c:Circuit) = Circuit(c.info, c.modules map {
- case m: Module => inferReadWrite(m)
- case m: ExtModule => m
- }, c.main)
+ def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite)
}
// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl"
// To use this transform, circuit name should be annotated with its TransId.
-class InferReadWrite(transID: TransID) extends Transform with LazyLogging {
- def execute(circuit:Circuit, map: AnnotationMap) =
- map get transID match {
- case Some(p) => p get CircuitName(circuit.main) match {
- case Some(InferReadWriteAnnotation(_, _)) => TransformResult((Seq(
- InferReadWritePass,
- CheckInitialization,
- ResolveKinds,
- InferTypes,
- ResolveGenders) foldLeft circuit){ (c, pass) =>
- val x = Utils.time(pass.name)(pass run c)
- logger debug x.serialize
- x
- }, None, Some(map))
- case _ => TransformResult(circuit, None, Some(map))
- }
- case _ => TransformResult(circuit, None, Some(map))
+class InferReadWrite(transID: TransID) extends Transform with SimpleRun {
+ def passSeq = Seq(
+ InferReadWritePass,
+ CheckInitialization,
+ InferTypes,
+ ResolveKinds,
+ ResolveGenders
+ )
+ def execute(c: Circuit, map: AnnotationMap) = map get transID match {
+ case Some(p) => p get CircuitName(c.main) match {
+ case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq)
+ case _ => error("Unexpected annotation for InferReadWrite")
}
+ case _ => TransformResult(c)
+ }
}
diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
index 7e3383b2..3af018bd 100644
--- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala
+++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
@@ -38,7 +38,7 @@ class InferReadWriteSpec extends SimpleTransformSpec {
val name = "Check Infer ReadWrite Ports"
def findReadWrite(s: Statement): Boolean = s match {
case s: DefMemory if s.readLatency > 0 && s.readwriters.size == 1 =>
- s.name == "mem" && s.readwriters.head == "rw_0"
+ s.name == "mem" && s.readwriters.head == "rw"
case s: Block =>
s.stmts exists findReadWrite
case _ => false