diff options
| author | Donggyu Kim | 2016-08-20 11:58:43 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-08 13:25:37 -0700 |
| commit | de32fe8128105413563a5fa746fcebf24c86d0a3 (patch) | |
| tree | 4e0b5520f3fc205eb8d6adb27d502c3e78e32d48 /src | |
| parent | 2a513ff47eebe38a81a1312c51972fcecaeb114f (diff) | |
clean up ExpandWhens
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 182 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 30 |
2 files changed, 94 insertions, 118 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 3d26298a..c9c4b7d1 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -35,9 +35,8 @@ import firrtl.PrimOps._ import firrtl.WrappedExpression._ // Datastructures -import scala.collection.mutable.HashMap -import scala.collection.mutable.LinkedHashMap -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable +import scala.collection.mutable.{HashMap, LinkedHashMap, ArrayBuffer} import annotation.tailrec @@ -52,41 +51,33 @@ object ExpandWhens extends Pass { def name = "Expand Whens" // ========== Expand When Utilz ========== - private def getEntries( - hash: LinkedHashMap[WrappedExpression, Expression], - exps: Seq[Expression]): LinkedHashMap[WrappedExpression, Expression] = { - val hashx = LinkedHashMap[WrappedExpression, Expression]() - exps foreach (e => if (hash.contains(e)) hashx(e) = hash(e)) - hashx - } private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = { def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, Default)) val exps = create_exps(WRef(n, t, ExpKind(), g)) - val expsx = ArrayBuffer[Expression]() - for (j <- 0 until exps.size) { - getGender(t, j, g) match { - case (BIGENDER | FEMALE) => expsx += exps(j) - case _ => + (exps.zipWithIndex foldLeft Seq[Expression]()){ + case (expsx, (exp, j)) => getGender(t, j, g) match { + case (BIGENDER | FEMALE) => expsx :+ exp + case _ => expsx } } - expsx } private def expandNetlist(netlist: LinkedHashMap[WrappedExpression, Expression]) = - netlist map { case (k, v) => - v match { - case WInvalid() => IsInvalid(NoInfo, k.e1) - case _ => Connect(NoInfo, k.e1, v) - } + netlist map { + case (k, WInvalid()) => IsInvalid(NoInfo, k.e1) + case (k, v) => Connect(NoInfo, k.e1, v) } // Searches nested scopes of defaults for lvalue // defaults uses mutable Map because we are searching LinkedHashMaps and conversion to immutable is VERY slow @tailrec - private def getDefault( - lvalue: WrappedExpression, - defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]]): Option[Expression] = { - if (defaults.isEmpty) None - else if (defaults.head.contains(lvalue)) defaults.head.get(lvalue) - else getDefault(lvalue, defaults.tail) + private def getDefault(lvalue: WrappedExpression, + defaults: Seq[mutable.Map[WrappedExpression, Expression]]): Option[Expression] = { + defaults match { + case Nil => None + case head :: tail => head get lvalue match { + case Some(p) => Some(p) + case None => getDefault(lvalue, tail) + } + } } // ------------ Pass ------------------- @@ -98,90 +89,75 @@ object ExpandWhens extends Pass { // defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow def expandWhens( netlist: LinkedHashMap[WrappedExpression, Expression], - defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]], + defaults: Seq[mutable.Map[WrappedExpression, Expression]], p: Expression) - (s: Statement): Statement = { - s match { - case w: DefWire => - getFemaleRefs(w.name, w.tpe, BIGENDER) foreach (ref => netlist(ref) = WVoid()) - w - case r: DefRegister => - getFemaleRefs(r.name, r.tpe, BIGENDER) foreach (ref => netlist(ref) = ref) - r - case c: Connect => - netlist(c.loc) = c.expr - EmptyStmt - case c: IsInvalid => - netlist(c.expr) = WInvalid() - EmptyStmt - case s: Conditionally => - val memos = ArrayBuffer[Statement]() + (s: Statement): Statement = s match { + case w: DefWire => + netlist ++= (getFemaleRefs(w.name, w.tpe, BIGENDER) map (ref => we(ref) -> WVoid())) + w + case r: DefRegister => + netlist ++= (getFemaleRefs(r.name, r.tpe, BIGENDER) map (ref => we(ref) -> ref)) + r + case c: Connect => + netlist(c.loc) = c.expr + EmptyStmt + case c: IsInvalid => + netlist(c.expr) = WInvalid() + EmptyStmt + case s: Conditionally => + val conseqNetlist = LinkedHashMap[WrappedExpression, Expression]() + val altNetlist = LinkedHashMap[WrappedExpression, Expression]() + val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq) + val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt) - val conseqNetlist = LinkedHashMap[WrappedExpression, Expression]() - val altNetlist = LinkedHashMap[WrappedExpression, Expression]() - val conseqStmt = expandWhens(conseqNetlist, netlist +: defaults, AND(p, s.pred))(s.conseq) - val altStmt = expandWhens(altNetlist, netlist +: defaults, AND(p, NOT(s.pred)))(s.alt) - - (conseqNetlist.keySet ++ altNetlist.keySet) foreach { lvalue => - // Defaults in netlist get priority over those in defaults - val default = if (netlist.contains(lvalue)) netlist.get(lvalue) else getDefault(lvalue, defaults) - val res = default match { - case Some(defaultValue) => - val trueValue = conseqNetlist.getOrElse(lvalue, defaultValue) - val falseValue = altNetlist.getOrElse(lvalue, defaultValue) - (trueValue, falseValue) match { - case (WInvalid(), WInvalid()) => WInvalid() - case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe) - case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe) - case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv)) - } - case None => - // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt - conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)) - } - - val memoNode = DefNode(s.info, namespace.newTemp, res) - val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE) - memos += memoNode - netlist(lvalue) = memoExpr + val memos = (conseqNetlist.keys ++ altNetlist.keys) map { lvalue => + // Defaults in netlist get priority over those in defaults + val default = netlist get lvalue match { + case Some(v) => Some(v) + case None => getDefault(lvalue, defaults) } - Block(Seq(conseqStmt, altStmt) ++ memos) - - case s: Print => - if(weq(p, one)) { - simlist += s - } else { - simlist += Print(s.info, s.string, s.args, s.clk, AND(p, s.en)) + val res = default match { + case Some(defaultValue) => + val trueValue = conseqNetlist getOrElse (lvalue, defaultValue) + val falseValue = altNetlist getOrElse (lvalue, defaultValue) + (trueValue, falseValue) match { + case (WInvalid(), WInvalid()) => WInvalid() + case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe) + case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe) + case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv)) + } + case None => + // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt + conseqNetlist getOrElse (lvalue, altNetlist(lvalue)) } - EmptyStmt - case s: Stop => - if (weq(p, one)) { - simlist += s - } else { - simlist += Stop(s.info, s.ret, s.clk, AND(p, s.en)) - } - EmptyStmt - case s => s map expandWhens(netlist, defaults, p) - } + + val memoNode = DefNode(s.info, namespace.newTemp, res) + val memoExpr = WRef(memoNode.name, res.tpe, NodeKind(), MALE) + netlist(lvalue) = memoExpr + memoNode + } + Block(Seq(conseqStmt, altStmt) ++ memos) + case s: Print => + simlist += (if (weq(p, one)) s else Print(s.info, s.string, s.args, s.clk, AND(p, s.en))) + EmptyStmt + case s: Stop => + simlist += (if (weq(p, one)) s else Stop(s.info, s.ret, s.clk, AND(p, s.en))) + EmptyStmt + case s => s map expandWhens(netlist, defaults, p) } val netlist = LinkedHashMap[WrappedExpression, Expression]() - // Add ports to netlist - m.ports foreach { port => - getFemaleRefs(port.name, port.tpe, to_gender(port.direction)) foreach (ref => netlist(ref) = WVoid()) - } - val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) - - (netlist, simlist, bodyx) + netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) => + getFemaleRefs(name, tpe, to_gender(dir)) map (ref => we(ref) -> WVoid()) + }) + (netlist, simlist, expandWhens(netlist, Seq(netlist), one)(m.body)) } - val modulesx = c.modules map { m => - m match { - case m: ExtModule => m - case m: Module => - val (netlist, simlist, bodyx) = expandWhens(m) - val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist) - Module(m.info, m.name, m.ports, newBody) - } + val modulesx = c.modules map { + case m: ExtModule => m + case m: Module => + val (netlist, simlist, bodyx) = expandWhens(m) + val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist) ++ simlist) + Module(m.info, m.name, m.ports, newBody) } Circuit(c.info, modulesx, c.main) } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 54ef6003..118e547c 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -5,7 +5,8 @@ import firrtl.passes._ import Annotations._ class ReplSeqMemSpec extends SimpleTransformSpec { - + val passSeq = Seq( + ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) def transforms (writer: java.io.Writer) = Seq( new Chisel3ToHighFirrtl(), new IRToWorkingIR(), @@ -14,6 +15,8 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new passes.InferReadWrite(TransID(-1)), new passes.ReplSeqMem(TransID(-2)), new MiddleFirrtlToLowFirrtl(), + (new Transform with SimpleRun { + def execute(c: ir.Circuit, a: AnnotationMap) = run(c, passSeq) }), new EmitFirrtl(writer) ) @@ -97,27 +100,24 @@ circuit sram6t : input io_wdata : UInt<32> input io_raddr : UInt<8> output io_rdata : UInt<32> - + inst mem of mem node T_0 = eq(io_wen, UInt<1>("h0")) node T_1 = and(io_en, T_0) wire T_2 : UInt<8> node GEN_0 = validif(T_1, io_raddr) - node GEN_1 = mux(T_1, UInt<1>("h1"), UInt<1>("h0")) node T_4 = and(io_en, io_wen) + node GEN_4 = validif(T_4, io_wdata) node GEN_2 = validif(T_4, io_waddr) - node GEN_3 = validif(T_4, clk) - node GEN_4 = mux(T_4, UInt<1>("h1"), UInt<1>("h0")) - node GEN_5 = validif(T_4, io_wdata) - node GEN_6 = mux(T_4, UInt<1>("h1"), UInt<1>("h0")) + node GEN_5 = validif(T_4, clk) io_rdata <= mem.R0_data mem.R0_addr <= bits(T_2, 6, 0) mem.R0_clk <= clk - mem.R0_en <= GEN_1 + mem.R0_en <= T_1 mem.W0_addr <= bits(GEN_2, 6, 0) - mem.W0_clk <= GEN_3 - mem.W0_en <= GEN_4 - mem.W0_data <= GEN_5 + mem.W0_clk <= GEN_5 + mem.W0_en <= T_4 + mem.W0_data <= GEN_4 T_2 <= GEN_0 extmodule mem_ext : @@ -140,16 +140,16 @@ circuit sram6t : input W0_en : UInt<1> input W0_clk : Clock input W0_data : UInt<32> - + inst mem_ext of mem_ext mem_ext.R0_addr <= R0_addr mem_ext.R0_en <= R0_en mem_ext.R0_clk <= R0_clk - R0_data <= bits(mem_ext.R0_data, 31, 0) + R0_data <= mem_ext.R0_data mem_ext.W0_addr <= W0_addr mem_ext.W0_en <= W0_en mem_ext.W0_clk <= W0_clk - mem_ext.W0_data <= W0_data + mem_ext.W0_data <= W0_data """.stripMargin val checkConf = """name mem_ext depth 128 width 32 ports write,read """ @@ -170,4 +170,4 @@ circuit sram6t : // readwrite vs. no readwrite // redundant memories (multiple instances of the same type of memory) // mask + no mask -// conf
\ No newline at end of file +// conf |
