diff options
| author | Donggyu Kim | 2016-09-16 01:02:39 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-21 13:17:02 -0700 |
| commit | 56f1014669638de90fa1c58007aaf4c16b9876ef (patch) | |
| tree | e902a48cb61f44b715a215a02cf9aa6b9a68a235 /src | |
| parent | 350ffd7bbc1b014b9d9b256da4181c59bf0419e3 (diff) | |
refactor InferReadWrite
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Namespace.scala | 15 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferReadWrite.scala | 240 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InferReadWriteSpec.scala | 2 |
3 files changed, 127 insertions, 130 deletions
diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala index 952670cf..1e922673 100644 --- a/src/main/scala/firrtl/Namespace.scala +++ b/src/main/scala/firrtl/Namespace.scala @@ -57,8 +57,6 @@ class Namespace private { } object Namespace { - def apply(): Namespace = new Namespace - // Initializes a namespace from a Module def apply(m: DefModule): Namespace = { val namespace = new Namespace @@ -69,7 +67,7 @@ object Namespace { case s: Block => s.stmts flatMap buildNamespaceStmt case _ => Nil } - namespace.namespace ++= (m.ports collect { case dec: IsDeclaration => dec.name }) + namespace.namespace ++= m.ports map (_.name) m match { case in: Module => namespace.namespace ++= buildNamespaceStmt(in.body) @@ -82,9 +80,14 @@ object Namespace { /** Initializes a [[Namespace]] for [[ir.Module]] names in a [[ir.Circuit]] */ def apply(c: Circuit): Namespace = { val namespace = new Namespace - c.modules foreach { m => - namespace.namespace += m.name - } + namespace.namespace ++= c.modules map (_.name) + namespace + } + + /** Initializes a [[Namespace]] from arbitrary strings **/ + def apply(names: Seq[String] = Nil): Namespace = { + val namespace = new Namespace + namespace.namespace ++= names namespace } } diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala index 38933103..ec996fdb 100644 --- a/src/main/scala/firrtl/passes/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/InferReadWrite.scala @@ -27,13 +27,14 @@ MODIFICATIONS. package firrtl.passes -import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap} -import com.typesafe.scalalogging.LazyLogging - import firrtl._ import firrtl.ir._ import firrtl.Mappers._ import firrtl.PrimOps._ +import firrtl.Utils.{one, zero, BoolType} +import MemPortUtils.memPortField +import AnalysisUtils.{Connects, getConnects} +import WrappedExpression.weq import Annotations._ case class InferReadWriteAnnotation(t: String, tID: TransID) @@ -50,140 +51,133 @@ case class InferReadWriteAnnotation(t: String, tID: TransID) object InferReadWritePass extends Pass { def name = "Infer ReadWrite Ports" - def inferReadWrite(m: Module) = { - import AnalysisUtils._ - import WrappedExpression.we - val connects = getConnects(m) - val repl = HashMap[String, Expression]() - val stmts = ArrayBuffer[Statement]() - val zero = we(UIntLiteral(0, IntWidth(1))) - val one = we(UIntLiteral(1, IntWidth(1))) + type Netlist = collection.mutable.HashMap[String, Expression] + type Statements = collection.mutable.ArrayBuffer[Statement] + type PortSet = collection.mutable.HashSet[String] - def getProductTermsFromExp(e: Expression): Seq[Expression] = - e match { - // No ConstProp yet... - case Mux(cond, tval, fval, _) if we(tval) == one && we(fval) == zero => - cond +: getProductTerms(cond.serialize) - // Visit each term of AND operation - case DoPrim(op, args, consts, tpe) if op == And => - e +: (args flatMap getProductTermsFromExp) - // Visit connected nodes to references - case _: WRef | _: SubField | _: SubIndex | _: SubAccess => - e +: getProductTerms(e.serialize) - // Otherwise just return itselt - case _ => - List(e) - } + private implicit def toString(e: Expression) = e.serialize + + def getProductTerms(connects: Connects)(e: Expression): Seq[Expression] = e match { + // No ConstProp yet... + case Mux(cond, tval, fval, _) if weq(tval, one) && weq(fval, zero) => + getProductTerms(connects)(cond) + // Visit each term of AND operation + case DoPrim(op, args, consts, tpe) if op == And => + e +: (args flatMap getProductTerms(connects)) + // Visit connected nodes to references + case _: WRef | _: WSubField | _: WSubIndex => connects get e match { + case None => Seq(e) + case Some(ex) => e +: getProductTerms(connects)(ex) + } + // Otherwise just return itself + case _ => Seq(e) + } + + def checkComplement(a: Expression, b: Expression) = (a, b) match { + // b ?= Not(a) + case (_, DoPrim(Not, args, _, _)) => weq(args.head, a) + // a ?= Not(b) + case (DoPrim(Not, args, _, _), _) => weq(args.head, b) + // b ?= Eq(a, 0) or b ?= Eq(0, a) + case (_, DoPrim(Eq, args, _, _)) => + weq(args(0), a) && weq(args(1), zero) || + weq(args(1), a) && weq(args(0), zero) + // a ?= Eq(b, 0) or b ?= Eq(0, a) + case (DoPrim(Eq, args, _, _), _) => + weq(args(0), b) && weq(args(1), zero) || + weq(args(1), b) && weq(args(0), zero) + case _ => false + } - def getProductTerms(node: String): Seq[Expression] = - if (connects contains node) getProductTermsFromExp(connects(node)) else Nil - def checkComplement(a: Expression, b: Expression) = (a, b) match { - // b ?= Not(a) - case (_, DoPrim(op, args, _, _)) if op == Not => - args.head.serialize == a.serialize - // a ?= Not(b) - case (DoPrim(op, args, _, _), _) if op == Not => - args.head.serialize == b.serialize - // b ?= Eq(a, 0) or b ?= Eq(0, a) - case (_, DoPrim(op, args, _, _)) if op == Eq => - args(0).serialize == a.serialize && we(args(1)) == zero || - args(1).serialize == a.serialize && we(args(0)) == zero - // a ?= Eq(b, 0) or b ?= Eq(0, a) - case (DoPrim(op, args, _, _), _) if op == Eq => - args(0).serialize == b.serialize && we(args(1)) == zero || - args(1).serialize == b.serialize && we(args(0)) == zero - case _ => false + def replaceExp(repl: Netlist)(e: Expression): Expression = + e map replaceExp(repl) match { + case e: WSubField => repl getOrElse (e.serialize, e) + case e => e } - def inferReadWrite(s: Statement): Statement = s map inferReadWrite match { - // infer readwrite ports only for non combinational memories - case mem: DefMemory if mem.readLatency > 0 => - val bt = UIntType(IntWidth(1)) - val ut = UnknownType - val ug = UNKNOWNGENDER - val readers = HashSet[String]() - val writers = HashSet[String]() - val readwriters = ArrayBuffer[String]() - for (w <- mem.writers ; r <- mem.readers) { - val wp = getProductTerms(s"${mem.name}.$w.en") - val rp = getProductTerms(s"${mem.name}.$r.en") - if (wp exists (a => rp exists (b => checkComplement(a, b)))) { - val allPorts = (mem.readers ++ mem.writers ++ mem.readwriters ++ readwriters).toSet - // Uniquify names by examining all ports of the memory - var rw = (for { - idx <- Stream from 0 - newName = s"rw_$idx" - if !allPorts(newName) - } yield newName).head - val rw_exp = WSubField(WRef(mem.name, ut, MemKind, ug), rw, ut, ug) - readwriters += rw - readers += r - writers += w - repl(s"${mem.name}.$r.en") = EmptyExpression - repl(s"${mem.name}.$r.clk") = EmptyExpression - repl(s"${mem.name}.$r.addr") = EmptyExpression - repl(s"${mem.name}.$r.data") = WSubField(rw_exp, "rdata", mem.dataType, MALE) - repl(s"${mem.name}.$w.en") = WSubField(rw_exp, "wmode", bt, FEMALE) - repl(s"${mem.name}.$w.clk") = EmptyExpression - repl(s"${mem.name}.$w.addr") = EmptyExpression - repl(s"${mem.name}.$w.data") = WSubField(rw_exp, "wdata", mem.dataType, FEMALE) - repl(s"${mem.name}.$w.mask") = WSubField(rw_exp, "wmask", ut, FEMALE) - stmts += Connect(NoInfo, WSubField(rw_exp, "clk", ClockType, FEMALE), - WRef("clk", ClockType, NodeKind, MALE)) - stmts += Connect(NoInfo, WSubField(rw_exp, "en", bt, FEMALE), - DoPrim(Or, List(connects(s"${mem.name}.$r.en"), connects(s"${mem.name}.$w.en")), Nil, bt)) - stmts += Connect(NoInfo, WSubField(rw_exp, "addr", ut, FEMALE), - Mux(connects(s"${mem.name}.$w.en"), connects(s"${mem.name}.$w.addr"), - connects(s"${mem.name}.$r.addr"), ut)) - } - } - if (readwriters.isEmpty) mem else DefMemory(mem.info, - mem.name, mem.dataType, mem.depth, mem.writeLatency, mem.readLatency, - mem.readers filterNot readers, mem.writers filterNot writers, - mem.readwriters ++ readwriters) + def replaceStmt(repl: Netlist)(s: Statement): Statement = + s map replaceStmt(repl) map replaceExp(repl) match { + case Connect(_, EmptyExpression, _) => EmptyStmt case s => s } - - def replaceExp(e: Expression): Expression = - e map replaceExp match { - case e: WSubField => repl getOrElse (e.serialize, e) - case e => e + + def inferReadWriteStmt(connects: Connects, + repl: Netlist, + stmts: Statements) + (s: Statement): Statement = s match { + // infer readwrite ports only for non combinational memories + case mem: DefMemory if mem.readLatency > 0 => + val ut = UnknownType + val ug = UNKNOWNGENDER + val readers = new PortSet + val writers = new PortSet + val readwriters = collection.mutable.ArrayBuffer[String]() + val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) + for (w <- mem.writers ; r <- mem.readers) { + val wp = getProductTerms(connects)(memPortField(mem, w, "en")) + val rp = getProductTerms(connects)(memPortField(mem, r, "en")) + if (wp exists (a => rp exists (b => checkComplement(a, b)))) { + val rw = namespace newName "rw" + val rwExp = createSubField(createRef(mem.name), rw) + readwriters += rw + readers += r + writers += w + repl(memPortField(mem, r, "clk")) = EmptyExpression + repl(memPortField(mem, r, "en")) = EmptyExpression + repl(memPortField(mem, r, "addr")) = EmptyExpression + repl(memPortField(mem, r, "data")) = createSubField(rwExp, "rdata") + repl(memPortField(mem, w, "clk")) = EmptyExpression + repl(memPortField(mem, w, "en")) = createSubField(rwExp, "wmode") + repl(memPortField(mem, w, "addr")) = EmptyExpression + repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata") + repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask") + stmts += Connect(NoInfo, createSubField(rwExp, "clk"), createRef("clk")) // TODO: fix it + stmts += Connect(NoInfo, createSubField(rwExp, "en"), + DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), + connects(memPortField(mem, w, "en"))), Nil, BoolType)) + stmts += Connect(NoInfo, createSubField(rwExp, "addr"), + Mux(connects(memPortField(mem, w, "en")), + connects(memPortField(mem, w, "addr")), + connects(memPortField(mem, r, "addr")), UnknownType)) + } } + if (readwriters.isEmpty) mem else mem copy ( + readers = mem.readers filterNot readers, + writers = mem.writers filterNot writers, + readwriters = mem.readwriters ++ readwriters) + case s => s map inferReadWriteStmt(connects, repl, stmts) + } - def replaceStmt(s: Statement): Statement = - s map replaceStmt map replaceExp match { - case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt - case s => s - } - - Module(m.info, m.name, m.ports, Block((m.body map inferReadWrite map replaceStmt) +: stmts.toSeq)) + def inferReadWrite(m: DefModule) = { + val connects = getConnects(m) + val repl = new Netlist + val stmts = new Statements + (m map inferReadWriteStmt(connects, repl, stmts) + map replaceStmt(repl)) match { + case m: ExtModule => m + case m: Module => m copy (body = Block(m.body +: stmts)) + } } - def run (c:Circuit) = Circuit(c.info, c.modules map { - case m: Module => inferReadWrite(m) - case m: ExtModule => m - }, c.main) + def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite) } // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite(transID: TransID) extends Transform with LazyLogging { - def execute(circuit:Circuit, map: AnnotationMap) = - map get transID match { - case Some(p) => p get CircuitName(circuit.main) match { - case Some(InferReadWriteAnnotation(_, _)) => TransformResult((Seq( - InferReadWritePass, - CheckInitialization, - ResolveKinds, - InferTypes, - ResolveGenders) foldLeft circuit){ (c, pass) => - val x = Utils.time(pass.name)(pass run c) - logger debug x.serialize - x - }, None, Some(map)) - case _ => TransformResult(circuit, None, Some(map)) - } - case _ => TransformResult(circuit, None, Some(map)) +class InferReadWrite(transID: TransID) extends Transform with SimpleRun { + def passSeq = Seq( + InferReadWritePass, + CheckInitialization, + InferTypes, + ResolveKinds, + ResolveGenders + ) + def execute(c: Circuit, map: AnnotationMap) = map get transID match { + case Some(p) => p get CircuitName(c.main) match { + case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq) + case _ => error("Unexpected annotation for InferReadWrite") } + case _ => TransformResult(c) + } } diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index 7e3383b2..3af018bd 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -38,7 +38,7 @@ class InferReadWriteSpec extends SimpleTransformSpec { val name = "Check Infer ReadWrite Ports" def findReadWrite(s: Statement): Boolean = s match { case s: DefMemory if s.readLatency > 0 && s.readwriters.size == 1 => - s.name == "mem" && s.readwriters.head == "rw_0" + s.name == "mem" && s.readwriters.head == "rw" case s: Block => s.stmts exists findReadWrite case _ => false |
