aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-08-20 11:58:43 -0700
committerDonggyu Kim2016-09-08 13:25:37 -0700
commitde32fe8128105413563a5fa746fcebf24c86d0a3 (patch)
tree4e0b5520f3fc205eb8d6adb27d502c3e78e32d48 /src
parent2a513ff47eebe38a81a1312c51972fcecaeb114f (diff)
clean up ExpandWhens
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala182
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala30
2 files changed, 94 insertions, 118 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index 3d26298a..c9c4b7d1 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -35,9 +35,8 @@ import firrtl.PrimOps._
import firrtl.WrappedExpression._
// Datastructures
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.LinkedHashMap
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+import scala.collection.mutable.{HashMap, LinkedHashMap, ArrayBuffer}
import annotation.tailrec
@@ -52,41 +51,33 @@ object ExpandWhens extends Pass {
def name = "Expand Whens"
// ========== Expand When Utilz ==========
- private def getEntries(
- hash: LinkedHashMap[WrappedExpression, Expression],
- exps: Seq[Expression]): LinkedHashMap[WrappedExpression, Expression] = {
- val hashx = LinkedHashMap[WrappedExpression, Expression]()
- exps foreach (e => if (hash.contains(e)) hashx(e) = hash(e))
- hashx
- }
private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = {
def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, Default))
val exps = create_exps(WRef(n, t, ExpKind(), g))
- val expsx = ArrayBuffer[Expression]()
- for (j <- 0 until exps.size) {
- getGender(t, j, g) match {
- case (BIGENDER | FEMALE) => expsx += exps(j)
- case _ =>
+ (exps.zipWithIndex foldLeft Seq[Expression]()){
+ case (expsx, (exp, j)) => getGender(t, j, g) match {
+ case (BIGENDER | FEMALE) => expsx :+ exp
+ case _ => expsx
}
}
- expsx
}
private def expandNetlist(netlist: LinkedHashMap[WrappedExpression, Expression]) =
- netlist map { case (k, v) =>
- v match {
- case WInvalid() => IsInvalid(NoInfo, k.e1)
- case _ => Connect(NoInfo, k.e1, v)
- }
+ netlist map {
+ case (k, WInvalid()) => IsInvalid(NoInfo, k.e1)
+ case (k, v) => Connect(NoInfo, k.e1, v)
}
// Searches nested scopes of defaults for lvalue
// defaults uses mutable Map because we are searching LinkedHashMaps and conversion to immutable is VERY slow
@tailrec
- private def getDefault(
- lvalue: WrappedExpression,
- defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]]): Option[Expression] = {
- if (defaults.isEmpty) None
- else if (defaults.head.contains(lvalue)) defaults.head.get(lvalue)
- else getDefault(lvalue, defaults.tail)
+ private def getDefault(lvalue: WrappedExpression,
+ defaults: Seq[mutable.Map[WrappedExpression, Expression]]): Option[Expression] = {
+ defaults match {
+ case Nil => None
+ case head :: tail => head get lvalue match {
+ case Some(p) => Some(p)
+ case None => getDefault(lvalue, tail)
+ }
+ }
}
// ------------ Pass -------------------
@@ -98,90 +89,75 @@ object ExpandWhens extends Pass {
// defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow
def expandWhens(
netlist: LinkedHashMap[WrappedExpression, Expression],
- defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]],
+ defaults: Seq[mutable.Map[WrappedExpression, Expression]],
p: Expression)
- (s: Statement): Statement = {
- s match {
- case w: DefWire =>
- getFemaleRefs(w.name, w.tpe, BIGENDER) foreach (ref => netlist(ref) = WVoid())
- w
- case r: DefRegister =>
- getFemaleRefs(r.name, r.tpe, BIGENDER) foreach (ref => netlist(ref) = ref)
- r
- case c: Connect =>
- netlist(c.loc) = c.expr
- EmptyStmt
- case c: IsInvalid =>
- netlist(c.expr) = WInvalid()
- EmptyStmt
- case s: Conditionally =>
- val memos = ArrayBuffer[Statement]()
+ (s: Statement): Statement = s match {
+ case w: DefWire =>
+ netlist ++= (getFemaleRefs(w.name, w.tpe, BIGENDER) map (ref => we(ref) -> WVoid()))
+ w
+ case r: DefRegister =>
+ netlist ++= (getFemaleRefs(r.name, r.tpe, BIGENDER) map (ref => we(ref) -> ref))
+ r
+ case c: Connect =>
+ netlist(c.loc) = c.expr
+ EmptyStmt
+ case c: IsInvalid =>
+ netlist(c.expr) = WInvalid()
+ EmptyStmt
+ case s: Conditionally =>
+ val conseqNetlist = LinkedHashMap[WrappedExpression, Expression]()
+ val altNetlist = LinkedHashMap[WrappedExpression, Expression]()
+ val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq)
+ val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt)
- val conseqNetlist = LinkedHashMap[WrappedExpression, Expression]()
- val altNetlist = LinkedHashMap[WrappedExpression, Expression]()
- val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq)
- val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt)
-
- (conseqNetlist.keySet ++ altNetlist.keySet) foreach { lvalue =>
- // Defaults in netlist get priority over those in defaults
- val default = if (netlist.contains(lvalue)) netlist.get(lvalue) else getDefault(lvalue, defaults)
- val res = default match {
- case Some(defaultValue) =>
- val trueValue = conseqNetlist.getOrElse(lvalue, defaultValue)
- val falseValue = altNetlist.getOrElse(lvalue, defaultValue)
- (trueValue, falseValue) match {
- case (WInvalid(), WInvalid()) => WInvalid()
- case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe)
- case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe)
- case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv))
- }
- case None =>
- // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
- conseqNetlist.getOrElse(lvalue, altNetlist(lvalue))
- }
-
- val memoNode = DefNode(s.info, namespace.newTemp, res)
- val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE)
- memos += memoNode
- netlist(lvalue) = memoExpr
+ val memos = (conseqNetlist.keys ++ altNetlist.keys) map { lvalue =>
+ // Defaults in netlist get priority over those in defaults
+ val default = netlist get lvalue match {
+ case Some(v) => Some(v)
+ case None => getDefault(lvalue, defaults)
}
- Block(Seq(conseqStmt, altStmt) ++ memos)
-
- case s: Print =>
- if(weq(p, one)) {
- simlist += s
- } else {
- simlist += Print(s.info, s.string, s.args, s.clk, AND(p, s.en))
+ val res = default match {
+ case Some(defaultValue) =>
+ val trueValue = conseqNetlist getOrElse (lvalue, defaultValue)
+ val falseValue = altNetlist getOrElse (lvalue, defaultValue)
+ (trueValue, falseValue) match {
+ case (WInvalid(), WInvalid()) => WInvalid()
+ case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe)
+ case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe)
+ case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv))
+ }
+ case None =>
+ // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
+ conseqNetlist getOrElse (lvalue, altNetlist(lvalue))
}
- EmptyStmt
- case s: Stop =>
- if (weq(p, one)) {
- simlist += s
- } else {
- simlist += Stop(s.info, s.ret, s.clk, AND(p, s.en))
- }
- EmptyStmt
- case s => s map expandWhens(netlist, defaults, p)
- }
+
+ val memoNode = DefNode(s.info, namespace.newTemp, res)
+ val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE)
+ netlist(lvalue) = memoExpr
+ memoNode
+ }
+ Block(Seq(conseqStmt, altStmt) ++ memos)
+ case s: Print =>
+ simlist += (if (weq(p, one)) s else Print(s.info, s.string, s.args, s.clk, AND(p, s.en)))
+ EmptyStmt
+ case s: Stop =>
+ simlist += (if (weq(p, one)) s else Stop(s.info, s.ret, s.clk, AND(p, s.en)))
+ EmptyStmt
+ case s => s map expandWhens(netlist, defaults, p)
}
val netlist = LinkedHashMap[WrappedExpression, Expression]()
-
// Add ports to netlist
- m.ports foreach { port =>
- getFemaleRefs(port.name, port.tpe, to_gender(port.direction)) foreach (ref => netlist(ref) = WVoid())
- }
- val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
-
- (netlist, simlist, bodyx)
+ netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) =>
+ getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid())
+ })
+ (netlist, simlist, expandWhens(netlist, Seq(netlist), one)(m.body))
}
- val modulesx = c.modules map { m =>
- m match {
- case m: ExtModule => m
- case m: Module =>
- val (netlist, simlist, bodyx) = expandWhens(m)
- val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist)
- Module(m.info, m.name, m.ports, newBody)
- }
+ val modulesx = c.modules map {
+ case m: ExtModule => m
+ case m: Module =>
+ val (netlist, simlist, bodyx) = expandWhens(m)
+ val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist)
+ Module(m.info, m.name, m.ports, newBody)
}
Circuit(c.info, modulesx, c.main)
}
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 54ef6003..118e547c 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -5,7 +5,8 @@ import firrtl.passes._
import Annotations._
class ReplSeqMemSpec extends SimpleTransformSpec {
-
+ val passSeq = Seq(
+ ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty)
def transforms (writer: java.io.Writer) = Seq(
new Chisel3ToHighFirrtl(),
new IRToWorkingIR(),
@@ -14,6 +15,8 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
new passes.InferReadWrite(TransID(-1)),
new passes.ReplSeqMem(TransID(-2)),
new MiddleFirrtlToLowFirrtl(),
+ (new Transform with SimpleRun {
+ def execute(c: ir.Circuit, a: AnnotationMap) = run(c, passSeq) }),
new EmitFirrtl(writer)
)
@@ -97,27 +100,24 @@ circuit sram6t :
input io_wdata : UInt<32>
input io_raddr : UInt<8>
output io_rdata : UInt<32>
-
+
inst mem of mem
node T_0 = eq(io_wen, UInt<1>("h0"))
node T_1 = and(io_en, T_0)
wire T_2 : UInt<8>
node GEN_0 = validif(T_1, io_raddr)
- node GEN_1 = mux(T_1, UInt<1>("h1"), UInt<1>("h0"))
node T_4 = and(io_en, io_wen)
+ node GEN_4 = validif(T_4, io_wdata)
node GEN_2 = validif(T_4, io_waddr)
- node GEN_3 = validif(T_4, clk)
- node GEN_4 = mux(T_4, UInt<1>("h1"), UInt<1>("h0"))
- node GEN_5 = validif(T_4, io_wdata)
- node GEN_6 = mux(T_4, UInt<1>("h1"), UInt<1>("h0"))
+ node GEN_5 = validif(T_4, clk)
io_rdata <= mem.R0_data
mem.R0_addr <= bits(T_2, 6, 0)
mem.R0_clk <= clk
- mem.R0_en <= GEN_1
+ mem.R0_en <= T_1
mem.W0_addr <= bits(GEN_2, 6, 0)
- mem.W0_clk <= GEN_3
- mem.W0_en <= GEN_4
- mem.W0_data <= GEN_5
+ mem.W0_clk <= GEN_5
+ mem.W0_en <= T_4
+ mem.W0_data <= GEN_4
T_2 <= GEN_0
extmodule mem_ext :
@@ -140,16 +140,16 @@ circuit sram6t :
input W0_en : UInt<1>
input W0_clk : Clock
input W0_data : UInt<32>
-
+
inst mem_ext of mem_ext
mem_ext.R0_addr <= R0_addr
mem_ext.R0_en <= R0_en
mem_ext.R0_clk <= R0_clk
- R0_data <= bits(mem_ext.R0_data, 31, 0)
+ R0_data <= mem_ext.R0_data
mem_ext.W0_addr <= W0_addr
mem_ext.W0_en <= W0_en
mem_ext.W0_clk <= W0_clk
- mem_ext.W0_data <= W0_data
+ mem_ext.W0_data <= W0_data
""".stripMargin
val checkConf = """name mem_ext depth 128 width 32 ports write,read """
@@ -170,4 +170,4 @@ circuit sram6t :
// readwrite vs. no readwrite
// redundant memories (multiple instances of the same type of memory)
// mask + no mask
-// conf \ No newline at end of file
+// conf